00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _SIMPLEKLRETMETHOD_HPP
00013 #define _SIMPLEKLRETMETHOD_HPP
00014
00015 #include <cmath>
00016 #include "UnigramLM.hpp"
00017 #include "ScoreFunction.hpp"
00018 #include "SimpleKLDocModel.hpp"
00019 #include "TextQueryRep.hpp"
00020 #include "TextQueryRetMethod.hpp"
00021 #include "Counter.hpp"
00022 #include "DocUnigramCounter.hpp"
00023
00024 namespace lemur
00025 {
00026 namespace retrieval
00027 {
00028
00030
00031 class SimpleKLQueryModel : public ArrayQueryRep {
00032 public:
00034 SimpleKLQueryModel(const lemur::api::TermQuery &qry,
00035 const lemur::api::Index &dbIndex) :
00036 ArrayQueryRep(dbIndex.termCountUnique()+1, qry, dbIndex), qm(NULL),
00037 ind(dbIndex), colKLComputed(false) {
00038 colQLikelihood = 0;
00039 colQueryLikelihood();
00040 }
00041
00043 SimpleKLQueryModel(const lemur::api::Index &dbIndex) :
00044 ArrayQueryRep(dbIndex.termCountUnique()+1), qm(NULL), ind(dbIndex),
00045 colKLComputed(false) {
00046 colQLikelihood = 0;
00047 startIteration();
00048 while (hasMore()) {
00049 lemur::api::QueryTerm *qt = nextTerm();
00050 setCount(qt->id(), 0);
00051 delete qt;
00052 }
00053 }
00054
00055
00056 virtual ~SimpleKLQueryModel(){ if (qm) delete qm;}
00057
00058
00060
00067 virtual void interpolateWith(const lemur::langmod::UnigramLM &qModel,
00068 double origModCoeff,
00069 int howManyWord, double prSumThresh=1,
00070 double prThresh=0);
00071 virtual double scoreConstant() const {
00072 return totalCount();
00073 }
00074
00076 virtual void load(istream &is);
00077
00079 virtual void save(ostream &os);
00080
00082 virtual void clarity(ostream &os);
00084 virtual double clarity() const;
00085
00087 double colDivergence() const {
00088 if (colKLComputed) {
00089 return colKL;
00090 } else {
00091 colKLComputed = true;
00092 double d=0;
00093 startIteration();
00094 while (hasMore()) {
00095 lemur::api::QueryTerm *qt=nextTerm();
00096 double pr = qt->weight()/(double)totalCount();
00097 double colPr = ((double)ind.termCount(qt->id()) /
00098 (double)(ind.termCount()));
00099 d += pr*log(pr/colPr);
00100 delete qt;
00101 }
00102 colKL=d;
00103 return d;
00104 }
00105 }
00106
00108 double KLDivergence(const lemur::langmod::UnigramLM &refMod) {
00109 double d=0;
00110 startIteration();
00111 while (hasMore()) {
00112 lemur::api::QueryTerm *qt=nextTerm();
00113 double pr = qt->weight()/(double)totalCount();
00114 d += pr*log(pr/refMod.prob(qt->id()));
00115 delete qt;
00116 }
00117 return d;
00118 }
00119
00120 double colQueryLikelihood() const {
00121 if (colQLikelihood == 0) {
00122
00123 lemur::api::COUNT_T tc = ind.termCount();
00124 startIteration();
00125 while (hasMore()) {
00126 lemur::api::QueryTerm *qt = nextTerm();
00127 lemur::api::TERMID_T id = qt->id();
00128 double qtf = qt->weight();
00129 lemur::api::COUNT_T qtcf = ind.termCount(id);
00130 double s = qtf * log((double)qtcf/(double)tc);
00131 colQLikelihood += s;
00132 delete qt;
00133 }
00134 }
00135 return colQLikelihood;
00136 }
00137
00138
00139 protected:
00140
00141 mutable double colQLikelihood;
00142 mutable double colKL;
00143 mutable bool colKLComputed;
00144
00145 lemur::api::IndexedRealVector *qm;
00146 const lemur::api::Index &ind;
00147 };
00148
00149
00150
00152
00167 class SimpleKLScoreFunc : public lemur::api::ScoreFunction {
00168 public:
00169 enum SimpleKLParameter::adjustedScoreMethods adjScoreMethod;
00170 void setScoreMethod(enum SimpleKLParameter::adjustedScoreMethods adj) {
00171 adjScoreMethod = adj;
00172 }
00173 virtual double matchedTermWeight(const lemur::api::QueryTerm *qTerm,
00174 const lemur::api::TextQueryRep *qRep,
00175 const lemur::api::DocInfo *info,
00176 const lemur::api::DocumentRep *dRep) const {
00177 double w = qTerm->weight();
00178 double d = dRep->termWeight(qTerm->id(),info);
00179 double l = log(d);
00180 double score = w*l;
00181
00182
00183
00184
00185 return score;
00186
00187 }
00189 virtual double adjustedScore(double origScore,
00190 const lemur::api::TextQueryRep *qRep,
00191 const lemur::api::DocumentRep *dRep) const {
00192 const SimpleKLQueryModel *qm = dynamic_cast<const SimpleKLQueryModel *>(qRep);
00193
00194
00195
00196
00197 double qsc = qm->scoreConstant();
00198 double dsc = log(dRep->scoreConstant());
00199 double cql = qm->colQueryLikelihood();
00200
00201 double s = dsc * qsc + origScore + cql;
00202 double qsNorm = origScore/qsc;
00203 double qmD = qm->colDivergence();
00204
00205
00206
00207
00209 switch (adjScoreMethod) {
00210 case SimpleKLParameter::QUERYLIKELIHOOD:
00212
00213
00214
00215 return s;
00216
00217 case SimpleKLParameter::CROSSENTROPY:
00219
00220 assert(qm->scoreConstant()!=0);
00221
00222
00223 s = qsNorm + dsc + cql/qsc;
00224 return (s);
00225 case SimpleKLParameter::NEGATIVEKLD:
00227
00228 assert(qm->scoreConstant()!=0);
00229 s = qsNorm + dsc - qmD;
00230
00231
00232
00233 return s;
00234
00235
00236 default:
00237 cerr << "unknown adjusted score method" << endl;
00238 return origScore;
00239 }
00240 }
00241
00242 };
00243
00245 class SimpleKLRetMethod : public lemur::api::TextQueryRetMethod {
00246 public:
00247
00249 SimpleKLRetMethod(const lemur::api::Index &dbIndex,
00250 const string &supportFileName,
00251 lemur::api::ScoreAccumulator &accumulator);
00252 virtual ~SimpleKLRetMethod();
00253
00254 virtual lemur::api::TextQueryRep *computeTextQueryRep(const lemur::api::TermQuery &qry) {
00255 return (new SimpleKLQueryModel(qry, ind));
00256 }
00257
00258 virtual lemur::api::DocumentRep *computeDocRep(lemur::api::DOCID_T docID);
00259
00260
00261 virtual lemur::api::ScoreFunction *scoreFunc() {
00262 return (scFunc);
00263 }
00264
00265 virtual void updateTextQuery(lemur::api::TextQueryRep &origRep,
00266 const lemur::api::DocIDSet &relDocs);
00267
00268 void setDocSmoothParam(SimpleKLParameter::DocSmoothParam &docSmthParam);
00269 void setQueryModelParam(SimpleKLParameter::QueryModelParam &queryModParam);
00270
00271 protected:
00272
00274 double *mcNorm;
00275
00277 double *docProbMass;
00279 lemur::api::COUNT_T *uniqueTermCount;
00281 lemur::langmod::UnigramLM *collectLM;
00283 lemur::langmod::DocUnigramCounter *collectLMCounter;
00285 SimpleKLScoreFunc *scFunc;
00286
00288
00289
00290 void computeMixtureFBModel(SimpleKLQueryModel &origRep,
00291 const lemur::api::DocIDSet & relDocs);
00293 void computeDivMinFBModel(SimpleKLQueryModel &origRep,
00294 const lemur::api::DocIDSet &relDocs);
00296 void computeMarkovChainFBModel(SimpleKLQueryModel &origRep,
00297 const lemur::api::DocIDSet &relDocs) ;
00299 void computeRM1FBModel(SimpleKLQueryModel &origRep,
00300 const lemur::api::DocIDSet & relDocs);
00302 void computeRM2FBModel(SimpleKLQueryModel &origRep,
00303 const lemur::api::DocIDSet & relDocs);
00305
00306 SimpleKLParameter::DocSmoothParam docParam;
00307 SimpleKLParameter::QueryModelParam qryParam;
00308
00310 void loadSupportFile();
00311 const string supportFile;
00312 };
00313
00314
00315 inline void SimpleKLRetMethod::setDocSmoothParam(SimpleKLParameter::DocSmoothParam &docSmthParam)
00316 {
00317 docParam = docSmthParam;
00318 loadSupportFile();
00319 }
00320
00321 inline void SimpleKLRetMethod::setQueryModelParam(SimpleKLParameter::QueryModelParam &queryModParam)
00322 {
00323 qryParam = queryModParam;
00324
00325
00326 scFunc->setScoreMethod(qryParam.adjScoreMethod);
00327 loadSupportFile();
00328 }
00329 }
00330 }
00331
00332 #endif