00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef _PLSA_HPP
00018 #define _PLSA_HPP
#include "common_headers.hpp"
#include <cmath>
#include <fstream>
#include <set>
#include "Index.hpp"
#include "FreqVector.hpp"
00024 namespace lemur
00025 {
00026 namespace cluster
00027 {
00028
00029
/// Forward declaration: the member-function-pointer type below must be able
/// to name PLSA before the class itself is defined.
class PLSA;

/// Pointer to a PLSA member function that takes two int arguments and returns
/// a double. The jointEstimate* members of PLSA have this signature, and
/// PLSA::doLogLikelihood accepts a value of this type.
typedef double (PLSA::*jointfuncType)(int indexD, int indexW);
00035
00036
00040
00041 class PLSA {
00042 public:
00043
00044
00046 PLSA(const lemur::api::Index &dbIndex, int numCats,
00047 lemur::utility::HashFreqVector **train,
00048 lemur::utility::HashFreqVector **validate, int numIter,
00049 int numRestarts, double betastart,
00050 double betastop, double anneal, double betaMod);
00051
00053 PLSA(const lemur::api::Index &dbIndex, int testPercentage, int numCats, int numIter,
00054 int numRestarts, double betastart,
00055 double betastop, double anneal, double betaMod);
00057 PLSA(const lemur::api::Index &dbIndex);
00058 virtual ~PLSA();
00059
00061 void iterateWithRestarts();
00062
00065 double *get_p_z() const {return p_z_best;}
00067 double **get_p_w_z() const {return p_w_z_best;}
00069 double **get_p_d_z() const {return p_d_z_best;}
00071 double getProb(int d, int w) const ;
00072
00074 int numWords() const {return sizeW;}
00076 int numDocs() const {return sizeD;}
00078 int numCats() const {return sizeZ;}
00081 bool readArrays();
00082
00083 private:
00084
00086 const lemur::api::Index &ind;
00088 int sizeZ;
00090 int sizeD;
00092 int sizeW;
00093
00095 lemur::utility::HashFreqVector **data;
00097 lemur::utility::HashFreqVector **testData;
00099 set<int, less<int> > *invIndex;
00100
00102 double startBeta, beta, betaMin;
00104 double betaModifier;
00106 double annealcue;
00108 int R;
00110 int numberOfIterations;
00112 int numberOfRestarts;
00114 double bestTestLL;
00116 double bestA;
00118 bool bestOnly;
00120 bool ownMem;
00122 double *p_z_current;
00124 double **p_w_z_current;
00126 double **p_d_z_current;
00127
00129 double *p_z_prev;
00131 double **p_w_z_prev;
00133 double **p_d_z_prev;
00134
00136 double *p_z_best;
00138 double **p_w_z_best;
00140 double **p_d_z_best;
00141
00142
00144 void setPrevToCurrent();
00146 void setCurrentToBest();
00148 void setBestToCurrent();
00150 void setBestToPrev();
00152 void setPrevToBest();
00153
00156 double getAverageLikelihood();
00159 double getAverageLikelihoodPrev();
00160
00162 double jointEstimate (int indexD, int indexW);
00164 double jointEstimateCurrent (int indexD, int indexW);
00166 double jointEstimateBest (int indexD, int indexW);
00169 double jointEstimateBeta (int indexD, int indexW);
00170
00172 void iterate();
00174 void initializeParameters();
00175
00178 double doLogLikelihood(jointfuncType, lemur::utility::HashFreqVector **&myData);
00180 double logLikelihood();
00182 double validateDataLogLikelihood();
00184 double validateCurrentLogLikelihood();
00186 double bestDataLogLikelihood();
00188 double interleavedIterationEM();
00190 void selectTestTrain(int testPercent);
00192 void init();
00194 void initR();
00196 enum pType {P_Z = 0, P_W_Z = 1, P_D_Z = 2};
00198 void writeArrays();
00200 bool readArray(ifstream& infile, enum pType which);
00202 void writeArray(ofstream& ofile, enum pType which);
00203 };
00204 }
00205 }
00206
00207 #endif