HowManyAreAnalyzed(): use status_user_agent to report progress
[linguistica.git] / cMTModel2Norm.cpp
blob6178e2a5631f137d75d48b060a44b5e8f1598f15
1 // Implementation of the cMTModel2Norm class
2 // Copyright © 2009 The University of Chicago
3 #include "cMTModel2Norm.h"
5 #include <QMessageBox>
6 #include "cMTModel1.h"
7 #include "mTVolca.h"
8 #include "cMT.h"
10 //////////////////////////////////////////////////////////////////////
11 // Construction/Destruction
12 //////////////////////////////////////////////////////////////////////
14 cMTModel2Norm::cMTModel2Norm(cMT* myMT, int Iterations, bool fromModel1)
15 : m_myMT(myMT),
16 m_Iterations(Iterations),
17 m_sentenceNormChunk(10),
18 m_getTfromModel1(fromModel1),
19 m_T(), m_A(),
20 m_softCountOfT(), m_softCountOfA(),
21 m_lamdaT(), m_lamdaA() { }
23 cMTModel2Norm::~cMTModel2Norm()
25 // XXX. not clear who owns what, so there are definitely some
26 // leaks here.
29 void cMTModel2Norm::initTandA()
31 mTVolca* myVolca;
32 int totalAssociatedLanguage2Words;
33 IntToIntToDouble::iterator IntToIntToDoubleIt;
34 IntToDouble::iterator IntToDoubleIt;
35 IntToDouble* oneListForOneLanguage1Word;
36 double uniformProb;
37 int key;
38 IntToDouble* oneListForOneLanuage2Chunk;
39 IntToDouble* oneSoftCountListForOneLanuage2Chunk;
40 int i,j;
44 myVolca = m_myMT ->m_Volca;
47 // Init m_T
48 if ( !m_getTfromModel1)
50 m_T = myVolca ->m_fastWordsPairs;
52 for ( IntToIntToDoubleIt = m_T.begin(); IntToIntToDoubleIt != m_T.end();IntToIntToDoubleIt++)
54 oneListForOneLanguage1Word = IntToIntToDoubleIt.data();
56 totalAssociatedLanguage2Words = oneListForOneLanguage1Word ->size();
58 uniformProb = 1.0 / totalAssociatedLanguage2Words;
60 for ( IntToDoubleIt = oneListForOneLanguage1Word ->begin(); IntToDoubleIt != oneListForOneLanguage1Word ->end(); IntToDoubleIt++)
62 key = IntToDoubleIt.key();
63 (*oneListForOneLanguage1Word)[key] = uniformProb;
67 else
69 m_T = m_myMT ->m_Model1 ->m_T;
73 // init m_A and m_softcountofA
74 for ( j=0; j< m_sentenceNormChunk; j++)
76 oneListForOneLanuage2Chunk = new IntToDouble();
77 oneSoftCountListForOneLanuage2Chunk = new IntToDouble();
78 m_A.insert(j, oneListForOneLanuage2Chunk);
79 m_softCountOfA.insert(j, oneSoftCountListForOneLanuage2Chunk);
81 uniformProb = 1.0/ m_sentenceNormChunk;
83 for ( i=0; i< m_sentenceNormChunk; i++)
85 oneListForOneLanuage2Chunk ->insert(i, uniformProb);
86 oneSoftCountListForOneLanuage2Chunk ->insert(i, 0.0);
91 // init m_softcountofT
92 m_softCountOfT = myVolca ->m_fastWordsSoftCounts;
94 for ( IntToIntToDoubleIt = m_softCountOfT.begin(); IntToIntToDoubleIt != m_softCountOfT.end();IntToIntToDoubleIt++)
96 oneListForOneLanguage1Word = IntToIntToDoubleIt.data();
98 for ( IntToDoubleIt = oneListForOneLanguage1Word ->begin(); IntToDoubleIt != oneListForOneLanguage1Word ->end(); IntToDoubleIt++)
100 IntToDoubleIt.data() = 0.0;
106 QMessageBox::information ( NULL, "Linguistica : MT Model2Norm", "Finished InitTandA", "OK" );
111 void cMTModel2Norm::EMLoops(int numberOfIterations)
113 int loopI;
115 for (loopI =0; loopI < numberOfIterations; loopI++)
117 // E step
118 EStep();
120 //QMessageBox::information ( NULL, "Linguistica : MT Model1", "Finished E-Step", "OK" );
122 // M step
123 MStep();
125 //QMessageBox::information ( NULL, "Linguistica : MT Model1", "Finished M-Step", "OK" );
127 // Clear softcouts for T
128 clearSoftCounts();
131 // Release softcountT memory
132 releaseSoftCounts();
135 void cMTModel2Norm::clearSoftCounts()
137 IntToIntToDouble::iterator IntToIntToDoubleIt;
138 IntToDouble* oneList;
139 IntToDouble* oneListForOneLanuage2Chunk;
140 IntToDouble::iterator IntToDoubleIt;
141 int key;
142 int i, j;
144 for ( IntToIntToDoubleIt = m_softCountOfT.begin(); IntToIntToDoubleIt != m_softCountOfT.end(); IntToIntToDoubleIt++)
146 oneList = IntToIntToDoubleIt.data();
148 for (IntToDoubleIt = oneList ->begin(); IntToDoubleIt != oneList ->end(); IntToDoubleIt++)
150 key = IntToDoubleIt.key();
151 (*oneList)[key] = 0.0;
155 for ( j=0; j< m_sentenceNormChunk; j++)
157 oneListForOneLanuage2Chunk = m_softCountOfA[j];
158 for ( i=0; i< m_sentenceNormChunk; i++)
160 (*oneListForOneLanuage2Chunk)[i] = 0.0;
167 void cMTModel2Norm::EStep()
169 mTVolca* myVolca;
170 int i;
172 myVolca = m_myMT ->m_Volca;
174 m_lamdaT.clear();
175 m_lamdaA.clear();
177 for ( i=0; i < myVolca ->m_countOfSentences; i++)
179 addSoftCountOfTandA(i);
184 void cMTModel2Norm::MStep()
186 IntToIntToDouble::iterator IntToIntToDoubleIt;
187 IntToDouble* oneList;
188 IntToDouble::iterator IntToDoubleIt;
189 int language1Id;
190 int language2Id;
191 int language1chunkId;
192 int language2chunkId;
193 double oneTotalSoftCount;
194 double oneLamda;
197 // Mstep for T
198 for ( IntToIntToDoubleIt = m_softCountOfT.begin(); IntToIntToDoubleIt != m_softCountOfT.end(); IntToIntToDoubleIt++)
200 language1Id = IntToIntToDoubleIt.key();
201 oneLamda = m_lamdaT[language1Id] ;
202 oneList = IntToIntToDoubleIt.data();
205 for (IntToDoubleIt = oneList ->begin(); IntToDoubleIt != oneList ->end(); IntToDoubleIt++)
207 language2Id = IntToDoubleIt.key();
208 oneTotalSoftCount = IntToDoubleIt.data();
209 (*(m_T[language1Id]))[language2Id] = oneTotalSoftCount / oneLamda ;
213 // Mstep for A
214 for ( IntToIntToDoubleIt = m_softCountOfA.begin(); IntToIntToDoubleIt != m_softCountOfA.end(); IntToIntToDoubleIt++)
216 language2chunkId = IntToIntToDoubleIt.key();
217 oneLamda = m_lamdaA[language2chunkId] ;
218 oneList = IntToIntToDoubleIt.data();
220 for (IntToDoubleIt = oneList ->begin(); IntToDoubleIt != oneList ->end(); IntToDoubleIt++)
222 language1chunkId = IntToDoubleIt.key();
223 oneTotalSoftCount = IntToDoubleIt.data();
224 (*(m_A[language2chunkId]))[language1chunkId] = oneTotalSoftCount / oneLamda ;
230 void cMTModel2Norm::addSoftCountOfTandA(int sentenceId)
232 mTVolca* myVolca;
233 double oneT;
234 double oneA;
235 double oneCount;
236 int language1SentenceLen;
237 int language2SentenceLen;
238 int language1ChunkId;
239 int language2ChunkId;
240 double deNumerator;
241 int l,m;
242 int language1WordId;
243 int language2WordId;
244 IntToInt* oneLan1Sentence;
245 IntToInt* oneLan2Sentence;
248 // add softcount of T
249 myVolca = m_myMT ->m_Volca;
251 oneLan1Sentence = myVolca ->m_language1Sentences[sentenceId];
252 oneLan2Sentence = myVolca ->m_language2Sentences[sentenceId];
254 language1SentenceLen = oneLan1Sentence ->size();
255 language2SentenceLen = oneLan2Sentence ->size();
257 for (m = 0; m < oneLan2Sentence->size(); ++m) {
258 language2WordId = (*oneLan2Sentence)[m];
259 language2ChunkId = static_cast<int>(
260 double(m) / language2SentenceLen * m_sentenceNormChunk);
262 Q_ASSERT(language2ChunkId >=0 && language2ChunkId <m_sentenceNormChunk);
265 deNumerator =0;
267 for (l = 0; l < oneLan1Sentence->size(); ++l) {
268 language1WordId = (*oneLan1Sentence)[l];
269 language1ChunkId = static_cast<int>(
270 double(l) / language1SentenceLen *
271 m_sentenceNormChunk);
273 Q_ASSERT(language1ChunkId >=0 && language1ChunkId <m_sentenceNormChunk);
275 oneT = (*(m_T[language1WordId]))[language2WordId];
276 oneA = (*(m_A[language2ChunkId]))[language1ChunkId];
277 deNumerator += oneT*oneA;
280 for (l = 0; l < oneLan1Sentence->size(); ++l) {
281 language1WordId = (*oneLan1Sentence)[l];
282 language1ChunkId = static_cast<int>(
283 double(l) / language1SentenceLen *
284 m_sentenceNormChunk);
286 oneA = (*(m_A[language2ChunkId]))[language1ChunkId];
287 oneT = (*(m_T[language1WordId]))[language2WordId] ;
288 oneCount = oneT*oneA/ deNumerator;
290 (*(m_softCountOfT[language1WordId]))[language2WordId] += oneCount;
292 if ( m_lamdaT.contains(language1WordId))
294 m_lamdaT[language1WordId] += oneCount;
296 else
298 m_lamdaT.insert(language1WordId, oneCount);
301 // add softcount of A
302 (*(m_softCountOfA[language2ChunkId]))[language1ChunkId] += oneCount;
304 if ( m_lamdaA.contains(language2ChunkId))
306 m_lamdaA[language2ChunkId] += oneCount;
308 else
310 m_lamdaA.insert(language2ChunkId, oneCount);
316 void cMTModel2Norm::releaseSoftCounts()
318 IntToIntToDouble::iterator IntToIntToDoubleIt;
319 IntToDouble* oneList;
321 for ( IntToIntToDoubleIt = m_softCountOfT.begin(); IntToIntToDoubleIt != m_softCountOfT.end(); IntToIntToDoubleIt++)
323 oneList = IntToIntToDoubleIt.data();
325 delete oneList;
329 for ( IntToIntToDoubleIt = m_softCountOfA.begin(); IntToIntToDoubleIt != m_softCountOfA.end(); IntToIntToDoubleIt++)
331 oneList = IntToIntToDoubleIt.data();
333 delete oneList;
338 void cMTModel2Norm::viterbiAll()
340 mTVolca* myVolca;
341 int i;
343 myVolca = m_myMT ->m_Volca;
345 myVolca ->clearSentenceViterbiAlignment();
347 for ( i=0; i < myVolca ->m_countOfSentences; i++)
349 viterbi(i);
354 void cMTModel2Norm::viterbi(int sentenceId)
356 mTVolca* myVolca;
357 double oneT;
358 double oneA;
359 double oneCount;
360 double bestCount;
361 int bestAlignmentId;
362 int language1SentenceLen;
363 int language2SentenceLen;
364 int language1ChunkId;
365 int language2ChunkId;
366 int l,m;
367 int language1WordId;
368 int language2WordId;
369 IntToInt* oneLan1Sentence;
370 IntToInt* oneLan2Sentence;
371 IntToInt* oneAlignment;
374 // add softcount of T
375 myVolca = m_myMT ->m_Volca;
377 oneAlignment = new IntToInt();
378 myVolca ->m_sentenceAlignments.insert(sentenceId, oneAlignment);
380 oneLan1Sentence = myVolca ->m_language1Sentences[sentenceId];
381 oneLan2Sentence = myVolca ->m_language2Sentences[sentenceId];
383 language1SentenceLen = oneLan1Sentence ->size();
384 language2SentenceLen = oneLan2Sentence ->size();
386 for (m = 0; m < language2SentenceLen; ++m) {
387 language2WordId = (*oneLan2Sentence)[m];
388 language2ChunkId = static_cast<int>(
389 double(m) / language2SentenceLen *
390 m_sentenceNormChunk);
392 Q_ASSERT(language2ChunkId >=0 && language2ChunkId <m_sentenceNormChunk);
394 bestAlignmentId = -1;
395 bestCount = 0.0;
397 for (l = 0; l < language1SentenceLen; ++l) {
398 language1WordId = (*oneLan1Sentence)[l];
399 language1ChunkId = static_cast<int>(
400 double(l) / language1SentenceLen *
401 m_sentenceNormChunk);
403 oneA = (*(m_A[language2ChunkId]))[language1ChunkId];
404 oneT = (*(m_T[language1WordId]))[language2WordId] ;
405 oneCount = oneT*oneA;
407 if ( oneCount > bestCount)
409 bestAlignmentId = l;
410 bestCount = oneCount;
414 oneAlignment->insert(m, bestAlignmentId);