3 /////////////////////////////////////////////////////////////////////////////
5 // Copyright (c) 1999-2005 David Ward
7 /////////////////////////////////////////////////////////////////////////////
9 #include "../../Common/Common.h"
10 #include "PPMLanguageModel.h"
17 using namespace Dasher
;
20 // Track memory leaks on Windows to the line that new'd the memory
23 #define DEBUG_NEW new( _NORMAL_BLOCK, THIS_FILE, __LINE__ )
26 static char THIS_FILE
[] = __FILE__
;
30 /////////////////////////////////////////////////////////////////////
32 CPPMLanguageModel::CPPMLanguageModel(Dasher::CEventHandler
*pEventHandler
, CSettingsStore
*pSettingsStore
, const CSymbolAlphabet
&SymbolAlphabet
)
33 :CLanguageModel(pEventHandler
, pSettingsStore
, SymbolAlphabet
), m_iMaxOrder(4), NodesAllocated(0), m_NodeAlloc(8192), m_ContextAlloc(1024) {
34 m_pRoot
= m_NodeAlloc
.Alloc();
37 m_pRootContext
= m_ContextAlloc
.Alloc();
38 m_pRootContext
->head
= m_pRoot
;
39 m_pRootContext
->order
= 0;
41 // Cache the result of update exclusion - otherwise we have to look up a lot when training, which is slow
43 // FIXME - this should be a boolean parameter
45 bUpdateExclusion
= ( GetLongParameter(LP_LM_UPDATE_EXCLUSION
) !=0 );
49 /////////////////////////////////////////////////////////////////////
51 CPPMLanguageModel::~CPPMLanguageModel() {
54 /////////////////////////////////////////////////////////////////////
55 // Get the probability distribution at the context
57 void CPPMLanguageModel::GetProbs(Context context
, std::vector
<unsigned int> &probs
, int norm
) const {
58 const CPPMContext
*ppmcontext
= (const CPPMContext
*)(context
);
60 DASHER_ASSERT(m_setContexts
.count(ppmcontext
) > 0);
62 int iNumSymbols
= GetSize();
64 probs
.resize(iNumSymbols
);
66 std::vector
< bool > exclusions(iNumSymbols
);
69 for(i
= 0; i
< iNumSymbols
; i
++) {
71 exclusions
[i
] = false;
74 // bool doExclusion = GetLongParameter( LP_LM_ALPHA );
75 bool doExclusion
= 0; //FIXME
77 int alpha
= GetLongParameter( LP_LM_ALPHA
);
78 int beta
= GetLongParameter( LP_LM_BETA
);
80 unsigned int iToSpend
= norm
;
82 CPPMnode
*pTemp
= ppmcontext
->head
;
87 CPPMnode
*pSymbol
= pTemp
->child
;
89 int sym
= pSymbol
->symbol
;
90 if(!(exclusions
[sym
] && doExclusion
))
91 iTotal
+= pSymbol
->count
;
92 pSymbol
= pSymbol
->next
;
96 unsigned int size_of_slice
= iToSpend
;
97 pSymbol
= pTemp
->child
;
99 if(!(exclusions
[pSymbol
->symbol
] && doExclusion
)) {
100 exclusions
[pSymbol
->symbol
] = 1;
102 unsigned int p
= static_cast < myint
> (size_of_slice
) * (100 * pSymbol
->count
- beta
) / (100 * iTotal
+ alpha
);
104 probs
[pSymbol
->symbol
] += p
;
107 // Usprintf(debug,TEXT("sym %u counts %d p %u tospend %u \n"),sym,s->count,p,tospend);
108 // DebugOutput(debug);
109 pSymbol
= pSymbol
->next
;
115 unsigned int size_of_slice
= iToSpend
;
118 for(i
= 1; i
< iNumSymbols
; i
++)
119 if(!(exclusions
[i
] && doExclusion
))
122 // std::ostringstream str;
123 // for (sym=0;sym<modelchars;sym++)
124 // str << probs[sym] << " ";
126 // DASHER_TRACEOUTPUT("probs %s",str.str().c_str());
128 // std::ostringstream str2;
129 // for (sym=0;sym<modelchars;sym++)
130 // str2 << valid[sym] << " ";
131 // str2 << std::endl;
132 // DASHER_TRACEOUTPUT("valid %s",str2.str().c_str());
134 for(i
= 1; i
< iNumSymbols
; i
++) {
135 if(!(exclusions
[i
] && doExclusion
)) {
136 unsigned int p
= size_of_slice
/ symbolsleft
;
142 int iLeft
= iNumSymbols
-1;
144 for(int j
= 1; j
< iNumSymbols
; ++j
) {
145 unsigned int p
= iToSpend
/ iLeft
;
151 DASHER_ASSERT(iToSpend
== 0);
154 void CPPMLanguageModel::AddSymbol(CPPMLanguageModel::CPPMContext
&context
, int sym
)
155 // add symbol to the context
156 // creates new nodes, updates counts
157 // and leaves 'context' at the new context
159 // Ignore attempts to add the root symbol
164 DASHER_ASSERT(sym
>= 0 && sym
< GetSize());
166 CPPMnode
*vineptr
, *temp
;
169 temp
= context
.head
->vine
;
170 context
.head
= AddSymbolToNode(context
.head
, sym
, &updatecnt
);
171 vineptr
= context
.head
;
175 vineptr
->vine
= AddSymbolToNode(temp
, sym
, &updatecnt
);
176 vineptr
= vineptr
->vine
;
179 vineptr
->vine
= m_pRoot
;
181 //m_iMaxOrder = LanguageModelParams()->GetValue(std::string("LMMaxOrder"));
182 m_iMaxOrder
= GetLongParameter( LP_LM_MAX_ORDER
);
184 while(context
.order
> m_iMaxOrder
) {
185 context
.head
= context
.head
->vine
;
190 /////////////////////////////////////////////////////////////////////
191 // Update context with symbol 'Symbol'
193 void CPPMLanguageModel::EnterSymbol(Context c
, int Symbol
) {
197 DASHER_ASSERT(Symbol
>= 0 && Symbol
< GetSize());
199 CPPMLanguageModel::CPPMContext
& context
= *(CPPMContext
*) (c
);
203 while(context
.head
) {
205 if(context
.order
< m_iMaxOrder
) { // Only try to extend the context if it's not going to make it too long
206 find
= context
.head
->find_symbol(Symbol
);
210 // Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
211 // DebugOutput(debug);
213 // std::cout << context.order << std::endl;
218 // If we can't extend the current context, follow vine pointer to shorten it and try again
221 context
.head
= context
.head
->vine
;
224 if(context
.head
== 0) {
225 context
.head
= m_pRoot
;
229 // std::cout << context.order << std::endl;
233 /////////////////////////////////////////////////////////////////////
235 void CPPMLanguageModel::LearnSymbol(Context c
, int Symbol
) {
240 DASHER_ASSERT(Symbol
>= 0 && Symbol
< GetSize());
241 CPPMLanguageModel::CPPMContext
& context
= *(CPPMContext
*) (c
);
242 AddSymbol(context
, Symbol
);
245 void CPPMLanguageModel::dumpSymbol(int sym
) {
246 if((sym
<= 32) || (sym
>= 127))
252 void CPPMLanguageModel::dumpString(char *str
, int pos
, int len
)
253 // Dump the string STR starting at position POS
257 for(p
= pos
; p
< pos
+ len
; p
++) {
259 if((cc
<= 31) || (cc
>= 127))
266 void CPPMLanguageModel::dumpTrie(CPPMLanguageModel::CPPMnode
*t
, int d
)
267 // diagnostic display of the PPM trie from node t and deeper
274 Usprintf( debug,TEXT("%5d %7x "), d, t );
275 //TODO: Uncomment this when headers sort out
276 //DebugOutput(debug);
277 if (t < 0) // pointer to input
280 Usprintf(debug,TEXT( " %3d %5d %7x %7x %7x <"), t->symbol,t->count, t->vine, t->child, t->next );
281 //TODO: Uncomment this when headers sort out
282 //DebugOutput(debug);
285 dumpString( dumpTrieStr, 0, d );
286 Usprintf( debug,TEXT(">\n") );
287 //TODO: Uncomment this when headers sort out
288 //DebugOutput(debug);
294 dumpTrieStr [d] = sym;
302 void CPPMLanguageModel::dump()
303 // diagnostic display of the whole PPM trie
308 Usprintf(debug,TEXT( "Dump of Trie : \n" ));
309 //TODO: Uncomment this when headers sort out
310 //DebugOutput(debug);
311 Usprintf(debug,TEXT( "---------------\n" ));
312 //TODO: Uncomment this when headers sort out
313 //DebugOutput(debug);
314 Usprintf( debug,TEXT( "depth node symbol count vine child next context\n") );
315 //TODO: Uncomment this when headers sort out
316 //DebugOutput(debug);
318 Usprintf( debug,TEXT( "---------------\n" ));
319 //TODO: Uncomment this when headers sort out
320 //DebugOutput(debug);
321 Usprintf(debug,TEXT( "\n" ));
322 //TODO: Uncomment this when headers sort out
323 //DebugOutput(debug);
327 ////////////////////////////////////////////////////////////////////////
328 /// PPMnode definitions
329 ////////////////////////////////////////////////////////////////////////
331 CPPMLanguageModel::CPPMnode
* CPPMLanguageModel::CPPMnode::find_symbol(int sym
) const
332 // see if symbol is a child of node
334 // printf("finding symbol %d at node %d\n",sym,node->id);
335 CPPMnode
*found
= child
;
338 if(found
->symbol
== sym
) {
346 CPPMLanguageModel::CPPMnode
* CPPMLanguageModel::AddSymbolToNode(CPPMnode
*pNode
, int sym
, int *update
) {
347 CPPMnode
*pReturn
= pNode
->find_symbol(sym
);
349 // std::cout << sym << ",";
351 if(pReturn
!= NULL
) {
352 // std::cout << "Using existing node" << std::endl;
354 // if (*update || (LanguageModelParams()->GetValue("LMUpdateExclusion") == 0) )
355 if(*update
|| !bUpdateExclusion
) { // perform update exclusions
362 // std::cout << "Creating new node" << std::endl;
364 pReturn
= m_NodeAlloc
.Alloc(); // count is initialized to 1
365 pReturn
->symbol
= sym
;
366 pReturn
->next
= pNode
->child
;
367 pNode
->child
= pReturn
;
375 struct BinaryRecord
{
380 unsigned short int m_iCount
;
384 bool CPPMLanguageModel::WriteToFile(std::string strFilename
) {
386 std::map
<CPPMnode
*, int> mapIdx
;
387 int iNextIdx(1); // Index of 0 means NULL;
389 std::ofstream
oOutputFile(strFilename
.c_str());
391 RecursiveWrite(m_pRoot
, &mapIdx
, &iNextIdx
, &oOutputFile
);
398 bool CPPMLanguageModel::RecursiveWrite(CPPMnode
*pNode
, std::map
<CPPMnode
*, int> *pmapIdx
, int *pNextIdx
, std::ofstream
*pOutputFile
) {
404 sBR
.m_iIndex
= GetIndex(pNode
, pmapIdx
, pNextIdx
);
405 sBR
.m_iChild
= GetIndex(pNode
->child
, pmapIdx
, pNextIdx
);
406 sBR
.m_iNext
= GetIndex(pNode
->next
, pmapIdx
, pNextIdx
);
407 sBR
.m_iVine
= GetIndex(pNode
->vine
, pmapIdx
, pNextIdx
);
408 sBR
.m_iCount
= pNode
->count
;
409 sBR
.m_iSymbol
= pNode
->symbol
;
411 pOutputFile
->write(reinterpret_cast<char*>(&sBR
), sizeof(BinaryRecord
));
413 CPPMnode
*pCurrentChild(pNode
->child
);
415 while(pCurrentChild
!= NULL
) {
416 RecursiveWrite(pCurrentChild
, pmapIdx
, pNextIdx
, pOutputFile
);
417 pCurrentChild
= pCurrentChild
->next
;
423 int CPPMLanguageModel::GetIndex(CPPMnode
*pAddr
, std::map
<CPPMnode
*, int> *pmapIdx
, int *pNextIdx
) {
429 std::map
<CPPMnode
*, int>::iterator
it(pmapIdx
->find(pAddr
));
431 if(it
== pmapIdx
->end()) {
433 pmapIdx
->insert(std::pair
<CPPMnode
*, int>(pAddr
, iIndex
));
443 bool CPPMLanguageModel::ReadFromFile(std::string strFilename
) {
445 std::ifstream
oInputFile(strFilename
.c_str());
446 std::map
<int, CPPMnode
*> oMap
;
448 bool bStarted(false);
450 while(!oInputFile
.eof()) {
451 oInputFile
.read(reinterpret_cast<char *>(&sBR
), sizeof(BinaryRecord
));
453 CPPMnode
*pCurrent(GetAddress(sBR
.m_iIndex
, &oMap
));
455 pCurrent
->child
= GetAddress(sBR
.m_iChild
, &oMap
);
456 pCurrent
->next
= GetAddress(sBR
.m_iNext
, &oMap
);
457 pCurrent
->vine
= GetAddress(sBR
.m_iVine
, &oMap
);
458 pCurrent
->count
= sBR
.m_iCount
;
459 pCurrent
->symbol
= sBR
.m_iSymbol
;
472 CPPMLanguageModel::CPPMnode
*CPPMLanguageModel::GetAddress(int iIndex
, std::map
<int, CPPMnode
*> *pMap
) {
473 std::map
<int, CPPMnode
*>::iterator
it(pMap
->find(iIndex
));
475 if(it
== pMap
->end()) {
477 pNewNode
= m_NodeAlloc
.Alloc();
478 pMap
->insert(std::pair
<int, CPPMnode
*>(iIndex
, pNewNode
));