Add another
[apertium.git] / apertium-tagger-training-tools / src / apertium-tagger-supervised.C
blob8875be925fc6754a18fd608634baeec6aae284b9
1 /*
2  * Copyright (C) 2004-2006 Felipe Sánchez-Martínez
3  * Copyright (C) 2006 Universitat d'Alacant
4  *
5  * This program is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU General Public License as
7  * published by the Free Software Foundation; either version 2 of the
8  * License, or (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
18  * 02111-1307, USA.
19  */
21 #include <fstream>
22 #include <iostream>
23 #include <string>
24 #include <cstring>
25 #include <cstdlib>
26 #include <float.h>
27 #include <sys/types.h>
28 #include <unistd.h>
29 #include <getopt.h>
30 #include <clocale>
32 #include <apertium/HMM.H>
33 #include <apertium/MorphoStream.H>
34 #include <apertium/TTag.H>
35 #include <apertium/TaggerWord.H>
36 #include <apertium/Collection.H>
37 #include <apertium/TaggerData.H>
38 #include <apertium/TSXReader.H>
39 #include "configure.H"
40 #include "SmoothUtils.H"
42 #define ZERO 1e-10
45 using namespace std;
48 //Global vars
49 TaggerData tagger_data;
50 TTag eos; //End-of-sentence tag
52 void check_file(FILE *f, const string& path) {
53   if (!f) {
54     cerr<<"Error: cannot open file '"<<path<<"'\n";
55     exit(EXIT_FAILURE);
56   }
59 void 
60 supervised(FILE *ftagged, FILE *funtagged) {
61   int i, j, k, nw=0;
63   map<int, map<int, double> > tags_pair; //NxN
64   map<int, map<int, double> > emis;     //NxM
65   map<int, double> ambclass_count;      //M
66   map<int, double> tags_count;          //N
67   map<int, double> tags_count_for_emis; //N
68   
69   MorphoStream stream_tagged(ftagged, true, &tagger_data);
70   MorphoStream stream_untagged(funtagged, true, &tagger_data);
71   
72   TaggerWord *word_tagged=NULL, *word_untagged=NULL;
73   
74   set<TTag> tags;
76   // Init counters
77   for(i=0; i<tagger_data.getN(); i++) {
78     tags_count[i]=0;
79     tags_count_for_emis[i]=0;
80     for(j=0; j<tagger_data.getN(); j++)
81       tags_pair[i][j]=0;
82   }
83   for(k=0; k<tagger_data.getM(); k++) {
84     ambclass_count[k]=0;
85     for(i=0; i<tagger_data.getN(); i++) {
86       if (tagger_data.getOutput()[k].find(i)!=tagger_data.getOutput()[k].end())
87         emis[i][k] = 0;
88     }  
89   }
91   TTag tag1, tag2;  
92   tag1 = eos; // The first seen tag is the end-of-sentence tag
93   
94   word_tagged = stream_tagged.get_next_word();
95   word_untagged = stream_untagged.get_next_word();
96   while(word_tagged) {
97     cerr<<*word_tagged<<" -- "<<*word_untagged<<"\n"; 
99     if (word_tagged->get_superficial_form()!=word_untagged->get_superficial_form()) {              
100       cerr<<"\nTagged text (.tagged) and analyzed text (.untagged) streams are not aligned.\n";
101       cerr<<"Take a look at tagged text (.tagged).\n";
102       cerr<<"Perhaps this is caused by a multiword unit that is not a multiword unit in one of the two files.\n";
103       cerr<<*word_tagged<<" -- "<<*word_untagged<<"\n"; 
104       exit(1);
105     }
107     if (++nw%100==0) cerr<<'.'<<flush; 
108     
109     tag2 = tag1;
110    
111     if (word_untagged==NULL) {
112       cerr<<"word_untagged==NULL\n";
113       exit(1);
114     }
116     if (word_tagged->get_tags().size()==0) // Unknown word
117       tag1 = -1;
118     else if (word_tagged->get_tags().size()>1) // Ambiguous word
119       cerr<<"Error in tagged text. An ambiguous word was found: "<<word_tagged->get_superficial_form()<<"\n";
120     else
121       tag1 = *(word_tagged->get_tags()).begin();
124     if ((tag1>=0) && (tag2>=0)) {
125       tags_pair[tag2][tag1]++;
126       tags_count[tag2]++;
127     }
129     if (word_untagged->get_tags().size()==0) { // Unknown word
130       tags = tagger_data.getOpenClass();
131     } else if (tagger_data.getOutput().has_not(word_untagged->get_tags())) { //We are training, there is no problem
132       string errors;
133       errors = "A new ambiguity class was found. I cannot continue.\n";
134       errors+= "Word '"+word_untagged->get_superficial_form()+"' not found in the dictionary.\n";
135       errors+= "New ambiguity class: "+word_untagged->get_string_tags()+"\n";
136       errors+= "Take a look at the dictionary, then retrain.";
137       fatal_error(errors);      
138     } else {
139       tags = word_untagged->get_tags();
140     }
142     k=tagger_data.getOutput()[tags];
143     if(tag1>=0) {
144       if (tagger_data.getOutput()[k].find(tag1)!=tagger_data.getOutput()[k].end()) {     
145         emis[tag1][k]++;
146         tags_count_for_emis[tag1]++;
147         ambclass_count[k]++;
148       } else {
149         cerr<<"Warning: Ambiguity class "<<k<<" is emmited from tag "<<tag1<<" but it should not\n";
150       }
151     }
152   
153     delete word_tagged;
154     word_tagged=stream_tagged.get_next_word();
155     delete word_untagged;
156     word_untagged=stream_untagged.get_next_word();       
157   }
158   
159   SmoothUtils::calculate_smoothed_parameters(tagger_data, tags_count, tags_pair, ambclass_count, emis, tags_count_for_emis, nw);  
160   
161   cerr<<"Number of words processed: "<<nw<<"\n";  
164 void apply_rules() {
165   bool found;
167   for(size_t i=0; i<tagger_data.getForbidRules().size(); i++) {
168     tagger_data.getA()[tagger_data.getForbidRules()[i].tagi][tagger_data.getForbidRules()[i].tagj] = ZERO;
169   }
171   for(size_t i=0; i<tagger_data.getEnforceRules().size(); i++) {
172     for(int j=0; j<tagger_data.getN(); j++) {
173       found = false;
174       for (size_t j2=0; j2<tagger_data.getEnforceRules()[i].tagsj.size(); j2++) {
175         if (tagger_data.getEnforceRules()[i].tagsj[j2]==j) {
176           found = true;
177           break;
178         }
179       }
180       if (!found) {
181         tagger_data.getA()[tagger_data.getEnforceRules()[i].tagi][j] = ZERO;
182       }
183     }
184   }
186   // Normalize probabilities
187   for(int i=0; i<tagger_data.getN(); i++) {
188     double sum=0;
189     for(int j=0; j<tagger_data.getN(); j++)
190       sum += tagger_data.getA()[i][j];
191     for(int j=0; j<tagger_data.getN(); j++) {
192       if(sum>0)
193         tagger_data.getA()[i][j] = tagger_data.getA()[i][j]/sum;
194       else
195         tagger_data.getA()[i][j] = 0;
196     }
197   }
200 void help(char *name) {
201   cerr<<"USAGE:\n";
202   cerr<<name<<" --tsxfile tsxfile --dicfile file.dic --tagged file.tagger --untagged file.untagged --outfile fileout.prob [--norules]\n\n";
204   cerr<<"ARGUMENTS: \n"
205       <<"   --tsxfile|-x: To provide the tagger specification file in XML\n"
206       <<"   --dicfile|-d: To specify the expanded dictionary that will be used to extract the\n"
207       <<"            ambiguity classes\n"
208       <<"   --tagged|-t: To specify the file with the tagged corpus to be used for training\n"
209       <<"   --untagged|-u: To specify the file with the untagged corpus to be used for training\n"
210       <<"   --outfile|-o: To specify the file in which the new parameters will be stored\n"
211       <<"   --norules|-n: Do not use forbid and enforce rules\n";
215 int main(int argc, char* argv[]) {
216   string tsxfile="";
217   string filedic="";
218   string filetagged="";
219   string fileuntagged="";
220   string fileout="";
222   bool use_forbid_enforce_rules=true;
224   int c;
225   int option_index=0;
227   cerr<<"LOCALE: "<<setlocale(LC_ALL,"")<<"\n";
229   cerr<<"Command line: ";
230   for(int i=0; i<argc; i++)
231     cerr<<argv[i]<<" ";
232   cerr<<"\n";
234   while (true) {
235     static struct option long_options[] =
236       {
237         {"tsxfile",  required_argument, 0, 'x'},
238         {"dicfile",  required_argument, 0, 'd'},
239         {"tagged",   required_argument, 0, 't'},
240         {"untagged", required_argument, 0, 'u'},
241         {"outfile",  required_argument, 0, 'o'},
242         {"norules",    no_argument,     0, 'n'},
243         {"help",       no_argument,     0, 'h'},
244         {"version",    no_argument,     0, 'v'},
245         {0, 0, 0, 0}
246       };
248     c=getopt_long(argc, argv, "x:d:t:u:o:nhv",long_options, &option_index);
249     if (c==-1)
250       break;
251       
252     switch (c) {
253     case 'x': 
254       tsxfile=optarg;
255       break;
256     case 'd': 
257       filedic=optarg;
258       break;
259     case 't': 
260       filetagged=optarg;
261       break;
262     case 'u': 
263       fileuntagged=optarg;
264       break;
265     case 'o': 
266       fileout=optarg;
267       break;
268     case 'n': 
269       use_forbid_enforce_rules=false;
270       break;
271     case 'h': 
272       help(argv[0]);
273       exit(EXIT_SUCCESS);
274       break;
275     case 'v':
276       cerr<<PACKAGE_STRING<<"\n";
277       cerr<<"LICENSE:\n\n"
278           <<"   Copyright (C) 2006 Felipe Sánchez Martínez\n\n"
279           <<"   This program is free software; you can redistribute it and/or\n"
280           <<"   modify it under the terms of the GNU General Public License as\n"
281           <<"   published by the Free Software Foundation; either version 2 of the\n"
282           <<"   License, or (at your option) any later version.\n"
283           <<"   This program is distributed in the hope that it will be useful, but\n"
284           <<"   WITHOUT ANY WARRANTY; without even the implied warranty of\n"
285           <<"   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\n"
286           <<"   General Public License for more details.\n"
287           <<"\n"
288           <<"   You should have received a copy of the GNU General Public License\n"
289           <<"   along with this program; if not, write to the Free Software\n"
290           <<"   Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA\n"
291           <<"   02111-1307, USA.\n";
292       exit(EXIT_SUCCESS);
293       break;    
294     default:
295       help(argv[0]);
296       exit(EXIT_FAILURE);
297       break;
298     }
299   }
301   if (tsxfile=="") {
302     cerr<<"Error: You did not provide a tagger specification file (.tsx). Use --tsxfile to do that\n";
303     help(argv[0]);
304     exit(EXIT_FAILURE);
305   }
307   if (filedic=="") {
308     cerr<<"Error: You did not provide an expanded dictionary file. Use --dicfile to do that\n";
309     help(argv[0]);
310     exit(EXIT_FAILURE);
311   }
313   if (filetagged=="") {
314     cerr<<"Error: You did not provide a tagged corpus. Use --tagged to do that\n";
315     help(argv[0]);
316     exit(EXIT_FAILURE);
317   }
319   if (fileuntagged=="") {
320     cerr<<"Error: You did not provide a untagged corpus. Use --untagged to do that\n";
321     help(argv[0]);
322     exit(EXIT_FAILURE);
323   }
325   if(fileout=="") {
326     cerr<<"Error: You did not provide an output file for the tagger parameters. Use --outfile to do that\n";
327     help(argv[0]);
328     exit(EXIT_FAILURE);
329   }
333   TSXReader treader;
334   treader.read(tsxfile);
335   tagger_data=treader.getTaggerData();
337   FILE *fdic, *ftagged, *funtagged, *fout;
339   HMM hmm(&tagger_data);
340   hmm.set_debug(true);
341   hmm.set_eos((tagger_data.getTagIndex())["TAG_SENT"]);
343   TaggerWord::setArrayTags(tagger_data.getArrayTags());
344   eos=(tagger_data.getTagIndex())["TAG_SENT"];
346   fdic=fopen(filedic.c_str(), "r");
347   check_file(fdic, filedic);
348   ftagged=fopen(filetagged.c_str(), "r");
349   check_file(ftagged, filetagged);
350   funtagged=fopen(fileuntagged.c_str(), "r");
351   check_file(funtagged, fileuntagged);
353   cerr<<"Calculating ambiguity classes ... "<<flush;
354   hmm.read_dictionary(fdic);
355   cerr<<"done.\n";
356   fclose(fdic);
358   supervised(ftagged, funtagged);
359   fclose(ftagged);
360   fclose(funtagged);
362   if (use_forbid_enforce_rules) {
363     cerr<<"Applying forbid and enforce rules ... "<<flush;
364     apply_rules();
365     cerr<<"done.\n";
366   }
368   fout=fopen(fileout.c_str(), "w");
369   check_file(fout, fileout);
370   cerr<<"Writing apertium-tagger parameters to file '"<<fileout<<"' ... "<<flush;
371   tagger_data.write(fout);
372   cerr<<"done.\n";
373   fclose(fout);