HowManyAreAnalyzed(): use status_user_agent to report progress
[linguistica.git] / StateEmitHMM.cpp
blob4e9441a673fb4e6807022d0e0d35c45dfaad87d5
1 // Implementation of StateEmitHMM methods
2 // Copyright © 2009 The University of Chicago
3 #include "StateEmitHMM.h"
5 #include <cstdlib>
6 #include <ctime>
7 #include <Q3TextStream>
8 #include <QIODevice>
9 #include <QFile>
10 #include <Q3SortedList>
11 #include <QMessageBox>
12 #include "linguisticamainwindow.h"
13 #include "ui/Status.h"
14 #include "Lexicon.h"
15 #include "GraphicView.h"
16 #include "Stem.h"
17 #include "WordCollection.h"
18 #include "StringSurrogate.h"
19 #include "Parse.h"
20 #include "CompareFunc.h"
21 #include "log2.h"
23 class CLexicon;
25 //////////////////////////////////////////////////////////////////////
26 // Construction/Destruction
27 //////////////////////////////////////////////////////////////////////
29 StateEmitHMM::StateEmitHMM(LinguisticaMainWindow* parent)
31 m_parent = parent;
32 m_PI = NULL;
33 m_A = NULL;
34 m_B = NULL;
35 m_Alpha = NULL;
36 m_Beta = NULL;
37 m_P = NULL;
38 m_PI_SoftCounts = NULL;
39 m_A_SoftCounts = NULL;
40 m_B_SoftCounts = NULL;
41 m_trainingDatum = NULL;
42 m_trainingDatumSource = NULL;
43 m_symbolStateList.clear();
44 m_Entropy = NULL;
45 m_countOfStates = 0;
46 m_countOfSymbols = 0;
47 m_lengthOfObservation = 0;
48 m_NumberOfIterations = 0;
49 m_countOfDataItems = 0;
50 m_maxLengthOfObservation =0;
51 m_HMMLog = parent->GetLogFileName();
52 m_HMMLogDirectory = parent->GetLogFileName();
54 // m_HMMLog = QString("Linguistica_HMM_Log.txt");
56 m_WordProbabilities = NULL;
57 m_dataType = HMMNONE;
59 m_encodedTrainingData.clear();
63 StateEmitHMM::StateEmitHMM(CLexicon * pLexicon )
65 m_parent = NULL;
66 m_PI = NULL;
67 m_A = NULL;
68 m_B = NULL;
69 m_Alpha = NULL;
70 m_Beta = NULL;
71 m_P = NULL;
72 m_PI_SoftCounts = NULL;
73 m_A_SoftCounts = NULL;
74 m_B_SoftCounts = NULL;
75 m_trainingDatum = NULL;
76 m_trainingDatumSource = NULL;
77 m_symbolStateList .clear();
78 m_Entropy = NULL;
79 m_countOfStates = 0;
80 m_countOfSymbols = 0;
81 m_lengthOfObservation = 0;
82 m_NumberOfIterations = 0;
83 m_countOfDataItems = 0;
84 m_maxLengthOfObservation =0;
85 m_WordProbabilities = NULL;
86 m_HMMLog = QString("Linguistica_HMM_Log.txt");
87 m_Lexicon = pLexicon;
88 m_dataType = HMMNONE;
90 m_encodedTrainingData.clear();
93 StateEmitHMM::~StateEmitHMM()
95 clear();
98 /// Figure out what kind of phonological segments are being used as the symbols of this HMM
99 ///
100 /// Parameters:
101 /// - type - can be PHONE_TIER1, PHONE_TIER2, or PARSE (which is essentially the spelling)
104 bool StateEmitHMM::preprocessData(eHmmDataType type, void* dataCollection)
106 // clear the HMM
107 clear();
109 // depending on the type, process the data
110 if ( type == PHONE_TIER1) // take the wordCollection, and get PhoneTier 1
112 CWordCollection* PhoneData;
113 int i;
114 CStem* pWord;
115 CParse* pDataItem;
116 int pCount;
117 int Index;
118 CStringSurrogate* poundCSS;
119 QString poundStr = QString("#");
120 void* pOriginData;
123 poundCSS = new CStringSurrogate(poundStr);
125 // The data is phoneCollection
126 PhoneData = (CWordCollection*)dataCollection;
130 Index = 0;
131 for (i = 0; i < PhoneData ->GetCount(); i++)
133 pWord = PhoneData ->GetAt(i);
134 pOriginData = (void*)pWord;
135 pCount = pWord->GetCorpusCount();
136 pDataItem = pWord ->GetPhonology_Tier1();
138 if ( pDataItem ->Size() ==0)
140 return false;
145 pDataItem ->RemovePiecesThatBegin((*poundCSS));
147 if ( pDataItem ->Size() ==0)
149 continue;
152 m_trainingDataSource. insert(Index, pOriginData);
153 m_trainingData. insert(Index, pDataItem);
154 m_trainingDataFrequency.insert(Index, pCount);
155 m_trainingDataSizes. insert(Index, pDataItem ->Size());
156 Index++;
158 if ( pDataItem ->Size() > m_maxLengthOfObservation)
160 m_maxLengthOfObservation = pDataItem ->Size() ;
164 m_countOfDataItems = Index;
165 delete poundCSS;
166 m_dataType = type;
167 return true;
170 if ( type == PHONE_TIER2) // take the wordCollection, and get PhoneTier 2
172 CWordCollection* PhoneData;
173 int i;
174 CStem* pWord;
175 CParse* pDataItem;
176 int pCount;
177 int Index;
178 void* pOriginData;
180 // The data is phoneCollection
181 PhoneData = (CWordCollection*)dataCollection;
183 Index = 0;
184 for (i = 0; i < PhoneData ->GetCount(); i++)
186 pWord = PhoneData ->GetAt(i);
187 pOriginData = (void*)pWord;
188 pCount = pWord->GetCorpusCount();
189 pDataItem = pWord ->GetPhonology_Tier2();
190 if ( pDataItem ->Size() ==0)
192 return false;
194 m_trainingDataSource. insert(Index, pOriginData);
195 m_trainingData. insert(Index, pDataItem);
196 m_trainingDataFrequency.insert(Index, pCount);
197 m_trainingDataSizes. insert(Index, pDataItem ->Size());
198 Index++;
200 if ( pDataItem ->Size() > m_maxLengthOfObservation)
202 m_maxLengthOfObservation = pDataItem ->Size() ;
206 m_countOfDataItems = Index;
207 m_dataType = type;
208 return true;
212 if ( type == PARSE) // take the wordCollection, and get CParse
214 CWordCollection* WordData;
215 int i;
216 CStem* pWord;
217 CParse* pDataItem;
218 int pCount;
219 int Index;
220 void* pOriginData;
222 // The data is WordCollection
223 WordData = (CWordCollection*)dataCollection;
225 Index = 0;
226 for (i = 0; i < WordData ->GetCount(); i++)
228 pWord = WordData ->GetAt(i);
229 pOriginData = (void*)pWord;
230 pCount = pWord->GetCorpusCount();
231 pDataItem = (CParse*)pWord;
233 m_trainingDataSource. insert(Index, pOriginData);
234 m_trainingData. insert(Index, pDataItem);
235 m_trainingDataFrequency.insert(Index, pCount);
236 m_trainingDataSizes. insert(Index, pDataItem ->Size());
237 Index++;
239 if ( pDataItem ->Size() > m_maxLengthOfObservation)
241 m_maxLengthOfObservation = pDataItem ->Size() ;
246 m_countOfDataItems = Index;
247 m_dataType = type;
248 return true;
251 return false;
254 void StateEmitHMM::init(int countOfStates, int loops)
256 StringToInt::iterator StringToIntIt;
257 StringToInt tempSymbolList;
258 int i,j;
259 QString oneSymbol;
260 int symbolIndex;
261 CParse* oneDataItem;
262 int itemSize;
263 int* oneIntArray;
264 int encodeId;
267 // count of states
268 if ( countOfStates < 2 )
270 QMessageBox::information(NULL, "error", "The number of states of one HMM should > 1", "OK");
271 return;
274 m_countOfStates = countOfStates;
276 // training data
277 if ( m_countOfDataItems == 0)
279 QMessageBox::information(NULL, "error", "The training data is empty", "OK");
280 clear();
281 return;
284 // m_Video = new Video ( countOfStates, m_countOfDataItems );
285 m_Video = new Video ( countOfStates );
286 // symbols
287 for ( i=0; i< m_countOfDataItems; i++)
289 oneDataItem = m_trainingData[i];
290 for ( j=1; j<= oneDataItem ->Size(); j++)
292 oneSymbol = (oneDataItem ->GetPiece(j)).Display();
293 tempSymbolList.insert(oneSymbol, 1);
298 symbolIndex = 0;
299 for ( StringToIntIt = tempSymbolList.begin(); StringToIntIt != tempSymbolList.end(); StringToIntIt++)
301 oneSymbol = StringToIntIt.key();
302 m_symbolList. insert(symbolIndex, oneSymbol);
303 m_symbolIndex. insert(oneSymbol, symbolIndex);
304 symbolIndex++;
307 m_countOfSymbols = symbolIndex;
309 // After we've got symbols, we encode training data
310 for ( i=0; i< m_countOfDataItems; i++)
312 oneDataItem = m_trainingData[i];
313 itemSize = m_trainingDataSizes[i];
314 oneIntArray = new int[1+itemSize]; // init an int array, whose size is (1+itemSize). The zero location is not used. we use from 1 to itemSize in this array
315 m_encodedTrainingData.insert(i, oneIntArray);
316 for ( j=1; j<= itemSize; j++)
318 oneSymbol = (oneDataItem ->GetPiece(j)).Display();
319 encodeId = m_symbolIndex[oneSymbol];
320 oneIntArray[j] = encodeId;
323 ///////////////////////////////////////////////////////////////////////////////
324 // Output alphabet members to file //
325 ///////////////////////////////////////////////////////////////////////////////
326 QFile file( m_HMMLogDirectory + "segments.txt");
327 if ( !file.open( QIODevice::WriteOnly | QIODevice::Append ) )
329 QMessageBox::information(NULL, "Error", "Can't Open the segments file!", "OK");
330 return;
333 Q3TextStream outf( &file );
334 for ( i=0; i< m_countOfSymbols; i++)
336 outf << m_symbolList[i] << endl;
339 file.close();
340 ///////////////////////////////////////////////////////////////////////////////
343 // allocate memory for prob parameters
344 // PI
345 m_PI = new double[m_countOfStates];
347 // A
348 m_A = new double*[m_countOfStates];
349 for ( i=0; i< m_countOfStates; i++)
351 m_A[i] = new double [m_countOfStates];
354 // B
355 m_B = new double*[m_countOfStates];
356 for ( i=0; i< m_countOfStates; i++)
358 m_B[i] = new double [m_countOfSymbols];
361 // m_PI_SoftCounts
362 m_PI_SoftCounts = new double[m_countOfStates];
364 // m_A_SoftCounts
365 m_A_SoftCounts = new double*[m_countOfStates];
366 for ( i=0; i< m_countOfStates; i++)
368 m_A_SoftCounts[i] = new double [m_countOfStates];
371 // m_B_SoftCounts
372 m_B_SoftCounts = new double*[m_countOfStates];
373 for ( i=0; i< m_countOfStates; i++)
375 m_B_SoftCounts[i] = new double [m_countOfSymbols];
378 // Alpha from 0 to T
379 m_Alpha = new double*[m_maxLengthOfObservation + 1];
380 for ( i=0; i< m_maxLengthOfObservation + 1; i++)
382 m_Alpha[i] = new double [m_countOfStates];
385 // Beta from 0 to T
386 m_Beta = new double*[m_maxLengthOfObservation + 1];
387 for ( i=0; i< m_maxLengthOfObservation + 1; i++)
389 m_Beta[i] = new double [m_countOfStates];
393 // m_P from 0 to T - 1
394 m_P = new double**[m_maxLengthOfObservation];
395 for ( i=0; i< m_maxLengthOfObservation; i++)
397 m_P[i] = new double* [m_countOfStates];
399 for ( j=0; j< m_countOfStates; j++)
401 m_P[i][j] = new double[m_countOfStates];
405 // Word probabilities:
406 m_WordProbabilities = new double [m_countOfDataItems];
409 // The Loop times
410 m_NumberOfIterations = loops;
412 ////////////////////////////////////////////////////////////////////////////////////////////////
413 ////////////////////////////////////////////////////////////////////////////////////////////////
414 QFile file2( m_HMMLogDirectory + "config.txt");
415 if ( !file2.open( QIODevice::WriteOnly | QIODevice::Append ) )
417 QMessageBox::information(NULL, "Error", "Can't Open the HMM Log config file!", "OK");
418 return;
421 Q3TextStream outf2( &file2 );
423 outf2 << m_countOfStates << " = Number of states " << endl;
424 outf2 << m_countOfSymbols << " = Number of symbols" << endl;
425 outf2 << m_NumberOfIterations << " = Number of iterations";
427 file2.close();
428 ////////////////////////////////////////////////////////////////////////////////////////////////
429 ////////////////////////////////////////////////////////////////////////////////////////////////
434 void StateEmitHMM::initPiAndAB()
436 using std::srand;
437 using std::time;
438 using std::rand;
440 double oneUniformProb;
441 int i,j;
442 double total = 0;
443 double random_weight = 0.25; //was 0.25
446 // srand the seed
447 srand ( time(NULL) );
449 // init Pi
450 oneUniformProb = 1.0 / m_countOfStates;
451 for ( i=0; i< m_countOfStates; i++)
453 m_PI[i] = oneUniformProb;
456 // init A
457 //------------------------------------------------//
458 /* initialize transition probabilities */
459 for (i = 0; i < m_countOfStates; i++)
461 total = 0;
462 for ( j=0; j< m_countOfStates; j++)
464 m_A[i][j] = (1 / (double) m_countOfStates) * (1-random_weight) + (random_weight) * rand()/(double)RAND_MAX;
465 total += m_A[i][j];
468 for ( j=0; j< m_countOfStates; j++)
470 m_A[i][j] = m_A[i][j] / total;
474 /* initialize emission probabilities */
475 total = 0;
476 for (i = 0; i < m_countOfStates; i++)
478 total = 0;
479 for ( j=0; j< m_countOfSymbols; j++)
481 m_B[i][j] = (1 / (double) m_countOfSymbols) * (1-random_weight-0.1) + (random_weight+0.1) * rand()/(double)RAND_MAX;
482 total += m_B[i][j];
485 for ( j=0; j< m_countOfSymbols; j++)
487 m_B[i][j] = m_B[i][j] / total;
493 /* initialize m_PI_SoftCounts as zero */
494 for ( i=0; i< m_countOfStates; i++)
496 m_PI_SoftCounts[i] = 0.0;
499 /* initialize m_A_SoftCounts as zero */
500 for (i = 0; i < m_countOfStates; i++)
502 for ( j=0; j< m_countOfStates; j++)
504 m_A_SoftCounts[i][j] = 0.0;
509 /* initialize m_B_SoftCounts as zero */
510 for (i = 0; i < m_countOfStates; i++)
512 for ( j=0; j< m_countOfSymbols; j++)
514 m_B_SoftCounts[i][j] = 0.0;
523 void StateEmitHMM::forward(int dataIndex, double& result)
525 int i,j,k;
526 double sum;
527 int symbolIndex;
528 double normSum;
529 CStem* pWord;
533 // Get the current dataItem as CParse
534 m_trainingDatum = m_encodedTrainingData[ dataIndex ];
535 m_trainingDatumSource = m_trainingDataSource [ dataIndex ];
536 m_lengthOfObservation = m_trainingDataSizes [ dataIndex ];
538 // init the alpha as PI;
539 for (i=0; i< m_countOfStates; i++)
541 m_Alpha[0][i]= m_PI[i];
544 // Induction step
545 for ( i=1; i <= m_lengthOfObservation; i++)
547 normSum = 0.0;
549 // figure out which symbol is emitted for this
550 symbolIndex = m_trainingDatum[i];
552 for ( j=0; j< m_countOfStates; j++)
554 sum = 0.0;
555 for (k=0; k< m_countOfStates; k++)
557 sum += m_Alpha[i-1][k] * m_A[k][j] * m_B[k][symbolIndex];
560 m_Alpha[i][j] = sum;
561 // normSum += sum; // add them up for normalization
564 // Norm the m_Alpha to avoid too small values
566 for ( j=0; j< m_countOfStates; j++)
568 m_Alpha[i][j] = m_Alpha[i][j] / normSum;
573 double probOfTheData;
575 probOfTheData = 0.0;
576 for ( i=0; i<m_countOfStates; i++)
578 probOfTheData += m_Alpha[m_lengthOfObservation][i];
581 result = probOfTheData;
582 Q_ASSERT (result);
584 // Find pointer to original word or whatever
585 if ( m_dataType == PHONE_TIER1) {
586 pWord = static_cast<CStem*>(m_trainingDatumSource);
587 pWord->SetHMM_LogProbability(-1 * base2log(result));
592 void StateEmitHMM::backward(int dataIndex, double& result)
594 int i,j,k;
595 double sum;
596 int symbolIndex;
597 double normSum;
599 // Get the current dataItem as CParse
600 m_trainingDatum = m_encodedTrainingData[dataIndex];
601 m_trainingDatumSource = m_trainingDataSource [dataIndex];
602 m_lengthOfObservation = m_trainingDataSizes [dataIndex];
605 // init the beta as 1;
606 for (i=0; i< m_countOfStates; i++)
608 m_Beta[m_lengthOfObservation][i]= 1;
611 // Induction step
612 for ( i=m_lengthOfObservation -1; i>=0; i--)
614 normSum = 0.0;
615 for ( j=0; j< m_countOfStates; j++)
617 // figure out which symbol is emitted for this
618 symbolIndex = m_trainingDatum[i+1];
620 sum = 0.0;
621 for (k=0; k< m_countOfStates; k++)
623 sum += m_Beta[i+1][k] * m_A[j][k] * m_B[j][symbolIndex];
626 m_Beta[i][j] = sum;
627 // normSum += sum;
630 // Norm the m_Beta to avoid too small values
632 for ( j=0; j< m_countOfStates; j++)
634 m_Beta[i][j] = m_Beta[i][j] / normSum;
640 double probOfTheData;
642 probOfTheData = 0.0;
643 for ( i=0; i<m_countOfStates; i++)
645 probOfTheData += m_PI[i]*m_Beta[0][i];
648 result = probOfTheData;
650 //QMessageBox::information(NULL, "Debug", QString("The String Prob Computed by forward as %1").arg(probOfTheData), "OK");
654 void StateEmitHMM::Expectation(int dataIndex)
657 /* Old version, with John's efforts
658 int i, j, k;
659 int symbolIndex;
660 double denominator;
661 double numerator;
662 double oneGamma;
663 int pCount;
664 double oneValue;
666 // Get the current dataItem as CParse
668 m_trainingDatum = m_encodedTrainingData [dataIndex];
669 m_trainingDatumSource = m_trainingDataSource [dataIndex];
670 m_lengthOfObservation = m_trainingDataSizes [dataIndex];
672 pCount = m_trainingDataFrequency[dataIndex];
674 Q_ASSERT(pCount > 0);
679 QFile file( m_HMMLogDirectory + "expectation.txt");
680 if ( !file.open( IO_WriteOnly | IO_Append ) )
682 QMessageBox::information(NULL, "Error", "Can't Open the HMM Log file!", "OK");
683 return;
685 QTextStream outf( &file );
687 if (dataIndex == 0)
689 outf << endl <<"----------------------------------------------------------------------" << endl;
691 else
693 outf << endl << endl;
695 outf << dataIndex << "\tProbs:\nL\tFrom\tTo";
699 // compute the m_P: // (Position in string, from-state, to-state);
700 for ( i=0; i< m_lengthOfObservation; i++)
702 // figure out which symbol is emitted for this
703 symbolIndex = m_trainingDatum[i+1];
705 // compute the denominator
706 denominator = 0.0;
707 for ( j=0; j< m_countOfStates; j++)
709 for (k=0; k< m_countOfStates; k++)
711 denominator += m_Alpha[i][j] * m_A[j][k] * m_B[j][symbolIndex]* m_Beta[i+1][k];
715 Q_ASSERT (denominator);
716 // compute numerator
717 for ( j=0; j< m_countOfStates; j++)
719 for (k=0; k< m_countOfStates; k++)
721 numerator = m_Alpha[i][j] * m_A[j][k] * m_B[j][symbolIndex]* m_Beta[i+1][k];
722 m_P[i][j][k] = numerator / denominator;
723 outf << endl << i << m_symbolList[symbolIndex] << "\t" << j << "\t" << k << "\t" << m_P[i][j][k];
729 // Next, compute softcount and add them to softcount data
731 // Estimate Pi
733 outf << "\n\n" << "Soft counts for Pi:" << "\t";
734 for ( i=0; i< m_countOfStates; i++)
736 oneGamma =0.0;
737 outf << "\nFrom:\t" << i;
738 for (j=0; j< m_countOfStates; j++)
740 oneGamma += m_P[0][i][j];
741 outf << "\tTo:\t"<< j <<" \t" << oneGamma << "\t";
744 oneValue = oneGamma;
745 m_PI_SoftCounts[i] += oneValue * pCount;
746 outf << "\tTotal soft counts (Pi) initially to state :"<< i << "\t"<<m_PI_SoftCounts[i];
749 // Estimate m_A
750 for ( i=0; i< m_countOfStates; i++)
752 denominator = 0.0;
753 for ( k=0; k< m_lengthOfObservation; k++)
755 oneGamma =0.0;
756 for ( j=0; j< m_countOfStates; j++)
758 oneGamma += m_P[k][i][j];
761 denominator += oneGamma;
763 Q_ASSERT (denominator);
765 // compute m_A[i][j]
766 for ( j=0; j< m_countOfStates; j++)
768 numerator =0.0;
769 for ( k=0; k< m_lengthOfObservation; k++)
771 numerator += m_P[k][i][j];
774 oneValue = numerator / denominator;
775 m_A_SoftCounts[i][j] += oneValue * pCount;
776 outf << "\nTotal soft counts from state: "<<i<< " to state " << j << "\t"<<m_A_SoftCounts[i][j];
781 outf << "\n\nSoft counts of symbol emissions "<< endl;
782 for (k = 0; k < m_lengthOfObservation; k++)
784 // figure out which symbol is emitted for this
785 symbolIndex = m_trainingDatum[k+1];
787 outf << endl <<"Dealing with symbol: " << m_symbolList[symbolIndex];
789 // Compute denominator
790 denominator =0.0;
792 for (i = 0; i < m_countOfStates; i++)
794 for (j = 0; j < m_countOfStates; j++)
796 denominator += m_P[k][i][j];
799 Q_ASSERT (denominator);
801 // Compute numerator and m_B
802 for (i = 0; i < m_countOfStates; i++)
804 outf << "\n\tState: " << i << " ";
805 for (j = 0; j < m_countOfStates; j++)
807 outf << " to state " << j << " ";
808 m_B_SoftCounts[i][symbolIndex] += ( m_P[k][i][j] / denominator) * pCount;
809 outf << m_P[k][i][j] / denominator;
814 for (i = 0; i < m_countOfStates; i++)
816 outf << endl << "\nIn State: "<< i;
817 for (j= 0; j < m_countOfSymbols; j++)
819 outf << "\tEmit symbol: "<< j<< " "<< m_symbolList[j] << " "<<m_B_SoftCounts[i][j];
825 file.close();
827 */ //end of OLD VERSION
829 int i, j, k;
830 int symbolIndex;
831 double denominator;
832 double numerator;
833 double oneGamma;
834 int pCount;
835 double oneValue;
837 // Get the current dataItem as CParse
838 m_trainingDatum = m_encodedTrainingData[dataIndex];
839 m_trainingDatumSource = m_trainingDataSource[dataIndex];
840 m_lengthOfObservation = m_trainingDataSizes[dataIndex];
841 pCount = m_trainingDataFrequency[dataIndex];
843 Q_ASSERT(pCount > 0);
846 denominator = 0.0;
847 for ( j=0; j< m_countOfStates; j++)
850 for (k=0; k< m_countOfStates; k++)
852 denominator += m_Alpha[i][j] * m_A[j][k] * m_B[j][symbolIndex]* m_Beta[i+1][k];
855 denominator += m_Alpha[m_lengthOfObservation][j];
859 // compute the m_P
860 for ( i=0; i< m_lengthOfObservation; i++)
862 // figure out which symbol is emitted for this
863 symbolIndex = m_trainingDatum[i+1];
865 // compute the denominator
868 Q_ASSERT (denominator);
869 // compute numerator
870 for ( j=0; j< m_countOfStates; j++)
872 for (k=0; k< m_countOfStates; k++)
874 numerator = m_Alpha[i][j] * m_A[j][k] * m_B[j][symbolIndex]* m_Beta[i+1][k];
875 m_P[i][j][k] = numerator / denominator;
881 // Next, compute softcount and add them to softcount data
883 // Estimate Pi
884 for ( i=0; i< m_countOfStates; i++)
886 oneGamma =0.0;
888 for (j=0; j< m_countOfStates; j++)
890 oneGamma += m_P[0][i][j];
893 oneValue = oneGamma;
894 m_PI_SoftCounts[i] += oneValue * pCount;
897 // Estimate m_A
898 for ( i=0; i< m_countOfStates; i++)
900 for ( j=0; j< m_countOfStates; j++)
902 denominator = 0.0;
903 for ( k=0; k< m_lengthOfObservation; k++)
905 denominator += m_P[k][i][j];
908 Q_ASSERT (denominator);
909 m_A_SoftCounts[i][j] += denominator * pCount;
914 // Estimate m_B
916 for (i=0; i< m_countOfStates; i++)
919 // Compute numerator and m_B
920 for (k=0; k< m_lengthOfObservation; k++)
922 // figure out which symbol is emitted for this
923 symbolIndex = m_trainingDatum[k+1];
925 numerator =0.0;
926 for ( j=0; j< m_countOfStates; j++)
928 numerator += m_P[k][i][j];
930 oneValue = numerator ;
931 m_B_SoftCounts[i][symbolIndex] += oneValue * pCount;
939 /// Calculate the new probability parameters, based on soft counts, and emission entropy
941 /// Parameters:
942 /// none
944 void StateEmitHMM::Maximization()
946 int i, j;
947 double total;
949 QFile file( m_HMMLogDirectory + "maximization.txt");
950 if ( !file.open( QIODevice::WriteOnly | QIODevice::Append ) )
952 QMessageBox::information(NULL, "Error", "Can't Open the maximization file!", "OK");
953 return;
955 Q3TextStream outf( &file );
957 outf << endl <<"----------------------------------------------------------------------" << endl;
959 outf << "\nPi softcounts:\t";
960 // Re-Estimate the m_PI
961 total =0.0;
962 for ( i=0; i< m_countOfStates; i++)
964 total += m_PI_SoftCounts[i];
965 outf << "State: "<<i<<"\t" << m_PI_SoftCounts[i] << "\t";
967 Q_ASSERT (total);
969 outf << "\n";
970 for ( i=0; i< m_countOfStates; i++)
972 m_PI[i] = m_PI_SoftCounts[i] / total;
973 m_PI_SoftCounts[i] = 0.0;
975 outf << endl << "Pi for State: "<<"\t" << m_PI[i];
979 outf << "\n\nMatrix A";
980 // Re-Estimate the m_A
981 for ( i=0; i< m_countOfStates; i++)
983 outf << "\n\nFrom State: "<< i << "\t";
984 total = 0.0;
985 for (j=0; j< m_countOfStates; j++)
987 total += m_A_SoftCounts[i][j];
988 outf << " To State: "<< j << "\t" << m_A_SoftCounts[i][j];
990 outf << " -- Total: "<< total;
991 Q_ASSERT (total);
993 outf << "\nTransition probabilities: ";
994 for (j=0; j< m_countOfStates; j++)
996 m_A[i][j] = m_A_SoftCounts[i][j] /total;
997 m_A_SoftCounts[i][j] = 0.0;
998 outf << "\tto: "<< j << "\t"<< m_A[i][j];
1002 outf << "\n\nMatrix B";
1003 // Re-Estimate the m_B
1004 for ( i=0; i< m_countOfStates; i++)
1006 outf << "\nFrom State: " << i << endl;
1007 total =0.0;
1008 for (j=0; j< m_countOfSymbols; j++)
1010 total += m_B_SoftCounts[i][j];
1011 outf << "\tCount of symbol "<< j << " "<<m_B_SoftCounts[i][j];
1013 outf << endl;
1015 // if (total == 0)
1016 // {
1017 // int aaa = 0;
1018 // }
1019 // ************** bug found JG Nov 30 2006 **********
1020 if (total == 0)
1022 for (j=0; j< m_countOfSymbols; j++)
1024 m_B[i][j] = 0;;
1025 m_B_SoftCounts[i][j] = 0.0;
1026 outf << "\tSymb: "<< j << "\t" << m_B[i][j];
1029 else
1031 for (j=0; j< m_countOfSymbols; j++)
1033 m_B[i][j] = m_B_SoftCounts[i][j] /total;
1034 m_B_SoftCounts[i][j] = 0.0;
1035 outf << "Symbol: "<< j << "\t" << m_B[i][j];
1040 // calculate each state's emission entropy:
1041 if (m_Entropy) delete m_Entropy;
1042 m_Entropy = new double [m_countOfStates];
1044 for ( i=0; i< m_countOfStates; i++)
1046 total =0.0;
1047 for (j=0; j< m_countOfSymbols; j++)
1049 Q_ASSERT (m_B[i][j] );
1050 if (base2log (m_B[i][j]) > -1* 1000000) // i.e., if m_B[i][j] ins't too small...
1052 total += -1 * m_B[i][j] * base2log ( m_B[i][j] );
1055 m_Entropy[i] = total;
1060 /// The highest level iteration
1061 double StateEmitHMM::trainParameters()
1063 CLexicon& lex = *m_Lexicon;
1064 linguistica::ui::status_user_agent& status = lex.status_display();
1066 int dataIndex;
1067 double stringProbWithForward ;
1068 double stringProbWithBackward;
1069 int HowFrequentlyToOutputStatistics;
1070 int dummy(0);
1071 VideoFrame* pVideoFrame;
1073 HowFrequentlyToOutputStatistics = 1;
1075 status.major_operation = "Training HMM";
1076 status.progress.clear();
1077 status.progress.set_denominator(m_NumberOfIterations);
1078 for (int i = 0; i<m_NumberOfIterations; i++) {
1079 m_LogProbabilityOfData = 0;
1080 status.progress = i;
1082 //OutputTransitions(i, m_HMMLog);
1083 for ( dataIndex =0; dataIndex < m_countOfDataItems; dataIndex++)
1086 forward (dataIndex, stringProbWithForward);
1087 backward(dataIndex, stringProbWithBackward);
1089 Q_ASSERT (stringProbWithForward == stringProbWithBackward);
1090 Q_ASSERT (stringProbWithForward);
1092 Expectation(dataIndex);
1094 m_LogProbabilityOfData += -1 * base2log (stringProbWithForward);
1098 Maximization();
1099 OutputTransitionsToLogFile(i, m_HMMLog);
1101 pVideoFrame = new VideoFrame(m_countOfStates );
1102 InsertValues (pVideoFrame);
1103 m_Video->insert (i, pVideoFrame);
1105 dummy++;
1106 if (dummy == HowFrequentlyToOutputStatistics )
1108 dummy = 0;
1109 OutputEmissions (i, m_HMMLog);
1110 OutputTransitions(i, m_HMMLog);
1111 OutputInitials (i, m_HMMLog);
1117 status.progress.clear();
1118 getStateListForASymbol();
1119 status.major_operation.clear();
1120 return m_LogProbabilityOfData;
1123 void StateEmitHMM::clear()
1125 double* oneArray;
1126 double** oneTwoDimensionArray;
1127 int i,j;
1128 IntToIntArray::iterator IntToIntArrayIt;
1129 int* oneIntArray;
1131 // Clear data
1132 m_trainingData.clear();
1133 m_trainingDataFrequency.clear();
1134 m_trainingDataSizes.clear();
1136 // Clear Origin data
1137 m_trainingDataSource.clear();
1140 // Clear m_encodedTrainingData
1141 for ( IntToIntArrayIt = m_encodedTrainingData.begin(); IntToIntArrayIt != m_encodedTrainingData.end(); IntToIntArrayIt++)
1143 oneIntArray = IntToIntArrayIt.data();
1144 if ( oneIntArray != NULL)
1146 delete [] oneIntArray;
1149 m_encodedTrainingData.clear();
1152 // Clear m_WordProbabilities
1153 if (m_WordProbabilities) delete m_WordProbabilities;
1155 // clear PI
1156 if ( m_PI != NULL)
1158 delete [] m_PI;
1159 m_PI = NULL;
1161 // clear A
1162 if ( m_A != NULL)
1164 for ( i=0; i< m_countOfStates; i++)
1166 oneArray = m_A[i];
1167 delete oneArray;
1169 delete [] m_A;
1170 m_A = NULL;
1172 // clear B
1173 if ( m_B != NULL)
1175 for ( i=0; i< m_countOfStates; i++)
1177 oneArray = m_B[i];
1178 delete oneArray;
1180 delete [] m_B;
1181 m_B = NULL;
1184 // clear m_PI_SoftCounts
1185 if ( m_PI_SoftCounts != NULL)
1187 delete [] m_PI_SoftCounts;
1188 m_PI_SoftCounts = NULL;
1191 // clear m_A_SoftCounts
1192 if ( m_A_SoftCounts != NULL)
1194 for ( i=0; i< m_countOfStates; i++)
1196 oneArray = m_A_SoftCounts[i];
1197 delete oneArray;
1199 delete [] m_A_SoftCounts;
1200 m_A_SoftCounts = NULL;
1202 // clear m_B_SoftCounts
1203 if ( m_B_SoftCounts != NULL)
1205 for ( i=0; i< m_countOfStates; i++)
1207 oneArray = m_B_SoftCounts[i];
1208 delete oneArray;
1210 delete [] m_B_SoftCounts;
1211 m_B_SoftCounts = NULL;
1214 // clear Alpha
1215 if ( m_Alpha != NULL)
1217 for ( i=0; i< m_maxLengthOfObservation + 1; i++)
1219 oneArray = m_Alpha[i];
1220 delete oneArray;
1222 delete [] m_Alpha;
1223 m_Alpha = NULL;
1225 // clear Beta
1226 if ( m_Beta != NULL)
1228 for ( i=0; i< m_maxLengthOfObservation + 1; i++)
1230 oneArray = m_Beta[i];
1231 delete oneArray;
1233 delete [] m_Beta;
1234 m_Beta = NULL;
1237 // clear m_P
1238 if ( m_P != NULL)
1240 for ( i=0; i< m_maxLengthOfObservation; i++)
1242 oneTwoDimensionArray = m_P[i];
1244 for ( j=0; j< m_countOfStates; j++)
1246 oneArray = oneTwoDimensionArray[j];
1247 delete oneArray;
1250 delete oneTwoDimensionArray;
1253 delete [] m_P;
1254 m_P = NULL;
1257 // clear m_symbolStateList
1258 IntTohmmSortedList::iterator IntTohmmSortedListIt;
1259 hmmSortedList* oneStateList;
1261 for(IntTohmmSortedListIt = m_symbolStateList.begin(); IntTohmmSortedListIt != m_symbolStateList.end(); IntTohmmSortedListIt++)
1263 oneStateList = IntTohmmSortedListIt.data();
1264 oneStateList ->setAutoDelete(true);
1265 delete oneStateList;
1267 m_symbolStateList.clear();
1269 // set construction parameters back to zeros
1270 m_countOfStates = 0;
1271 m_countOfSymbols = 0;
1272 m_maxLengthOfObservation = 0;
1273 m_NumberOfIterations = 0;
1274 m_countOfDataItems = 0;
1275 m_dataType = HMMNONE;
1277 // clear Volca
1278 m_symbolList .clear();
1279 m_symbolIndex .clear();
1281 // clear entropy
1282 if ( m_Entropy )
1284 delete m_Entropy;
1289 void StateEmitHMM::getStateListForASymbol()
1291 int i, j;
1292 double total;
1293 double oneValue,
1294 restValue;
1295 hmmSortedList* oneList;
1296 hmmForSortingItem* oneItem;
1297 const double MaxLogRatio = 10;
1298 const double MinLogRatio = -100;
1299 double LogRatio;
1301 for ( i=0; i< m_countOfSymbols; i++)
1303 oneList = new hmmSortedList();
1304 m_symbolStateList.insert(i, oneList);
1306 total = 0.0;
1307 for (j=0; j< m_countOfStates; j++)
1309 total += m_B[j][i];
1312 for (j=0; j< m_countOfStates; j++)
1314 oneValue = m_B[j][i];
1315 restValue = total - oneValue;
1317 if ( restValue == 0.0)
1319 LogRatio = MaxLogRatio;
1321 else
1323 if (oneValue == 0.0) {
1324 LogRatio = MinLogRatio;
1325 } else {
1326 Q_ASSERT(restValue != 0);
1327 LogRatio = base2log(oneValue / restValue);
1330 if ( LogRatio > MaxLogRatio)
1332 LogRatio = MaxLogRatio;
1335 if ( LogRatio < MinLogRatio)
1337 LogRatio = MinLogRatio;
1342 oneItem = new hmmForSortingItem(j, LogRatio );
1343 oneList ->append(oneItem);
1346 oneList ->sort();
1351 void StateEmitHMM::OutputTransitionsToLogFile(int IterationNumber,
1352 QString LogFileName)
1354 QFile file(LogFileName);
1355 if (!file.open(QIODevice::WriteOnly | QIODevice::Append)) {
1356 QMessageBox::information(0, "Error",
1357 "Can't Open the HMM Log file!", "OK");
1358 return;
1360 QTextStream outf(&file);
1362 outf << endl << IterationNumber << '\t';
1363 for (int i = 0; i < m_countOfStates; ++i)
1364 for (int j = 0; j < m_countOfStates; ++j)
1365 outf << m_A[i][j] << '\t';
1366 for (int i = 0; i < m_countOfStates; ++i)
1367 outf << m_Entropy[i] << '\t';
1370 void StateEmitHMM::OutputTransitions(int /* iteration number, unused */,
1371 QString /* log file name. XXX. unused */)
1373 QFile file2(QString("%1transitions.txt").arg(m_HMMLogDirectory));
1374 if (!file2.open(QIODevice::WriteOnly | QIODevice::Append)) {
1375 QMessageBox::information(0, "Error",
1376 "Can't Open the HMM transitions file!", "OK");
1377 return;
1380 QTextStream outf2(&file2);
1381 for (int i = 0; i < m_countOfStates; ++i)
1382 for (int j = 0; j < m_countOfStates; ++j)
1383 outf2 << m_A[i][j] << '\t';
1384 outf2 << endl;
1387 void StateEmitHMM::OutputEmissions(int /* iteration number, unused */,
1388 QString /* log file name. XXX. unused */)
1390 QFile file(QString("%1emissions.txt").arg(m_HMMLogDirectory));
1391 if (!file.open(QIODevice::WriteOnly | QIODevice::Append)) {
1392 QMessageBox::information(0, "Error",
1393 "Can't Open the HMM emissions file!", "OK");
1394 return;
1397 QTextStream outf(&file);
1398 for (int i = 0; i < m_countOfStates; ++i)
1399 for (int j = 0; j < m_countOfSymbols; ++j)
1400 outf << m_B[i][j] << '\t';
1401 outf << endl;
1404 void StateEmitHMM::OutputInitials(int /* iteration number, unused */,
1405 QString /* log file name. XXX. unused */)
1407 QFile file(QString("%1initials.txt").arg(m_HMMLogDirectory));
1408 if (!file.open(QIODevice::WriteOnly | QIODevice::Append)) {
1409 QMessageBox::information(0, "Error",
1410 "Can't Open the HMM initials file!", "OK");
1411 return;
1414 QTextStream outf(&file);
1415 for (int j = 0; j < m_countOfStates; ++j)
1416 outf << m_PI[j] << '\t';
1417 outf << endl;
1420 void StateEmitHMM::logInfo(double totalLogProbability)
1422 QFile file( m_HMMLog );
1423 int i, j;
1424 QString oneOutString;
1426 if ( !file.open( QIODevice::WriteOnly | QIODevice::Append ) )
1428 QMessageBox::information(NULL, "Error", "Can't Open the HMM Log file!", "OK");
1429 return;
1432 Q3TextStream outf( &file );
1434 outf << "\n*** One SnapShot ***" << endl <<endl;
1436 outf << "Total log probability: "<<totalLogProbability << endl << endl;
1438 // Output PI
1439 outf << "** PI Values **" << endl;
1440 for (i=0; i<m_countOfStates; i++)
1442 oneOutString = QString("State %1:\t %2\t").arg(i).arg(m_PI[i]);
1443 outf <<" "<< oneOutString << endl;
1446 // Output A
1447 outf << "** A Values **" << endl;
1448 for (i=0; i<m_countOfStates; i++)
1450 for ( j =0; j< m_countOfStates; j++)
1452 oneOutString = QString("State %1 -> %2 :\t%3").arg(i).arg(j).arg(m_A[i][j]);
1453 outf <<" "<< oneOutString << endl;
1457 // Output B
1458 outf << "** B Values **" << endl;
1459 for (i=0; i<m_countOfStates; i++)
1461 for ( j =0; j< m_countOfSymbols; j++)
1463 oneOutString = QString("State %1\tEmit\t%2 :\t%3").arg(i).arg(m_symbolList[j]).arg(m_B[i][j]);
1464 outf <<" "<< oneOutString << endl;
1468 // Output the m_symbolStateList
1469 QString thisSymbol;
1470 hmmSortedList* oneList;
1471 hmmForSortingItem* oneItem;
1472 int stateNumber, thisStateNumber;
1473 double ratioValue;
1475 ////////////////////////////////////////////////////////////////////////////////////////////////////
1476 // The special, but usual, case where we care about a 2-state HMM:
1477 ////////////////////////////////////////////////////////////////////////////////////////////////////
1478 if (m_countOfStates == 2)
1480 outf << "** Two State system: Log ratio probability, State 0 to State 1 **" << endl;
1481 for ( i=0; i< m_countOfSymbols; i++)
1483 thisSymbol = m_symbolList[i];
1484 // oneList = m_symbolStateList[i];
1485 outf <<endl << thisSymbol<< "\t" <<
1486 m_B[0][i] <<"\t" <<
1487 m_B[1][i] <<"\t" <<
1488 base2log ( m_B[0][i] / m_B[1][i] );
1491 ////////////////////////////////////////////////////////////////////////////////////////////////////
1493 ////////////////////////////////////////////////////////////////////////////////////////////////////
1496 outf << endl << endl << "** Symbol StateList: Log ratio probability, this state to all others, individually sorted by best state **" << endl;
1497 for ( i=0; i< m_countOfSymbols; i++)
1499 thisSymbol = m_symbolList[i];
1500 oneList = m_symbolStateList[i];
1502 outf <<endl << thisSymbol<< "\t";
1504 for ( oneItem=oneList ->first(); oneItem != 0; oneItem=oneList ->next())
1506 // stateNumber = oneItem ->m_stateNumber;
1507 // ratioValue = oneItem ->m_probRatio;
1508 // oneOutString = QString("%1\t%2\t").arg(stateNumber).arg(ratioValue );
1509 // outf << oneOutString ;
1513 ////////////////////////////////////////////////////////////////////////////////////////////////////
1514 // Symbol StateList: Log ratio probability, this state to all others, states in normal order
1515 ////////////////////////////////////////////////////////////////////////////////////////////////////
1517 outf << endl << endl << "\n\n** Symbol StateList: Log ratio probability, this state to all others, states in normal order **" << endl;
1518 for ( i=0; i< m_countOfSymbols; i++)
1520 thisSymbol = m_symbolList[i];
1521 oneList = m_symbolStateList[i];
1523 outf <<endl << thisSymbol<< " ";
1525 for (stateNumber= 0; stateNumber < m_countOfStates; stateNumber++ )
1527 for ( oneItem=oneList ->first(); oneItem != 0; oneItem=oneList ->next()) // this "unsorts" the list
1529 thisStateNumber = oneItem ->m_stateNumber;
1530 if ( thisStateNumber == stateNumber )
1532 ratioValue = oneItem ->m_probRatio;
1533 oneOutString = QString("%1\t%2\t").arg(stateNumber).arg(ratioValue );
1534 outf << oneOutString ;
1540 ////////////////////////////////////////////////////////////////////////////////////////////////////
1541 // JG: Output ordered list of symbols for each state.
1542 ////////////////////////////////////////////////////////////////////////////////////////////////////
1544 double* SymbolEmissions; SymbolEmissions = NULL;
1545 int* SortedOutput; SortedOutput = NULL;
1547 outf << "\n\n** Each state: ordered emissions **";
1548 for (i = 0; i < m_countOfStates; i++)
1550 outf << endl << "\nState " << i;
1551 if ( SymbolEmissions) { delete SymbolEmissions; }
1552 if ( SortedOutput ) { delete SortedOutput; }
1553 SymbolEmissions = new double[ m_countOfSymbols];
1554 SortedOutput = new int [m_countOfSymbols];
1556 for (j = 0; j < m_countOfSymbols; j++)
1558 SymbolEmissions[j] = m_B[i][j];
1560 SortVector (SortedOutput, SymbolEmissions, m_countOfSymbols);
1562 for (j = 0; j < m_countOfSymbols; j++)
1564 outf << endl <<
1565 m_symbolList[SortedOutput[j]] << "\t"<<
1566 m_B[i][SortedOutput[j]];
1570 ////////////////////////////////////////////////////////////////////////////////////////////////////
1572 ////////////////////////////////////////////////////////////////////////////////////////////////////
1574 outf <<endl <<endl<<endl;
1575 file.close();
1580 // --------------------------------- For graphical display --------------------------//
1581 void StateEmitHMM::InsertValues ( VideoFrame* pVideoFrame)
1583 LabeledDataPoint* pDataPoint = NULL;
1585 for (int i = 0; i < m_countOfSymbols; i++)
1587 pDataPoint = new LabeledDataPoint ( m_symbolList[i], m_countOfStates );
1588 for (int j = 0; j < m_countOfStates; j++)
1590 *pDataPoint << m_B[j][i];
1592 pVideoFrame->append ( pDataPoint );
1598 void StateEmitHMM::Display ()
1600 // m_Video;
1602 int NumberOfSymbols;
1603 int NumberOfStates;
1604 double** ptrData;
1605 IntToString listOfSymbols;
1606 int i,j;
1607 int NumberOfIterations =100;
1609 NumberOfSymbols = m_Video->GetNumberOfDataPoints();
1610 NumberOfStates = m_Video->GetDimensionality();
1612 // Update the GraphicView to Display the HMM multidimension data
1613 ptrData = new double*[ NumberOfSymbols ];
1614 for ( i=0; i<NumberOfSymbols; i++)
1616 ptrData[i] = new double[NumberOfStates] ;
1619 // assign initial ptrData
1620 for(i=0; i<NumberOfSymbols;i++)
1622 listOfSymbols.insert(i, QString("s%1").arg(i));
1623 for(j=0; j<NumberOfStates; j++)
1625 ptrData[i][j] = 1.0 / NumberOfStates;
1629 // Iterations...
1631 for ( pFrame = m_Video->first(); pFrame; pFrame = m_Video->next() )
1634 // Display on Graphic View
1635 updateSmallGraphicDisplaySlotForMultiDimensionData(NumberOfStates,
1636 NumberOfSymbols,
1637 ptrData,
1638 listOfSymbols);
1642 // Sleep for a short time : Is there any Qt sleep function ?
1643 int dummy;
1644 for(j=0; j<500000; j++)
1646 dummy++;
1647 dummy--;
1651 //clean ptrData
1652 for(i=0; i<NumberOfSymbols; i++)
1654 delete [] ptrData[i];
1657 delete []ptrData;