Add the implementation of SDBN clickmodel
[xapian.git] / xapian-applications / omega / clickmodel / simplifieddbn.cc
blob0606f65c6f4945cc412eb3ca26879a4f4d277596
1 /** @file simplifieddbn.cc
2 * @brief SimplifiedDBN class - the Simplified DBN click model.
3 */
4 /* Copyright (C) 2017 Vivek Pal
6 * This program is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU General Public License as
8 * published by the Free Software Foundation; either version 2 of the
9 * License, or (at your option) any later version.
11 * This program is distributed in the hope that it will be useful
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
#include <config.h>

#include "simplifieddbn.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using namespace std;
// NOTE(review): QID, DOCIDS and CLICKS are not referenced by any code in
// this file as shown, and their values do not match the CSV column indices
// that build_sessions() reads (docids from column 2, clicks from column 4).
// Confirm whether they are still needed before relying on them.
#define QID 0
#define DOCIDS 1
#define CLICKS 2
39 string
40 SimplifiedDBN::name()
42 return "SimplifiedDBN";
/** Split a comma-separated docid list into its individual docids.
 *
 *  Every ',' (and any embedded NUL character) ends one docid, so the result
 *  always holds one more element than there are separators: "" gives {""}
 *  and "1,2," gives {"1", "2", ""}.
 *
 *  @param str_docids  Comma-separated docids, e.g. "45,36,19".
 *
 *  @return The docids in their original order.
 */
static vector<string>
get_docid_list(const string &str_docids)
{
    // Two-character separator set: ',' plus '\0' (hence the explicit length).
    static const string separators(",\0", 2);

    vector<string> pieces;
    string::size_type piece_start = 0;
    for (;;) {
        const string::size_type sep =
            str_docids.find_first_of(separators, piece_start);
        if (sep == string::npos) {
            // Whatever follows the last separator is the final docid.
            pieces.push_back(str_docids.substr(piece_start));
            return pieces;
        }
        pieces.push_back(str_docids.substr(piece_start, sep - piece_start));
        piece_start = sep + 1;
    }
}
/** Parse a comma-separated list of "docid:click_count" entries.
 *
 *  Entries are separated by ',' (or an embedded NUL), exactly as in
 *  get_docid_list(), so both functions yield the same number of elements
 *  for the same field layout.  For each entry the click count is the text
 *  after the first ':' (or the whole entry if it has no ':').  An entry
 *  whose count cannot be read as an integer, including an empty entry,
 *  yields 0 rather than repeating the previous entry's count.
 *
 *  @param str_clicks  String of the form "docid1:count1,docid2:count2,...".
 *
 *  @return One click count per entry, in order ("" gives {0}).
 */
static vector<int>
get_click_list(const string &str_clicks)
{
    // Two-character separator set: ',' plus '\0' (hence the explicit length).
    static const string separators(",\0", 2);

    vector<int> clicks;
    string::size_type start = 0;
    for (;;) {
        const string::size_type end =
            str_clicks.find_first_of(separators, start);
        const string entry = (end == string::npos)
            ? str_clicks.substr(start)
            : str_clicks.substr(start, end - start);

        // Get the click count string (after the first ':', if any).
        const string::size_type delimiter_pos = entry.find(':');
        istringstream count_stream(delimiter_pos == string::npos
                                   ? entry
                                   : entry.substr(delimiter_pos + 1));

        // Each entry is parsed independently; click stays 0 if no integer
        // can be read, so a malformed or empty entry never inherits the
        // previous entry's count.
        int click = 0;
        count_stream >> click;
        clicks.push_back(click);

        if (end == string::npos)
            break;
        start = end + 1;
    }
    return clicks;
}
90 vector<Session>
91 SimplifiedDBN::build_sessions(const string &logfile)
93 ifstream file;
94 file.open(logfile, ios::in);
96 if (!file) {
97 throw runtime_error("ERROR: Specified file does not exist.");
100 string line;
102 // Skip the first line.
103 getline(file, line);
105 vector<Session> sessions;
107 // Start reading file from the second line.
108 while (getline(file, line)) {
110 istringstream ss(line);
112 vector<string> row_data;
114 while (ss >> std::ws) {
115 string column_element;
116 if (ss.peek() == '"') {
117 int pos = ss.tellg();
118 ss.seekg(pos + 1);
119 char ch;
120 while (ss.get(ch)) {
121 if (ch == '"')
122 break;
123 column_element += ch;
125 } else {
126 if (ss.peek() == ',') {
127 int pos = ss.tellg();
128 ss.seekg(pos + 1);
130 getline(ss, column_element, ',');
132 row_data.push_back(column_element);
135 string qid = row_data[0];
136 string query = row_data[1];
137 string docids = row_data[2];
138 string clicks = row_data[4];
140 Session s;
141 s.create_session(qid, docids, clicks);
142 sessions.push_back(s);
145 file.close();
146 return sessions;
149 void
150 SimplifiedDBN::train(const vector<Session> &sessions)
152 map<string, map<string, map<int, map<int, double>>>> doc_rel_fractions;
154 for (auto&& session : sessions) {
155 string qid = session.get_qid();
157 vector<string> docids = get_docid_list(session.get_docids());
159 vector<int> clicks = get_click_list(session.get_clicks());
161 int last_clicked_pos = clicks.size() - 1;
163 for (size_t j = 0; j < clicks.size(); ++j)
164 if (clicks[j] != 0)
165 last_clicked_pos = j;
167 // Initialise some values.
168 for (int k = 0; k <= last_clicked_pos; ++k) {
169 doc_rel_fractions[qid][docids[k]][PARAM_ATTR_PROB][0] = 1.0;
170 doc_rel_fractions[qid][docids[k]][PARAM_ATTR_PROB][1] = 1.0;
171 doc_rel_fractions[qid][docids[k]][PARAM_SAT_PROB][0] = 1.0;
172 doc_rel_fractions[qid][docids[k]][PARAM_SAT_PROB][1] = 1.0;
175 for (int k = 0; k <= last_clicked_pos; ++k) {
176 if (clicks[k] != 0) {
177 doc_rel_fractions[qid][docids[k]][PARAM_ATTR_PROB][1] += 1;
178 if (int(k) == last_clicked_pos)
179 doc_rel_fractions[qid][docids[k]][PARAM_SAT_PROB][1] += 1;
180 else
181 doc_rel_fractions[qid][docids[k]][PARAM_SAT_PROB][0] += 1;
182 } else {
183 doc_rel_fractions[qid][docids[k]][PARAM_ATTR_PROB][0] += 1;
188 for (auto i = doc_rel_fractions.begin(); i != doc_rel_fractions.end();
189 ++i) {
190 string qid = i->first;
191 for (auto&& j : i->second) {
192 doc_relevances[qid][j.first][PARAM_ATTR_PROB] =
193 j.second[PARAM_ATTR_PROB][1] /
194 (j.second[PARAM_ATTR_PROB][1] +
195 j.second[PARAM_ATTR_PROB][0]);
196 doc_relevances[qid][j.first][PARAM_SAT_PROB] =
197 j.second[PARAM_SAT_PROB][1] /
198 (j.second[PARAM_SAT_PROB][1] +
199 j.second[PARAM_SAT_PROB][0]);
204 vector<pair<string, double>>
205 SimplifiedDBN::get_predicted_relevances(const Session &session)
207 vector<pair<string, double>> docid_relevances;
209 vector<string> docids = get_docid_list(session.get_docids());
211 for (size_t i = 0; i < docids.size(); ++i) {
212 double attr_prob =
213 doc_relevances[session.get_qid()][docids[i]][PARAM_ATTR_PROB];
214 double sat_prob =
215 doc_relevances[session.get_qid()][docids[i]][PARAM_SAT_PROB];
216 docid_relevances.push_back(make_pair(docids[i], attr_prob * sat_prob));
218 return docid_relevances;