1 // Implementation of StateEmitHMM methods
2 // Copyright © 2009 The University of Chicago
3 #include "StateEmitHMM.h"
7 #include <Q3TextStream>
10 #include <Q3SortedList>
11 #include <QMessageBox>
12 #include "linguisticamainwindow.h"
13 #include "ui/Status.h"
15 #include "GraphicView.h"
17 #include "WordCollection.h"
18 #include "StringSurrogate.h"
20 #include "CompareFunc.h"
25 //////////////////////////////////////////////////////////////////////
26 // Construction/Destruction
27 //////////////////////////////////////////////////////////////////////
// Constructs an HMM attached to the main window.
// Every pointer and counter assigned below is reset to an empty state and the
// log file name is taken from `parent`.
29 StateEmitHMM::StateEmitHMM(LinguisticaMainWindow
* parent
)
// Soft-count tables and the "current datum" pointers start out unallocated.
38 m_PI_SoftCounts
= NULL
;
39 m_A_SoftCounts
= NULL
;
40 m_B_SoftCounts
= NULL
;
41 m_trainingDatum
= NULL
;
42 m_trainingDatumSource
= NULL
;
43 m_symbolStateList
.clear();
// No training data has been loaded yet, so all sizes and counts are zero.
47 m_lengthOfObservation
= 0;
48 m_NumberOfIterations
= 0;
49 m_countOfDataItems
= 0;
50 m_maxLengthOfObservation
=0;
// NOTE(review): m_HMMLog and m_HMMLogDirectory both receive GetLogFileName().
// Other methods append names such as "segments.txt" to m_HMMLogDirectory, so
// confirm that a directory prefix was not intended here.
51 m_HMMLog
= parent
->GetLogFileName();
52 m_HMMLogDirectory
= parent
->GetLogFileName();
54 // m_HMMLog = QString("Linguistica_HMM_Log.txt");
56 m_WordProbabilities
= NULL
;
59 m_encodedTrainingData
.clear();
// Constructs an HMM from a lexicon pointer.
// Performs the same resets as the main-window overload, but uses a fixed log
// file name. m_HMMLogDirectory is not assigned in the visible lines of this
// overload.
63 StateEmitHMM::StateEmitHMM(CLexicon
* pLexicon
)
// Soft-count tables and the "current datum" pointers start out unallocated.
72 m_PI_SoftCounts
= NULL
;
73 m_A_SoftCounts
= NULL
;
74 m_B_SoftCounts
= NULL
;
75 m_trainingDatum
= NULL
;
76 m_trainingDatumSource
= NULL
;
77 m_symbolStateList
.clear();
// No training data has been loaded yet, so all sizes and counts are zero.
81 m_lengthOfObservation
= 0;
82 m_NumberOfIterations
= 0;
83 m_countOfDataItems
= 0;
84 m_maxLengthOfObservation
=0;
85 m_WordProbabilities
= NULL
;
// Fixed log file name (no parent window to ask).
86 m_HMMLog
= QString("Linguistica_HMM_Log.txt");
90 m_encodedTrainingData
.clear();
// Destructor.
// NOTE(review): the body is not visible in this excerpt; confirm that it
// releases the arrays allocated in init() (for example by calling clear()).
93 StateEmitHMM::~StateEmitHMM()
98 /// Figure out what kind of phonological segments are being used as the symbols of this HMM
101 /// - type - can be PHONE_TIER1, PHONE_TIER2, or PARSE (which is essentially the spelling)
// - dataCollection - untyped pointer that every branch casts to CWordCollection*.
//
// For each word taken from the collection the method records, under the same
// key `Index`: the word pointer (m_trainingDataSource), the parse used as the
// observation (m_trainingData), the corpus count (m_trainingDataFrequency)
// and the observation length (m_trainingDataSizes). It also tracks the
// longest observation and finally stores Index in m_countOfDataItems.
// NOTE(review): the declaration and increment of `Index`, and the actions
// taken when an item has size 0, are not visible in this excerpt.
104 bool StateEmitHMM::preprocessData(eHmmDataType type
, void* dataCollection
)
109 // depending on the type, process the data
110 if ( type
== PHONE_TIER1
) // take the wordCollection, and get PhoneTier 1
112 CWordCollection
* PhoneData
;
118 CStringSurrogate
* poundCSS
;
119 QString poundStr
= QString("#");
// NOTE(review): poundCSS is heap-allocated and no matching delete is visible
// in this excerpt -- possible leak, confirm.
123 poundCSS
= new CStringSurrogate(poundStr
);
125 // The data is phoneCollection
126 PhoneData
= (CWordCollection
*)dataCollection
;
131 for (i
= 0; i
< PhoneData
->GetCount(); i
++)
133 pWord
= PhoneData
->GetAt(i
);
134 pOriginData
= (void*)pWord
;
135 pCount
= pWord
->GetCorpusCount();
136 pDataItem
= pWord
->GetPhonology_Tier1();
138 if ( pDataItem
->Size() ==0)
// Remove the pieces that begin with "#" from the tier-1 parse, then re-check
// the size. NOTE(review): pDataItem is the pointer returned by the word, so
// this presumably edits the word's own tier-1 data in place -- confirm.
145 pDataItem
->RemovePiecesThatBegin((*poundCSS
));
147 if ( pDataItem
->Size() ==0)
// Record the four parallel entries for this item under the same key.
152 m_trainingDataSource
. insert(Index
, pOriginData
);
153 m_trainingData
. insert(Index
, pDataItem
);
154 m_trainingDataFrequency
.insert(Index
, pCount
);
155 m_trainingDataSizes
. insert(Index
, pDataItem
->Size());
// Track the longest observation; init() sizes m_Alpha/m_Beta/m_P from it.
158 if ( pDataItem
->Size() > m_maxLengthOfObservation
)
160 m_maxLengthOfObservation
= pDataItem
->Size() ;
164 m_countOfDataItems
= Index
;
170 if ( type
== PHONE_TIER2
) // take the wordCollection, and get PhoneTier 2
172 CWordCollection
* PhoneData
;
180 // The data is phoneCollection
181 PhoneData
= (CWordCollection
*)dataCollection
;
184 for (i
= 0; i
< PhoneData
->GetCount(); i
++)
186 pWord
= PhoneData
->GetAt(i
);
187 pOriginData
= (void*)pWord
;
188 pCount
= pWord
->GetCorpusCount();
189 pDataItem
= pWord
->GetPhonology_Tier2();
190 if ( pDataItem
->Size() ==0)
// Record the four parallel entries for this item under the same key.
194 m_trainingDataSource
. insert(Index
, pOriginData
);
195 m_trainingData
. insert(Index
, pDataItem
);
196 m_trainingDataFrequency
.insert(Index
, pCount
);
197 m_trainingDataSizes
. insert(Index
, pDataItem
->Size());
200 if ( pDataItem
->Size() > m_maxLengthOfObservation
)
202 m_maxLengthOfObservation
= pDataItem
->Size() ;
206 m_countOfDataItems
= Index
;
212 if ( type
== PARSE
) // take the wordCollection, and get CParse
214 CWordCollection
* WordData
;
222 // The data is WordCollection
223 WordData
= (CWordCollection
*)dataCollection
;
226 for (i
= 0; i
< WordData
->GetCount(); i
++)
228 pWord
= WordData
->GetAt(i
);
229 pOriginData
= (void*)pWord
;
230 pCount
= pWord
->GetCorpusCount();
// In PARSE mode the word object itself is used as the parse.
231 pDataItem
= (CParse
*)pWord
;
233 m_trainingDataSource
. insert(Index
, pOriginData
);
234 m_trainingData
. insert(Index
, pDataItem
);
235 m_trainingDataFrequency
.insert(Index
, pCount
);
236 m_trainingDataSizes
. insert(Index
, pDataItem
->Size());
239 if ( pDataItem
->Size() > m_maxLengthOfObservation
)
241 m_maxLengthOfObservation
= pDataItem
->Size() ;
246 m_countOfDataItems
= Index
;
// Prepares the model for training on data already loaded by preprocessData().
//   countOfStates - number of hidden states; values below 2 raise a message box.
//   loops         - number of training iterations, stored in m_NumberOfIterations.
// Steps visible below: validate the inputs, create the Video container, collect
// the set of distinct symbols, encode every training item as an int array,
// append the alphabet to "segments.txt", allocate all probability and work
// arrays, and append a summary to "config.txt".
// NOTE(review): what happens after each message box (presumably an early
// return) is not visible in this excerpt.
254 void StateEmitHMM::init(int countOfStates
, int loops
)
256 StringToInt::iterator StringToIntIt
;
257 StringToInt tempSymbolList
;
268 if ( countOfStates
< 2 )
270 QMessageBox::information(NULL
, "error", "The number of states of one HMM should > 1", "OK");
274 m_countOfStates
= countOfStates
;
277 if ( m_countOfDataItems
== 0)
279 QMessageBox::information(NULL
, "error", "The training data is empty", "OK");
284 // m_Video = new Video ( countOfStates, m_countOfDataItems );
285 m_Video
= new Video ( countOfStates
);
// First pass: gather every symbol that occurs in the training data. Pieces are
// addressed 1..Size(); the map value is a dummy, only the keys are used.
287 for ( i
=0; i
< m_countOfDataItems
; i
++)
289 oneDataItem
= m_trainingData
[i
];
290 for ( j
=1; j
<= oneDataItem
->Size(); j
++)
292 oneSymbol
= (oneDataItem
->GetPiece(j
)).Display();
293 tempSymbolList
.insert(oneSymbol
, 1);
// Give each distinct symbol an integer id and keep both directions of the
// mapping (id -> symbol in m_symbolList, symbol -> id in m_symbolIndex).
// NOTE(review): the initialisation and increment of symbolIndex are not
// visible in this excerpt.
299 for ( StringToIntIt
= tempSymbolList
.begin(); StringToIntIt
!= tempSymbolList
.end(); StringToIntIt
++)
301 oneSymbol
= StringToIntIt
.key();
302 m_symbolList
. insert(symbolIndex
, oneSymbol
);
303 m_symbolIndex
. insert(oneSymbol
, symbolIndex
);
307 m_countOfSymbols
= symbolIndex
;
309 // After we've got symbols, we encode training data
310 for ( i
=0; i
< m_countOfDataItems
; i
++)
312 oneDataItem
= m_trainingData
[i
];
313 itemSize
= m_trainingDataSizes
[i
];
314 oneIntArray
= new int[1+itemSize
]; // init an int array, whose size is (1+itemSize). The zero location is not used. we use from 1 to itemSize in this array
// The array is owned by m_encodedTrainingData and released in clear().
315 m_encodedTrainingData
.insert(i
, oneIntArray
);
316 for ( j
=1; j
<= itemSize
; j
++)
318 oneSymbol
= (oneDataItem
->GetPiece(j
)).Display();
319 encodeId
= m_symbolIndex
[oneSymbol
];
320 oneIntArray
[j
] = encodeId
;
323 ///////////////////////////////////////////////////////////////////////////////
324 // Output alphabet members to file //
325 ///////////////////////////////////////////////////////////////////////////////
326 QFile
file( m_HMMLogDirectory
+ "segments.txt");
327 if ( !file
.open( QIODevice::WriteOnly
| QIODevice::Append
) )
329 QMessageBox::information(NULL
, "Error", "Can't Open the segments file!", "OK");
333 Q3TextStream
outf( &file
);
334 for ( i
=0; i
< m_countOfSymbols
; i
++)
336 outf
<< m_symbolList
[i
] << endl
;
340 ///////////////////////////////////////////////////////////////////////////////
343 // allocate memory for prob parameters
// m_PI: one initial probability per state.
345 m_PI
= new double[m_countOfStates
];
// m_A: states x states transition matrix.
348 m_A
= new double*[m_countOfStates
];
349 for ( i
=0; i
< m_countOfStates
; i
++)
351 m_A
[i
] = new double [m_countOfStates
];
// m_B: states x symbols emission matrix.
355 m_B
= new double*[m_countOfStates
];
356 for ( i
=0; i
< m_countOfStates
; i
++)
358 m_B
[i
] = new double [m_countOfSymbols
];
// Soft-count accumulators with the same shapes as m_PI, m_A and m_B.
362 m_PI_SoftCounts
= new double[m_countOfStates
];
365 m_A_SoftCounts
= new double*[m_countOfStates
];
366 for ( i
=0; i
< m_countOfStates
; i
++)
368 m_A_SoftCounts
[i
] = new double [m_countOfStates
];
372 m_B_SoftCounts
= new double*[m_countOfStates
];
373 for ( i
=0; i
< m_countOfStates
; i
++)
375 m_B_SoftCounts
[i
] = new double [m_countOfSymbols
];
// m_Alpha / m_Beta: (max observation length + 1) x states work tables.
379 m_Alpha
= new double*[m_maxLengthOfObservation
+ 1];
380 for ( i
=0; i
< m_maxLengthOfObservation
+ 1; i
++)
382 m_Alpha
[i
] = new double [m_countOfStates
];
386 m_Beta
= new double*[m_maxLengthOfObservation
+ 1];
387 for ( i
=0; i
< m_maxLengthOfObservation
+ 1; i
++)
389 m_Beta
[i
] = new double [m_countOfStates
];
393 // m_P from 0 to T - 1
// m_P: (max observation length) x states x states table.
394 m_P
= new double**[m_maxLengthOfObservation
];
395 for ( i
=0; i
< m_maxLengthOfObservation
; i
++)
397 m_P
[i
] = new double* [m_countOfStates
];
399 for ( j
=0; j
< m_countOfStates
; j
++)
401 m_P
[i
][j
] = new double[m_countOfStates
];
405 // Word probabilities:
406 m_WordProbabilities
= new double [m_countOfDataItems
];
410 m_NumberOfIterations
= loops
;
412 ////////////////////////////////////////////////////////////////////////////////////////////////
413 ////////////////////////////////////////////////////////////////////////////////////////////////
// Append the model dimensions to "config.txt".
414 QFile
file2( m_HMMLogDirectory
+ "config.txt");
415 if ( !file2
.open( QIODevice::WriteOnly
| QIODevice::Append
) )
417 QMessageBox::information(NULL
, "Error", "Can't Open the HMM Log config file!", "OK");
421 Q3TextStream
outf2( &file2
);
423 outf2
<< m_countOfStates
<< " = Number of states " << endl
;
424 outf2
<< m_countOfSymbols
<< " = Number of symbols" << endl
;
425 outf2
<< m_NumberOfIterations
<< " = Number of iterations";
428 ////////////////////////////////////////////////////////////////////////////////////////////////
429 ////////////////////////////////////////////////////////////////////////////////////////////////
// Sets the starting parameters for training.
// m_PI becomes uniform. Each entry of m_A and m_B is a uniform value blended
// with a random perturbation and then divided by `total`. All three soft-count
// tables are zeroed.
// NOTE(review): the lines that reset and accumulate `total` per row are not
// visible in this excerpt; presumably it is the row sum, making each row sum
// to 1 -- confirm.
434 void StateEmitHMM::initPiAndAB()
440 double oneUniformProb
;
443 double random_weight
= 0.25; //was 0.25
// Re-seeds the C random generator from the wall clock on every call.
447 srand ( time(NULL
) );
450 oneUniformProb
= 1.0 / m_countOfStates
;
451 for ( i
=0; i
< m_countOfStates
; i
++)
453 m_PI
[i
] = oneUniformProb
;
457 //------------------------------------------------//
458 /* initialize transition probabilities */
459 for (i
= 0; i
< m_countOfStates
; i
++)
462 for ( j
=0; j
< m_countOfStates
; j
++)
// uniform * (1 - random_weight) + random_weight * (random value in [0, 1])
464 m_A
[i
][j
] = (1 / (double) m_countOfStates
) * (1-random_weight
) + (random_weight
) * rand()/(double)RAND_MAX
;
468 for ( j
=0; j
< m_countOfStates
; j
++)
470 m_A
[i
][j
] = m_A
[i
][j
] / total
;
474 /* initialize emission probabilities */
476 for (i
= 0; i
< m_countOfStates
; i
++)
479 for ( j
=0; j
< m_countOfSymbols
; j
++)
// Same blend as for m_A, but the random share is random_weight + 0.1.
481 m_B
[i
][j
] = (1 / (double) m_countOfSymbols
) * (1-random_weight
-0.1) + (random_weight
+0.1) * rand()/(double)RAND_MAX
;
485 for ( j
=0; j
< m_countOfSymbols
; j
++)
487 m_B
[i
][j
] = m_B
[i
][j
] / total
;
493 /* initialize m_PI_SoftCounts as zero */
494 for ( i
=0; i
< m_countOfStates
; i
++)
496 m_PI_SoftCounts
[i
] = 0.0;
499 /* initialize m_A_SoftCounts as zero */
500 for (i
= 0; i
< m_countOfStates
; i
++)
502 for ( j
=0; j
< m_countOfStates
; j
++)
504 m_A_SoftCounts
[i
][j
] = 0.0;
509 /* initialize m_B_SoftCounts as zero */
510 for (i
= 0; i
< m_countOfStates
; i
++)
512 for ( j
=0; j
< m_countOfSymbols
; j
++)
514 m_B_SoftCounts
[i
][j
] = 0.0;
// Forward pass for one training item.
//   dataIndex - key of the item in the training-data maps.
//   result    - receives the sum of m_Alpha over all states at the last position.
// Also selects the item as the "current datum" (m_trainingDatum,
// m_trainingDatumSource, m_lengthOfObservation). For PHONE_TIER1 data,
// -log2(result) is additionally stored on the source word.
// NOTE(review): the lines that reset `sum` and store it into m_Alpha[i][j]
// are not visible in this excerpt.
523 void StateEmitHMM::forward(int dataIndex
, double& result
)
533 // Get the current data item (the int-encoded symbol array built in init())
534 m_trainingDatum
= m_encodedTrainingData
[ dataIndex
];
535 m_trainingDatumSource
= m_trainingDataSource
[ dataIndex
];
536 m_lengthOfObservation
= m_trainingDataSizes
[ dataIndex
];
538 // init the alpha as PI;
539 for (i
=0; i
< m_countOfStates
; i
++)
541 m_Alpha
[0][i
]= m_PI
[i
];
545 for ( i
=1; i
<= m_lengthOfObservation
; i
++)
549 // figure out which symbol is emitted for this
550 symbolIndex
= m_trainingDatum
[i
];
552 for ( j
=0; j
< m_countOfStates
; j
++)
555 for (k
=0; k
< m_countOfStates
; k
++)
// The emission probability is taken from the source state k, not from the
// destination state j.
557 sum
+= m_Alpha
[i
-1][k
] * m_A
[k
][j
] * m_B
[k
][symbolIndex
];
561 // normSum += sum; // add them up for normalization
564 // Norm the m_Alpha to avoid too small values
// NOTE(review): the next two lines divide by normSum, whose accumulation is
// commented out just above; they look like disabled code -- confirm.
566 for ( j=0; j< m_countOfStates; j++)
568 m_Alpha[i][j] = m_Alpha[i][j] / normSum;
// NOTE(review): no initial value for probOfTheData is visible before the
// += below -- confirm it is zeroed.
573 double probOfTheData
;
576 for ( i
=0; i
<m_countOfStates
; i
++)
578 probOfTheData
+= m_Alpha
[m_lengthOfObservation
][i
];
581 result
= probOfTheData
;
584 // Find pointer to original word or whatever
585 if ( m_dataType
== PHONE_TIER1
) {
586 pWord
= static_cast<CStem
*>(m_trainingDatumSource
);
587 pWord
->SetHMM_LogProbability(-1 * base2log(result
));
// Backward pass for one training item.
//   dataIndex - key of the item in the training-data maps.
//   result    - receives sum over states i of m_PI[i] * m_Beta[0][i].
// Also selects the item as the "current datum", exactly as forward() does.
// NOTE(review): the lines that reset `sum` and store it into m_Beta[i][j]
// are not visible in this excerpt.
592 void StateEmitHMM::backward(int dataIndex
, double& result
)
599 // Get the current data item (the int-encoded symbol array built in init())
600 m_trainingDatum
= m_encodedTrainingData
[dataIndex
];
601 m_trainingDatumSource
= m_trainingDataSource
[dataIndex
];
602 m_lengthOfObservation
= m_trainingDataSizes
[dataIndex
];
605 // init the beta as 1;
606 for (i
=0; i
< m_countOfStates
; i
++)
608 m_Beta
[m_lengthOfObservation
][i
]= 1;
// Walk from the last position back to position 0.
612 for ( i
=m_lengthOfObservation
-1; i
>=0; i
--)
615 for ( j
=0; j
< m_countOfStates
; j
++)
617 // figure out which symbol is emitted for this
618 symbolIndex
= m_trainingDatum
[i
+1];
621 for (k
=0; k
< m_countOfStates
; k
++)
// As in forward(), the emission is attributed to the source state j.
623 sum
+= m_Beta
[i
+1][k
] * m_A
[j
][k
] * m_B
[j
][symbolIndex
];
630 // Norm the m_Beta to avoid too small values
// NOTE(review): no accumulation of normSum is visible in this function; the
// next two lines look like disabled code -- confirm.
632 for ( j=0; j< m_countOfStates; j++)
634 m_Beta[i][j] = m_Beta[i][j] / normSum;
// NOTE(review): no initial value for probOfTheData is visible before the
// += below -- confirm it is zeroed.
640 double probOfTheData
;
643 for ( i
=0; i
<m_countOfStates
; i
++)
645 probOfTheData
+= m_PI
[i
]*m_Beta
[0][i
];
648 result
= probOfTheData
;
650 //QMessageBox::information(NULL, "Debug", QString("The String Prob Computed by forward as %1").arg(probOfTheData), "OK");
// Expectation step for one training item.
// Reads the m_Alpha and m_Beta tables, fills m_P[position][from][to] for that
// item, and adds the item's soft counts, weighted by its corpus count
// (pCount), to m_PI_SoftCounts, m_A_SoftCounts and m_B_SoftCounts.
//   dataIndex - key of the item in the training-data maps.
// NOTE(review): the lines that reset denominator, numerator and oneGamma, and
// that assign oneValue in the PI loop, are not visible in this excerpt.
654 void StateEmitHMM::Expectation(int dataIndex
)
657 /* Old version, with John's efforts
666 // Get the current dataItem as CParse
668 m_trainingDatum = m_encodedTrainingData [dataIndex];
669 m_trainingDatumSource = m_trainingDataSource [dataIndex];
670 m_lengthOfObservation = m_trainingDataSizes [dataIndex];
672 pCount = m_trainingDataFrequency[dataIndex];
674 Q_ASSERT(pCount > 0);
679 QFile file( m_HMMLogDirectory + "expectation.txt");
680 if ( !file.open( IO_WriteOnly | IO_Append ) )
682 QMessageBox::information(NULL, "Error", "Can't Open the HMM Log file!", "OK");
685 QTextStream outf( &file );
689 outf << endl <<"----------------------------------------------------------------------" << endl;
693 outf << endl << endl;
695 outf << dataIndex << "\tProbs:\nL\tFrom\tTo";
699 // compute the m_P: // (Position in string, from-state, to-state);
700 for ( i=0; i< m_lengthOfObservation; i++)
702 // figure out which symbol is emitted for this
703 symbolIndex = m_trainingDatum[i+1];
705 // compute the denominator
707 for ( j=0; j< m_countOfStates; j++)
709 for (k=0; k< m_countOfStates; k++)
711 denominator += m_Alpha[i][j] * m_A[j][k] * m_B[j][symbolIndex]* m_Beta[i+1][k];
715 Q_ASSERT (denominator);
717 for ( j=0; j< m_countOfStates; j++)
719 for (k=0; k< m_countOfStates; k++)
721 numerator = m_Alpha[i][j] * m_A[j][k] * m_B[j][symbolIndex]* m_Beta[i+1][k];
722 m_P[i][j][k] = numerator / denominator;
723 outf << endl << i << m_symbolList[symbolIndex] << "\t" << j << "\t" << k << "\t" << m_P[i][j][k];
729 // Next, compute softcount and add them to softcount data
733 outf << "\n\n" << "Soft counts for Pi:" << "\t";
734 for ( i=0; i< m_countOfStates; i++)
737 outf << "\nFrom:\t" << i;
738 for (j=0; j< m_countOfStates; j++)
740 oneGamma += m_P[0][i][j];
741 outf << "\tTo:\t"<< j <<" \t" << oneGamma << "\t";
745 m_PI_SoftCounts[i] += oneValue * pCount;
746 outf << "\tTotal soft counts (Pi) initially to state :"<< i << "\t"<<m_PI_SoftCounts[i];
750 for ( i=0; i< m_countOfStates; i++)
753 for ( k=0; k< m_lengthOfObservation; k++)
756 for ( j=0; j< m_countOfStates; j++)
758 oneGamma += m_P[k][i][j];
761 denominator += oneGamma;
763 Q_ASSERT (denominator);
766 for ( j=0; j< m_countOfStates; j++)
769 for ( k=0; k< m_lengthOfObservation; k++)
771 numerator += m_P[k][i][j];
774 oneValue = numerator / denominator;
775 m_A_SoftCounts[i][j] += oneValue * pCount;
776 outf << "\nTotal soft counts from state: "<<i<< " to state " << j << "\t"<<m_A_SoftCounts[i][j];
781 outf << "\n\nSoft counts of symbol emissions "<< endl;
782 for (k = 0; k < m_lengthOfObservation; k++)
784 // figure out which symbol is emitted for this
785 symbolIndex = m_trainingDatum[k+1];
787 outf << endl <<"Dealing with symbol: " << m_symbolList[symbolIndex];
789 // Compute denominator
792 for (i = 0; i < m_countOfStates; i++)
794 for (j = 0; j < m_countOfStates; j++)
796 denominator += m_P[k][i][j];
799 Q_ASSERT (denominator);
801 // Compute numerator and m_B
802 for (i = 0; i < m_countOfStates; i++)
804 outf << "\n\tState: " << i << " ";
805 for (j = 0; j < m_countOfStates; j++)
807 outf << " to state " << j << " ";
808 m_B_SoftCounts[i][symbolIndex] += ( m_P[k][i][j] / denominator) * pCount;
809 outf << m_P[k][i][j] / denominator;
814 for (i = 0; i < m_countOfStates; i++)
816 outf << endl << "\nIn State: "<< i;
817 for (j= 0; j < m_countOfSymbols; j++)
819 outf << "\tEmit symbol: "<< j<< " "<< m_symbolList[j] << " "<<m_B_SoftCounts[i][j];
827 */ //end of OLD VERSION
837 // Get the current data item (the int-encoded symbol array built in init())
838 m_trainingDatum
= m_encodedTrainingData
[dataIndex
];
839 m_trainingDatumSource
= m_trainingDataSource
[dataIndex
];
840 m_lengthOfObservation
= m_trainingDataSizes
[dataIndex
];
841 pCount
= m_trainingDataFrequency
[dataIndex
];
843 Q_ASSERT(pCount
> 0);
// The shared denominator is the sum of m_Alpha over all states at the last
// position of this item.
847 for ( j
=0; j
< m_countOfStates
; j
++)
// NOTE(review): the next two lines use i and symbolIndex before either is
// assigned in the visible code of this function; they look like disabled
// leftovers from the old version -- confirm.
850 for (k=0; k< m_countOfStates; k++)
852 denominator += m_Alpha[i][j] * m_A[j][k] * m_B[j][symbolIndex]* m_Beta[i+1][k];
855 denominator
+= m_Alpha
[m_lengthOfObservation
][j
];
// Fill m_P for every position i, from-state j and to-state k.
860 for ( i
=0; i
< m_lengthOfObservation
; i
++)
862 // figure out which symbol is emitted for this
863 symbolIndex
= m_trainingDatum
[i
+1];
865 // compute the denominator
868 Q_ASSERT (denominator
);
870 for ( j
=0; j
< m_countOfStates
; j
++)
872 for (k
=0; k
< m_countOfStates
; k
++)
874 numerator
= m_Alpha
[i
][j
] * m_A
[j
][k
] * m_B
[j
][symbolIndex
]* m_Beta
[i
+1][k
];
875 m_P
[i
][j
][k
] = numerator
/ denominator
;
881 // Next, compute softcount and add them to softcount data
// Initial-state soft counts: m_P at position 0 summed over destination states.
884 for ( i
=0; i
< m_countOfStates
; i
++)
888 for (j
=0; j
< m_countOfStates
; j
++)
890 oneGamma
+= m_P
[0][i
][j
];
894 m_PI_SoftCounts
[i
] += oneValue
* pCount
;
// Transition soft counts: m_P[k][i][j] summed over all positions k. The
// variable `denominator` is reused here as that running sum.
898 for ( i
=0; i
< m_countOfStates
; i
++)
900 for ( j
=0; j
< m_countOfStates
; j
++)
903 for ( k
=0; k
< m_lengthOfObservation
; k
++)
905 denominator
+= m_P
[k
][i
][j
];
908 Q_ASSERT (denominator
);
909 m_A_SoftCounts
[i
][j
] += denominator
* pCount
;
// Emission soft counts: for each position k, m_P[k][i][j] summed over
// destination states j is credited to the symbol at position k+1 of the datum.
916 for (i
=0; i
< m_countOfStates
; i
++)
919 // Compute numerator and m_B
920 for (k
=0; k
< m_lengthOfObservation
; k
++)
922 // figure out which symbol is emitted for this
923 symbolIndex
= m_trainingDatum
[k
+1];
926 for ( j
=0; j
< m_countOfStates
; j
++)
928 numerator
+= m_P
[k
][i
][j
];
930 oneValue
= numerator
;
931 m_B_SoftCounts
[i
][symbolIndex
] += oneValue
* pCount
;
939 /// Calculate the new probability parameters, based on soft counts, and emission entropy
// Re-estimates m_PI, m_A and m_B as soft counts divided by `total`, resets the
// soft-count tables to zero for the next iteration, and recomputes the
// per-state emission entropy m_Entropy. A trace is appended to
// "maximization.txt".
// NOTE(review): the lines that reset `total` before each accumulation, and
// what follows the message box (presumably an early return), are not visible
// in this excerpt.
944 void StateEmitHMM::Maximization()
949 QFile
file( m_HMMLogDirectory
+ "maximization.txt");
950 if ( !file
.open( QIODevice::WriteOnly
| QIODevice::Append
) )
952 QMessageBox::information(NULL
, "Error", "Can't Open the maximization file!", "OK");
955 Q3TextStream
outf( &file
);
957 outf
<< endl
<<"----------------------------------------------------------------------" << endl
;
959 outf
<< "\nPi softcounts:\t";
960 // Re-Estimate the m_PI
962 for ( i
=0; i
< m_countOfStates
; i
++)
964 total
+= m_PI_SoftCounts
[i
];
965 outf
<< "State: "<<i
<<"\t" << m_PI_SoftCounts
[i
] << "\t";
970 for ( i
=0; i
< m_countOfStates
; i
++)
972 m_PI
[i
] = m_PI_SoftCounts
[i
] / total
;
973 m_PI_SoftCounts
[i
] = 0.0;
975 outf
<< endl
<< "Pi for State: "<<"\t" << m_PI
[i
];
979 outf
<< "\n\nMatrix A";
980 // Re-Estimate the m_A
981 for ( i
=0; i
< m_countOfStates
; i
++)
983 outf
<< "\n\nFrom State: "<< i
<< "\t";
985 for (j
=0; j
< m_countOfStates
; j
++)
987 total
+= m_A_SoftCounts
[i
][j
];
988 outf
<< " To State: "<< j
<< "\t" << m_A_SoftCounts
[i
][j
];
990 outf
<< " -- Total: "<< total
;
993 outf
<< "\nTransition probabilities: ";
994 for (j
=0; j
< m_countOfStates
; j
++)
996 m_A
[i
][j
] = m_A_SoftCounts
[i
][j
] /total
;
997 m_A_SoftCounts
[i
][j
] = 0.0;
998 outf
<< "\tto: "<< j
<< "\t"<< m_A
[i
][j
];
1002 outf
<< "\n\nMatrix B";
1003 // Re-Estimate the m_B
1004 for ( i
=0; i
< m_countOfStates
; i
++)
1006 outf
<< "\nFrom State: " << i
<< endl
;
1008 for (j
=0; j
< m_countOfSymbols
; j
++)
1010 total
+= m_B_SoftCounts
[i
][j
];
1011 outf
<< "\tCount of symbol "<< j
<< " "<<m_B_SoftCounts
[i
][j
];
1019 // ************** bug found JG Nov 30 2006 **********
// NOTE(review): in the visible lines this loop zeroes m_B_SoftCounts before
// the following loop divides the same counts by `total`. Confirm which of
// the two loops is live in the full file.
1022 for (j
=0; j
< m_countOfSymbols
; j
++)
1025 m_B_SoftCounts
[i
][j
] = 0.0;
1026 outf
<< "\tSymb: "<< j
<< "\t" << m_B
[i
][j
];
1031 for (j
=0; j
< m_countOfSymbols
; j
++)
1033 m_B
[i
][j
] = m_B_SoftCounts
[i
][j
] /total
;
1034 m_B_SoftCounts
[i
][j
] = 0.0;
1035 outf
<< "Symbol: "<< j
<< "\t" << m_B
[i
][j
];
1040 // calculate each state's emission entropy:
// m_Entropy is allocated with new[] just below, so it must be released with
// the array form of delete.
1041 if (m_Entropy
) delete [] m_Entropy
;
1042 m_Entropy
= new double [m_countOfStates
];
1044 for ( i
=0; i
< m_countOfStates
; i
++)
1047 for (j
=0; j
< m_countOfSymbols
; j
++)
// In debug builds this asserts that every emission probability is non-zero.
1049 Q_ASSERT (m_B
[i
][j
] );
1050 if (base2log (m_B
[i
][j
]) > -1* 1000000) // i.e., if m_B[i][j] isn't too small...
1052 total
+= -1 * m_B
[i
][j
] * base2log ( m_B
[i
][j
] );
1055 m_Entropy
[i
] = total
;
1060 /// The highest level iteration
// Runs m_NumberOfIterations training passes. In each pass every training item
// goes through forward(), backward() and Expectation(), and -log2 of its
// forward probability is added to m_LogProbabilityOfData. After the item loop
// the transitions are logged and a VideoFrame of the current emissions is
// stored in m_Video. Returns m_LogProbabilityOfData of the last pass.
// NOTE(review): the call to Maximization() and the handling of `dummy` are
// not visible in this excerpt.
1061 double StateEmitHMM::trainParameters()
1063 CLexicon
& lex
= *m_Lexicon
;
1064 linguistica::ui::status_user_agent
& status
= lex
.status_display();
1067 double stringProbWithForward
;
1068 double stringProbWithBackward
;
1069 int HowFrequentlyToOutputStatistics
;
1071 VideoFrame
* pVideoFrame
;
1073 HowFrequentlyToOutputStatistics
= 1;
// Progress is reported in units of iterations.
1075 status
.major_operation
= "Training HMM";
1076 status
.progress
.clear();
1077 status
.progress
.set_denominator(m_NumberOfIterations
);
1078 for (int i
= 0; i
<m_NumberOfIterations
; i
++) {
1079 m_LogProbabilityOfData
= 0;
1080 status
.progress
= i
;
1082 //OutputTransitions(i, m_HMMLog);
1083 for ( dataIndex
=0; dataIndex
< m_countOfDataItems
; dataIndex
++)
1086 forward (dataIndex
, stringProbWithForward
);
1087 backward(dataIndex
, stringProbWithBackward
);
// NOTE(review): this compares two doubles for exact equality. The two values
// are accumulated in different orders, so rounding may make the assertion fire
// in debug builds -- consider a tolerance.
1089 Q_ASSERT (stringProbWithForward
== stringProbWithBackward
);
1090 Q_ASSERT (stringProbWithForward
);
1092 Expectation(dataIndex
);
1094 m_LogProbabilityOfData
+= -1 * base2log (stringProbWithForward
);
1099 OutputTransitionsToLogFile(i
, m_HMMLog
);
// One frame per iteration; the frame is handed to m_Video.
1101 pVideoFrame
= new VideoFrame(m_countOfStates
);
1102 InsertValues (pVideoFrame
);
1103 m_Video
->insert (i
, pVideoFrame
);
1106 if (dummy
== HowFrequentlyToOutputStatistics
)
1109 OutputEmissions (i
, m_HMMLog
);
1110 OutputTransitions(i
, m_HMMLog
);
1111 OutputInitials (i
, m_HMMLog
);
1117 status
.progress
.clear();
// Build the per-symbol state lists from the final m_B.
1118 getStateListForASymbol();
1119 status
.major_operation
.clear();
1120 return m_LogProbabilityOfData
;
// Releases the training data and the arrays allocated in init(), deletes the
// per-symbol state lists, and resets the construction parameters to zero.
// NOTE(review): several delete statements (for m_PI, m_A, m_B, the rows
// fetched into oneArray, and the m_Alpha/m_Beta/m_P top-level arrays) are not
// visible in this excerpt.
1123 void StateEmitHMM::clear()
1126 double** oneTwoDimensionArray
;
1128 IntToIntArray::iterator IntToIntArrayIt
;
1132 m_trainingData
.clear();
1133 m_trainingDataFrequency
.clear();
1134 m_trainingDataSizes
.clear();
1136 // Clear Origin data
1137 m_trainingDataSource
.clear();
1140 // Clear m_encodedTrainingData
// Each value is an int array allocated with new[] in init().
1141 for ( IntToIntArrayIt
= m_encodedTrainingData
.begin(); IntToIntArrayIt
!= m_encodedTrainingData
.end(); IntToIntArrayIt
++)
1143 oneIntArray
= IntToIntArrayIt
.data();
1144 if ( oneIntArray
!= NULL
)
1146 delete [] oneIntArray
;
1149 m_encodedTrainingData
.clear();
1152 // Clear m_WordProbabilities
// Allocated with new double[] in init(), so the array form of delete is
// required. The pointer is reset so that a second clear() does not free it
// again.
1153 if (m_WordProbabilities
) { delete [] m_WordProbabilities
; m_WordProbabilities
= NULL
; }
1164 for ( i
=0; i
< m_countOfStates
; i
++)
1175 for ( i
=0; i
< m_countOfStates
; i
++)
1184 // clear m_PI_SoftCounts
1185 if ( m_PI_SoftCounts
!= NULL
)
1187 delete [] m_PI_SoftCounts
;
1188 m_PI_SoftCounts
= NULL
;
1191 // clear m_A_SoftCounts
1192 if ( m_A_SoftCounts
!= NULL
)
1194 for ( i
=0; i
< m_countOfStates
; i
++)
1196 oneArray
= m_A_SoftCounts
[i
];
1199 delete [] m_A_SoftCounts
;
1200 m_A_SoftCounts
= NULL
;
1202 // clear m_B_SoftCounts
1203 if ( m_B_SoftCounts
!= NULL
)
1205 for ( i
=0; i
< m_countOfStates
; i
++)
1207 oneArray
= m_B_SoftCounts
[i
];
1210 delete [] m_B_SoftCounts
;
1211 m_B_SoftCounts
= NULL
;
1215 if ( m_Alpha
!= NULL
)
1217 for ( i
=0; i
< m_maxLengthOfObservation
+ 1; i
++)
1219 oneArray
= m_Alpha
[i
];
1226 if ( m_Beta
!= NULL
)
1228 for ( i
=0; i
< m_maxLengthOfObservation
+ 1; i
++)
1230 oneArray
= m_Beta
[i
];
1240 for ( i
=0; i
< m_maxLengthOfObservation
; i
++)
1242 oneTwoDimensionArray
= m_P
[i
];
1244 for ( j
=0; j
< m_countOfStates
; j
++)
1246 oneArray
= oneTwoDimensionArray
[j
];
// m_P[i] was allocated with new double*[] in init(), so the array form of
// delete is required.
1250 delete [] oneTwoDimensionArray
;
1257 // clear m_symbolStateList
1258 IntTohmmSortedList::iterator IntTohmmSortedListIt
;
1259 hmmSortedList
* oneStateList
;
1261 for(IntTohmmSortedListIt
= m_symbolStateList
.begin(); IntTohmmSortedListIt
!= m_symbolStateList
.end(); IntTohmmSortedListIt
++)
1263 oneStateList
= IntTohmmSortedListIt
.data();
// Auto-delete makes the list free the items it holds when it is destroyed.
1264 oneStateList
->setAutoDelete(true);
1265 delete oneStateList
;
1267 m_symbolStateList
.clear();
1269 // set construction parameters back to zeros
1270 m_countOfStates
= 0;
1271 m_countOfSymbols
= 0;
1272 m_maxLengthOfObservation
= 0;
1273 m_NumberOfIterations
= 0;
1274 m_countOfDataItems
= 0;
1275 m_dataType
= HMMNONE
;
1278 m_symbolList
.clear();
1279 m_symbolIndex
.clear();
// Builds, for every symbol i, a list with one (state, log-ratio) item per
// state j and stores it in m_symbolStateList[i]. The ratio compares
// m_B[j][i] with `total - m_B[j][i]` and is clamped to
// [MinLogRatio, MaxLogRatio].
// NOTE(review): the body of the first j-loop (presumably summing m_B[j][i]
// into `total`) and the else keywords between the three ratio cases are not
// visible in this excerpt.
1289 void StateEmitHMM::getStateListForASymbol()
1295 hmmSortedList
* oneList
;
1296 hmmForSortingItem
* oneItem
;
1297 const double MaxLogRatio
= 10;
1298 const double MinLogRatio
= -100;
1301 for ( i
=0; i
< m_countOfSymbols
; i
++)
// The list is owned by m_symbolStateList and deleted in clear().
1303 oneList
= new hmmSortedList();
1304 m_symbolStateList
.insert(i
, oneList
);
1307 for (j
=0; j
< m_countOfStates
; j
++)
1312 for (j
=0; j
< m_countOfStates
; j
++)
1314 oneValue
= m_B
[j
][i
];
1315 restValue
= total
- oneValue
;
// A zero remainder would divide by zero: use the maximum ratio instead.
1317 if ( restValue
== 0.0)
1319 LogRatio
= MaxLogRatio
;
// A zero numerator would give log(0): use the minimum ratio instead.
1323 if (oneValue
== 0.0) {
1324 LogRatio
= MinLogRatio
;
1326 Q_ASSERT(restValue
!= 0);
1327 LogRatio
= base2log(oneValue
/ restValue
);
// Clamp the computed ratio to the allowed range.
1330 if ( LogRatio
> MaxLogRatio
)
1332 LogRatio
= MaxLogRatio
;
1335 if ( LogRatio
< MinLogRatio
)
1337 LogRatio
= MinLogRatio
;
1342 oneItem
= new hmmForSortingItem(j
, LogRatio
);
1343 oneList
->append(oneItem
);
// Appends one tab-separated record to LogFileName: the iteration number, every
// entry of m_A in row order, then every entry of m_Entropy.
//   IterationNumber - written as the first field of the record.
//   LogFileName     - path of the file opened in append mode.
// NOTE(review): what follows the message box on open failure (presumably an
// early return) is not visible in this excerpt.
1351 void StateEmitHMM::OutputTransitionsToLogFile(int IterationNumber
,
1352 QString LogFileName
)
1354 QFile
file(LogFileName
);
1355 if (!file
.open(QIODevice::WriteOnly
| QIODevice::Append
)) {
1356 QMessageBox::information(0, "Error",
1357 "Can't Open the HMM Log file!", "OK");
1360 QTextStream
outf(&file
);
// Each record starts on a new line.
1362 outf
<< endl
<< IterationNumber
<< '\t';
1363 for (int i
= 0; i
< m_countOfStates
; ++i
)
1364 for (int j
= 0; j
< m_countOfStates
; ++j
)
1365 outf
<< m_A
[i
][j
] << '\t';
1366 for (int i
= 0; i
< m_countOfStates
; ++i
)
1367 outf
<< m_Entropy
[i
] << '\t';
// Appends every entry of m_A, tab-separated and in row order, to
// "<m_HMMLogDirectory>transitions.txt". Both parameters are unused.
// NOTE(review): what follows the message box on open failure, and any line
// terminator written after the loop, are not visible in this excerpt.
1370 void StateEmitHMM::OutputTransitions(int /* iteration number, unused */,
1371 QString
/* log file name. XXX. unused */)
1373 QFile
file2(QString("%1transitions.txt").arg(m_HMMLogDirectory
));
1374 if (!file2
.open(QIODevice::WriteOnly
| QIODevice::Append
)) {
1375 QMessageBox::information(0, "Error",
1376 "Can't Open the HMM transitions file!", "OK");
1380 QTextStream
outf2(&file2
);
1381 for (int i
= 0; i
< m_countOfStates
; ++i
)
1382 for (int j
= 0; j
< m_countOfStates
; ++j
)
1383 outf2
<< m_A
[i
][j
] << '\t';
// Appends every entry of m_B (states x symbols), tab-separated and in row
// order, to "<m_HMMLogDirectory>emissions.txt". Both parameters are unused.
// NOTE(review): what follows the message box on open failure, and any line
// terminator written after the loop, are not visible in this excerpt.
1387 void StateEmitHMM::OutputEmissions(int /* iteration number, unused */,
1388 QString
/* log file name. XXX. unused */)
1390 QFile
file(QString("%1emissions.txt").arg(m_HMMLogDirectory
));
1391 if (!file
.open(QIODevice::WriteOnly
| QIODevice::Append
)) {
1392 QMessageBox::information(0, "Error",
1393 "Can't Open the HMM emissions file!", "OK");
1397 QTextStream
outf(&file
);
1398 for (int i
= 0; i
< m_countOfStates
; ++i
)
1399 for (int j
= 0; j
< m_countOfSymbols
; ++j
)
1400 outf
<< m_B
[i
][j
] << '\t';
// Appends every entry of m_PI, tab-separated, to
// "<m_HMMLogDirectory>initials.txt". Both parameters are unused.
// NOTE(review): what follows the message box on open failure, and any line
// terminator written after the loop, are not visible in this excerpt.
1404 void StateEmitHMM::OutputInitials(int /* iteration number, unused */,
1405 QString
/* log file name. XXX. unused */)
1407 QFile
file(QString("%1initials.txt").arg(m_HMMLogDirectory
));
1408 if (!file
.open(QIODevice::WriteOnly
| QIODevice::Append
)) {
1409 QMessageBox::information(0, "Error",
1410 "Can't Open the HMM initials file!", "OK");
1414 QTextStream
outf(&file
);
1415 for (int j
= 0; j
< m_countOfStates
; ++j
)
1416 outf
<< m_PI
[j
] << '\t';
// Appends a human-readable snapshot of the model to the m_HMMLog file.
//   totalLogProbability - value printed in the snapshot header.
// Sections written, in order: PI values, A values, B values, the two-state log
// ratio (only when m_countOfStates == 2), the per-symbol state lists, and each
// state's emissions in the order produced by SortVector().
// NOTE(review): what follows the message box on open failure, and the release
// of SymbolEmissions / SortedOutput after the last state, are not visible in
// this excerpt.
1420 void StateEmitHMM::logInfo(double totalLogProbability
)
1422 QFile
file( m_HMMLog
);
1424 QString oneOutString
;
1426 if ( !file
.open( QIODevice::WriteOnly
| QIODevice::Append
) )
1428 QMessageBox::information(NULL
, "Error", "Can't Open the HMM Log file!", "OK");
1432 Q3TextStream
outf( &file
);
1434 outf
<< "\n*** One SnapShot ***" << endl
<<endl
;
1436 outf
<< "Total log probability: "<<totalLogProbability
<< endl
<< endl
;
1439 outf
<< "** PI Values **" << endl
;
1440 for (i
=0; i
<m_countOfStates
; i
++)
1442 oneOutString
= QString("State %1:\t %2\t").arg(i
).arg(m_PI
[i
]);
1443 outf
<<" "<< oneOutString
<< endl
;
1447 outf
<< "** A Values **" << endl
;
1448 for (i
=0; i
<m_countOfStates
; i
++)
1450 for ( j
=0; j
< m_countOfStates
; j
++)
1452 oneOutString
= QString("State %1 -> %2 :\t%3").arg(i
).arg(j
).arg(m_A
[i
][j
]);
1453 outf
<<" "<< oneOutString
<< endl
;
1458 outf
<< "** B Values **" << endl
;
1459 for (i
=0; i
<m_countOfStates
; i
++)
1461 for ( j
=0; j
< m_countOfSymbols
; j
++)
1463 oneOutString
= QString("State %1\tEmit\t%2 :\t%3").arg(i
).arg(m_symbolList
[j
]).arg(m_B
[i
][j
]);
1464 outf
<<" "<< oneOutString
<< endl
;
1468 // Output the m_symbolStateList
1470 hmmSortedList
* oneList
;
1471 hmmForSortingItem
* oneItem
;
1472 int stateNumber
, thisStateNumber
;
1475 ////////////////////////////////////////////////////////////////////////////////////////////////////
1476 // The special, but usual, case where we care about a 2-state HMM:
1477 ////////////////////////////////////////////////////////////////////////////////////////////////////
1478 if (m_countOfStates
== 2)
1480 outf
<< "** Two State system: Log ratio probability, State 0 to State 1 **" << endl
;
1481 for ( i
=0; i
< m_countOfSymbols
; i
++)
1483 thisSymbol
= m_symbolList
[i
];
1484 // oneList = m_symbolStateList[i];
// NOTE(review): no guard against m_B[1][i] == 0 is visible before the
// division below -- confirm.
1485 outf
<<endl
<< thisSymbol
<< "\t" <<
1488 base2log ( m_B
[0][i
] / m_B
[1][i
] );
1491 ////////////////////////////////////////////////////////////////////////////////////////////////////
1493 ////////////////////////////////////////////////////////////////////////////////////////////////////
1496 outf
<< endl
<< endl
<< "** Symbol StateList: Log ratio probability, this state to all others, individually sorted by best state **" << endl
;
1497 for ( i
=0; i
< m_countOfSymbols
; i
++)
1499 thisSymbol
= m_symbolList
[i
];
1500 oneList
= m_symbolStateList
[i
];
1502 outf
<<endl
<< thisSymbol
<< "\t";
1504 for ( oneItem
=oneList
->first(); oneItem
!= 0; oneItem
=oneList
->next())
1506 // stateNumber = oneItem ->m_stateNumber;
1507 // ratioValue = oneItem ->m_probRatio;
1508 // oneOutString = QString("%1\t%2\t").arg(stateNumber).arg(ratioValue );
1509 // outf << oneOutString ;
1513 ////////////////////////////////////////////////////////////////////////////////////////////////////
1514 // Symbol StateList: Log ratio probability, this state to all others, states in normal order
1515 ////////////////////////////////////////////////////////////////////////////////////////////////////
1517 outf
<< endl
<< endl
<< "\n\n** Symbol StateList: Log ratio probability, this state to all others, states in normal order **" << endl
;
1518 for ( i
=0; i
< m_countOfSymbols
; i
++)
1520 thisSymbol
= m_symbolList
[i
];
1521 oneList
= m_symbolStateList
[i
];
1523 outf
<<endl
<< thisSymbol
<< " ";
// For each state number in ascending order, scan the list for the item that
// carries that state number and print its ratio.
1525 for (stateNumber
= 0; stateNumber
< m_countOfStates
; stateNumber
++ )
1527 for ( oneItem
=oneList
->first(); oneItem
!= 0; oneItem
=oneList
->next()) // this "unsorts" the list
1529 thisStateNumber
= oneItem
->m_stateNumber
;
1530 if ( thisStateNumber
== stateNumber
)
1532 ratioValue
= oneItem
->m_probRatio
;
1533 oneOutString
= QString("%1\t%2\t").arg(stateNumber
).arg(ratioValue
);
1534 outf
<< oneOutString
;
1540 ////////////////////////////////////////////////////////////////////////////////////////////////////
1541 // JG: Output ordered list of symbols for each state.
1542 ////////////////////////////////////////////////////////////////////////////////////////////////////
1544 double* SymbolEmissions
; SymbolEmissions
= NULL
;
1545 int* SortedOutput
; SortedOutput
= NULL
;
1547 outf
<< "\n\n** Each state: ordered emissions **";
1548 for (i
= 0; i
< m_countOfStates
; i
++)
1550 outf
<< endl
<< "\nState " << i
;
// Both buffers are allocated with new[] below, so the buffers left over from
// the previous state must be released with the array form of delete.
1551 if ( SymbolEmissions
) { delete [] SymbolEmissions
; }
1552 if ( SortedOutput
) { delete [] SortedOutput
; }
1553 SymbolEmissions
= new double[ m_countOfSymbols
];
1554 SortedOutput
= new int [m_countOfSymbols
];
1556 for (j
= 0; j
< m_countOfSymbols
; j
++)
1558 SymbolEmissions
[j
] = m_B
[i
][j
];
// SortVector fills SortedOutput with symbol indices in sorted order.
1560 SortVector (SortedOutput
, SymbolEmissions
, m_countOfSymbols
);
1562 for (j
= 0; j
< m_countOfSymbols
; j
++)
1565 m_symbolList
[SortedOutput
[j
]] << "\t"<<
1566 m_B
[i
][SortedOutput
[j
]];
1570 ////////////////////////////////////////////////////////////////////////////////////////////////////
1572 ////////////////////////////////////////////////////////////////////////////////////////////////////
1574 outf
<<endl
<<endl
<<endl
;
1580 // --------------------------------- For graphical display --------------------------//
// Fills one video frame with the current emission probabilities: for every
// symbol i a LabeledDataPoint labelled with the symbol is created, given the
// values m_B[j][i] for each state j, and appended to the frame.
//   pVideoFrame - frame that receives the data points.
1581 void StateEmitHMM::InsertValues ( VideoFrame
* pVideoFrame
)
1583 LabeledDataPoint
* pDataPoint
= NULL
;
1585 for (int i
= 0; i
< m_countOfSymbols
; i
++)
1587 pDataPoint
= new LabeledDataPoint ( m_symbolList
[i
], m_countOfStates
);
// One coordinate per state for this symbol.
1588 for (int j
= 0; j
< m_countOfStates
; j
++)
1590 *pDataPoint
<< m_B
[j
][i
];
// The data point is handed to the frame.
1592 pVideoFrame
->append ( pDataPoint
);
1598 void StateEmitHMM::Display ()
1602 int NumberOfSymbols;
1605 IntToString listOfSymbols;
1607 int NumberOfIterations =100;
1609 NumberOfSymbols = m_Video->GetNumberOfDataPoints();
1610 NumberOfStates = m_Video->GetDimensionality();
1612 // Update the GraphicView to Display the HMM multidimension data
1613 ptrData = new double*[ NumberOfSymbols ];
1614 for ( i=0; i<NumberOfSymbols; i++)
1616 ptrData[i] = new double[NumberOfStates] ;
1619 // assign initial ptrData
1620 for(i=0; i<NumberOfSymbols;i++)
1622 listOfSymbols.insert(i, QString("s%1").arg(i));
1623 for(j=0; j<NumberOfStates; j++)
1625 ptrData[i][j] = 1.0 / NumberOfStates;
1631 for ( pFrame = m_Video->first(); pFrame; pFrame = m_Video->next() )
1634 // Display on Graphic View
1635 updateSmallGraphicDisplaySlotForMultiDimensionData(NumberOfStates,
1642 // Sleep for a short time : Is there any Qt sleep function ?
1644 for(j=0; j<500000; j++)
1652 for(i=0; i<NumberOfSymbols; i++)
1654 delete [] ptrData[i];