Improve @param docs for Database::compact()
[xapian.git] / xapian-letor / letor_internal_refactored.cc
blobba2ca0fd19a50734400973d954e1c58aaf9bad85
1 /** @file letor_internal.cc
2 * @brief Internals of Xapian::Letor class
3 */
4 /* Copyright (C) 2011 Parth Gupta
6 * This program is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU General Public License as
8 * published by the Free Software Foundation; either version 2 of the
9 * License, or (at your option) any later version.
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
19 * USA
22 #include <config.h>
24 #include <xapian/letor.h>
26 #include <xapian.h>
28 #include "letor_internal_refactored.h"
29 #include "featurevector.h"
30 #include "str.h"
31 #include "stringutils.h"
33 #include <cstdio>
34 #include <cstdlib>
35 #include <cstring>
36 #include "safeerrno.h"
37 #include "safeunistd.h"
39 #include <algorithm>
40 #include <list>
41 #include <iostream>
42 #include <fstream>
43 #include <sstream>
44 #include <string>
45 #include <map>
46 #include <math.h>
48 #include <libsvm/svm.h>
49 #define Malloc(type, n) (type *)malloc((n) * sizeof(type))
51 using namespace std;
53 using namespace Xapian;
// File-scope state shared with LIBSVM's C API (training and prediction
// helpers below drive LIBSVM through these objects, mirroring the
// structure of LIBSVM's svm-train.c / svm-predict.c examples).
55 struct svm_parameter param;
56 struct svm_problem prob;
57 struct svm_model *model;
58 struct svm_node *x_space;
59 int cross_validation;
60 int nr_fold;
// Sparse feature vector reused when scoring a single document; grown by
// doubling max_nr_attr as needed.
62 struct svm_node *x;
63 int max_nr_attr = 64;
// Non-zero would request per-class probability estimates from LIBSVM.
65 int predict_probability = 0;
// Line buffer used by readline()/read_problem() when parsing train.txt.
67 static char *line = NULL;
68 static int max_line_len;
// Buffer size handed to getcwd() in get_cwd(); paths longer than this fail.
70 int MAXPATHLEN = 200;
73 //Stop-words
74 static const char * sw[] = {
75 "a", "about", "an", "and", "are", "as", "at",
76 "be", "by",
77 "en",
78 "for", "from",
79 "how",
80 "i", "in", "is", "it",
81 "of", "on", "or",
82 "that", "the", "this", "to",
83 "was", "what", "when", "where", "which", "who", "why", "will", "with"
/// Report a malformed line in a training/test file and abort the process.
/// Called by read_problem() and letor_score() when LIBSVM-format parsing
/// fails.  Never returns.
/// @param line_num  1-based line number at which the error occurred.
static void exit_input_error(int line_num) {
    printf("Error at Line : %d", line_num);
    exit(1);
}
/// Format a double using ostringstream's default (precision-6) formatting.
/// Used to build LIBSVM-format feature strings in letor_score().
/// @param value  the number to format.
/// @return the decimal representation, or an empty string if streaming fails.
static std::string convertDouble(double value) {
    std::ostringstream o;
    if (!(o << value))
        return std::string();
    return o.str();
}
101 static string get_cwd() {
102 char temp[MAXPATHLEN];
103 return (getcwd(temp, MAXPATHLEN) ? std::string(temp) : std::string());
107 /* This method will calculate the score assigned by the Letor function.
108 * It will take MSet as input then convert the documents in feature vectors
109 * then normalize them according to QueryLevelNorm
110 * and after that use the machine learned model file
111 * to assign a score to the document
 *
 * NOTE(review): reads the trained SVM model from "<cwd>/model.txt" —
 * presumably the file written by letor_learn_model(); confirm the
 * workflow requires training to have been run first.
 */
113 map<Xapian::docid, double>
114 Letor::Internal::letor_score(const Xapian::MSet & mset) {
116 map<Xapian::docid, double> letor_mset;
// Collection-level statistics, computed once per query.
118 map<string, long int> coll_len;
119 coll_len = collection_length(letor_db);
121 map<string, long int> coll_tf;
122 coll_tf = collection_termfreq(letor_db, letor_query);
124 map<string, double> idf;
125 idf = inverse_doc_freq(letor_db, letor_query);
127 int first = 1; //used as a flag in QueryLevelNorm module
129 typedef list<double> List1; //the values of a particular feature for MSet documents will be stored in the list
130 typedef map<int, List1> Map3; //the above list will be mapped to an integer with its feature id.
132 /* So the whole structure will look like below if there are 5 documents in MSet and 3 features to be calculated
134 * 1 -> 32.12 - 23.12 - 43.23 - 12.12 - 65.23
135 * 2 -> 31.23 - 21.43 - 33.99 - 65.23 - 22.22
136 * 3 -> 1.21 - 3.12 - 2.23 - 6.32 - 4.23
138 * And after that we divide the whole list by the maximum value for that feature in all the 5 documents
139 * So we divide the values of Feature 1 in above case by 65.23 and hence all the values of that features for that query
140 * will belongs to [0,1] and is known as Query level Norm
 */
143 Map3 norm;
145 map<int, list<double> >::iterator norm_outer;
146 list<double>::iterator norm_inner;
148 typedef list<string> List2;
149 List2 doc_ids;
// Pass 1: compute the 19 feature values for every document in the MSet.
151 for (Xapian::MSetIterator i = mset.begin(); i != mset.end(); ++i) {
152 Xapian::Document doc = i.get_document();
154 map<string, long int> tf;
155 tf = termfreq(doc, letor_query);
157 map<string, long int> doclen;
158 doclen = doc_length(letor_db, doc);
160 double f[20];
// Each feature is computed for 't'itle, 'b'ody and 'w'hole document.
162 f[1] = calculate_f1(letor_query, tf, 't'); //storing the feature values from array index 1 to sync it with feature number.
163 f[2] = calculate_f1(letor_query, tf, 'b');
164 f[3] = calculate_f1(letor_query, tf, 'w');
166 f[4] = calculate_f2(letor_query, tf, doclen, 't');
167 f[5] = calculate_f2(letor_query, tf, doclen, 'b');
168 f[6] = calculate_f2(letor_query, tf, doclen, 'w');
170 f[7] = calculate_f3(letor_query, idf, 't');
171 f[8] = calculate_f3(letor_query, idf, 'b');
172 f[9] = calculate_f3(letor_query, idf, 'w');
174 f[10] = calculate_f4(letor_query, coll_tf, coll_len, 't');
175 f[11] = calculate_f4(letor_query, coll_tf, coll_len, 'b');
176 f[12] = calculate_f4(letor_query, coll_tf, coll_len, 'w');
178 f[13] = calculate_f5(letor_query, tf, idf, doclen, 't');
179 f[14] = calculate_f5(letor_query, tf, idf, doclen, 'b');
180 f[15] = calculate_f5(letor_query, tf, idf, doclen, 'w');
182 f[16] = calculate_f6(letor_query, tf, doclen, coll_tf, coll_len, 't');
183 f[17] = calculate_f6(letor_query, tf, doclen, coll_tf, coll_len, 'b');
184 f[18] = calculate_f6(letor_query, tf, doclen, coll_tf, coll_len, 'w');
// Feature 19 is the raw Xapian weight of the document.
186 f[19] = i.get_weight();
188 /* This module will make the data structure to store the whole features values for
189 * all the documents for a particular query along with its relevance judgements
 */
192 if (first == 1) {
193 for (int j = 1; j < 20; ++j) {
194 List1 l;
195 l.push_back(f[j]);
196 norm.insert(pair<int, list<double> >(j, l));
198 first = 0;
199 } else {
// Subsequent documents: append each feature value to its per-feature list.
200 norm_outer = norm.begin();
201 int k = 1;
202 for (; norm_outer != norm.end(); ++norm_outer) {
203 norm_outer->second.push_back(f[k]);
204 ++k;
207 }//for closed
209 /* this is the place where we have to normalize the norm and after that store it in the file. */
211 if (!norm.empty()) {
// Pass 2 (query-level norm): divide every feature list by its maximum.
// NOTE(review): the first ++norm_outer skips feature 1, so feature 1 is
// never normalized here — looks unintentional; compare with
// prepare_training_file() where the skipped entry is the relevance label.
212 norm_outer = norm.begin();
213 ++norm_outer;
214 int k = 0;
215 for (; norm_outer != norm.end(); ++norm_outer) {
216 k = 0;
217 double max = norm_outer->second.front();
218 for (norm_inner = norm_outer->second.begin(); norm_inner != norm_outer->second.end(); ++norm_inner) {
219 if (*norm_inner > max)
220 max = *norm_inner;
222 for (norm_inner = norm_outer->second.begin(); norm_inner != norm_outer->second.end(); ++norm_inner) {
223 if (max != 0) // sometimes value for whole feature is 0 and hence it may cause 'divide-by-zero'
224 *norm_inner /= max;
225 ++k;
// Pass 3: for each document, serialise its normalized features into a
// LIBSVM-format line ("0 1:v1 2:v2 ...") and score it with the model.
229 int xx = 0, j = 0;
230 Xapian::MSetIterator mset_iter = mset.begin();
231 Xapian::Document doc;
232 while (xx<k) {
233 doc = mset_iter.get_document();
235 string test_case = "0 ";
236 j = 0;
237 norm_outer = norm.begin();
238 ++j;
239 for (; norm_outer != norm.end(); ++norm_outer) {
240 test_case.append(str(j));
241 test_case.append(":");
242 test_case.append(convertDouble(norm_outer->second.front()));
243 test_case.append(" ");
244 norm_outer->second.pop_front();
245 ++j;
247 ++xx;
249 string model_file;
250 model_file = get_cwd();
251 model_file = model_file.append("/model.txt"); // will create "model.txt" in currect working directory
253 model = svm_load_model(model_file.c_str());
254 x = (struct svm_node *)malloc(max_nr_attr * sizeof(struct svm_node));
256 int total = 0;
258 int svm_type = svm_get_svm_type(model);
259 int nr_class = svm_get_nr_class(model);
261 if (predict_probability) {
262 if (svm_type == NU_SVR || svm_type == EPSILON_SVR) {
263 printf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n" , svm_get_svr_probability(model));
264 } else {
265 int *labels = (int *) malloc(nr_class * sizeof(int));
266 svm_get_labels(model, labels);
267 free(labels);
271 max_line_len = 1024;
272 line = (char *)malloc(max_line_len*sizeof(char));
// NOTE(review): the malloc'd buffer above is leaked immediately, and
// strtok() below then writes NUL bytes into test_case's internal buffer
// through this const_cast — undefined behavior; should copy the string.
274 line = const_cast<char *>(test_case.c_str());
276 int i = 0;
277 double predict_label;
278 char *idx, *val, *label, *endptr;
279 int inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
// Tokenise the LIBSVM-format line: first the label, then idx:val pairs.
281 label = strtok(line, " \t\n");
282 if (label == NULL) // empty line
283 exit_input_error(total + 1);
285 if (strtod(label, &endptr)) {
286 // Ignore the result (I guess we're just syntax checking the file?)
288 if (endptr == label || *endptr != '\0')
289 exit_input_error(total + 1);
291 while (1) {
292 if (i >= max_nr_attr - 1) { // need one more for index = -1
293 max_nr_attr *= 2;
294 x = (struct svm_node *)realloc(x, max_nr_attr * sizeof(struct svm_node));
297 idx = strtok(NULL, ":");
298 val = strtok(NULL, " \t");
300 if (val == NULL)
301 break;
302 errno = 0;
303 x[i].index = (int)strtol(idx, &endptr, 10);
305 if (endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index)
306 exit_input_error(total + 1);
307 else
308 inst_max_index = x[i].index;
310 errno = 0;
311 x[i].value = strtod(val, &endptr);
312 if (endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
313 exit_input_error(total+1);
314 ++i;
// LIBSVM sparse vectors are terminated by index == -1.
317 x[i].index = -1;
319 predict_label = svm_predict(model, x); //this is the score for a particular document
321 letor_mset[doc.get_docid()] = predict_label;
323 ++mset_iter;
324 }//while closed
325 }//if closed
327 return letor_mset;
330 static char* readline(FILE *input) {
331 int len;
333 if (fgets(line, max_line_len, input) == NULL)
334 return NULL;
336 while (strrchr(line, '\n') == NULL) {
337 max_line_len *= 2;
338 line = (char *)realloc(line, max_line_len);
339 len = (int)strlen(line);
340 if (fgets(line + len, max_line_len - len, input) == NULL)
341 break;
343 return line;
/* Parse a file in LIBSVM's sparse training format ("label idx:val ...")
 * into the file-scope `prob`, `prob.x`, `prob.y` and `x_space`.
 * Adapted from LIBSVM's svm-train.c: first pass counts lines and
 * elements, second pass fills the arrays.  Exits the process via
 * exit()/exit_input_error() on any I/O or format error.
 */
346 static void read_problem(const char *filename) {
347 int elements, max_index, inst_max_index, i, j;
348 FILE *fp = fopen(filename, "r");
349 char *endptr;
350 char *idx, *val, *label;
352 if (fp == NULL) {
353 fprintf(stderr, "can't open input file %s\n", filename);
354 exit(1);
// Pass 1: count the number of instances (prob.l) and total idx:val
// elements so the arrays can be allocated in one shot.
357 prob.l = 0;
358 elements = 0;
360 max_line_len = 1024;
361 line = Malloc(char, max_line_len);
363 while (readline(fp) != NULL) {
364 char *p = strtok(line, " \t"); // label
366 // features
367 while (1) {
368 p = strtok(NULL, " \t");
369 if (p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
370 break;
371 ++elements;
// One extra element per instance for the index = -1 terminator.
373 ++elements;
374 ++prob.l;
376 rewind(fp);
378 prob.y = Malloc(double, prob.l);
379 prob.x = Malloc(struct svm_node *, prob.l);
380 x_space = Malloc(struct svm_node, elements);
// Pass 2: re-read the file and fill prob.y / prob.x / x_space.
382 max_index = 0;
383 j = 0;
385 for (i = 0; i < prob.l; ++i) {
386 inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
387 readline(fp);
388 prob.x[i] = &x_space[j];
389 label = strtok(line, " \t\n");
390 if (label == NULL) // empty line
391 exit_input_error(i + 1);
392 prob.y[i] = strtod(label, &endptr);
393 if (endptr == label || *endptr != '\0')
394 exit_input_error(i + 1);
396 while (1) {
397 idx = strtok(NULL, ":");
398 val = strtok(NULL, " \t");
400 if (val == NULL)
401 break;
// Indices must be strictly increasing within an instance.
403 errno = 0;
404 x_space[j].index = (int)strtol(idx, &endptr, 10);
406 if (endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
407 exit_input_error(i + 1);
408 else
409 inst_max_index = x_space[j].index;
411 errno = 0;
412 x_space[j].value = strtod(val, &endptr);
414 if (endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
415 exit_input_error(i + 1);
417 ++j;
420 if (inst_max_index > max_index)
421 max_index = inst_max_index;
// Terminate this instance's sparse vector.
422 x_space[j++].index = -1;
// Default gamma = 1/num_features when the caller left it at 0.
425 if (param.gamma == 0 && max_index > 0)
426 param.gamma = 1.0 / max_index;
// Precomputed kernels require column 0 to hold the sample serial number.
428 if (param.kernel_type == PRECOMPUTED)
429 for (i = 0; i < prob.l; ++i) {
430 if (prob.x[i][0].index != 0) {
431 fprintf(stderr, "Wrong input format: first column must be 0:sample_serial_number\n");
432 exit(1);
434 if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) {
435 fprintf(stderr, "Wrong input format: sample_serial_number out of range\n");
436 exit(1);
439 fclose(fp);
/* Train an SVM model with LIBSVM and persist it.
 *
 * @param s_type  LIBSVM svm_type to use (passed straight into
 *                svm_parameter::svm_type).
 * @param k_type  LIBSVM kernel_type (svm_parameter::kernel_type).
 *
 * Reads the training data from <cwd>/train.txt (the file produced by
 * prepare_training_file()) and writes the trained model to
 * <cwd>/model.txt (consumed by letor_score()).  Exits the process if
 * the parameters are invalid or the model cannot be saved.
 */
442 void
443 Letor::Internal::letor_learn_model(int s_type, int k_type) {
444 // default values
445 param.svm_type = s_type;
446 param.kernel_type = k_type;
447 param.degree = 3;
448 param.gamma = 0; // 1/num_features
449 param.coef0 = 0;
450 param.nu = 0.5;
451 param.cache_size = 100;
452 param.C = 1;
453 param.eps = 1e-3;
454 param.p = 0.1;
455 param.shrinking = 1;
456 param.probability = 0;
457 param.nr_weight = 0;
458 param.weight_label = NULL;
459 param.weight = NULL;
460 cross_validation = 0;
462 printf("Learning the model..");
463 string input_file_name;
464 string model_file_name;
465 const char *error_msg;
// Both files live in the current working directory.
467 input_file_name = get_cwd().append("/train.txt");
468 model_file_name = get_cwd().append("/model.txt");
// Populate the global `prob`/`x_space` from train.txt, then validate
// the parameter set before training.
470 read_problem(input_file_name.c_str());
471 error_msg = svm_check_parameter(&prob, &param);
472 if (error_msg) {
473 fprintf(stderr, "svm_check_parameter failed: %s\n", error_msg);
474 exit(1);
477 model = svm_train(&prob, &param);
478 if (svm_save_model(model_file_name.c_str(), model)) {
479 fprintf(stderr, "can't save model to file %s\n", model_file_name.c_str());
480 exit(1);
485 /* This is the method which prepares the train.txt file of the standard Letor Format.
486 * param 1 = Xapian Database
487 * param 2 = path to the query file
488 * param 3 = path to the qrel file
490 * output : It produces the train.txt file in /core/examples/ which is taken as input by the letor_learn_model() method
493 void
494 Letor::Internal::prepare_training_file(const string & queryfile, const string & qrel_file, Xapian::doccount msetsize) {
// Output: one LIBSVM-format line per (query, judged document) pair.
496 ofstream train_file;
497 train_file.open("train.txt");
// Query parsing setup: English stemming, stop-word removal, OR semantics.
499 Xapian::SimpleStopper mystopper(sw, sw + sizeof(sw) / sizeof(sw[0]));
500 Xapian::Stem stemmer("english");
502 Xapian::QueryParser parser;
503 parser.add_prefix("title", "S");
504 parser.add_prefix("subject", "S");
506 parser.set_database(letor_db);
507 parser.set_default_op(Xapian::Query::OP_OR);
508 parser.set_stemmer(stemmer);
509 parser.set_stemming_strategy(Xapian::QueryParser::STEM_SOME);
510 parser.set_stopper(&mystopper);
512 /* ---------------------------- store whole qrel file in a Map<> ---------------------*/
514 // typedef map<string, int> Map1; //docid and relevance judgement 0/1
515 // typedef map<string, Map1> Map2; // qid and map1
516 // Map2 qrel;
518 map<string, map<string, int> > qrel;
// qrel maps qid -> (docid -> relevance judgement), loaded via FeatureVector.
520 Xapian::FeatureVector fv;
521 fv.set_database(letor_db);
522 qrel = fv.load_relevance(qrel_file);
527 map<string, map<string, int> >::iterator outerit;
528 map<string, int>::iterator innerit;
530 //reading qrel in a map over.
532 map<string, long int> coll_len;
533 coll_len = collection_length(letor_db);
535 string str1;
536 ifstream myfile1;
537 myfile1.open(queryfile.c_str(), ios::in);
539 while (!myfile1.eof()) { //reading all the queries line by line from the query file
540 typedef list<double> List1; //the values of a particular feature for MSet documents will be stored in the list
541 typedef map<int, List1> Map3; //the above list will be mapped to an integer with its feature id.
543 /* So the whole structure will look like below if there are 5 documents in MSet and 3 features to be calculated
545 * 1 -> 32.12 - 23.12 - 43.23 - 12.12 - 65.23
546 * 2 -> 31.23 - 21.43 - 33.99 - 65.23 - 22.22
547 * 3 -> 1.21 - 3.12 - 2.23 - 6.32 - 4.23
549 * And after that we divide the whole list by the maximum value for that feature in all the 5 documents
550 * So we divide the values of Feature 1 in above case by 65.23 and hence all the values of that features for that query
551 * will belongs to [0,1] and is known as Query level Norm
 */
553 Map3 norm;
555 map< int, list<double> >::iterator norm_outer;
556 list<double>::iterator norm_inner;
558 typedef list<string> List2;
559 List2 doc_ids;
561 getline(myfile1, str1);
562 if (str1.empty()) {
563 break;
// NOTE(review): assumes each query-file line looks like
//   <qid> '<query text>'
// i.e. qid up to the first space, query between single quotes — confirm
// against the expected query file format.
566 string qid = str1.substr(0, (int)str1.find(" "));
567 string querystr = str1.substr((int)str1.find("'")+1, (str1.length() - ((int)str1.find("'") + 2)));
// Prepend "title:<term>" for every term so the title field is searched
// too; qq ends up as "title:tN ... title:t1 <original query>".
569 string qq = querystr;
570 istringstream iss(querystr);
571 string title = "title:";
572 while (iss) {
573 string t;
574 iss >> t;
575 if (t.empty())
576 break;
577 string temp;
578 temp.append(title);
579 temp.append(t);
580 temp.append(" ");
581 temp.append(qq);
582 qq = temp;
585 cout << "Processing Query: " << qq << "\n";
587 Xapian::Query query = parser.parse_query(qq,
588 parser.FLAG_DEFAULT|
589 parser.FLAG_SPELLING_CORRECTION);
591 Xapian::Enquire enquire(letor_db);
592 enquire.set_query(query);
594 Xapian::MSet mset = enquire.get_mset(0, msetsize);
596 Xapian::Letor ltr;
// Per-query collection statistics for the feature calculators.
598 map<string, long int> coll_tf;
599 coll_tf = collection_termfreq(letor_db, query);
601 map<string, double> idf;
602 idf = inverse_doc_freq(letor_db, query);
604 int first = 1; //used as a flag in QueryLevelNorm and module
// Compute the 19 features for each retrieved document (same layout as
// letor_score(): 't'itle / 'b'ody / 'w'hole variants of f1..f6 plus the
// raw Xapian weight as feature 19).
606 for (Xapian::MSetIterator i = mset.begin(); i != mset.end(); ++i) {
607 Xapian::Document doc = i.get_document();
609 map<string, long int> tf;
610 tf = termfreq(doc, query);
612 map<string, long int> doclen;
613 doclen = doc_length(letor_db, doc);
615 double f[20];
617 f[1] = calculate_f1(query, tf, 't');
618 f[2] = calculate_f1(query, tf, 'b');
619 f[3] = calculate_f1(query, tf, 'w');
621 f[4] = calculate_f2(query, tf, doclen, 't');
622 f[5] = calculate_f2(query, tf, doclen, 'b');
623 f[6] = calculate_f2(query, tf, doclen, 'w');
625 f[7] = calculate_f3(query, idf, 't');
626 f[8] = calculate_f3(query, idf, 'b');
627 f[9] = calculate_f3(query, idf, 'w');
629 f[10] = calculate_f4(query, coll_tf, coll_len, 't');
630 f[11] = calculate_f4(query, coll_tf, coll_len, 'b');
631 f[12] = calculate_f4(query, coll_tf, coll_len, 'w');
633 f[13] = calculate_f5(query, tf, idf, doclen, 't');
634 f[14] = calculate_f5(query, tf, idf, doclen, 'b');
635 f[15] = calculate_f5(query, tf, idf, doclen, 'w');
637 f[16] = calculate_f6(query, tf, doclen, coll_tf, coll_len, 't');
638 f[17] = calculate_f6(query, tf, doclen, coll_tf, coll_len, 'b');
639 f[18] = calculate_f6(query, tf, doclen, coll_tf, coll_len, 'w');
641 f[19] = i.get_weight();
// Recover the external document name from the stored document data:
// the text between "url=" and "sample=", then the basename without
// extension.  NOTE(review): assumes every document's data contains
// "url=" followed by "sample=" — confirm against the indexer.
643 string data = doc.get_data();
645 string temp_id = data.substr(data.find("url=", 0), (data.find("sample=", 0) - data.find("url=", 0)));
647 string id = temp_id.substr(temp_id.rfind('/') + 1, (temp_id.rfind('.') - temp_id.rfind('/') - 1)); //to parse the actual document name associated with the documents if any
// Only documents with a relevance judgement in the qrel file are
// written to the training data.
650 outerit = qrel.find(qid);
651 if (outerit != qrel.end()) {
652 innerit = outerit->second.find(id);
653 if (innerit != outerit->second.end()) {
654 int q1 = innerit->second;
655 cout << q1 << " Qid:" << qid << " #docid:" << id << "\n";
657 /* This module will make the data structure to store the whole features values for
658 * all the documents for a particular query along with its relevance judgements
 */
// norm[0] holds the relevance labels; norm[1..19] hold the features.
660 if (first == 1) {
661 List1 l;
662 l.push_back((double)q1);
663 norm.insert(pair<int, list<double> >(0, l));
664 doc_ids.push_back(id);
665 for (int j = 1; j < 20; ++j) {
666 List1 l1;
667 l1.push_back(f[j]);
668 norm.insert(pair<int, list<double> >(j, l1));
670 first = 0;
671 } else {
672 norm_outer = norm.begin();
673 norm_outer->second.push_back(q1);
674 ++norm_outer;
675 doc_ids.push_back(id);
676 int k = 1;
677 for (; norm_outer != norm.end(); ++norm_outer) {
678 norm_outer->second.push_back(f[k]);
679 ++k;
685 }//for closed
687 /* this is the place where we have to normalize the norm and after that store it in the file. */
689 if (norm.empty())
690 continue;
// Query-level normalization: skip norm[0] (the relevance label) and
// divide every feature list by its maximum value.
692 norm_outer = norm.begin();
693 ++norm_outer;
694 int k = 0;
695 for (; norm_outer != norm.end(); ++norm_outer) {
696 k = 0;
697 double max = norm_outer->second.front();
698 for (norm_inner = norm_outer->second.begin(); norm_inner != norm_outer->second.end(); ++norm_inner) {
699 if (*norm_inner > max)
700 max = *norm_inner;
702 for (norm_inner = norm_outer->second.begin(); norm_inner != norm_outer->second.end(); ++norm_inner) {
703 if (max != 0)
704 *norm_inner /= max;
705 ++k;
// Emit one LIBSVM-format line per judged document:
// "<label> 1:f1 2:f2 ... 19:f19".
709 for (int i = 0; i < k; ++i) {
710 int j = 0;
711 norm_outer = norm.begin();
712 train_file << norm_outer->second.front();
713 norm_outer->second.pop_front();
714 ++norm_outer;
715 ++j;
716 //Uncomment the line below if you want 'Qid' in the training file
717 // train_file << " qid:" << qid;
718 for (; norm_outer != norm.end(); ++norm_outer) {
719 train_file << " " << j << ":" << norm_outer->second.front();
720 norm_outer->second.pop_front();
721 ++j;
723 //Uncomment the line below if you want 'DocID' in the training file
724 // train_file << " #docid:" << doc_ids.front();
725 train_file << "\n";
726 doc_ids.pop_front();
729 }//while closed
730 myfile1.close();
731 train_file.close();