tagging release
[dasher.git] / Src / DasherCore / LanguageModelling / JapaneseLanguageModel.cpp
blob4c473b065332a70635c2d69da15edcd06eae5f3d
1 // JapaneseLanguageModel.cpp
2 //
3 /////////////////////////////////////////////////////////////////////////////
4 //
5 // Copyright (c) 1999-2005 David Ward
6 // 2005 Takashi Kaburagi
7 //
8 /////////////////////////////////////////////////////////////////////////////
10 #include "../../Common/Common.h"
12 #include "JapaneseLanguageModel.h"
14 #include "KanjiConversion.h"
15 #ifdef HAVE_LIBCANNA
16 #include "KanjiConversionCanna.h"
17 #elif WIN32
18 #include "KanjiConversionIME.h"
19 #endif
21 #include <math.h>
22 #include <stack>
23 #include <sstream>
24 #include <iostream>
26 using namespace Dasher;
27 using namespace std;
// In Windows debug builds, route every `new` in this file through the MSVC
// debug allocator so leak reports point at the allocating source line.
#if defined(_WIN32) && defined(_DEBUG)
#define DEBUG_NEW new( _NORMAL_BLOCK, THIS_FILE, __LINE__ )
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
39 /////////////////////////////////////////////////////////////////////
41 CJapaneseLanguageModel::CJapaneseLanguageModel(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore, const CSymbolAlphabet &SymbolAlphabet)
42 :CLanguageModel(pEventHandler, pSettingsStore, SymbolAlphabet), m_iMaxOrder(5), NodesAllocated(0), m_NodeAlloc(8192), m_ContextAlloc(1024) {
43 m_pRoot = m_NodeAlloc.Alloc();
44 m_pRoot->symbol = -1;
46 m_pRootContext = m_ContextAlloc.Alloc();
47 m_pRootContext->head = m_pRoot;
48 m_pRootContext->order = 0;
50 // Cache the result of update exclusion - otherwise we have to look up a lot when training, which is slow
52 bUpdateExclusion = ( GetLongParameter(LP_LM_UPDATE_EXCLUSION) !=0 );
56 /////////////////////////////////////////////////////////////////////
58 CJapaneseLanguageModel::~CJapaneseLanguageModel() {
61 /////////////////////////////////////////////////////////////////////
62 // Get the probability distribution at the context
64 void CJapaneseLanguageModel::GetProbs(Context context, std::vector<unsigned int> &probs, int norm) const {
65 CPPMContext *ppmcontext = (CPPMContext *) (context);
67 int iNumSymbols = GetSize();
69 probs.resize(iNumSymbols);
71 std::vector < bool > exclusions(iNumSymbols);
73 int i;
74 unsigned int ui;
76 for(i = 0; i < iNumSymbols; i++) {
77 probs[i] = 0;
78 exclusions[i] = false;
81 bool doExclusion = 0; //FIXME
83 int alpha = GetLongParameter( LP_LM_ALPHA );
84 int beta = GetLongParameter( LP_LM_BETA );
87 unsigned int iToSpend = norm;
89 CPPMnode *pTemp = ppmcontext->head;
91 bool has_convert_symbol = 0; // Flag to show if a conversion symbol appears in the history
92 std::vector < symbol > hiragana; // Hiragana sequence to be converted
93 std::vector < std::vector < symbol > >candidate; // Temp list for candidates
94 unsigned int kanji_pos = 0;
96 // Initialize all probabilities to 0
97 for(ui = 0; ui < probs.size(); ui++) {
98 probs[ui] = 0;
101 // Look for a "Start Conversion" character in history
102 for(ui = 0; ui < ppmcontext->history.size(); ui++) {
103 if(ppmcontext->history[ui] == GetEndConversionSymbol()) {
104 // If "End Conversion" character was found, clear the history
105 ppmcontext->history.clear();
106 break;
108 else if(ppmcontext->history[ui] == GetStartConversionSymbol()) {
109 // We found the "Start Conversion" symbol!
111 //TODO: Write a conversion code here
112 #ifdef HAVE_LIBCANNA
113 CKanjiConversionCanna canna;
114 #elif WIN32
115 CKanjiConversionIME canna;
116 #else
117 CKanjiConversion canna;
118 #endif
119 //CKanjiConversion canna;
120 std::vector < std::string > cand_list;
122 // convert symbols to string
123 string str;
124 for(unsigned int j(0); j < hiragana.size(); j++) {
125 str += GetText(hiragana[j]);
128 // do kanji conversion
129 canna.ConvertKanji(str);
131 // create candidate list
132 for(unsigned int j(0); j < canna.phrase.size(); j++) {
133 std::vector < std::string > tmp_cand_list;
134 tmp_cand_list = cand_list;
135 int nCandList = tmp_cand_list.size();
136 cand_list.clear();
137 for(unsigned int k(0); k < canna.phrase[j].candidate_list.size(); k++) {
138 //CandidateString(canna.phrase[j].candidate_list[k]);
139 //cout << "RAW:" << canna.phrase[j].candidate_list[k] << endl;
140 if(k >= 30)
141 break;
143 if(nCandList) {
144 for(int n = 0; n < nCandList; n++) {
145 cand_list.push_back(tmp_cand_list[n] + canna.phrase[j].candidate_list[k]);
148 else {
149 cand_list.push_back(canna.phrase[j].candidate_list[k]);
154 // std::cout << "Candidate list has " << cand_list.size() << " entries" << std::endl;
156 // convert strings to symbols
157 for(unsigned int j(0); j < cand_list.size(); j++) {
158 //cout << "[" << j << "]" << cand_list[j] << endl;
159 std::vector < symbol > new_cand;
160 //SetCandidateString(cand_list[j]);
161 GetSymbols(&new_cand, &cand_list[j], false);
162 /*for( int k=0; k<new_cand.size(); k++ )
163 cout << GetText(new_cand[k]) << "[" << new_cand[k] << "] ";
164 cout << endl; */
165 candidate.push_back(new_cand);
167 candidate.push_back(hiragana); // escape to hiragana
168 has_convert_symbol = 1;
171 else if(has_convert_symbol && candidate.size()) {
172 // disable the candidate if the symbol does not match the present symbol
173 for(unsigned int j(0); j < candidate.size(); j++) {
174 if(kanji_pos < candidate[j].size()) {
175 if(ppmcontext->history[ui] != candidate[j][kanji_pos]) {
176 for(unsigned int k(0); k < candidate[j].size(); k++) {
177 candidate[j][k] = GetStartConversionSymbol();
182 // Count up the kanji symbols we have seen so far
183 kanji_pos++;
184 //cout << "Kanji Pos:" << kanji_pos << endl;
186 else {
187 hiragana.push_back(ppmcontext->history[ui]);
192 //== Kanji Candidates
193 if(has_convert_symbol && candidate.size()) {
194 for(ui = 0; ui < probs.size(); ui++)
195 exclusions[ui] = 0;
197 // assign a large probability for the candidate
198 int candidate_rank = 1;
199 for(ui = 0; ui < candidate.size(); ui++) {
200 /*cout << "Cand" << ui << ":";
201 for( int j=0; j<candidate[ui].size(); j++ ){
202 cout << GetText( candidate[ui][j] );
204 cout << endl; */
206 if(kanji_pos < candidate[ui].size()) { // check if kanji_pos is valid in present candidate
207 //cout << candidate_rank << ":" << GetText(candidate[ui][kanji_pos]) << "[" << kanji_pos << "]" << endl;
208 if(candidate[ui][kanji_pos] != GetStartConversionSymbol() && !exclusions[candidate[ui][kanji_pos]]) { // check if present candidate is enabled
209 /*cout << "Selected:";
210 for( int hoge = 0; hoge<candidate[ui].size(); hoge++ ){
211 cout << GetText(candidate[ui][hoge]) << "(" << candidate[ui][hoge] << ")";
213 cout << endl; */
215 uint32 p = (uint32) ((double)iToSpend / ((candidate_rank + 15) * (candidate_rank + 16))); // a large probability
216 probs[candidate[ui][kanji_pos]] += p;
217 iToSpend -= p;
218 exclusions[candidate[ui][kanji_pos]] = 1;
219 candidate_rank++;
224 if(candidate_rank > 1) {
225 unsigned int total = 0;
226 for(ui = 0; ui < probs.size(); ui++) {
227 total += probs[ui];
229 if(total) {
230 iToSpend = norm;
231 for(unsigned int ui(0); ui < probs.size(); ui++) {
232 if(probs[ui]) {
233 //cout << GetText(ui) << " " << probs[ui] << " -> ";
234 probs[ui] = (uint32) (((double)norm / (double)total) * (double)probs[ui]);
235 iToSpend -= probs[ui];
236 //cout << probs[ui] << endl;
240 pTemp = NULL; // Skip normal PPM
242 else {
243 probs[GetEndConversionSymbol()] += iToSpend;
244 iToSpend = 0;
245 candidate.clear(); // conversion is finished. clear the history
246 ppmcontext->history.clear();
247 pTemp = NULL; // Skip normal PPM
251 while(pTemp != 0) {
252 int iTotal = 0;
254 CPPMnode *pSymbol = pTemp->child;
255 while(pSymbol) {
256 int sym = pSymbol->symbol;
257 if(!(exclusions[sym] && doExclusion))
258 iTotal += pSymbol->count;
259 pSymbol = pSymbol->next;
262 if(iTotal) {
263 unsigned int size_of_slice = iToSpend;
264 pSymbol = pTemp->child;
265 while(pSymbol) {
266 if(!(exclusions[pSymbol->symbol] && doExclusion)) {
267 exclusions[pSymbol->symbol] = 1;
269 unsigned int p = static_cast < myint > (size_of_slice) * (100 * pSymbol->count - beta) / (100 * iTotal + alpha);
271 probs[pSymbol->symbol] += p;
272 iToSpend -= p;
274 // Usprintf(debug,TEXT("sym %u counts %d p %u tospend %u \n"),sym,s->count,p,tospend);
275 // DebugOutput(debug);
276 pSymbol = pSymbol->next;
279 pTemp = pTemp->vine;
282 unsigned int size_of_slice = iToSpend;
283 int symbolsleft = 0;
285 for(i = 0; i < iNumSymbols; i++)
286 if(!(exclusions[i] && doExclusion))
287 symbolsleft++;
289 // std::ostringstream str;
290 // for (sym=0;sym<modelchars;sym++)
291 // str << probs[sym] << " ";
292 // str << std::endl;
293 // DASHER_TRACEOUTPUT("probs %s",str.str().c_str());
295 // std::ostringstream str2;
296 // for (sym=0;sym<modelchars;sym++)
297 // str2 << valid[sym] << " ";
298 // str2 << std::endl;
299 // DASHER_TRACEOUTPUT("valid %s",str2.str().c_str());
301 for(i = 0; i < iNumSymbols; i++) {
302 if(!(exclusions[i] && doExclusion)) {
303 unsigned int p = size_of_slice / symbolsleft;
304 probs[i] += p;
305 iToSpend -= p;
309 int iLeft = iNumSymbols;
311 for(int j = 0; j < iNumSymbols; ++j) {
312 unsigned int p = iToSpend / iLeft;
313 probs[j] += p;
314 --iLeft;
315 iToSpend -= p;
318 DASHER_ASSERT(iToSpend == 0);
321 void CJapaneseLanguageModel::AddSymbol(CJapaneseLanguageModel::CPPMContext &context, int sym)
322 // add symbol to the context
323 // creates new nodes, updates counts
324 // and leaves 'context' at the new context
326 DASHER_ASSERT(sym >= 0 && sym <= GetSize());
328 CPPMnode *vineptr, *temp;
329 int updatecnt = 1;
331 temp = context.head->vine;
332 context.head = AddSymbolToNode(context.head, sym, &updatecnt);
333 vineptr = context.head;
334 context.order++;
336 while(temp != 0) {
337 vineptr->vine = AddSymbolToNode(temp, sym, &updatecnt);
338 vineptr = vineptr->vine;
339 temp = temp->vine;
341 vineptr->vine = m_pRoot;
343 // m_iMaxOrder = LanguageModelParams()->GetValue(std::string("LMMaxOrder"));
344 m_iMaxOrder = GetLongParameter( LP_LM_MAX_ORDER );
346 while(context.order > m_iMaxOrder) {
347 context.head = context.head->vine;
348 context.order--;
352 /////////////////////////////////////////////////////////////////////
353 // Update context with symbol 'Symbol'
355 void CJapaneseLanguageModel::EnterSymbol(Context c, int Symbol) {
356 DASHER_ASSERT(Symbol >= 0 && Symbol <= GetSize());
358 CJapaneseLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
360 CPPMnode *find;
362 context.history.push_back(Symbol);
363 if(context.history.size() > 100) {
364 context.history.erase(context.history.begin());
367 while(context.head) {
369 if(context.order < m_iMaxOrder) { // Only try to extend the context if it's not going to make it too long
370 find = context.head->find_symbol(Symbol);
371 if(find) {
372 context.order++;
373 context.head = find;
374 // Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
375 // DebugOutput(debug);
377 // std::cout << context.order << std::endl;
378 return;
382 // If we can't extend the current context, follow vine pointer to shorten it and try again
384 context.order--;
385 context.head = context.head->vine;
388 if(context.head == 0) {
389 context.head = m_pRoot;
390 context.order = 0;
393 // std::cout << context.order << std::endl;
397 /////////////////////////////////////////////////////////////////////
399 void CJapaneseLanguageModel::LearnSymbol(Context c, int Symbol) {
400 DASHER_ASSERT(Symbol >= 0 && Symbol <= GetSize());
402 CJapaneseLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
403 AddSymbol(context, Symbol);
406 void CJapaneseLanguageModel::dumpSymbol(int sym) {
407 if((sym <= 32) || (sym >= 127))
408 printf("<%d>", sym);
409 else
410 printf("%c", sym);
413 void CJapaneseLanguageModel::dumpString(char *str, int pos, int len)
414 // Dump the string STR starting at position POS
416 char cc;
417 int p;
418 for(p = pos; p < pos + len; p++) {
419 cc = str[p];
420 if((cc <= 31) || (cc >= 127))
421 printf("<%d>", cc);
422 else
423 printf("%c", cc);
427 void CJapaneseLanguageModel::dumpTrie(CJapaneseLanguageModel::CPPMnode *t, int d)
428 // diagnostic display of the PPM trie from node t and deeper
430 //TODO
432 dchar debug[256];
433 int sym;
434 CPPMnode *s;
435 Usprintf( debug,TEXT("%5d %7x "), d, t );
436 //TODO: Uncomment this when headers sort out
437 //DebugOutput(debug);
438 if (t < 0) // pointer to input
439 printf( " <" );
440 else {
441 Usprintf(debug,TEXT( " %3d %5d %7x %7x %7x <"), t->symbol,t->count, t->vine, t->child, t->next );
442 //TODO: Uncomment this when headers sort out
443 //DebugOutput(debug);
446 dumpString( dumpTrieStr, 0, d );
447 Usprintf( debug,TEXT(">\n") );
448 //TODO: Uncomment this when headers sort out
449 //DebugOutput(debug);
450 if (t != 0) {
451 s = t->child;
452 while (s != 0) {
453 sym =s->symbol;
455 dumpTrieStr [d] = sym;
456 dumpTrie( s, d+1 );
457 s = s->next;
463 void CJapaneseLanguageModel::dump()
464 // diagnostic display of the whole PPM trie
466 // TODO:
468 dchar debug[256];
469 Usprintf(debug,TEXT( "Dump of Trie : \n" ));
470 //TODO: Uncomment this when headers sort out
471 //DebugOutput(debug);
472 Usprintf(debug,TEXT( "---------------\n" ));
473 //TODO: Uncomment this when headers sort out
474 //DebugOutput(debug);
475 Usprintf( debug,TEXT( "depth node symbol count vine child next context\n") );
476 //TODO: Uncomment this when headers sort out
477 //DebugOutput(debug);
478 dumpTrie( root, 0 );
479 Usprintf( debug,TEXT( "---------------\n" ));
480 //TODO: Uncomment this when headers sort out
481 //DebugOutput(debug);
482 Usprintf(debug,TEXT( "\n" ));
483 //TODO: Uncomment this when headers sort out
484 //DebugOutput(debug);
488 ////////////////////////////////////////////////////////////////////////
489 /// PPMnode definitions
490 ////////////////////////////////////////////////////////////////////////
492 CJapaneseLanguageModel::CPPMnode * CJapaneseLanguageModel::CPPMnode::find_symbol(int sym) const
493 // see if symbol is a child of node
495 // printf("finding symbol %d at node %d\n",sym,node->id);
496 CPPMnode *found = child;
498 while(found) {
499 if(found->symbol == sym) {
500 return found;
502 found = found->next;
504 return 0;
507 CJapaneseLanguageModel::CPPMnode * CJapaneseLanguageModel::AddSymbolToNode(CPPMnode *pNode, int sym, int *update) {
508 CPPMnode *pReturn = pNode->find_symbol(sym);
510 // std::cout << sym << ",";
512 if(pReturn != NULL) {
513 // std::cout << "Using existing node" << std::endl;
515 // if (*update || (LanguageModelParams()->GetValue("LMUpdateExclusion") == 0) )
516 if(*update || !bUpdateExclusion) { // perform update exclusions
517 pReturn->count++;
518 *update = 0;
520 return pReturn;
523 // std::cout << "Creating new node" << std::endl;
525 pReturn = m_NodeAlloc.Alloc(); // count is initialized to 1
526 pReturn->symbol = sym;
527 pReturn->next = pNode->child;
528 pNode->child = pReturn;
530 ++NodesAllocated;
532 return pReturn;