1 // JapaneseLanguageModel.cpp
3 /////////////////////////////////////////////////////////////////////////////
5 // Copyright (c) 1999-2005 David Ward
6 // 2005 Takashi Kaburagi
8 /////////////////////////////////////////////////////////////////////////////
10 #include "../../Common/Common.h"
12 #include "JapaneseLanguageModel.h"
14 #include "KanjiConversion.h"
16 #include "KanjiConversionCanna.h"
18 #include "KanjiConversionIME.h"
// Make the names in namespace Dasher usable unqualified in the rest of this file.
26 using namespace Dasher
;
29 // Track memory leaks on Windows to the line that new'd the memory
// DEBUG_NEW expands to a new-expression that passes _NORMAL_BLOCK, THIS_FILE and
// __LINE__ as extra arguments.
// NOTE(review): the preprocessor guards that presumably restrict this to Windows
// debug builds are not visible here - confirm.
32 #define DEBUG_NEW new( _NORMAL_BLOCK, THIS_FILE, __LINE__ )
// Name of this translation unit, as consumed by DEBUG_NEW above.
35 static char THIS_FILE
[] = __FILE__
;
39 /////////////////////////////////////////////////////////////////////
// Constructor.
//
// Forwards pEventHandler, pSettingsStore and SymbolAlphabet to the CLanguageModel
// base class, initialises m_iMaxOrder to 5 and NodesAllocated to 0, and constructs
// m_NodeAlloc and m_ContextAlloc with the arguments 8192 and 1024
// (presumably allocator block sizes - TODO confirm against the allocator class).
// The body then allocates the root node and a root context that refers to it.
41 CJapaneseLanguageModel::CJapaneseLanguageModel(Dasher::CEventHandler
*pEventHandler
, CSettingsStore
*pSettingsStore
, const CSymbolAlphabet
&SymbolAlphabet
)
42 :CLanguageModel(pEventHandler
, pSettingsStore
, SymbolAlphabet
), m_iMaxOrder(5), NodesAllocated(0), m_NodeAlloc(8192), m_ContextAlloc(1024) {
// Obtain the root node from the node allocator.
43 m_pRoot
= m_NodeAlloc
.Alloc();
// Obtain the root context from the context allocator; it starts at the root
// node with order 0.
46 m_pRootContext
= m_ContextAlloc
.Alloc();
47 m_pRootContext
->head
= m_pRoot
;
48 m_pRootContext
->order
= 0;
50 // Cache the result of update exclusion - otherwise we have to look up a lot when training, which is slow
// bUpdateExclusion is true exactly when LP_LM_UPDATE_EXCLUSION is non-zero.
52 bUpdateExclusion
= ( GetLongParameter(LP_LM_UPDATE_EXCLUSION
) !=0 );
56 /////////////////////////////////////////////////////////////////////
// Destructor.
// NOTE(review): no statements of the body are visible here, so nothing can be
// said about what it releases - confirm against the full file.
58 CJapaneseLanguageModel::~CJapaneseLanguageModel() {
61 /////////////////////////////////////////////////////////////////////
62 // Get the probability distribution at the context
// Fills 'probs' with one entry per symbol (GetSize() entries) for the given context.
//
// context : opaque handle, reinterpreted below as a CPPMContext pointer.
// probs   : output vector; it is resized to GetSize().
// norm    : total amount to distribute; it is copied into iToSpend, and the final
//           assertion expects iToSpend to have been brought down to exactly 0.
//
// Visible structure:
//   1. Scan the context's symbol history for the end-conversion and
//      start-conversion symbols, collecting other symbols into 'hiragana'.
//   2. On a start-conversion symbol, turn the collected hiragana into a string, run
//      a kanji converter on it, and build a list of candidate symbol sequences.
//   3. If candidates exist, give them rank-dependent shares of the probability.
//   4. Otherwise distribute probability from the counts in the PPM trie, then
//      spread the remainder over the symbols.
//
// NOTE(review): this member is declared const, yet it calls history.clear() on the
// context it was handed, so the caller's context is modified.
// NOTE(review): several declarations used below (i, ui, str, iTotal, symbolsleft)
// and a number of loop bodies and closing braces are not visible here.
64 void CJapaneseLanguageModel::GetProbs(Context context
, std::vector
<unsigned int> &probs
, int norm
) const {
// Reinterpret the opaque context handle as this model's context structure.
65 CPPMContext
*ppmcontext
= (CPPMContext
*) (context
);
67 int iNumSymbols
= GetSize();
69 probs
.resize(iNumSymbols
);
// One flag per symbol; the loop below clears every flag.
71 std::vector
< bool > exclusions(iNumSymbols
);
76 for(i
= 0; i
< iNumSymbols
; i
++) {
78 exclusions
[i
] = false;
// doExclusion is initialised to false here and no assignment to it is visible
// below, so each "(exclusions[x] && doExclusion)" test evaluates to false.
81 bool doExclusion
= 0; //FIXME
// Smoothing parameters read from the settings; they enter the share formula in
// the trie section further down.
83 int alpha
= GetLongParameter( LP_LM_ALPHA
);
84 int beta
= GetLongParameter( LP_LM_BETA
);
// Amount still to be handed out; starts as the whole of norm.
87 unsigned int iToSpend
= norm
;
// Trie node for the current context; set to NULL further down to skip the
// normal PPM section.
89 CPPMnode
*pTemp
= ppmcontext
->head
;
91 bool has_convert_symbol
= 0; // Flag to show if a conversion symbol appears in the history
92 std::vector
< symbol
> hiragana
; // Hiragana sequence to be converted
93 std::vector
< std::vector
< symbol
> >candidate
; // Temp list for candidates
// Index into each candidate of the symbol compared against the history.
94 unsigned int kanji_pos
= 0;
96 // Initialize all probabilities to 0
97 for(ui
= 0; ui
< probs
.size(); ui
++) {
101 // Look for a "Start Conversion" character in history
102 for(ui
= 0; ui
< ppmcontext
->history
.size(); ui
++) {
103 if(ppmcontext
->history
[ui
] == GetEndConversionSymbol()) {
104 // If "End Conversion" character was found, clear the history
// After clear() the history size is 0, so the enclosing index loop stops at its
// next condition check.
105 ppmcontext
->history
.clear();
108 else if(ppmcontext
->history
[ui
] == GetStartConversionSymbol()) {
109 // We found the "Start Conversion" symbol!
111 //TODO: Write a conversion code here
// NOTE(review): three declarations of 'canna' with different converter types
// follow one another; presumably each belongs to a different preprocessor branch
// that is not visible here - confirm.
113 CKanjiConversionCanna canna
;
115 CKanjiConversionIME canna
;
117 CKanjiConversion canna
;
119 //CKanjiConversion canna;
120 std::vector
< std::string
> cand_list
;
122 // convert symbols to string
// Appends the text of every collected hiragana symbol to str.
124 for(unsigned int j(0); j
< hiragana
.size(); j
++) {
125 str
+= GetText(hiragana
[j
]);
128 // do kanji conversion
129 canna
.ConvertKanji(str
);
131 // create candidate list
// For each phrase of the conversion result, cand_list is first copied into
// tmp_cand_list; each candidate of the phrase is then either appended to every
// copied entry or pushed on its own.
// NOTE(review): the condition choosing between those two cases is not visible here.
132 for(unsigned int j(0); j
< canna
.phrase
.size(); j
++) {
133 std::vector
< std::string
> tmp_cand_list
;
134 tmp_cand_list
= cand_list
;
135 int nCandList
= tmp_cand_list
.size();
137 for(unsigned int k(0); k
< canna
.phrase
[j
].candidate_list
.size(); k
++) {
138 //CandidateString(canna.phrase[j].candidate_list[k]);
139 //cout << "RAW:" << canna.phrase[j].candidate_list[k] << endl;
144 for(int n
= 0; n
< nCandList
; n
++) {
145 cand_list
.push_back(tmp_cand_list
[n
] + canna
.phrase
[j
].candidate_list
[k
]);
149 cand_list
.push_back(canna
.phrase
[j
].candidate_list
[k
]);
154 // std::cout << "Candidate list has " << cand_list.size() << " entries" << std::endl;
156 // convert strings to symbols
// Each candidate string is turned into a symbol sequence via GetSymbols and
// stored in 'candidate'.
157 for(unsigned int j(0); j
< cand_list
.size(); j
++) {
158 //cout << "[" << j << "]" << cand_list[j] << endl;
159 std::vector
< symbol
> new_cand
;
160 //SetCandidateString(cand_list[j]);
161 GetSymbols(&new_cand
, &cand_list
[j
], false);
// NOTE(review): the terminator of the block comment opened on the next line is
// not visible here.
162 /*for( int k=0; k<new_cand.size(); k++ )
163 cout << GetText(new_cand[k]) << "[" << new_cand[k] << "] ";
165 candidate
.push_back(new_cand
);
167 candidate
.push_back(hiragana
); // escape to hiragana
168 has_convert_symbol
= 1;
171 else if(has_convert_symbol
&& candidate
.size()) {
172 // disable the candidate if the symbol does not match the present symbol
// A candidate is disabled by overwriting all of its symbols with the
// start-conversion symbol; the candidate section below tests for that value.
173 for(unsigned int j(0); j
< candidate
.size(); j
++) {
174 if(kanji_pos
< candidate
[j
].size()) {
175 if(ppmcontext
->history
[ui
] != candidate
[j
][kanji_pos
]) {
176 for(unsigned int k(0); k
< candidate
[j
].size(); k
++) {
177 candidate
[j
][k
] = GetStartConversionSymbol();
182 // Count up the kanji symbols we have seen so far
184 //cout << "Kanji Pos:" << kanji_pos << endl;
// Remember the history symbol as part of the hiragana sequence.
187 hiragana
.push_back(ppmcontext
->history
[ui
]);
192 //== Kanji Candidates
193 if(has_convert_symbol
&& candidate
.size()) {
194 for(ui
= 0; ui
< probs
.size(); ui
++)
197 // assign a large probability for the candidate
198 int candidate_rank
= 1;
199 for(ui
= 0; ui
< candidate
.size(); ui
++) {
200 /*cout << "Cand" << ui << ":";
201 for( int j=0; j<candidate[ui].size(); j++ ){
202 cout << GetText( candidate[ui][j] );
206 if(kanji_pos
< candidate
[ui
].size()) { // check if kanji_pos is valid in present candidate
207 //cout << candidate_rank << ":" << GetText(candidate[ui][kanji_pos]) << "[" << kanji_pos << "]" << endl;
208 if(candidate
[ui
][kanji_pos
] != GetStartConversionSymbol() && !exclusions
[candidate
[ui
][kanji_pos
]]) { // check if present candidate is enabled
209 /*cout << "Selected:";
210 for( int hoge = 0; hoge<candidate[ui].size(); hoge++ ){
211 cout << GetText(candidate[ui][hoge]) << "(" << candidate[ui][hoge] << ")";
// Share for this candidate: iToSpend / ((rank + 15) * (rank + 16)), computed in
// double and truncated to uint32.
215 uint32 p
= (uint32
) ((double)iToSpend
/ ((candidate_rank
+ 15) * (candidate_rank
+ 16))); // a large probability
216 probs
[candidate
[ui
][kanji_pos
]] += p
;
// Mark the symbol so that a later candidate with the same symbol at kanji_pos
// fails the !exclusions test above.
218 exclusions
[candidate
[ui
][kanji_pos
]] = 1;
224 if(candidate_rank
> 1) {
225 unsigned int total
= 0;
226 for(ui
= 0; ui
< probs
.size(); ui
++) {
// Rescale every entry by norm / total and deduct the rescaled value from iToSpend.
231 for(unsigned int ui(0); ui
< probs
.size(); ui
++) {
233 //cout << GetText(ui) << " " << probs[ui] << " -> ";
234 probs
[ui
] = (uint32
) (((double)norm
/ (double)total
) * (double)probs
[ui
]);
235 iToSpend
-= probs
[ui
];
236 //cout << probs[ui] << endl;
240 pTemp
= NULL
; // Skip normal PPM
// Whatever is left in iToSpend is added to the end-conversion symbol.
// NOTE(review): the condition guarding this statement is not visible here.
243 probs
[GetEndConversionSymbol()] += iToSpend
;
245 candidate
.clear(); // conversion is finished. clear the history
246 ppmcontext
->history
.clear();
247 pTemp
= NULL
; // Skip normal PPM
// Trie section, first pass over the children of pTemp: the count of each child
// that is not skipped by the exclusion test is added to iTotal.
// NOTE(review): the enclosing loops are not visible here.
254 CPPMnode
*pSymbol
= pTemp
->child
;
256 int sym
= pSymbol
->symbol
;
257 if(!(exclusions
[sym
] && doExclusion
))
258 iTotal
+= pSymbol
->count
;
259 pSymbol
= pSymbol
->next
;
// Second pass over the same children: each one is flagged in 'exclusions' and
// receives size_of_slice * (100 * count - beta) / (100 * iTotal + alpha).
263 unsigned int size_of_slice
= iToSpend
;
264 pSymbol
= pTemp
->child
;
266 if(!(exclusions
[pSymbol
->symbol
] && doExclusion
)) {
267 exclusions
[pSymbol
->symbol
] = 1;
269 unsigned int p
= static_cast < myint
> (size_of_slice
) * (100 * pSymbol
->count
- beta
) / (100 * iTotal
+ alpha
);
271 probs
[pSymbol
->symbol
] += p
;
274 // Usprintf(debug,TEXT("sym %u counts %d p %u tospend %u \n"),sym,s->count,p,tospend);
275 // DebugOutput(debug);
276 pSymbol
= pSymbol
->next
;
// Two loops over all symbols follow; in the second, each symbol that passes the
// exclusion test is given size_of_slice / symbolsleft.
// NOTE(review): the statement that sets symbolsleft is not visible here;
// presumably the first loop counts the symbols that pass the test - confirm.
282 unsigned int size_of_slice
= iToSpend
;
285 for(i
= 0; i
< iNumSymbols
; i
++)
286 if(!(exclusions
[i
] && doExclusion
))
289 // std::ostringstream str;
290 // for (sym=0;sym<modelchars;sym++)
291 // str << probs[sym] << " ";
293 // DASHER_TRACEOUTPUT("probs %s",str.str().c_str());
295 // std::ostringstream str2;
296 // for (sym=0;sym<modelchars;sym++)
297 // str2 << valid[sym] << " ";
298 // str2 << std::endl;
299 // DASHER_TRACEOUTPUT("valid %s",str2.str().c_str());
301 for(i
= 0; i
< iNumSymbols
; i
++) {
302 if(!(exclusions
[i
] && doExclusion
)) {
303 unsigned int p
= size_of_slice
/ symbolsleft
;
// Final pass: each iteration computes iToSpend / iLeft for one symbol.
// NOTE(review): the updates of probs, iToSpend and iLeft inside this loop are not
// visible here.
309 int iLeft
= iNumSymbols
;
311 for(int j
= 0; j
< iNumSymbols
; ++j
) {
312 unsigned int p
= iToSpend
/ iLeft
;
// Everything must have been handed out by this point.
318 DASHER_ASSERT(iToSpend
== 0);
// Adds 'sym' to the trie at 'context' and moves 'context' to the resulting node.
//
// context : context to extend; its head (and, in code not visible here,
//           presumably its order) is updated in place.
// sym     : symbol to add.
//
// NOTE(review): 'updatecnt' is used below but its declaration is not visible here.
321 void CJapaneseLanguageModel::AddSymbol(CJapaneseLanguageModel::CPPMContext
&context
, int sym
)
322 // add symbol to the context
323 // creates new nodes, updates counts
324 // and leaves 'context' at the new context
// NOTE(review): the upper bound is inclusive, so sym == GetSize() passes, although
// GetProbs indexes vectors of GetSize() entries by stored symbols; presumably this
// should be sym < GetSize() - confirm.
326 DASHER_ASSERT(sym
>= 0 && sym
<= GetSize());
328 CPPMnode
*vineptr
, *temp
;
// Save the vine pointer of the old head before the head is replaced.
331 temp
= context
.head
->vine
;
332 context
.head
= AddSymbolToNode(context
.head
, sym
, &updatecnt
);
333 vineptr
= context
.head
;
// Add the symbol under 'temp' as well and link the node added before it to the
// result through its vine pointer.
// NOTE(review): the loop around these statements, and the way 'temp' advances,
// are not visible here.
337 vineptr
->vine
= AddSymbolToNode(temp
, sym
, &updatecnt
);
338 vineptr
= vineptr
->vine
;
// The last node of the chain points at the root.
341 vineptr
->vine
= m_pRoot
;
343 // m_iMaxOrder = LanguageModelParams()->GetValue(std::string("LMMaxOrder"));
// The maximum order is re-read from the settings on every call.
344 m_iMaxOrder
= GetLongParameter( LP_LM_MAX_ORDER
);
// While the context is longer than allowed, step to the vine node.
// NOTE(review): the statement that lowers context.order is not visible here.
346 while(context
.order
> m_iMaxOrder
) {
347 context
.head
= context
.head
->vine
;
352 /////////////////////////////////////////////////////////////////////
353 // Update context with symbol 'Symbol'
// Moves context 'c' forward by 'Symbol'.
//
// c      : opaque context handle, reinterpreted below as a CPPMContext.
// Symbol : symbol that was entered.
//
// The symbol is appended to the context's history (capped at 100 entries), then
// the trie is searched for a child matching Symbol, shortening the context along
// vine pointers when it cannot be extended.
// NOTE(review): 'find' is used below but its declaration, and the statements that
// act on a successful lookup, are not visible here.
355 void CJapaneseLanguageModel::EnterSymbol(Context c
, int Symbol
) {
// NOTE(review): the upper bound is inclusive, so Symbol == GetSize() passes,
// although GetProbs indexes vectors of GetSize() entries by symbol; presumably
// this should be Symbol < GetSize() - confirm.
356 DASHER_ASSERT(Symbol
>= 0 && Symbol
<= GetSize());
358 CJapaneseLanguageModel::CPPMContext
& context
= *(CPPMContext
*) (c
);
// Record the symbol; once the history holds more than 100 entries the oldest one
// is dropped from the front.
362 context
.history
.push_back(Symbol
);
363 if(context
.history
.size() > 100) {
364 context
.history
.erase(context
.history
.begin());
367 while(context
.head
) {
369 if(context
.order
< m_iMaxOrder
) { // Only try to extend the context if it's not going to make it too long
370 find
= context
.head
->find_symbol(Symbol
);
374 // Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
375 // DebugOutput(debug);
377 // std::cout << context.order << std::endl;
382 // If we can't extend the current context, follow vine pointer to shorten it and try again
385 context
.head
= context
.head
->vine
;
// If following vine pointers ended at a null head, restart from the root.
388 if(context
.head
== 0) {
389 context
.head
= m_pRoot
;
393 // std::cout << context.order << std::endl;
397 /////////////////////////////////////////////////////////////////////
399 void CJapaneseLanguageModel::LearnSymbol(Context c
, int Symbol
) {
400 DASHER_ASSERT(Symbol
>= 0 && Symbol
<= GetSize());
402 CJapaneseLanguageModel::CPPMContext
& context
= *(CPPMContext
*) (c
);
403 AddSymbol(context
, Symbol
);
// Diagnostic helper for a single symbol value.
//
// sym : symbol value to dump.
//
// The visible test separates values at or below 32 and at or above 127 from the
// values in between.
// NOTE(review): the statements chosen by this test are not visible here;
// presumably the two ranges are printed in different formats - confirm.
406 void CJapaneseLanguageModel::dumpSymbol(int sym
) {
407 if((sym
<= 32) || (sym
>= 127))
// Diagnostic helper that walks 'len' characters of 'str' starting at index 'pos'.
//
// str : buffer to dump.
// pos : index of the first character visited.
// len : number of characters visited (the loop runs from pos up to, but not
//       including, pos + len).
//
// NOTE(review): the declarations of 'p' and 'cc' and the output statements are not
// visible here.
413 void CJapaneseLanguageModel::dumpString(char *str
, int pos
, int len
)
414 // Dump the string STR starting at position POS
418 for(p
= pos
; p
< pos
+ len
; p
++) {
// NOTE(review): the lower cut-off here is 31, whereas dumpSymbol uses 32 -
// confirm whether the difference is intended.
420 if((cc
<= 31) || (cc
>= 127))
// Diagnostic dump of the trie starting at node 't'.
//
// t : node to dump.
// d : depth value; it is printed in the first column, passed to dumpString as the
//     length, and used as the index written in dumpTrieStr.
//
// Each Usprintf call formats into 'debug'; every call that would emit that buffer
// (DebugOutput) is commented out.
// NOTE(review): the declarations of 'debug', 'dumpTrieStr' and 'sym', and the
// traversal of child nodes, are not visible here; it is also not visible whether
// these statements are compiled at all - confirm.
427 void CJapaneseLanguageModel::dumpTrie(CJapaneseLanguageModel::CPPMnode
*t
, int d
)
428 // diagnostic display of the PPM trie from node t and deeper
435 Usprintf( debug,TEXT("%5d %7x "), d, t );
436 //TODO: Uncomment this when headers sort out
437 //DebugOutput(debug);
// NOTE(review): 't' is a pointer, so comparing it with "< 0" is suspect - confirm
// what this test is meant to detect.
438 if (t < 0) // pointer to input
441 Usprintf(debug,TEXT( " %3d %5d %7x %7x %7x <"), t->symbol,t->count, t->vine, t->child, t->next );
442 //TODO: Uncomment this when headers sort out
443 //DebugOutput(debug);
446 dumpString( dumpTrieStr, 0, d );
447 Usprintf( debug,TEXT(">\n") );
448 //TODO: Uncomment this when headers sort out
449 //DebugOutput(debug);
// Store the symbol at index d of dumpTrieStr.
455 dumpTrieStr [d] = sym;
// Diagnostic dump of the whole trie: formats a title, a column header and
// separator lines into 'debug'.
//
// Every call that would emit the formatted text (DebugOutput) is commented out.
// NOTE(review): the declaration of 'debug' and the call that dumps the trie itself
// are not visible here; presumably dumpTrie is invoked between the header and the
// closing separator - confirm.
463 void CJapaneseLanguageModel::dump()
464 // diagnostic display of the whole PPM trie
469 Usprintf(debug,TEXT( "Dump of Trie : \n" ));
470 //TODO: Uncomment this when headers sort out
471 //DebugOutput(debug);
472 Usprintf(debug,TEXT( "---------------\n" ));
473 //TODO: Uncomment this when headers sort out
474 //DebugOutput(debug);
475 Usprintf( debug,TEXT( "depth node symbol count vine child next context\n") );
476 //TODO: Uncomment this when headers sort out
477 //DebugOutput(debug);
479 Usprintf( debug,TEXT( "---------------\n" ));
480 //TODO: Uncomment this when headers sort out
481 //DebugOutput(debug);
482 Usprintf(debug,TEXT( "\n" ));
483 //TODO: Uncomment this when headers sort out
484 //DebugOutput(debug);
488 ////////////////////////////////////////////////////////////////////////
489 /// PPMnode definitions
490 ////////////////////////////////////////////////////////////////////////
// Looks among this node's children for one whose symbol equals 'sym'.
//
// sym : symbol to search for.
//
// The search starts at 'child' and compares each visited node's symbol with sym.
// NOTE(review): the loop, the step to the next sibling and the return statements
// are not visible here; presumably the matching node is returned, or a null
// pointer when there is none (AddSymbolToNode compares the result with NULL) -
// confirm.
492 CJapaneseLanguageModel::CPPMnode
* CJapaneseLanguageModel::CPPMnode::find_symbol(int sym
) const
493 // see if symbol is a child of node
495 // printf("finding symbol %d at node %d\n",sym,node->id);
496 CPPMnode
*found
= child
;
499 if(found
->symbol
== sym
) {
// Finds the child of 'pNode' that carries 'sym', creating it when it is missing.
//
// pNode  : parent node whose children are searched.
// sym    : symbol to look up or add.
// update : pointer to a flag read when the child already exists.
//
// When no child matches, a new node is taken from m_NodeAlloc, given the symbol
// and inserted at the front of pNode's child list.
// NOTE(review): the statements inside the "existing node" branch and the end of
// this function (including its return) are not visible here.
507 CJapaneseLanguageModel::CPPMnode
* CJapaneseLanguageModel::AddSymbolToNode(CPPMnode
*pNode
, int sym
, int *update
) {
508 CPPMnode
*pReturn
= pNode
->find_symbol(sym
);
510 // std::cout << sym << ",";
512 if(pReturn
!= NULL
) {
513 // std::cout << "Using existing node" << std::endl;
515 // if (*update || (LanguageModelParams()->GetValue("LMUpdateExclusion") == 0) )
// The guarded statements run when *update is non-zero or when update exclusion is
// switched off (bUpdateExclusion is false).
516 if(*update
|| !bUpdateExclusion
) { // perform update exclusions
523 // std::cout << "Creating new node" << std::endl;
525 pReturn
= m_NodeAlloc
.Alloc(); // count is initialized to 1
526 pReturn
->symbol
= sym
;
// Insert the new node at the head of the parent's child list.
527 pReturn
->next
= pNode
->child
;
528 pNode
->child
= pReturn
;