3 /////////////////////////////////////////////////////////////////////////////
5 // Copyright (c) 2005 David Ward
7 /////////////////////////////////////////////////////////////////////////////
9 // LanguageModel test application
11 #include "../../Common/Common.h"
12 #include "../../DasherCore/LanguageModelling/PPMLanguageModel.h"
13 #include "../../DasherCore/LanguageModelling/WordLanguageModel.h"
14 #include "../../DasherCore/LanguageModelling/MixtureLanguageModel.h"
15 #include "../../DasherCore/LanguageModelling/LanguageModelParams.h"
17 #include "../../DasherCore/Alphabet/AlphIO.h"
18 #include "../../DasherCore/Alphabet/Alphabet.h"
27 #include <gsl/gsl_vector.h>
28 #include <gsl/gsl_multimin.h>
30 #include <TextHandler.hpp>
31 #include <TrecParser.hpp>
33 using namespace Dasher
;
// Forward declaration of the objective function f(x, params), which is
// defined later in this file, so that code above its definition can refer
// to it.
// NOTE(review): presumably installed as oMinFunction.f in main(); that
// assignment is not visible in this view — confirm.
36 double f(const gsl_vector
* x
, void *params
);
// Compression experiment built on cExperiment. Its Execute() method, defined
// out of line further down, trains a language model on a text corpus with
// LearnSymbol() and scores a held-out test segment with GetProbs().
38 class cCompressionExperiment
:public cExperiment
{
// Passes the experiment name/prefix straight through to the cExperiment base.
40 cCompressionExperiment(const std::string
& oPrefix
):cExperiment(oPrefix
) {
// TextHandler that exists only so a TrecParser can be driven. Execute()
// constructs it with the address of a stringstream, installs it with
// setTextHandler(), and afterwards reads the parsed text back out of that
// stringstream.
45 class dummy_handler
:public TextHandler
{
47 // A dummy text handler to allow us to use the trec parser
// Takes the destination stream; presumably stored in `ss`.
// NOTE(review): the stream appears to be owned by the caller (Execute()
// passes the address of a local), so the handler must not outlive it.
50 dummy_handler(std::stringstream
* _ss
) {
// Parser callback for words (TextHandler interface).
// NOTE(review): body not visible here — presumably appends `word` to *ss;
// confirm.
54 virtual char *handleWord(char *word
, const char *original
, PropertyList
* list
) {
// Parser callback at the end of a document (TextHandler interface).
// NOTE(review): body not visible here — confirm what, if anything, it
// writes to *ss.
59 virtual char *handleEndDoc(char *token
, const char *original
, PropertyList
* list
) {
// Destination for the parsed text; appears to be non-owning (see the
// constructor note).
64 std::stringstream
* ss
;
// Entry point. Sets up a 3-dimensional GSL minimisation problem with a fixed
// starting point and unit step sizes. The objective is presumably f(),
// declared above, but the assignment is not visible here. The simplex
// minimisation loop itself is commented out below.
// NOTE(review): argc/argv are not used in the visible code.
67 int main(int argc
, char *argv
[]) {
68 int iNumDimensions(3);
70 gsl_multimin_function oMinFunction
;
// NOTE(review): oMinFunction.f is not assigned in the visible lines. It must
// point at the objective before any minimiser uses this struct — confirm on
// the missing lines.
73 oMinFunction
.n
= iNumDimensions
;
// In the visible code f() never reads its params argument, so none is
// supplied.
74 oMinFunction
.params
= NULL
;
// Starting point in unconstrained coordinates. f() maps them to
// LMAlpha = exp(x0)*100, LMBeta = (tanh(x1)+1)*50 and
// LMWordAlpha = exp(x2)*100.
76 gsl_vector
*pXInit(gsl_vector_alloc(iNumDimensions
));
78 gsl_vector_set(pXInit
, 0, -0.56636);
79 gsl_vector_set(pXInit
, 1, 0.60203);
80 gsl_vector_set(pXInit
, 2, 3.89152);
// Initial simplex step size of 1.0 along every dimension.
82 gsl_vector
*pXStep(gsl_vector_alloc(iNumDimensions
));
83 gsl_vector_set_all(pXStep
, 1.0);
// NOTE(review): pXInit and pXStep come from gsl_vector_alloc(), but no
// matching gsl_vector_free() is visible here — confirm they are released.
// Disabled optimisation driver: nmsimplex minimiser, 100 iterations, writing
// each iterate and its objective value to "fmin.op".
87 // gsl_multimin_fminimizer *pMinimizer( gsl_multimin_fminimizer_alloc( gsl_multimin_fminimizer_nmsimplex, iNumDimensions ));
88 // gsl_multimin_fminimizer_set( pMinimizer, &oMinFunction, pXInit, pXStep );
90 // std::ofstream oMinOutputFile( "fmin.op" );
92 // int iNumIterations( 100 );
93 // for( int i(0); i < iNumIterations; ++i ) {
94 // gsl_multimin_fminimizer_iterate( pMinimizer );
96 // gsl_vector *pXCurrent( gsl_multimin_fminimizer_x( pMinimizer ));
97 // double dMin( gsl_multimin_fminimizer_minimum( pMinimizer ));
99 // for( int j(0); j < iNumDimensions; ++j ) {
100 // oMinOutputFile << gsl_vector_get( pXCurrent, j ) << " ";
103 // oMinOutputFile << dMin << std::endl;
// Objective function for the GSL minimiser. Builds the "Experiment5"
// compression experiment: three tunable parameters are derived from x and
// the remaining settings are fixed.
//   x       3-element parameter vector in unconstrained coordinates.
//   params  not read in the visible code (main() sets it to NULL).
// NOTE(review): the return statement is not visible in this view. It
// presumably runs the experiment and returns its score — confirm.
110 double f(const gsl_vector
*x
, void *params
) {
111 cCompressionExperiment
oExpt("Experiment5");
// exp() keeps both alpha parameters positive and tanh() keeps LMBeta between
// 0 and 100, whatever values the unconstrained minimiser puts in x.
// NOTE(review): these double expressions are passed to SetParameterInt(),
// presumably truncating to int — confirm.
113 oExpt
.SetParameterInt("LMAlpha", exp(gsl_vector_get(x
, 0)) * 100);
114 oExpt
.SetParameterInt("LMBeta", (tanh(gsl_vector_get(x
, 1)) + 1) * 50);
115 oExpt
.SetParameterInt("LMWordAlpha", exp(gsl_vector_get(x
, 2)) * 100);
// Fixed settings, consumed by Execute():
//   LetterOrder     -> LMMaxOrder
//   ModelType       -> switch that picks the language model class
//   Dictionary      -> LMDictionary
//   LetterExclusion -> LMLetterExclusion
117 oExpt
.SetParameterInt("LetterOrder", 200);
119 oExpt
.SetParameterInt("ModelType", 1);
120 oExpt
.SetParameterInt("Dictionary", 1);
121 oExpt
.SetParameterInt("LetterExclusion", 1);
126 double cCompressionExperiment::Execute() {
128 std::cerr
<< "Setting up language model ... " << std::flush
;
130 string userlocation
= "/usr/local/share/dasher/";
131 string filename
= "alphabet.english.xml";
133 vector
< string
> vFileNames
;
134 vFileNames
.push_back(filename
);
136 // Set up the CAlphIO
137 std::auto_ptr
< CAlphIO
> ptrAlphIO(new CAlphIO("", userlocation
, vFileNames
));
139 string strID
= "English alphabet with lots of punctuation";
140 const CAlphIO::AlphInfo
& AlphInfo
= ptrAlphIO
->GetInfo(strID
);
142 // Create the Alphabet that converts plain text to symbols
143 std::auto_ptr
< CAlphabet
> ptrAlphabet(new CAlphabet(AlphInfo
));
145 string strFileCompress
= userlocation
+ "training_english_GB.txt";
146 // string strFileCompress = "testfile.txt";
148 std::cerr
<< "done." << std::endl
;
150 std::cerr
<< "Loading data file ... " << std::flush
;
154 std::stringstream strCompress
;
157 dummy_handler
dh(&strCompress
);
159 tp
.setTextHandler(&dh
);
161 // tp.parseFile( "/mnt/data2/pjc51/enron/enron_short_trec.txt" );
163 tp
.parseFile("/data/tiree2/pjc51/enron/enron_short_trec.txt");
167 std::cerr
<< "done." << std::endl
;
169 std::cerr
<< "Converting to symbols ... " << std::flush
;
171 int iLength(strCompress
.str().size());
172 int iTestSize(10000);
174 std::vector
< symbol
> vSymbols
;
175 ptrAlphabet
->GetSymbols(&vSymbols
, &(strCompress
.str().substr(0, iLength
- iTestSize
)), false /*IsMore */ );
177 std::vector
< symbol
> vSymbolsTest
;
178 ptrAlphabet
->GetSymbols(&vSymbolsTest
, &(strCompress
.str().substr(iLength
- iTestSize
, iTestSize
)), false /*IsMore */ );
180 std::cerr
<< "(" << vSymbolsTest
.size() << " test symbols) ";
182 std::cerr
<< "done." << std::endl
;
184 // Set up the language model for compression test
186 CSymbolAlphabet
alphabet(ptrAlphabet
->GetNumberSymbols());
187 alphabet
.SetSpaceSymbol(ptrAlphabet
->GetSpaceSymbol());
188 alphabet
.SetAlphabetPointer(&*ptrAlphabet
);
192 int update_exclusion(1);
194 CLanguageModelParams settings
;
196 settings
.SetValue("LMMaxOrder", GetParameterInt("LetterOrder"));
198 settings
.SetValue("LMExclusion", exclusion
);
199 settings
.SetValue("LMUpdateExclusion", update_exclusion
);
201 settings
.SetValue("LMAlpha", GetParameterInt("LMAlpha")); // 49
202 settings
.SetValue("LMBeta", GetParameterInt("LMBeta")); // 77
204 settings
.SetValue("LMWordAlpha", GetParameterInt("LMWordAlpha"));
205 settings
.SetValue("LMDictionary", GetParameterInt("Dictionary"));
207 settings
.SetValue("LMLetterExclusion", GetParameterInt("LetterExclusion"));
211 switch (GetParameterInt("ModelType")) {
213 lm
= new CPPMLanguageModel(alphabet
, &settings
);
216 lm
= new CWordLanguageModel(alphabet
, &settings
);
220 std::cerr
<< "Calculating compression ... " << std::flush
;
222 CLanguageModel::Context context
;
223 context
= lm
->CreateEmptyContext();
225 std::vector
< unsigned int >Probs
;
226 int iNormTot
= 1 << 16;
227 int iNorm
= iNormTot
;
229 int iASize
= alphabet
.GetSize();
231 int iExtra
= (iNormTot
- iNorm
) / (iASize
- 1);
238 for(int i
= 0; i
< vSymbols
.size(); i
++) {
240 int iPc(i
* 100 / vSymbols
.size());
246 CLanguageModel::Context oTestContext
;
247 oTestContext
= lm
->CreateEmptyContext();
249 double dSumLogPTest(0.0);
253 for(int k(0); k
< vSymbolsTest
.size(); ++k
) {
255 lm
->GetProbs(oTestContext
, Probs
, iNorm
);
259 symbol s
= vSymbolsTest
[k
];
261 // std::cerr << s << std::endl;
265 double p
= static_cast < double >(j
+ 1) / static_cast < double >(iNormTot
+ iASize
- 1);
269 for(int l(1); l
< iASize
; ++l
) {
271 double dPTemp(static_cast < double >(Probs
[l
] + 1) / static_cast < double >(iNormTot
+ iASize
- 1));
275 if((dPTemp
<= 0.0) || (dPTemp
>= 1.0))
276 std::cout
<< "warning - prob of " << dPTemp
<< std::endl
;
280 std::cout
<< "Total: " << iTot
<< ", " << iNorm
<< std::endl
;
282 DASHER_ASSERT(p
!= 0);
284 // std::cout << "p: " << p << std::endl;
286 dSumLogPTest
+= log(p
);
288 lm
->EnterSymbol(oTestContext
, s
);
291 std::cerr
<< iPc
<< " " << iTestCount
<< "% " << dSumLogPTest
<< " " << i
<< " " << -dSumLogPTest
/ log(2.0) / static_cast < double >(vSymbolsTest
.size()) << std::endl
;
296 lm
->GetProbs(context
, Probs
, iNorm
);
298 symbol sbl
= vSymbols
[i
];
300 lm
->LearnSymbol(context
, sbl
);
303 std::cerr
<< "done." << std::endl
;
307 return -dSumLogP
/ log(2.0) / vSymbols
.size();