// WordLanguageModel.cpp
//
/////////////////////////////////////////////////////////////////////////////
//
// Copyright (c) 1999-2004 David Ward
//
/////////////////////////////////////////////////////////////////////////////
#include "../../Common/Common.h"
#include "WordLanguageModel.h"
#include "PPMLanguageModel.h"

#include <cmath>
#include <stack>
#include <iostream>
#include <fstream>
using namespace Dasher;
using namespace std;

// static TCHAR debug[256];
typedef unsigned long ulong;

#ifdef DASHER_WIN32
#define snprintf _snprintf
#endif
// Track memory leaks on Windows to the line that new'd the memory
#ifdef _WIN32
#ifdef _DEBUG_MEMLEAKS
#define DEBUG_NEW new( _NORMAL_BLOCK, THIS_FILE, __LINE__ )
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
#endif
///////////////////////////////////////////////////////////////////

void CWordLanguageModel::CWordContext::dump() {
  // diagnostic output
  // TODO uncomment this when headers sorted out
  //dchar debug[128];
  //Usprintf(debug,TEXT("head %x order %d\n"),head,order);
  //DebugOutput(debug);
}
////////////////////////////////////////////////////////////////////////
/// Wordnode definitions
////////////////////////////////////////////////////////////////////////

/// Return the child of a node with a given symbol, or NULL if there is no child with that symbol yet

CWordLanguageModel::CWordnode *CWordLanguageModel::CWordnode::find_symbol(int sym) const {
  CWordnode *found = child;
  while(found) {
    if(found->sbl == sym)
      return found;
    found = found->next;
  }
  return 0;
}
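// Implementation note (my reading of the structure, not from any design
// document): children of a node are kept as a singly linked list - `child`
// points at the most recently added child and `next` chains the siblings -
// so find_symbol is a linear scan over the siblings. AddSymbolToNode below
// pushes new children onto the front of the list, which tends to keep
// recently used symbols near the head of the scan.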
void CWordLanguageModel::CWordnode::RecursiveDump(std::ofstream &file) {

  CWordnode *pCurrentChild(child);

  file << "\"" << this << "\" [label=\"" << this->sbl << "\\n" << this->count << "\"]" << std::endl;
  file << "\"" << this << "\" -> \"" << vine << "\" [style=dashed]" << std::endl;

  while(pCurrentChild) {
    file << "\"" << this << "\" -> \"" << pCurrentChild << "\"" << std::endl;
    pCurrentChild->RecursiveDump(file);
    pCurrentChild = pCurrentChild->next;
  }
}
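// RecursiveDump emits the trie in Graphviz dot syntax: solid edges to
// children, dashed edges along vine pointers. A minimal usage sketch,
// mirroring the commented-out block at the end of CollapseContext (the
// file name "trie.dot" here is just an example):
//
//   std::ofstream ofile("trie.dot");
//   ofile << "digraph G {" << std::endl;
//   m_pRoot->RecursiveDump(ofile);
//   ofile << "}" << std::endl;
//
// The output can then be rendered with e.g. `dot -Tpng trie.dot -o trie.png`.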
CWordLanguageModel::CWordnode *CWordLanguageModel::AddSymbolToNode(CWordnode *pNode, symbol sym, int *update, bool bLearn) {

  // FIXME - need to implement bLearn

  CWordnode *pReturn = pNode->find_symbol(sym);

  if(pReturn != NULL) {
    if(*update) {
      // std::cout << "USHRT_MAX: " << USHRT_MAX << " " << bLearn << std::endl;
      // if( (pReturn->count < USHRT_MAX) && bLearn ) // Truncate counts at storage limit
      if(bLearn)
        pReturn->count++;
      *update = 0;
    }
    return pReturn;
  }

  pReturn = m_NodeAlloc.Alloc();        // count is initialised to 1
  pReturn->sbl = sym;
  pReturn->next = pNode->child;
  pNode->child = pReturn;

  if(!bLearn) {
    --(pReturn->count);         // FIXME - in the long term, don't allocate
                                // nodes if we're not learning, but should be
                                // okay for now
  }

  // std::cout << pReturn->count << std::endl;

  ++NodesAllocated;

  return pReturn;
}
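// A note on the *update parameter (my reading of the call sites below, not
// from any documentation): it implements PPM-style update exclusion. The
// caller passes a flag set to 1; the first time an existing node is found
// and its count bumped, the flag is zeroed, so shorter contexts reached
// afterwards via vine pointers do not have their counts incremented again
// for the same symbol.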
/////////////////////////////////////////////////////////////////////
// CWordLanguageModel defs
/////////////////////////////////////////////////////////////////////

CWordLanguageModel::CWordLanguageModel(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore,
                                       const CSymbolAlphabet &Alphabet)
:CLanguageModel(pEventHandler, pSettingsStore, Alphabet), NodesAllocated(0),
max_order(2), m_NodeAlloc(8192), m_ContextAlloc(1024) {

  // Construct a root node for the trie

  m_pRoot = m_NodeAlloc.Alloc();
  m_pRoot->sbl = -1;
  m_pRoot->count = 0;

  // Create a spelling model

  pSpellingModel = new CPPMLanguageModel(m_pEventHandler, m_pSettingsStore, Alphabet);

  // Construct a root context

  m_rootcontext = new CWordContext(m_pRoot, 0);

  m_rootcontext->m_pSpellingModel = pSpellingModel;
  m_rootcontext->oSpellingContext = pSpellingModel->CreateEmptyContext();

  iWordStart = 8192;

  nextid = iWordStart;          // Start of indices for words - may need to increase this for *really* large alphabets

  if(GetBoolParameter(BP_LM_DICTIONARY)) {

    std::ifstream DictFile("/usr/share/dict/words");    // FIXME - hardcoded paths == bad

    std::string CurrentWord;

    while(DictFile >> CurrentWord) {    // extraction fails cleanly at end of file

      CurrentWord = CurrentWord + " ";

      // std::cout << CurrentWord << std::endl;

      CPPMLanguageModel::Context TempContext(pSpellingModel->CreateEmptyContext());

      // std::cout << SymbolAlphabet().GetAlphabetPointer() << std::endl;

      std::vector<symbol> Symbols;
      SymbolAlphabet().GetAlphabetPointer()->GetSymbols(&Symbols, &CurrentWord, false);

      for(std::vector<symbol>::iterator it(Symbols.begin()); it != Symbols.end(); ++it) {
        pSpellingModel->LearnSymbol(TempContext, *it);
      }

      pSpellingModel->ReleaseContext(TempContext);
    }
  }

  // oSpellingContext = pSpellingModel->CreateEmptyContext();

  wordidx = 0;
}
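// Note on the dictionary pass above (my reading of the code): when
// BP_LM_DICTIONARY is set, each word from the system word list (plus a
// trailing space) is converted to alphabet symbols and fed once through a
// throwaway PPM context, so the letter-level "spelling" model starts with
// plausible letter statistics before the user has written anything. The
// hardcoded path obviously limits this to systems where
// /usr/share/dict/words exists, as the FIXME acknowledges.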
CWordLanguageModel::~CWordLanguageModel() {

  delete m_rootcontext;
  delete pSpellingModel;

  // A non-recursive node deletion algorithm using a stack
  /* std::stack<CWordnode*> deletenodes;
     deletenodes.push(m_pRoot);
     while (!deletenodes.empty()) {
       CWordnode* temp = deletenodes.top();
       deletenodes.pop();
       CWordnode* next;
       do {
         next = temp->next;

         // push the child
         if (temp->child)
           deletenodes.push(temp->child);

         delete temp;

         temp = next;
       } while (temp != 0);
     } */
}
int CWordLanguageModel::lookup_word(const std::string &w) {
  if(dict[w] == 0) {
    dict[w] = nextid;
    ++nextid;
  }

  return dict[w];
}

int CWordLanguageModel::lookup_word_const(const std::string &w) const {
  std::cout << "Looking up (const) " << w << std::endl;

  return dict.find(w)->second;
}
/////////////////////////////////////////////////////////////////////
// get the probability distribution at the context

void CWordLanguageModel::GetProbs(Context context, std::vector<unsigned int> &probs, int norm) const {

  // Got rid of const below

  CWordLanguageModel::CWordContext *wordcontext = (CWordContext *)(context);

  // Make sure that the probability vector has the right length

  int iNumSymbols = GetSize();
  probs.resize(iNumSymbols);

  // For the prototype work with double precision to make things easier to normalise

  std::vector<double> dProbs(iNumSymbols);

  for(std::vector<double>::iterator it(dProbs.begin()); it != dProbs.end(); ++it)
    *it = 0.0;

  double alpha = GetLongParameter(LP_LM_WORD_ALPHA) / 100.0;
  // double beta = LanguageModelParams()->GetValue( std::string( "LMBeta" ) )/100.0;

  // Ignore beta for now - we'll need to know how many different words have been seen, not just the total count.

  double dToSpend(1.0);

  CWordnode *pTmp = wordcontext->head;
  CWordnode *pTmpWord = wordcontext->word_head;

  // We'll assume that these stay in sync for now - maybe do something more robust later.

  while(pTmp) {

    // Get the total count from the word node

    int iTotal(pTmpWord->count);

    if(iTotal) {

      CWordnode *pTmpChild(pTmp->child);

      while(pTmpChild) {
        // make sure we only get child nodes which correspond
        // to symbols (not words).

        if(pTmpChild->sbl < iWordStart) {

          double dP;

          if(pTmpChild->count > 0)
            dP = dToSpend * (pTmpChild->count) / static_cast<double>(iTotal + alpha);
          else
            dP = 0.0;

          dProbs[pTmpChild->sbl] += dP;
        }

        pTmpChild = pTmpChild->next;
      }
    }

    dToSpend *= alpha / static_cast<double>(iTotal + alpha);

    pTmp = pTmp->vine;
    pTmpWord = pTmpWord->vine;
  }

  // Get probabilities from the spelling model (note we cache these for later)

  wordcontext->m_iSpellingNorm = norm;

  int iSpellingNorm(wordcontext->m_iSpellingNorm);

  wordcontext->m_pSpellingModel->GetProbs(wordcontext->oSpellingContext, wordcontext->oSpellingProbs, iSpellingNorm);

  double dNorm(0.0);

  for(int i(0); i < iNumSymbols; ++i) {
    dProbs[i] += wordcontext->m_dSpellingFactor * wordcontext->oSpellingProbs[i] / static_cast<double>(wordcontext->m_iSpellingNorm);
    dNorm += dProbs[i];
  }

  // Convert back to integer representation

  int iToSpend(norm);

  for(int i(0); i < iNumSymbols; ++i) {
    probs[i] = (unsigned int)(norm * dProbs[i] / dNorm);
    iToSpend -= probs[i];
  }

  // Check that we haven't got anything left over due to rounding errors:

  int iLeft = iNumSymbols;

  for(int j = 0; j < iNumSymbols; ++j) {
    unsigned int p = iToSpend / iLeft;
    probs[j] += p;
    --iLeft;
    iToSpend -= p;
  }

  DASHER_ASSERT(iToSpend == 0);
}
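// A sketch of the smoothing performed above (my reading of the code, not
// taken from any Dasher documentation): at each context depth, symbol s
// receives the share
//
//   P(s) += dToSpend * c(s) / (T + alpha)
//
// where c(s) is the child's count, T the total count at that depth and
// alpha the escape parameter (LP_LM_WORD_ALPHA / 100). The unclaimed mass
//
//   dToSpend *= alpha / (T + alpha)
//
// "escapes" to the next-shorter context via the vine pointer, as in
// PPM-style back-off, before finally being mixed with the letter-level
// spelling model's distribution and renormalised to `norm`.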
/// Collapse the context. This also has the effect of entering a count
/// for the word into the word part of the model

void CWordLanguageModel::CollapseContext(CWordLanguageModel::CWordContext &context, bool bLearn) {

  // Letters appear at the end of the trie:

  if(max_order == 0) {
    // If max_order = 0 then we are not keeping track of previous
    // words, so we just collapse the letter part of the context and
    // return

    // FIXME - not sure this will work any more - don't use this
    // branch without checking that it's doing the right thing

    context.head = m_pRoot;
    context.order = 0;
  }
  else {

    std::vector<symbol> oSymbols;

    for(std::string::iterator it(context.current_word.begin()); it != context.current_word.end(); it += 4) {
      std::string fragment(it, it + 4);
      oSymbols.push_back(atoi(fragment.c_str()));
    }

    if(bLearn) {                // Only do this if we are learning
      // We need to increment all substrings - start at the current context stripped back to the word level

      bool bUpdateExclusion(false);     // Whether to keep going or not

      CWordnode *pCurrent(context.word_head);

      // Keep track of pointers to all child nodes

      // std::vector< std::vector< CWordnode* >* > oNodeCache;

      std::vector<CWordnode *> **apNodeCache;

      apNodeCache = new std::vector<CWordnode *>*[oSymbols.size()];

      for(unsigned int i(0); i < oSymbols.size(); ++i)
        apNodeCache[i] = new std::vector<CWordnode *>;

      // FIXME - remember to delete member vectors when we're done

      // FIXME broken SymbolAlphabet().GetAlphabetPointer()->GetSymbols( &oSymbols, &(context.current_word), false );

      // We're not storing the actual string - just a list of symbol IDs

      while((pCurrent != NULL) && !bUpdateExclusion) {

        // std::cout << "Incrementing" << std::endl;

        ++(pCurrent->count);

        int i(0);

        // std::vector< CWordnode* > *pCurrentCache( new std::vector< CWordnode* > );

        CWordnode *pTmp(pCurrent);

        bUpdateExclusion = true;

        for(std::vector<symbol>::iterator it(oSymbols.begin()); it != oSymbols.end(); ++it) {
          int iSymbol(*it);

          // std::cout << "Symbol " << iSymbol << std::endl;

          CWordnode *pTmpChild(pTmp->find_symbol(iSymbol));

          // std::cout << "pTmpChild: " << pTmpChild << std::endl;

          if(pTmpChild == NULL) {
            // We don't already have this child, so add a new node

            pTmpChild = m_NodeAlloc.Alloc();
            pTmpChild->sbl = iSymbol;
            pTmpChild->next = pTmp->child;
            pTmp->child = pTmpChild;

            bUpdateExclusion = false;

            // Newly allocated child already has a count of one, so no need to increment it explicitly
          }
          else {
            if(pTmpChild->count == 0)
              bUpdateExclusion = false;
            ++(pTmpChild->count);
          }

          apNodeCache[i]->push_back(pTmpChild);
          ++i;
          pTmp = pTmpChild;
        }

        pCurrent = pCurrent->vine;

        // std::cout << "foo: " << pCurrent << " " << bUpdateExclusion << std::endl;
      }

      // Now we need to go through and fix up the vine pointers

      // for( std::vector< std::vector< CWordnode* >* >::iterator it( oNodeCache.begin() ); it != oNodeCache.end(); ++it ) {
      for(unsigned int i(0); i < oSymbols.size(); ++i) {

        CWordnode *pPreviousNode(NULL); // Start with a NULL pointer

        for(std::vector<CWordnode *>::reverse_iterator it2(apNodeCache[i]->rbegin()); it2 != apNodeCache[i]->rend(); ++it2) {
          (*it2)->vine = pPreviousNode;
          pPreviousNode = (*it2);
        }

        delete apNodeCache[i];
      }

      delete[] apNodeCache;
    }

    // Collapse down word part regardless of whether we're learning or not

    int oldnextid(nextid);

    int iNewSymbol(lookup_word(context.current_word));

    // Insert into the spelling model if this is a new word

    if((nextid > oldnextid) || (GetBoolParameter(BP_LM_LETTER_EXCLUSION))) {

      context.m_pSpellingModel->ReleaseContext(context.oSpellingContext);
      context.oSpellingContext = context.m_pSpellingModel->CreateEmptyContext();

      for(std::vector<symbol>::iterator it(oSymbols.begin()); it != oSymbols.end(); ++it) {
        context.m_pSpellingModel->LearnSymbol(context.oSpellingContext, *it);
      }
    }

    CWordnode *pTmp(context.word_head);
    CWordnode *pTmpChild;
    CWordnode *pTmpVine(NULL);

    // std::cout << "pTmp is " << pTmp << std::endl;

    int iUpdateExclusion(1);

    pTmpChild = AddSymbolToNode(pTmp, iNewSymbol, &iUpdateExclusion, false);    // FIXME - might have added a new node here, so fix up vine pointers.

    // std::cout << "New node: " << pTmpChild << std::endl;

    context.word_head = pTmpChild;
    ++context.word_order;
    pTmpVine = pTmpChild;
    pTmp = pTmp->vine;

    while(pTmp != NULL) {

      // std::cout << "pTmp is " << pTmp << std::endl;

      pTmpChild = AddSymbolToNode(pTmp, iNewSymbol, &iUpdateExclusion, false);  // FIXME - might have added a new node here, so fix up vine pointers.

      // std::cout << "New node: " << pTmpChild << std::endl;

      if(pTmpVine)
        pTmpVine->vine = pTmpChild;

      pTmpVine = pTmpChild;
      pTmp = pTmp->vine;
    }

    pTmpVine->vine = m_pRoot;

    // Finally get rid of the letter part of the context

    // std::cout << "Changed head to " << context.word_head << std::endl;

    while(context.word_order > 2) {
      context.word_head = context.word_head->vine;
      // std::cout << " * Followed vine to head to " << context.word_head << std::endl;
      --(context.word_order);
    }

    context.head = context.word_head;
    context.order = context.word_order;
    context.current_word = "";

    context.m_pSpellingModel->ReleaseContext(context.oSpellingContext);
    context.oSpellingContext = context.m_pSpellingModel->CreateEmptyContext();
  }

  // if( wordidx == 1 ) {
  //   ofstream ofile( "graph.dot" );
  //   ofile << "digraph G {" << std::endl;
  //   m_pRoot->RecursiveDump( ofile );
  //   ofile << "}" << std::endl;
  //   ofile.close();
  //   exit(0);
  // }

  ++wordidx;
}
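// A worked example of what CollapseContext does (hypothetical values, not
// from the Dasher sources): suppose the user has just finished the word
// "to" followed by a space. current_word then holds the 4-digit codes for
// 't', 'o' and space, and lookup_word maps that string to a word id
// >= iWordStart (8192), so word ids and letter symbols can share one trie
// without colliding. That id is appended to every context on the vine
// chain via AddSymbolToNode, the word-level order is capped at 2 (a
// trigram-like window over the two most recent words), and the letter
// part of the context is reset ready for the next word.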
void CWordLanguageModel::LearnSymbol(Context c, int Symbol) {
  CWordContext &context = *(CWordContext *)(c);
  AddSymbol(context, Symbol, true);
}

/// Add a symbol to the context: creates new nodes, updates counts and
/// leaves 'context' at the new context
void CWordLanguageModel::AddSymbol(CWordLanguageModel::CWordContext &context, symbol sym, bool bLearn) {
  DASHER_ASSERT(sym >= 0 && sym < GetSize());

  if(context.oSpellingProbs.size() != 0)
    context.m_dSpellingFactor *= context.oSpellingProbs[sym] / static_cast<double>(context.m_iSpellingNorm);

  // Update the context for the spelling model

  context.m_pSpellingModel->EnterSymbol(context.oSpellingContext, sym);

  // Add the symbol to the letter part of the context. Note that we don't do any learning at this stage

  CWordnode *pTmp(context.head);        // Current node
  CWordnode *pTmpVine;                  // Child created last time around (for vine pointers)

  // Context head is a special case so that we can increment order etc.

  int foo2(1);

  // std::cout << "aa: " << pTmp << " " << m_pRoot << std::endl;

  pTmpVine = AddSymbolToNode(pTmp, sym, &foo2, false);  // Last parameter is whether to learn or not

  context.head = pTmpVine;
  ++context.order;

  pTmp = pTmp->vine;
  CWordnode *pTmpNew;                   // Child created this time around

  while(pTmp != NULL) {

    int foo(1);

    pTmpNew = AddSymbolToNode(pTmp, sym, &foo, false);

    // Connect up vine pointers if necessary

    if(pTmpVine) {
      pTmpVine->vine = pTmpNew;
    }

    pTmpVine = pTmpNew;

    // Follow vine pointers

    pTmp = pTmp->vine;
  }

  pTmpVine->vine = NULL;        // (not sure if this is needed)

  // Add the new symbol to the string representation too - note that the
  // string is actually a series of integers, not the actual symbols -
  // doesn't matter as long as we're consistent and unique.

  char sbuffer[5];
  snprintf(sbuffer, 5, "%04d", sym);
  context.current_word.append(sbuffer);
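  // For example (illustrative values, not taken from the sources):
  // entering the symbols 17, 4 and 29 in turn leaves current_word as
  // "001700040029". CollapseContext slices this back into symbols four
  // digits at a time; the zero padding is what keeps the encoding
  // unambiguous.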
  // Collapse the context (with learning) if we've just entered a space
  // FIXME - we need to generalise this for more languages.

  if(sym == SymbolAlphabet().GetSpaceSymbol()) {
    CollapseContext(context, bLearn);
    context.m_dSpellingFactor = 1.0;
  }
}
/////////////////////////////////////////////////////////////////////
// update context with symbol 'Symbol'

void CWordLanguageModel::EnterSymbol(Context c, int Symbol) {
  // Same as AddSymbol but without learning in CollapseContext

  CWordContext &context = *(CWordContext *)(c);
  AddSymbol(context, Symbol, false);
}