tagging release
[dasher.git] / trunk / Src / DasherCore / LanguageModelling / PPMLanguageModel.cpp
blob0e67afb94718c49b33369cfe17d027336ebc9de6
// PPMLanguageModel.cpp
2 //
3 /////////////////////////////////////////////////////////////////////////////
4 //
5 // Copyright (c) 1999-2005 David Ward
6 //
7 /////////////////////////////////////////////////////////////////////////////
9 #include "../../Common/Common.h"
10 #include "PPMLanguageModel.h"
12 #include <math.h>
13 #include <stack>
14 #include <sstream>
15 #include <iostream>
17 using namespace Dasher;
18 using namespace std;
// Track memory leaks on Windows to the line that new'd the memory.
// Only active in Windows debug builds: every `new` below this point in the
// file is rewritten to the debug-heap form that records THIS_FILE / __LINE__.
#if defined(_WIN32) && defined(_DEBUG)
#define DEBUG_NEW new( _NORMAL_BLOCK, THIS_FILE, __LINE__ )
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
30 /////////////////////////////////////////////////////////////////////
32 CPPMLanguageModel::CPPMLanguageModel(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore, const CSymbolAlphabet &SymbolAlphabet)
33 :CLanguageModel(pEventHandler, pSettingsStore, SymbolAlphabet), m_iMaxOrder(4), NodesAllocated(0), m_NodeAlloc(8192), m_ContextAlloc(1024) {
34 m_pRoot = m_NodeAlloc.Alloc();
35 m_pRoot->symbol = -1;
37 m_pRootContext = m_ContextAlloc.Alloc();
38 m_pRootContext->head = m_pRoot;
39 m_pRootContext->order = 0;
41 // Cache the result of update exclusion - otherwise we have to look up a lot when training, which is slow
43 // FIXME - this should be a boolean parameter
45 bUpdateExclusion = ( GetLongParameter(LP_LM_UPDATE_EXCLUSION) !=0 );
49 /////////////////////////////////////////////////////////////////////
51 CPPMLanguageModel::~CPPMLanguageModel() {
54 /////////////////////////////////////////////////////////////////////
55 // Get the probability distribution at the context
57 void CPPMLanguageModel::GetProbs(Context context, std::vector<unsigned int> &probs, int norm) const {
58 const CPPMContext *ppmcontext = (const CPPMContext *)(context);
60 DASHER_ASSERT(m_setContexts.count(ppmcontext) > 0);
62 int iNumSymbols = GetSize();
64 probs.resize(iNumSymbols);
66 std::vector < bool > exclusions(iNumSymbols);
68 int i;
69 for(i = 0; i < iNumSymbols; i++) {
70 probs[i] = 0;
71 exclusions[i] = false;
74 // bool doExclusion = GetLongParameter( LP_LM_ALPHA );
75 bool doExclusion = 0; //FIXME
77 int alpha = GetLongParameter( LP_LM_ALPHA );
78 int beta = GetLongParameter( LP_LM_BETA );
80 unsigned int iToSpend = norm;
82 CPPMnode *pTemp = ppmcontext->head;
84 while(pTemp != 0) {
85 int iTotal = 0;
87 CPPMnode *pSymbol = pTemp->child;
88 while(pSymbol) {
89 int sym = pSymbol->symbol;
90 if(!(exclusions[sym] && doExclusion))
91 iTotal += pSymbol->count;
92 pSymbol = pSymbol->next;
95 if(iTotal) {
96 unsigned int size_of_slice = iToSpend;
97 pSymbol = pTemp->child;
98 while(pSymbol) {
99 if(!(exclusions[pSymbol->symbol] && doExclusion)) {
100 exclusions[pSymbol->symbol] = 1;
102 unsigned int p = static_cast < myint > (size_of_slice) * (100 * pSymbol->count - beta) / (100 * iTotal + alpha);
104 probs[pSymbol->symbol] += p;
105 iToSpend -= p;
107 // Usprintf(debug,TEXT("sym %u counts %d p %u tospend %u \n"),sym,s->count,p,tospend);
108 // DebugOutput(debug);
109 pSymbol = pSymbol->next;
112 pTemp = pTemp->vine;
115 unsigned int size_of_slice = iToSpend;
116 int symbolsleft = 0;
118 for(i = 1; i < iNumSymbols; i++)
119 if(!(exclusions[i] && doExclusion))
120 symbolsleft++;
122 // std::ostringstream str;
123 // for (sym=0;sym<modelchars;sym++)
124 // str << probs[sym] << " ";
125 // str << std::endl;
126 // DASHER_TRACEOUTPUT("probs %s",str.str().c_str());
128 // std::ostringstream str2;
129 // for (sym=0;sym<modelchars;sym++)
130 // str2 << valid[sym] << " ";
131 // str2 << std::endl;
132 // DASHER_TRACEOUTPUT("valid %s",str2.str().c_str());
134 for(i = 1; i < iNumSymbols; i++) {
135 if(!(exclusions[i] && doExclusion)) {
136 unsigned int p = size_of_slice / symbolsleft;
137 probs[i] += p;
138 iToSpend -= p;
142 int iLeft = iNumSymbols-1;
144 for(int j = 1; j < iNumSymbols; ++j) {
145 unsigned int p = iToSpend / iLeft;
146 probs[j] += p;
147 --iLeft;
148 iToSpend -= p;
151 DASHER_ASSERT(iToSpend == 0);
154 void CPPMLanguageModel::AddSymbol(CPPMLanguageModel::CPPMContext &context, int sym)
155 // add symbol to the context
156 // creates new nodes, updates counts
157 // and leaves 'context' at the new context
159 // Ignore attempts to add the root symbol
161 if(sym==0)
162 return;
164 DASHER_ASSERT(sym >= 0 && sym < GetSize());
166 CPPMnode *vineptr, *temp;
167 int updatecnt = 1;
169 temp = context.head->vine;
170 context.head = AddSymbolToNode(context.head, sym, &updatecnt);
171 vineptr = context.head;
172 context.order++;
174 while(temp != 0) {
175 vineptr->vine = AddSymbolToNode(temp, sym, &updatecnt);
176 vineptr = vineptr->vine;
177 temp = temp->vine;
179 vineptr->vine = m_pRoot;
181 //m_iMaxOrder = LanguageModelParams()->GetValue(std::string("LMMaxOrder"));
182 m_iMaxOrder = GetLongParameter( LP_LM_MAX_ORDER );
184 while(context.order > m_iMaxOrder) {
185 context.head = context.head->vine;
186 context.order--;
190 /////////////////////////////////////////////////////////////////////
191 // Update context with symbol 'Symbol'
193 void CPPMLanguageModel::EnterSymbol(Context c, int Symbol) {
194 if(Symbol==0)
195 return;
197 DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
199 CPPMLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
201 CPPMnode *find;
203 while(context.head) {
205 if(context.order < m_iMaxOrder) { // Only try to extend the context if it's not going to make it too long
206 find = context.head->find_symbol(Symbol);
207 if(find) {
208 context.order++;
209 context.head = find;
210 // Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
211 // DebugOutput(debug);
213 // std::cout << context.order << std::endl;
214 return;
218 // If we can't extend the current context, follow vine pointer to shorten it and try again
220 context.order--;
221 context.head = context.head->vine;
224 if(context.head == 0) {
225 context.head = m_pRoot;
226 context.order = 0;
229 // std::cout << context.order << std::endl;
233 /////////////////////////////////////////////////////////////////////
235 void CPPMLanguageModel::LearnSymbol(Context c, int Symbol) {
236 if(Symbol==0)
237 return;
240 DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
241 CPPMLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
242 AddSymbol(context, Symbol);
245 void CPPMLanguageModel::dumpSymbol(int sym) {
246 if((sym <= 32) || (sym >= 127))
247 printf("<%d>", sym);
248 else
249 printf("%c", sym);
252 void CPPMLanguageModel::dumpString(char *str, int pos, int len)
253 // Dump the string STR starting at position POS
255 char cc;
256 int p;
257 for(p = pos; p < pos + len; p++) {
258 cc = str[p];
259 if((cc <= 31) || (cc >= 127))
260 printf("<%d>", cc);
261 else
262 printf("%c", cc);
266 void CPPMLanguageModel::dumpTrie(CPPMLanguageModel::CPPMnode *t, int d)
267 // diagnostic display of the PPM trie from node t and deeper
269 //TODO
271 dchar debug[256];
272 int sym;
273 CPPMnode *s;
274 Usprintf( debug,TEXT("%5d %7x "), d, t );
275 //TODO: Uncomment this when headers sort out
276 //DebugOutput(debug);
277 if (t < 0) // pointer to input
278 printf( " <" );
279 else {
280 Usprintf(debug,TEXT( " %3d %5d %7x %7x %7x <"), t->symbol,t->count, t->vine, t->child, t->next );
281 //TODO: Uncomment this when headers sort out
282 //DebugOutput(debug);
285 dumpString( dumpTrieStr, 0, d );
286 Usprintf( debug,TEXT(">\n") );
287 //TODO: Uncomment this when headers sort out
288 //DebugOutput(debug);
289 if (t != 0) {
290 s = t->child;
291 while (s != 0) {
292 sym =s->symbol;
294 dumpTrieStr [d] = sym;
295 dumpTrie( s, d+1 );
296 s = s->next;
302 void CPPMLanguageModel::dump()
303 // diagnostic display of the whole PPM trie
305 // TODO:
307 dchar debug[256];
308 Usprintf(debug,TEXT( "Dump of Trie : \n" ));
309 //TODO: Uncomment this when headers sort out
310 //DebugOutput(debug);
311 Usprintf(debug,TEXT( "---------------\n" ));
312 //TODO: Uncomment this when headers sort out
313 //DebugOutput(debug);
314 Usprintf( debug,TEXT( "depth node symbol count vine child next context\n") );
315 //TODO: Uncomment this when headers sort out
316 //DebugOutput(debug);
317 dumpTrie( root, 0 );
318 Usprintf( debug,TEXT( "---------------\n" ));
319 //TODO: Uncomment this when headers sort out
320 //DebugOutput(debug);
321 Usprintf(debug,TEXT( "\n" ));
322 //TODO: Uncomment this when headers sort out
323 //DebugOutput(debug);
327 ////////////////////////////////////////////////////////////////////////
328 /// PPMnode definitions
329 ////////////////////////////////////////////////////////////////////////
331 CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::find_symbol(int sym) const
332 // see if symbol is a child of node
334 // printf("finding symbol %d at node %d\n",sym,node->id);
335 CPPMnode *found = child;
337 while(found) {
338 if(found->symbol == sym) {
339 return found;
341 found = found->next;
343 return 0;
346 CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode, int sym, int *update) {
347 CPPMnode *pReturn = pNode->find_symbol(sym);
349 // std::cout << sym << ",";
351 if(pReturn != NULL) {
352 // std::cout << "Using existing node" << std::endl;
354 // if (*update || (LanguageModelParams()->GetValue("LMUpdateExclusion") == 0) )
355 if(*update || !bUpdateExclusion) { // perform update exclusions
356 pReturn->count++;
357 *update = 0;
359 return pReturn;
362 // std::cout << "Creating new node" << std::endl;
364 pReturn = m_NodeAlloc.Alloc(); // count is initialized to 1
365 pReturn->symbol = sym;
366 pReturn->next = pNode->child;
367 pNode->child = pReturn;
369 ++NodesAllocated;
371 return pReturn;
// On-disk form of one PPM trie node, as written by RecursiveWrite and read by
// ReadFromFile. Pointers are replaced by integer indices handed out by
// GetIndex; index 0 stands for NULL.
//
// NOTE(review): records are written as raw memory, so the file format depends
// on the platform's int/short sizes, struct padding and byte order.
struct BinaryRecord {
  int m_iIndex;             // index of this node
  int m_iChild;             // index of the first child (0 = none)
  int m_iNext;              // index of the next sibling (0 = none)
  int m_iVine;              // index of the vine (next-shorter context) node (0 = none)
  unsigned short m_iCount;  // node count, narrowed to 16 bits on write
  short m_iSymbol;          // node symbol, narrowed to 16 bits on write
};
384 bool CPPMLanguageModel::WriteToFile(std::string strFilename) {
386 std::map<CPPMnode *, int> mapIdx;
387 int iNextIdx(1); // Index of 0 means NULL;
389 std::ofstream oOutputFile(strFilename.c_str());
391 RecursiveWrite(m_pRoot, &mapIdx, &iNextIdx, &oOutputFile);
393 oOutputFile.close();
395 return false;
398 bool CPPMLanguageModel::RecursiveWrite(CPPMnode *pNode, std::map<CPPMnode *, int> *pmapIdx, int *pNextIdx, std::ofstream *pOutputFile) {
400 // Dump node here
402 BinaryRecord sBR;
404 sBR.m_iIndex = GetIndex(pNode, pmapIdx, pNextIdx);
405 sBR.m_iChild = GetIndex(pNode->child, pmapIdx, pNextIdx);
406 sBR.m_iNext = GetIndex(pNode->next, pmapIdx, pNextIdx);
407 sBR.m_iVine = GetIndex(pNode->vine, pmapIdx, pNextIdx);
408 sBR.m_iCount = pNode->count;
409 sBR.m_iSymbol = pNode->symbol;
411 pOutputFile->write(reinterpret_cast<char*>(&sBR), sizeof(BinaryRecord));
413 CPPMnode *pCurrentChild(pNode->child);
415 while(pCurrentChild != NULL) {
416 RecursiveWrite(pCurrentChild, pmapIdx, pNextIdx, pOutputFile);
417 pCurrentChild = pCurrentChild->next;
420 return true;
423 int CPPMLanguageModel::GetIndex(CPPMnode *pAddr, std::map<CPPMnode *, int> *pmapIdx, int *pNextIdx) {
425 int iIndex;
426 if(pAddr == NULL)
427 iIndex = 0;
428 else {
429 std::map<CPPMnode *, int>::iterator it(pmapIdx->find(pAddr));
431 if(it == pmapIdx->end()) {
432 iIndex = *pNextIdx;
433 pmapIdx->insert(std::pair<CPPMnode *, int>(pAddr, iIndex));
434 ++(*pNextIdx);
436 else {
437 iIndex = it->second;
440 return iIndex;
443 bool CPPMLanguageModel::ReadFromFile(std::string strFilename) {
445 std::ifstream oInputFile(strFilename.c_str());
446 std::map<int, CPPMnode*> oMap;
447 BinaryRecord sBR;
448 bool bStarted(false);
450 while(!oInputFile.eof()) {
451 oInputFile.read(reinterpret_cast<char *>(&sBR), sizeof(BinaryRecord));
453 CPPMnode *pCurrent(GetAddress(sBR.m_iIndex, &oMap));
455 pCurrent->child = GetAddress(sBR.m_iChild, &oMap);
456 pCurrent->next = GetAddress(sBR.m_iNext, &oMap);
457 pCurrent->vine = GetAddress(sBR.m_iVine, &oMap);
458 pCurrent->count = sBR.m_iCount;
459 pCurrent->symbol = sBR.m_iSymbol;
461 if(!bStarted) {
462 m_pRoot = pCurrent;
463 bStarted = true;
467 oInputFile.close();
469 return false;
472 CPPMLanguageModel::CPPMnode *CPPMLanguageModel::GetAddress(int iIndex, std::map<int, CPPMnode*> *pMap) {
473 std::map<int, CPPMnode*>::iterator it(pMap->find(iIndex));
475 if(it == pMap->end()) {
476 CPPMnode *pNewNode;
477 pNewNode = m_NodeAlloc.Alloc();
478 pMap->insert(std::pair<int, CPPMnode*>(iIndex, pNewNode));
479 return pNewNode;
481 else {
482 return it->second;