Fix smoketest.lua to work with newer lua
[xapian.git] / xapian-letor / letor_internal.cc
blob1de63398dcc5ab7141442c9bfa18200265f925ae
/** @file letor_internal.cc
 * @brief Internals of Xapian::Letor class
 */
/* Copyright (C) 2011 Parth Gupta
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
 * USA
 */
22 #include <config.h>
24 #include <xapian/letor.h>
26 #include <xapian.h>
28 #include "letor_internal.h"
29 #include "featuremanager.h"
30 #include "str.h"
31 #include "stringutils.h"
32 #include "ranker.h"
33 #include "svmranker.h"
35 #include <cstdio>
36 #include <cstdlib>
37 #include <cstring>
38 #include "safeerrno.h"
39 #include "safeunistd.h"
41 #include <algorithm>
42 #include <list>
43 #include <iostream>
44 #include <fstream>
45 #include <sstream>
46 #include <string>
47 #include <map>
48 #include <math.h>
50 #include <libsvm/svm.h>
51 #define Malloc(type, n) (type *)malloc((n) * sizeof(type))
53 using namespace std;
55 using namespace Xapian;
57 struct svm_parameter param;
58 struct svm_problem prob;
59 struct svm_model *model;
60 struct svm_node *x_space;
61 int cross_validation;
62 int nr_fold;
66 struct svm_node *x;
67 int max_nr_attr = 64;
69 int predict_probability = 0;
71 static char *line = NULL;
72 static int max_line_len;
74 int MAXPATHLEN = 200;
// Stop-words handed to Xapian::SimpleStopper when parsing training queries.
static const char * sw[] = {
    "a", "about", "an", "and", "are", "as", "at",
    "be", "by",
    "en",
    "for", "from",
    "how",
    "i", "in", "is", "it",
    "of", "on", "or",
    "that", "the", "this", "to",
    "was", "what", "when", "where", "which", "who", "why", "will", "with"
};
/** Report a malformed line in the training input and terminate the process.
 *
 *  @param line_num  1-based number of the offending input line.
 *
 *  The diagnostic goes to stderr (like every other error in this file) and is
 *  newline-terminated so that it is not glued to whatever is printed next.
 */
static void exit_input_error(int line_num) {
    fprintf(stderr, "Error at Line : %d\n", line_num);
    exit(1);
}
/** Format a double using the default stream formatting.
 *
 *  @param value  Number to format.
 *  @return The textual form of @a value, or an empty string if the stream
 *          insertion fails.
 */
static std::string convertDouble(double value) {
    std::ostringstream formatted;
    formatted << value;
    if (!formatted)
        return std::string();
    return formatted.str();
}
/** Return the current working directory.
 *
 *  @return The absolute path of the working directory, or an empty string if
 *          getcwd() fails for any reason other than the buffer being too small.
 *
 *  The buffer starts small and is doubled whenever getcwd() reports ERANGE, so
 *  long paths are no longer silently turned into an empty string, and no
 *  variable-length array is needed.
 */
static std::string get_cwd() {
    std::string buffer(256, '\0');
    for (;;) {
        if (getcwd(&buffer[0], buffer.size()) != NULL) {
            // Trim the buffer down to the NUL-terminated path getcwd() wrote.
            buffer.resize(std::strlen(buffer.c_str()));
            return buffer;
        }
        if (errno != ERANGE)
            return std::string();
        buffer.resize(buffer.size() * 2);
    }
}
112 /* This method will calculate the score assigned by the Letor function.
113 * It will take MSet as input then convert the documents in feature vectors
114 * then normalize them according to QueryLevelNorm
115 * and after that use the machine learned model file
116 * to assign a score to the document
118 map<Xapian::docid, double>
119 Letor::Internal::letor_score(const Xapian::MSet & mset) {
121 map<Xapian::docid, double> letor_mset;
123 Xapian::FeatureManager fm;
124 fm.set_database(letor_db);
125 fm.set_query(letor_query);
127 std::string s="1";
128 Xapian::RankList rl = fm.create_rank_list(mset, s);
129 std::list<double> scores =ranker.rank(rl);
131 /*code to convert list<double> scores to map<docid,double>*/
133 return letor_mset;
136 static char* readline(FILE *input) {
137 int len;
139 if (fgets(line, max_line_len, input) == NULL)
140 return NULL;
142 while (strrchr(line, '\n') == NULL) {
143 max_line_len *= 2;
144 line = (char *)realloc(line, max_line_len);
145 len = (int)strlen(line);
146 if (fgets(line + len, max_line_len - len, input) == NULL)
147 break;
149 return line;
152 static void read_problem(const char *filename) {
153 int elements, max_index, inst_max_index, i, j;
154 FILE *fp = fopen(filename, "r");
155 char *endptr;
156 char *idx, *val, *label;
158 if (fp == NULL) {
159 fprintf(stderr, "can't open input file %s\n", filename);
160 exit(1);
163 prob.l = 0;
164 elements = 0;
166 max_line_len = 1024;
167 line = Malloc(char, max_line_len);
169 while (readline(fp) != NULL) {
170 char *p = strtok(line, " \t"); // label
172 // features
173 while (1) {
174 p = strtok(NULL, " \t");
175 if (p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
176 break;
177 ++elements;
179 ++elements;
180 ++prob.l;
182 rewind(fp);
184 prob.y = Malloc(double, prob.l);
185 prob.x = Malloc(struct svm_node *, prob.l);
186 x_space = Malloc(struct svm_node, elements);
188 max_index = 0;
189 j = 0;
191 for (i = 0; i < prob.l; ++i) {
192 inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
193 readline(fp);
194 prob.x[i] = &x_space[j];
195 label = strtok(line, " \t\n");
196 if (label == NULL) // empty line
197 exit_input_error(i + 1);
198 prob.y[i] = strtod(label, &endptr);
199 if (endptr == label || *endptr != '\0')
200 exit_input_error(i + 1);
202 while (1) {
203 idx = strtok(NULL, ":");
204 val = strtok(NULL, " \t");
206 if (val == NULL)
207 break;
209 errno = 0;
210 x_space[j].index = (int)strtol(idx, &endptr, 10);
212 if (endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
213 exit_input_error(i + 1);
214 else
215 inst_max_index = x_space[j].index;
217 errno = 0;
218 x_space[j].value = strtod(val, &endptr);
220 if (endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
221 exit_input_error(i + 1);
223 ++j;
226 if (inst_max_index > max_index)
227 max_index = inst_max_index;
228 x_space[j++].index = -1;
231 if (param.gamma == 0 && max_index > 0)
232 param.gamma = 1.0 / max_index;
234 if (param.kernel_type == PRECOMPUTED)
235 for (i = 0; i < prob.l; ++i) {
236 if (prob.x[i][0].index != 0) {
237 fprintf(stderr, "Wrong input format: first column must be 0:sample_serial_number\n");
238 exit(1);
240 if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) {
241 fprintf(stderr, "Wrong input format: sample_serial_number out of range\n");
242 exit(1);
245 fclose(fp);
248 void
249 Letor::Internal::letor_learn_model(int s_type, int k_type) {
250 // default values
251 param.svm_type = s_type;
252 param.kernel_type = k_type;
253 param.degree = 3;
254 param.gamma = 0; // 1/num_features
255 param.coef0 = 0;
256 param.nu = 0.5;
257 param.cache_size = 100;
258 param.C = 1;
259 param.eps = 1e-3;
260 param.p = 0.1;
261 param.shrinking = 1;
262 param.probability = 0;
263 param.nr_weight = 0;
264 param.weight_label = NULL;
265 param.weight = NULL;
266 cross_validation = 0;
268 printf("Learning the model..");
269 string input_file_name;
270 string model_file_name;
271 const char *error_msg;
273 input_file_name = get_cwd().append("/train.txt");
274 model_file_name = get_cwd().append("/model.txt");
276 read_problem(input_file_name.c_str());
277 error_msg = svm_check_parameter(&prob, &param);
278 if (error_msg) {
279 fprintf(stderr, "svm_check_parameter failed: %s\n", error_msg);
280 exit(1);
283 model = svm_train(&prob, &param);
284 if (svm_save_model(model_file_name.c_str(), model)) {
285 fprintf(stderr, "can't save model to file %s\n", model_file_name.c_str());
286 exit(1);
291 /* This is the method which prepares the train.txt file of the standard Letor Format.
292 * param 1 = Xapian Database
293 * param 2 = path to the query file
294 * param 3 = path to the qrel file
296 * output : It produces the train.txt method in /core/examples/ which is taken as input by learn_model() method
299 static void
300 write_to_file(std::list<Xapian::RankList> l) {
301 ofstream train_file;
302 train_file.open("train.txt");
303 // write it down with proper format
304 for (list<Xapian::RankList>::iterator it = l.begin(); it != l.end(); it++);
308 void
309 Letor::Internal::prepare_training_file(const string & queryfile, const string & qrel_file, Xapian::doccount msetsize) {
311 // ofstream train_file;
312 // train_file.open("train.txt");
314 Xapian::SimpleStopper mystopper(sw, sw + sizeof(sw) / sizeof(sw[0]));
315 Xapian::Stem stemmer("english");
317 Xapian::QueryParser parser;
318 parser.add_prefix("title", "S");
319 parser.add_prefix("subject", "S");
321 parser.set_database(letor_db);
322 parser.set_default_op(Xapian::Query::OP_OR);
323 parser.set_stemmer(stemmer);
324 parser.set_stemming_strategy(Xapian::QueryParser::STEM_SOME);
325 parser.set_stopper(&mystopper);
327 /* ---------------------------- store whole qrel file in a Map<> ---------------------*/
329 // typedef map<string, int> Map1; //docid and relevance judjement 0/1
330 // typedef map<string, Map1> Map2; // qid and map1
331 // Map2 qrel;
333 map<string, map<string, int> > qrel; // 1
335 Xapian::FeatureManager fm;
336 fm.set_database(letor_db);
337 fm.load_relevance(qrel_file);
338 qrel = fm.load_relevance(qrel_file);
340 list<Xapian::RankList> l;
342 string str1;
343 ifstream myfile1;
344 myfile1.open(queryfile.c_str(), ios::in);
347 while (!myfile1.eof()) { //reading all the queries line by line from the query file
349 getline(myfile1, str1);
350 if (str1.empty()) {
351 break;
354 string qid = str1.substr(0, (int)str1.find(" "));
355 string querystr = str1.substr((int)str1.find("'")+1, (str1.length() - ((int)str1.find("'") + 2)));
357 string qq = querystr;
358 istringstream iss(querystr);
359 string title = "title:";
360 while (iss) {
361 string t;
362 iss >> t;
363 if (t.empty())
364 break;
365 string temp;
366 temp.append(title);
367 temp.append(t);
368 temp.append(" ");
369 temp.append(qq);
370 qq = temp;
373 cout << "Processing Query: " << qq << "\n";
375 Xapian::Query query = parser.parse_query(qq,
376 parser.FLAG_DEFAULT|
377 parser.FLAG_SPELLING_CORRECTION);
379 Xapian::Enquire enquire(letor_db);
380 enquire.set_query(query);
382 Xapian::MSet mset = enquire.get_mset(0, msetsize);
384 fm.set_query(query);
386 Xapian::RankList rl = fm.create_rank_list(mset, qid);
387 l.push_back(rl);
388 }//while closed
389 myfile1.close();
390 write_to_file(l);
391 // train_file.close();