3 /////////////////////////////////////////////////////////////////////////////
5 // Copyright (c) 1999-2004 David Ward
7 /////////////////////////////////////////////////////////////////////////////
9 #include "../../Common/Common.h"
10 #include "WordLanguageModel.h"
11 #include "PPMLanguageModel.h"
18 using namespace Dasher
;
21 // static TCHAR debug[256];
22 typedef unsigned long ulong
;
25 #define snprintf _snprintf
28 // Track memory leaks on Windows to the line that new'd the memory
30 #ifdef _DEBUG_MEMLEAKS
31 #define DEBUG_NEW new( _NORMAL_BLOCK, THIS_FILE, __LINE__ )
34 static char THIS_FILE
[] = __FILE__
;
38 ///////////////////////////////////////////////////////////////////
// Debug helper: intended to print this context's head pointer and order.
// Currently a no-op - the only statement is commented out below, and the
// 'debug' buffer it would write into is also commented out at file scope.
40 void CWordLanguageModel::CWordContext::dump()
43 // TODO uncomment this when headers sorted out
45 //Usprintf(debug,TEXT("head %x order %d\n"),head,order);
49 ////////////////////////////////////////////////////////////////////////
50 /// Wordnode definitions
51 ////////////////////////////////////////////////////////////////////////
53 /// Return the child of a node with a given symbol, or NULL if there is no child with that symbol yet
// Return the child of this node whose symbol matches 'sym', or NULL if no
// such child exists (per the declaration comment above).  Children form a
// singly-linked sibling list headed at 'child'; only the start of the
// linear search is visible in this chunk - the loop over 'next' pointers
// and the return are not shown here.
55 CWordLanguageModel::CWordnode
* CWordLanguageModel::CWordnode::find_symbol(int sym
) const {
// Begin the search at the head of the child list.
56 CWordnode
*found
= child
;
// Dump this node and its subtree to 'file' in Graphviz "dot" syntax:
// the node is labelled with its symbol id (sbl) and count, a dashed edge
// points at its vine node, and a solid edge points at each child, which is
// then dumped recursively.  Intended for offline visualisation of the trie
// (see the commented-out "graph.dot" code later in this file).
65 void CWordLanguageModel::CWordnode::RecursiveDump(std::ofstream
&file
) {
67 CWordnode
*pCurrentChild(child
);
// Emit this node with label "<symbol>\n<count>" (node name is its address).
69 file
<< "\"" << this << "\" [label=\"" << this->sbl
<< "\\n" << this->count
<< "\"]" << std::endl
;
// Dashed edge from this node to its vine (shorter-context) node.
71 file
<< "\"" << this << "\" -> \"" << vine
<< "\" [style=dashed]" << std::endl
;
// Walk the sibling list: emit a parent->child edge, recurse, advance.
73 while(pCurrentChild
) {
74 file
<< "\"" << this << "\" -> \"" << pCurrentChild
<< "\"" << std::endl
;
75 pCurrentChild
->RecursiveDump(file
);
76 pCurrentChild
= pCurrentChild
->next
;
// Find pNode's child for 'sym', creating it if necessary, and return it.
// A newly allocated node is pushed on the front of pNode's child list and
// starts with count 1; when bLearn is false the count is decremented back
// so the un-learned insertion leaves the statistics unchanged.
// NOTE(review): the branch structure between the "found" and "not found"
// cases is not fully visible in this chunk - the role of 'update' (an
// update-exclusion flag judging by the callers) should be confirmed
// against the complete source.
80 CWordLanguageModel::CWordnode
* CWordLanguageModel::AddSymbolToNode(CWordnode
*pNode
, symbol sym
, int *update
, bool bLearn
) {
82 // FIXME - need to implement bLearn
// Look for an existing child with this symbol.
84 CWordnode
*pReturn
= pNode
->find_symbol(sym
);
89 // std::cout << "USHRT_MAX: " << USHRT_MAX << " " << bLearn << std::endl;
91 // if( (pReturn->count < USHRT_MAX) && bLearn ) // Truncate counts at storage limit
92 if(bLearn
) // Truncate counts at storage limit
// No existing child: allocate a fresh node from the pool and link it in
// at the head of the child list.
99 pReturn
= m_NodeAlloc
.Alloc(); // count is initialized to 1
101 pReturn
->next
= pNode
->child
;
102 pNode
->child
= pReturn
;
// Not learning: undo the implicit count of 1 on the fresh node.
105 --(pReturn
->count
); // FIXME - in the long term, don't allocate
106 // nodes if we're not learning, but should be
110 // std::cout << pReturn->count << std::endl;
117 /////////////////////////////////////////////////////////////////////
118 // CWordLanguageModel defs
119 /////////////////////////////////////////////////////////////////////
// Construct the word-level language model: allocate the trie root, create
// a letter-level PPM "spelling" model over the same alphabet, build the
// root context bound to that spelling model, and (optionally, when
// BP_LM_DICTIONARY is set) pre-train the spelling model on a system word
// list.  max_order(2) limits the word-level context length; the pool
// allocators pre-size node and context storage.
121 CWordLanguageModel::CWordLanguageModel(Dasher::CEventHandler
*pEventHandler
, CSettingsStore
*pSettingsStore
,
122 const CSymbolAlphabet
&Alphabet
)
123 :CLanguageModel(pEventHandler
, pSettingsStore
, Alphabet
), NodesAllocated(0),
124 max_order(2), m_NodeAlloc(8192), m_ContextAlloc(1024) {
126 // Construct a root node for the trie
128 m_pRoot
= m_NodeAlloc
.Alloc();
132 // Create a spelling model
134 pSpellingModel
= new CPPMLanguageModel(m_pEventHandler
, m_pSettingsStore
, Alphabet
);
136 // Construct a root context
138 m_rootcontext
= new CWordContext(m_pRoot
, 0);
140 m_rootcontext
->m_pSpellingModel
= pSpellingModel
;
141 m_rootcontext
->oSpellingContext
= pSpellingModel
->CreateEmptyContext();
// Word ids are allocated upwards from iWordStart so they never collide
// with letter symbol ids in the shared trie.
145 nextid
= iWordStart
; // Start of indices for words - may need to increase this for *really* large alphabets
148 if(GetBoolParameter(BP_LM_DICTIONARY
)) {
150 std::ifstream
DictFile("/usr/share/dict/words"); // FIXME - hardcoded paths == bad
152 std::string CurrentWord
;
// NOTE(review): testing eof() before extraction means the final word is
// processed twice (the last failed >> leaves CurrentWord unchanged);
// also nothing handles the file failing to open.
154 while(!DictFile
.eof()) {
155 DictFile
>> CurrentWord
;
// Append the word delimiter so the model learns word boundaries too.
157 CurrentWord
= CurrentWord
+ " ";
159 // std::cout << CurrentWord << std::endl;
// Train each dictionary word in a fresh, independent PPM context.
161 CPPMLanguageModel::Context
TempContext(pSpellingModel
->CreateEmptyContext());
163 // std::cout << SymbolAlphabet().GetAlphabetPointer() << std::endl;
// Convert the word text into alphabet symbol ids.
165 std::vector
< symbol
> Symbols
;
166 SymbolAlphabet().GetAlphabetPointer()->GetSymbols(&Symbols
, &CurrentWord
, false);
168 for(std::vector
< symbol
>::iterator
it(Symbols
.begin()); it
!= Symbols
.end(); ++it
) {
169 pSpellingModel
->LearnSymbol(TempContext
, *it
);
172 pSpellingModel
->ReleaseContext(TempContext
);
177 // oSpellingContext = pSpellingModel->CreateEmptyContext();
// Destructor: release the root context and the letter-level spelling
// model.  Trie nodes are not deleted individually here - presumably they
// are reclaimed wholesale by m_NodeAlloc's pool destructor (the explicit
// stack-based deletion below is commented out) - TODO confirm against the
// allocator's implementation.
183 CWordLanguageModel::~CWordLanguageModel() {
185 delete m_rootcontext
;
186 delete pSpellingModel
;
188 // A non-recursive node deletion algorithm using a stack
189 /* std::stack<CWordnode*> deletenodes;
190 deletenodes.push(m_pRoot);
191 while (!deletenodes.empty())
193 CWordnode* temp = deletenodes.top();
202 deletenodes.push(temp->child);
// Map word string 'w' to its integer id.  The body is not visible in this
// chunk; judging by the caller in CollapseContext (which compares 'nextid'
// before and after the call to detect new words), this assigns a fresh id
// from 'nextid' when 'w' has not been seen before - TODO confirm.
214 int CWordLanguageModel::lookup_word(const std::string
&w
) {
// Const lookup of word string 'w' in the id dictionary - unlike
// lookup_word this never inserts.
// NOTE(review): dict.find(w)->second is dereferenced without checking
// against dict.end(), which is undefined behaviour when 'w' is absent;
// the std::cout line is leftover debug output to stdout.
223 int CWordLanguageModel::lookup_word_const(const std::string
&w
) const {
224 std::cout
<< "Looking up (const) " << w
<< std::endl
;
226 return dict
.find(w
)->second
;
229 /////////////////////////////////////////////////////////////////////
230 // get the probability distribution at the context
// Fill 'probs' (resized to the alphabet size) with the probability
// distribution over the next letter symbol at 'context', scaled so the
// entries are intended to total 'norm'.  Works in double precision
// internally, blends counts from the word-level trie (walked via vine
// pointers with escape weight 'alpha') with cached letter probabilities
// from the PPM spelling model, then converts back to integers and
// distributes any rounding remainder.
// NOTE(review): several declarations used below (dP, dNorm, iToSpend) and
// some loop/branch closers are not visible in this chunk - the comments
// here describe only what the visible statements do.
232 void CWordLanguageModel::GetProbs(Context context
, std::vector
<unsigned int> &probs
, int norm
) const {
233 // Got rid of const below
// Recover the concrete context type from the opaque handle.
235 CWordLanguageModel::CWordContext
* wordcontext
= (CWordContext
*) (context
);
237 // Make sure that the probability vector has the right length
239 int iNumSymbols
= GetSize();
240 probs
.resize(iNumSymbols
);
242 // For the prototype work with double precision to make things easier to normalise
244 std::vector
< double >dProbs(iNumSymbols
);
// Zero the working distribution (loop body not visible in this chunk).
246 for(std::vector
< double >::iterator
it(dProbs
.begin()); it
!= dProbs
.end(); ++it
)
// Escape/smoothing weight, configured in percent.
249 double alpha
= GetLongParameter(LP_LM_WORD_ALPHA
) / 100.0;
250 // double beta = LanguageModelParams()->GetValue( std::string( "LMBeta" ) )/100.0;
252 // Ignore beta for now - we'll need to know how many different words have been seen, not just the total count.
// Fraction of probability mass still unassigned as we back off.
254 double dToSpend(1.0);
// Walk both the letter head and the word head in parallel.
256 CWordnode
*pTmp
= wordcontext
->head
;
257 CWordnode
*pTmpWord
= wordcontext
->word_head
;
259 // We'll assume that these stay in sync for now - maybe do something more robust later.
263 // Get the total count from the word node
265 int iTotal(pTmpWord
->count
);
269 CWordnode
*pTmpChild(pTmp
->child
);
272 // make sure we only get child nodes which correspond
273 // to symbols (not words).
// Letter children have sbl < iWordStart; word children are skipped.
275 if(pTmpChild
->sbl
< iWordStart
) {
279 if(pTmpChild
->count
> 0)
// Assign this child its share of the remaining mass.
280 dP
= dToSpend
* (pTmpChild
->count
) / static_cast < double >(iTotal
+ alpha
);
284 dProbs
[pTmpChild
->sbl
] += dP
;
287 pTmpChild
= pTmpChild
->next
;
// Back off: keep only the escape fraction for shorter contexts.
292 dToSpend
*= alpha
/ static_cast < double >(iTotal
+ alpha
);
295 pTmpWord
= pTmpWord
->vine
;
299 // Get probabilities from the spelling model (note we cache these for later)
// The norm and the resulting letter distribution are cached on the
// context so AddSymbol can reuse them when updating m_dSpellingFactor.
301 wordcontext
->m_iSpellingNorm
= norm
;
303 int iSpellingNorm(wordcontext
->m_iSpellingNorm
);
305 wordcontext
->m_pSpellingModel
->GetProbs(wordcontext
->oSpellingContext
, wordcontext
->oSpellingProbs
, iSpellingNorm
);
// Blend in the spelling-model distribution, scaled by the running
// per-word spelling factor.
309 for(int i(0); i
< iNumSymbols
; ++i
) {
310 dProbs
[i
] += wordcontext
->m_dSpellingFactor
* wordcontext
->oSpellingProbs
[i
] / static_cast < double >(wordcontext
->m_iSpellingNorm
);
314 // Convert back to integer representation
// dNorm and iToSpend are declared in lines elided from this chunk.
318 for(int i(0); i
< iNumSymbols
; ++i
) {
319 probs
[i
] = (unsigned int) (norm
* dProbs
[i
] / dNorm
);
320 iToSpend
-= probs
[i
];
323 // Check that we haven't got anything left over due to rounding errors:
// Spread the integer remainder across the symbols.
325 int iLeft
= iNumSymbols
;
327 for(int j
= 0; j
< iNumSymbols
; ++j
) {
328 unsigned int p
= iToSpend
/ iLeft
;
// All the mass must have been assigned by now.
334 DASHER_ASSERT(iToSpend
== 0);
337 /// Collapse the context. This also has the effect of entering a count
338 /// for the word into the word part of the model
// Collapse 'context' at a word boundary: decode the accumulated
// current_word string (a sequence of 4-digit symbol ids) back into
// symbols, optionally (bLearn) increment the letter counts for the word
// along every vine-linked context with update exclusion, register the word
// itself (entering it into the word level of the trie and, if new, into
// the spelling model), then reset the context to the word level and clear
// the per-word state.
// NOTE(review): several branch/loop closers and the extraction of
// 'iSymbol' from the iterator are not visible in this chunk; comments
// describe only the visible statements.
340 void CWordLanguageModel::CollapseContext(CWordLanguageModel::CWordContext
&context
, bool bLearn
) {
342 // Letters appear at the end of the trie:
346 // If max_order = 0 then we are not keeping track of previous
347 // words, so we just collapse the letter part of the context and
350 // FIXME - not sure this will work any more - don't use this
351 // branch without checking that it's doing the right thing
353 context
.head
= m_pRoot
;
// Decode current_word: each 4-character chunk is one zero-padded symbol
// id written by AddSymbol's snprintf("%04d").
359 std::vector
< symbol
> oSymbols
;
361 for(std::string::iterator
it(context
.current_word
.begin()); it
!= context
.current_word
.end(); it
+= 4) {
362 std::string
fragment(it
, it
+ 4);
363 oSymbols
.push_back(atoi(fragment
.c_str()));
366 if(bLearn
) { // Only do this if we are learning
367 // We need to increment all substrings - start at the current context striped back to the word level
369 bool bUpdateExclusion(false); // Whether to keep going or not
371 CWordnode
*pCurrent(context
.word_head
);
373 // Keep track of pointers to all child nodes
375 // std::vector< std::vector< CWordnode* >* > oNodeCache;
// Per-symbol cache of the nodes touched at each context length, used
// below to rebuild vine pointers.  Raw new'd array of new'd vectors.
377 std::vector
< CWordnode
* >**apNodeCache
;
379 apNodeCache
= new std::vector
< CWordnode
* >*[oSymbols
.size()];
381 for(unsigned int i(0); i
< oSymbols
.size(); ++i
)
382 apNodeCache
[i
] = new std::vector
< CWordnode
* >;
384 // FIXME - remember to delete member vectors when we're done
386 // FIXME broken SymbolAlphabet().GetAlphabetPointer()->GetSymbols( &oSymbols, &(context.current_word), false );
388 // We're not storing the actual string - just a list of symbol IDs
// Walk up the vine chain, learning the word's letters at each context
// length until update exclusion says every node already had a count.
390 while((pCurrent
!= NULL
) && !bUpdateExclusion
) {
392 // std::cout << "Incrementing" << std::endl;
398 // std::vector< CWordnode* > *pCurrentCache( new std::vector< CWordnode* > );
400 CWordnode
*pTmp(pCurrent
);
// Assume exclusion applies unless we create or revive a node below.
402 bUpdateExclusion
= true;
404 for(std::vector
< symbol
>::iterator
it(oSymbols
.begin()); it
!= oSymbols
.end(); ++it
) {
407 // std::cout << "Symbol " << iSymbol << std::endl;
409 CWordnode
*pTmpChild(pTmp
->find_symbol(iSymbol
));
411 // std::cout << "pTmpChild: " << pTmpChild << std::endl;
413 if(pTmpChild
== NULL
) {
414 // We don't already have this child, so add a new node
416 pTmpChild
= m_NodeAlloc
.Alloc();
417 pTmpChild
->sbl
= iSymbol
;
418 pTmpChild
->next
= pTmp
->child
;
419 pTmp
->child
= pTmpChild
;
// New node seen: keep propagating to shorter contexts.
421 bUpdateExclusion
= false;
423 // Newly allocated child already has a count of one, so no need to increment it explicitly
427 if(pTmpChild
->count
== 0)
428 bUpdateExclusion
= false;
429 ++(pTmpChild
->count
);
// Remember this node so its vine pointer can be fixed up below.
432 apNodeCache
[i
]->push_back(pTmpChild
);
438 pCurrent
= pCurrent
->vine
;
440 // std::cout << "foo: " << pCurrent << " " << bUpdateExclusion << std::endl;
444 // Now we need to go through and fix up the vine pointers
446 // for( std::vector< std::vector< CWordnode* >* >::iterator it( oNodeCache.begin() ); it != oNodeCache.end(); ++it ) {
// For each symbol position, chain the cached nodes from longest to
// shortest context: iterate in reverse so each node's vine points at
// the node one context-length shorter (NULL for the shortest).
447 for(unsigned int i(0); i
< oSymbols
.size(); ++i
) {
449 CWordnode
*pPreviousNode(NULL
); // Start with a NULL pointer
451 for(std::vector
< CWordnode
* >::reverse_iterator
it2(apNodeCache
[i
]->rbegin()); it2
!= apNodeCache
[i
]->rend(); ++it2
) {
452 (*it2
)->vine
= pPreviousNode
;
453 pPreviousNode
= (*it2
);
456 delete apNodeCache
[i
];
463 // Collapse down word part regardless of whether we're learning or not
// lookup_word bumps nextid for unseen words, so comparing before/after
// detects whether this word is new.
465 int oldnextid(nextid
);
467 int iNewSymbol(lookup_word(context
.current_word
));
469 // Insert into the spelling model if this is a new word
471 if((nextid
> oldnextid
) || (GetBoolParameter(BP_LM_LETTER_EXCLUSION
))) {
// Re-learn the word's letters in a fresh spelling-model context.
473 context
.m_pSpellingModel
->ReleaseContext(context
.oSpellingContext
);
474 context
.oSpellingContext
= context
.m_pSpellingModel
->CreateEmptyContext();
476 for(std::vector
< int >::iterator
it(oSymbols
.begin()); it
!= oSymbols
.end(); ++it
) {
477 context
.m_pSpellingModel
->LearnSymbol(context
.oSpellingContext
, *it
);
// Enter the word id itself at every word-level context length,
// chaining vine pointers between the nodes as we go.
482 CWordnode
*pTmp(context
.word_head
);
483 CWordnode
*pTmpChild
;
484 CWordnode
*pTmpVine(NULL
);
486 // std::cout << "pTmp is " << pTmp << std::endl;
488 int iUpdateExclusion(1);
492 pTmpChild
= AddSymbolToNode(pTmp
, iNewSymbol
, &iUpdateExclusion
, false); // FIXME - might have added a new node here, so fix up vine pointers.
494 // std::cout << "New node: " << pTmpChild << std::endl;
496 context
.word_head
= pTmpChild
;
497 ++context
.word_order
;
498 pTmpVine
= pTmpChild
;
502 while(pTmp
!= NULL
) {
504 // std::cout << "pTmp is " << pTmp << std::endl;
506 pTmpChild
= AddSymbolToNode(pTmp
, iNewSymbol
, &iUpdateExclusion
, false); // FIXME - might have added a new node here, so fix up vine pointers.
508 // std::cout << "New node: " << pTmpChild << std::endl;
511 pTmpVine
->vine
= pTmpChild
;
513 pTmpVine
= pTmpChild
;
// Ground the final vine pointer at the trie root.
517 pTmpVine
->vine
= m_pRoot
;
519 // Finally get rid of the letter part of the context
521 // std::cout << "Changed head to " << context.word_head << std::endl;
// Trim the word context back to at most two words (max_order).
523 while(context
.word_order
> 2) {
524 context
.word_head
= context
.word_head
->vine
;
525 // std::cout << " * Followed vine to head to " << context.word_head << std::endl;
526 --(context
.word_order
);
// Reset the letter-level context to the word level and clear the
// accumulated per-word state ready for the next word.
529 context
.head
= context
.word_head
;
530 context
.order
= context
.word_order
;
531 context
.current_word
= "";
533 context
.m_pSpellingModel
->ReleaseContext(context
.oSpellingContext
);
534 context
.oSpellingContext
= context
.m_pSpellingModel
->CreateEmptyContext();
538 // if( wordidx == 1 ) {
539 // ofstream ofile( "graph.dot" );
541 // ofile << "digraph G {" << std::endl;
543 // m_pRoot->RecursiveDump( ofile );
545 // ofile << "}" << std::endl;
// Public entry point: enter 'Symbol' into context handle 'c' WITH
// learning - delegates to AddSymbol with bLearn = true, so counts are
// updated when the word is collapsed at the next space.
556 void CWordLanguageModel::LearnSymbol(Context c
, int Symbol
) {
// Recover the concrete context type from the opaque handle.
557 CWordContext
& context
= *(CWordContext
*) (c
);
558 AddSymbol(context
, Symbol
, true);
561 /// add symbol to the context creates new nodes, updates counts and
562 /// leaves 'context' at the new context
// Advance 'context' by letter symbol 'sym': fold the cached spelling
// probability for 'sym' into m_dSpellingFactor, advance the spelling-model
// context, extend the letter part of the trie context at every vine level
// (without learning), append the symbol's 4-digit id to current_word, and
// - when 'sym' is the alphabet's space symbol - collapse the context,
// learning the completed word iff bLearn.
// NOTE(review): declarations of foo/foo2 (update-exclusion outputs),
// sbuffer, and the vine-following part of the loop are not visible in
// this chunk.
564 void CWordLanguageModel::AddSymbol(CWordLanguageModel::CWordContext
&context
, symbol sym
, bool bLearn
) {
565 DASHER_ASSERT(sym
>= 0 && sym
< GetSize());
// Accumulate the spelling model's probability of this letter (using the
// distribution cached by GetProbs) into the per-word spelling factor.
567 if( context
.oSpellingProbs
.size() != 0 )
568 context
.m_dSpellingFactor
*= context
.oSpellingProbs
[sym
] / static_cast < double >(context
.m_iSpellingNorm
);
570 // Update the context for the spelling model;
572 context
.m_pSpellingModel
->EnterSymbol(context
.oSpellingContext
, sym
);
574 // Add the symbol to the letter part of the context. Note that we don't do any learning at this stage
576 CWordnode
*pTmp(context
.head
); // Current node
577 CWordnode
*pTmpVine
; // Child created last time around (for vine pointers)
579 // Context head is a special case so that we can increment order etc.
583 // std::cout << "aa: " << pTmp << " " << m_pRoot << std::endl;
585 pTmpVine
= AddSymbolToNode(pTmp
, sym
, &foo2
, false); // Last parameter is whether to learn or not
587 context
.head
= pTmpVine
;
591 CWordnode
*pTmpNew
; // Child created this time around
// Repeat at each shorter context along the vine chain, wiring each
// longer context's node to the next shorter one.
593 while(pTmp
!= NULL
) {
597 pTmpNew
= AddSymbolToNode(pTmp
, sym
, &foo
, false);
599 // Connect up vine pointers if necessary
602 pTmpVine
->vine
= pTmpNew
;
607 // Follow vine pointers
// Terminate the chain built above.
613 pTmpVine
->vine
= NULL
; // (not sure if this is needed)
615 // Add the new symbol to the string representation too - note that
616 // string is actually a series of integers, not the actual symbols -
617 // doesn't matter as long as we're consistent and unique.
// Each symbol becomes a fixed-width 4-digit id; CollapseContext decodes
// current_word in 4-character chunks.
620 snprintf(sbuffer
, 5, "%04d", sym
);
621 context
.current_word
.append(sbuffer
);
623 // Collapse the context (with learning) if we've just entered a space
624 // FIXME - we need to generalise this for more languages.
626 if(sym
== SymbolAlphabet().GetSpaceSymbol()) {
627 CollapseContext(context
, bLearn
);
// Start the spelling factor afresh for the next word.
628 context
.m_dSpellingFactor
= 1.0;
633 /////////////////////////////////////////////////////////////////////
634 // update context with symbol 'Symbol'
// Public entry point: enter 'Symbol' into context handle 'c' WITHOUT
// learning - delegates to AddSymbol with bLearn = false, mirroring
// LearnSymbol above.
636 void CWordLanguageModel::EnterSymbol(Context c
, int Symbol
) {
637 // Same as AddSymbol but without learning in CollapseContext
// Recover the concrete context type from the opaque handle.
639 CWordContext
& context
= *(CWordContext
*) (c
);
640 AddSymbol(context
, Symbol
, false);