tagging release
[dasher.git] / Src / Test / LanguageModelling / pjc51_main.cpp
blobef1451c07f363e3e6ab0ed82e691da21152858f8
1 // main.cpp
2 //
3 /////////////////////////////////////////////////////////////////////////////
4 //
5 // Copyright (c) 2005 David Ward
6 //
7 /////////////////////////////////////////////////////////////////////////////
9 // LanguageModel test application
11 #include "../../Common/Common.h"
12 #include "../../DasherCore/LanguageModelling/PPMLanguageModel.h"
13 #include "../../DasherCore/LanguageModelling/WordLanguageModel.h"
14 #include "../../DasherCore/LanguageModelling/MixtureLanguageModel.h"
15 #include "../../DasherCore/LanguageModelling/LanguageModelParams.h"
17 #include "../../DasherCore/Alphabet/AlphIO.h"
18 #include "../../DasherCore/Alphabet/Alphabet.h"
20 #include <fstream>
21 #include <iostream>
22 #include <cmath>
23 #include <sstream>
25 #include "lib_expt.h"
27 #include <gsl/gsl_vector.h>
28 #include <gsl/gsl_multimin.h>
30 #include <TextHandler.hpp>
31 #include <TrecParser.hpp>
33 using namespace Dasher;
34 using namespace std;
36 double f(const gsl_vector * x, void *params);
38 class cCompressionExperiment:public cExperiment {
39 public:
40 cCompressionExperiment(const std::string & oPrefix):cExperiment(oPrefix) {
42 double Execute();
45 class dummy_handler:public TextHandler {
47 // A dummy text handler to allow us to use the trec parser
48 public:
50 dummy_handler(std::stringstream * _ss) {
51 ss = _ss;
54 virtual char *handleWord(char *word, const char *original, PropertyList * list) {
55 (*ss) << word << " ";
56 return word;
59 virtual char *handleEndDoc(char *token, const char *original, PropertyList * list) {
60 return token;
63 protected:
64 std::stringstream * ss;
67 int main(int argc, char *argv[]) {
68 int iNumDimensions(3);
70 gsl_multimin_function oMinFunction;
72 oMinFunction.f = f;
73 oMinFunction.n = iNumDimensions;
74 oMinFunction.params = NULL;
76 gsl_vector *pXInit(gsl_vector_alloc(iNumDimensions));
78 gsl_vector_set(pXInit, 0, -0.56636);
79 gsl_vector_set(pXInit, 1, 0.60203);
80 gsl_vector_set(pXInit, 2, 3.89152);
82 gsl_vector *pXStep(gsl_vector_alloc(iNumDimensions));
83 gsl_vector_set_all(pXStep, 1.0);
85 f(pXInit, NULL);
87 // gsl_multimin_fminimizer *pMinimizer( gsl_multimin_fminimizer_alloc( gsl_multimin_fminimizer_nmsimplex, iNumDimensions ));
88 // gsl_multimin_fminimizer_set( pMinimizer, &oMinFunction, pXInit, pXStep );
90 // std::ofstream oMinOutputFile( "fmin.op" );
92 // int iNumIterations( 100 );
93 // for( int i(0); i < iNumIterations; ++i ) {
94 // gsl_multimin_fminimizer_iterate( pMinimizer );
96 // gsl_vector *pXCurrent( gsl_multimin_fminimizer_x( pMinimizer ));
97 // double dMin( gsl_multimin_fminimizer_minimum( pMinimizer ));
99 // for( int j(0); j < iNumDimensions; ++j ) {
100 // oMinOutputFile << gsl_vector_get( pXCurrent, j ) << " ";
101 // }
103 // oMinOutputFile << dMin << std::endl;
105 // }
107 return 0;
110 double f(const gsl_vector *x, void *params) {
111 cCompressionExperiment oExpt("Experiment5");
113 oExpt.SetParameterInt("LMAlpha", exp(gsl_vector_get(x, 0)) * 100);
114 oExpt.SetParameterInt("LMBeta", (tanh(gsl_vector_get(x, 1)) + 1) * 50);
115 oExpt.SetParameterInt("LMWordAlpha", exp(gsl_vector_get(x, 2)) * 100);
117 oExpt.SetParameterInt("LetterOrder", 200);
119 oExpt.SetParameterInt("ModelType", 1);
120 oExpt.SetParameterInt("Dictionary", 1);
121 oExpt.SetParameterInt("LetterExclusion", 1);
123 return oExpt.Run();
126 double cCompressionExperiment::Execute() {
128 std::cerr << "Setting up language model ... " << std::flush;
130 string userlocation = "/usr/local/share/dasher/";
131 string filename = "alphabet.english.xml";
133 vector < string > vFileNames;
134 vFileNames.push_back(filename);
136 // Set up the CAlphIO
137 std::auto_ptr < CAlphIO > ptrAlphIO(new CAlphIO("", userlocation, vFileNames));
139 string strID = "English alphabet with lots of punctuation";
140 const CAlphIO::AlphInfo & AlphInfo = ptrAlphIO->GetInfo(strID);
142 // Create the Alphabet that converts plain text to symbols
143 std::auto_ptr < CAlphabet > ptrAlphabet(new CAlphabet(AlphInfo));
145 string strFileCompress = userlocation + "training_english_GB.txt";
146 // string strFileCompress = "testfile.txt";
148 std::cerr << "done." << std::endl;
150 std::cerr << "Loading data file ... " << std::flush;
154 std::stringstream strCompress;
156 TrecParser tp;
157 dummy_handler dh(&strCompress);
159 tp.setTextHandler(&dh);
161 // tp.parseFile( "/mnt/data2/pjc51/enron/enron_short_trec.txt" );
163 tp.parseFile("/data/tiree2/pjc51/enron/enron_short_trec.txt");
167 std::cerr << "done." << std::endl;
169 std::cerr << "Converting to symbols ... " << std::flush;
171 int iLength(strCompress.str().size());
172 int iTestSize(10000);
174 std::vector < symbol > vSymbols;
175 ptrAlphabet->GetSymbols(&vSymbols, &(strCompress.str().substr(0, iLength - iTestSize)), false /*IsMore */ );
177 std::vector < symbol > vSymbolsTest;
178 ptrAlphabet->GetSymbols(&vSymbolsTest, &(strCompress.str().substr(iLength - iTestSize, iTestSize)), false /*IsMore */ );
180 std::cerr << "(" << vSymbolsTest.size() << " test symbols) ";
182 std::cerr << "done." << std::endl;
184 // Set up the language model for compression test
186 CSymbolAlphabet alphabet(ptrAlphabet->GetNumberSymbols());
187 alphabet.SetSpaceSymbol(ptrAlphabet->GetSpaceSymbol());
188 alphabet.SetAlphabetPointer(&*ptrAlphabet);
190 int order(200);
191 int exclusion(0);
192 int update_exclusion(1);
194 CLanguageModelParams settings;
196 settings.SetValue("LMMaxOrder", GetParameterInt("LetterOrder"));
198 settings.SetValue("LMExclusion", exclusion);
199 settings.SetValue("LMUpdateExclusion", update_exclusion);
201 settings.SetValue("LMAlpha", GetParameterInt("LMAlpha")); // 49
202 settings.SetValue("LMBeta", GetParameterInt("LMBeta")); // 77
204 settings.SetValue("LMWordAlpha", GetParameterInt("LMWordAlpha"));
205 settings.SetValue("LMDictionary", GetParameterInt("Dictionary"));
207 settings.SetValue("LMLetterExclusion", GetParameterInt("LetterExclusion"));
209 CLanguageModel *lm;
211 switch (GetParameterInt("ModelType")) {
212 case 0:
213 lm = new CPPMLanguageModel(alphabet, &settings);
214 break;
215 case 1:
216 lm = new CWordLanguageModel(alphabet, &settings);
217 break;
220 std::cerr << "Calculating compression ... " << std::flush;
222 CLanguageModel::Context context;
223 context = lm->CreateEmptyContext();
225 std::vector < unsigned int >Probs;
226 int iNormTot = 1 << 16;
227 int iNorm = iNormTot;
229 int iASize = alphabet.GetSize();
231 int iExtra = (iNormTot - iNorm) / (iASize - 1);
232 double dSumLogP = 0;
234 // Loop over symbols
236 int iPcOld(-1);
238 for(int i = 0; i < vSymbols.size(); i++) {
240 int iPc(i * 100 / vSymbols.size());
242 if(iPc > iPcOld) {
244 // Do a test...
246 CLanguageModel::Context oTestContext;
247 oTestContext = lm->CreateEmptyContext();
249 double dSumLogPTest(0.0);
251 int iTestCount(0);
253 for(int k(0); k < vSymbolsTest.size(); ++k) {
255 lm->GetProbs(oTestContext, Probs, iNorm);
257 ++iTestCount;
259 symbol s = vSymbolsTest[k];
261 // std::cerr << s << std::endl;
263 int j = Probs[s];
265 double p = static_cast < double >(j + 1) / static_cast < double >(iNormTot + iASize - 1);
267 int iTot(0);
269 for(int l(1); l < iASize; ++l) {
271 double dPTemp(static_cast < double >(Probs[l] + 1) / static_cast < double >(iNormTot + iASize - 1));
273 iTot += Probs[l];
275 if((dPTemp <= 0.0) || (dPTemp >= 1.0))
276 std::cout << "warning - prob of " << dPTemp << std::endl;
279 if(iTot > iNorm)
280 std::cout << "Total: " << iTot << ", " << iNorm << std::endl;
282 DASHER_ASSERT(p != 0);
284 // std::cout << "p: " << p << std::endl;
286 dSumLogPTest += log(p);
288 lm->EnterSymbol(oTestContext, s);
291 std::cerr << iPc << " " << iTestCount << "% " << dSumLogPTest << " " << i << " " << -dSumLogPTest / log(2.0) / static_cast < double >(vSymbolsTest.size()) << std::endl;
294 iPcOld = iPc;
296 lm->GetProbs(context, Probs, iNorm);
298 symbol sbl = vSymbols[i];
300 lm->LearnSymbol(context, sbl);
303 std::cerr << "done." << std::endl;
305 delete lm;
307 return -dSumLogP / log(2.0) / vSymbols.size();