Make PostList subclasses return PostList* not Internal*
[xapian.git] / xapian-letor / ranker / svmranker.cc
blob2e4490f07530522852507eb38c1c9b2e1721a4ce
1 /** @file svmranker.cc
2 * @brief Implementation of Ranking-SVM
3 */
4 /* Copyright (C) 2012 Parth Gupta
5 * Copyright (C) 2016 Ayush Tomar
7 * This program is free software; you can redistribute it and/or
8 * modify it under the terms of the GNU General Public License as
9 * published by the Free Software Foundation; either version 2 of the
10 * License, or (at your option) any later version.
12 * This program is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
17 * You should have received a copy of the GNU General Public License
18 * along with this program; if not, write to the Free Software
19 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
20 * USA
24 Ranking-SVM is adapted from the paper:
25 Joachims, Thorsten. "Optimizing search engines using clickthrough data."
26 Proceedings of the eighth ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2002.
#include <config.h>

#include "xapian-letor/ranker.h"

#include "debuglog.h"
#include "serialise-double.h"

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <unistd.h>
#include <vector>

#include <libsvm/svm.h>

using namespace std;
using namespace Xapian;
/// No-op print callback; train() installs it with
/// svm_set_print_string_function() so libsvm's progress output is discarded.
static void
empty_function(const char* /* message */)
{
}
51 static void
52 clear_svm_problem(svm_problem *problem)
54 delete [] problem->y;
55 problem->y = NULL;
56 for (int i = 0; i < problem->l; ++i) {
57 delete [] problem->x[i];
59 delete [] problem->x;
60 problem->x = NULL;
63 SVMRanker::SVMRanker()
65 LOGCALL_CTOR(API, "SVMRanker", NO_ARGS);
68 SVMRanker::~SVMRanker()
70 LOGCALL_DTOR(API, "SVMRanker");
73 static int
74 get_non_zero_num(const Xapian::FeatureVector & fv_non_zero)
76 int non_zero = 0;
77 for (int i = 0; i < fv_non_zero.get_fcount(); ++i)
78 if (fv_non_zero.get_feature_value(i) != 0.0)
79 ++non_zero;
80 return non_zero;
83 void
84 SVMRanker::train(const std::vector<Xapian::FeatureVector> & training_data)
86 LOGCALL_VOID(API, "SVMRanker::train", training_data);
87 svm_set_print_string_function(&empty_function);
88 struct svm_parameter param;
89 param.svm_type = NU_SVR;
90 param.kernel_type = LINEAR;
91 param.degree = 3; // parameter for poly Kernel, default 3
92 param.gamma = 0; // parameter for poly/rbf/sigmoid Kernel, default 1/num_features
93 param.coef0 = 0; // parameter for poly/sigmoid Kernel, default 0
94 param.nu = 0.5; // parameter for nu-SVC/one-class SVM/nu-SVR, default 0.5
95 param.cache_size = 100; // default 40MB
96 param.C = 1; // penalty parameter, default 1
97 param.eps = 1e-3; // stopping criteria, default 1e-3
98 param.p = 0.1; // parameter for e -SVR, default 0.1
99 param.shrinking = 1; // use the shrinking heuristics
100 param.probability = 0; // probability estimates
101 param.nr_weight = 0; // parameter for C-SVCl
102 param.weight_label = NULL; // parameter for C-SVC
103 param.weight = NULL; // parameter for c-SVC
105 int fvv_len = training_data.size();
106 if (fvv_len == 0) {
107 throw LetorInternalError("Training data is empty. Check training file.");
109 int feature_cnt = training_data[0].get_fcount();
111 struct svm_problem prob;
112 // set the parameters for svm_problem
113 prob.l = fvv_len;
114 // feature vector
115 prob.x = new svm_node* [prob.l];
116 prob.y = new double [prob.l];
118 // Sparse storage; libsvm needs only non-zero features
119 for (int i = 0; i < fvv_len; ++i) {
120 prob.x[i] = new svm_node [get_non_zero_num(training_data[i]) + 1]; // one extra for default -1 at the end
121 int non_zero_flag = 0;
122 for (int k = 0; k < feature_cnt; ++k) {
123 double fval = training_data[i].get_feature_value(k);
124 if (fval != 0.0) {
125 prob.x[i][non_zero_flag].index = k;
126 prob.x[i][non_zero_flag].value = fval;
127 ++non_zero_flag;
130 prob.x[i][non_zero_flag].index = -1;
131 prob.x[i][non_zero_flag].value = -1;
133 prob.y[i] = training_data[i].get_label();
136 const char * error_msg;
137 error_msg = svm_check_parameter(&prob, &param);
138 if (error_msg)
139 throw LetorInternalError("svm_check_parameter failed: %s", error_msg);
141 struct svm_model * trainmodel = svm_train(&prob, &param);
142 // Generate temporary file to extract model data
143 char templ[] = "/tmp/svmtemp.XXXXXX";
144 int fd = mkstemp(templ);
145 if (fd == -1) {
146 throw LetorInternalError("Training failed: ", errno);
148 try {
149 svm_save_model(templ, trainmodel);
150 // Read content of model to string
151 std::ifstream f(templ);
152 std::ostringstream ss;
153 ss << f.rdbuf();
154 // Save model string
155 this->model_data = ss.str();
156 close(fd);
157 std::remove(templ);
158 } catch (...) {
159 close(fd);
160 std::remove(templ);
162 svm_free_and_destroy_model(&trainmodel);
163 clear_svm_problem(&prob);
164 svm_destroy_param(&param);
165 if (this->model_data.empty()) {
166 throw LetorInternalError("SVM model empty. Training failed.");
170 void
171 SVMRanker::save_model_to_metadata(const string & model_key)
173 LOGCALL_VOID(API, "SVMRanker::save_model_to_metadata", model_key);
174 Xapian::WritableDatabase letor_db(get_database_path());
175 string key = model_key;
176 if (key.empty()) {
177 key = "SVMRanker.model.default";
179 letor_db.set_metadata(key, this->model_data);
182 void
183 SVMRanker::load_model_from_metadata(const string & model_key)
185 LOGCALL_VOID(API, "SVMRanker::load_model_from_metadata", model_key);
186 Xapian::Database letor_db(get_database_path());
187 string key = model_key;
188 if (key.empty()) {
189 key = "SVMRanker.model.default";
191 string model_data_load = letor_db.get_metadata(key);
192 // Throw exception if no model data associated with key
193 if (model_data_load.empty()) {
194 throw Xapian::LetorInternalError("No model found. Check key.");
196 this->model_data = model_data_load;
199 std::vector<FeatureVector>
200 SVMRanker::rank_fvv(const std::vector<FeatureVector> & fvv) const
202 LOGCALL(API, std::vector<FeatureVector>, "SVMRanker::rank_fvv", fvv);
203 if (this->model_data.empty()) {
204 throw LetorInternalError("SVM model empty. Load correct model.");
206 // Generate temporary file containing model data for svm_load_model() method
207 struct svm_model * model = NULL;
208 char templ[] = "/tmp/svmtemp.XXXXXX";
209 int fd = mkstemp(templ);
210 if (fd == -1) {
211 throw LetorInternalError("Ranking failed: ", errno);
213 try {
214 std::ofstream f(templ);
215 f << this->model_data.c_str();
216 f.close();
217 model = svm_load_model(templ);
218 close(fd);
219 std::remove(templ);
220 } catch (...) {
221 close(fd);
222 std::remove(templ);
224 if (!model) {
225 throw LetorInternalError("Ranking failed. SVM model not usable.");
228 std::vector<FeatureVector> testfvv = fvv;
230 struct svm_node * test = NULL;
231 for (size_t i = 0; i < testfvv.size(); ++i) {
232 test = new svm_node [get_non_zero_num(testfvv[i]) + 1];
233 int feature_cnt = testfvv[i].get_fcount();
234 int non_zero_flag = 0;
235 for (int k = 0; k < feature_cnt; ++k) {
236 double fval = testfvv[i].get_feature_value(k);
237 if (fval != 0.0) {
238 test[non_zero_flag].index = k;
239 test[non_zero_flag].value = fval;
240 ++non_zero_flag;
243 test[non_zero_flag].index = -1;
244 test[non_zero_flag].value = -1;
245 testfvv[i].set_score(svm_predict(model, test));
246 delete [] test;
248 svm_free_and_destroy_model(&model);
249 return testfvv;