Add the implementation of SDBN clickmodel
[xapian.git] / xapian-applications / omega / clickmodel / tests / sdbntest.cc
blob8d9723b3dadfca6e386c6abef45269389555af08
/** @file sdbntest.cc
 * @brief tests for the Simplified DBN clickmodel class.
 */
/* Copyright (C) 2017 Vivek Pal
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
 */
#include <config.h>

#include <cmath>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <clickmodel/simplifieddbn.h>

using namespace std;
// Small integer constants named after the three pieces of session data
// (query ID, document IDs, clicks).
// NOTE(review): none of these is referenced in the visible part of this
// file; presumably they index the columns of a log record - confirm before
// relying on that meaning.
#define QID 0
#define DOCIDS 1
#define CLICKS 2
/** Expected contents of a session, in the string form that main() compares
 *  against Session::get_qid(), get_docids() and get_clicks().
 */
struct sessions_items {
    /// Expected query ID.
    const std::string qid;
    /// Expected document IDs, e.g. "45,36,14,54,42".
    const std::string docids;
    /// Expected clicks, e.g. "45:0,36:0,14:0,54:2,42:0".
    const std::string clicks;
};

/** One build_sessions() test case: the log file to parse and the session
 *  data every session built from it is expected to carry.
 */
struct sessions_testcase {
    /// Path of the log file handed to SimplifiedDBN::build_sessions().
    const std::string logfile;
    /// Values each resulting session is compared against.
    const sessions_items sessions;
};
/** Return the directory that test data paths are resolved against.
 *
 *  @return The value of the "srcdir" environment variable, or "." when the
 *	    variable is not set.
 */
static std::string get_srcdir() {
    // The string returned by getenv() must not be modified by the program,
    // so hold it through a pointer-to-const.
    const char* p = std::getenv("srcdir");
    if (!p) return ".";
    return std::string(p);
}
55 static bool
56 almost_equal(double x, double y, double epsilon = 0.001) {
57 return (abs(x - y) < epsilon);
60 int main() {
61 string srcdir = get_srcdir();
63 string sample_log1 = srcdir + "/clickmodel/testdata/test1.log";
64 string sample_log2 = srcdir + "/clickmodel/testdata/test2.log";
65 string sample_log3 = srcdir + "/clickmodel/testdata/test3.log";
68 sessions_testcase sessions_tests[] = {
69 {sample_log1, {"821f03288846297c2cf43c34766a38f7",
70 "45,36,14,54,42",
71 "45:0,36:0,14:0,54:2,42:0"}},
72 {sample_log2, {"","",""}}
75 SimplifiedDBN sdbn;
77 int failure_count = 0;
79 // Tests for SimplifiedDBN::build_sessions method.
80 for (size_t i = 0; i < sizeof(sessions_tests) / sizeof(sessions_tests[0]);
81 ++i) {
82 vector<Session> result;
83 try {
84 result = sdbn.build_sessions(sessions_tests[i].logfile);
85 } catch (std::exception &ex) {
86 cout << ex.what() << endl;
87 ++failure_count;
88 // Specified file doesn't exist. Skip subsequent tests.
89 continue;
92 for (auto&& x : result) {
93 if (x.get_qid() != sessions_tests[i].sessions.qid) {
94 cerr << "ERROR: Query ID mismatch occurred. " << endl
95 << "Expected: " << sessions_tests[i].sessions.qid
96 << " Received: "<< x.get_qid() << endl;
97 ++failure_count;
99 if (x.get_docids() != sessions_tests[i].sessions.docids) {
100 cerr << "ERROR: Doc ID(s) mismatch occurred. " << endl
101 << "Expected: " << sessions_tests[i].sessions.docids
102 << " Received: " << x.get_docids() << endl;
103 ++failure_count;
105 if (x.get_clicks() != sessions_tests[i].sessions.clicks) {
106 cerr << "ERROR: Clicks mismatch occurred. " << endl
107 << "Expected: " << sessions_tests[i].sessions.clicks
108 << " Received: " << x.get_clicks() << endl;
109 ++failure_count;
114 // Train Simplified DBN model on a dummy training file.
115 vector<Session> training_sessions;
116 try {
117 training_sessions = sdbn.build_sessions(sample_log3);
118 } catch (std::exception &ex) {
119 cout << ex.what() << endl;
120 ++failure_count;
123 sdbn.train(training_sessions);
125 pair<string, double> relevance11 ("45", 0.166);
126 pair<string, double> relevance12 ("36", 0.166);
127 pair<string, double> relevance13 ("14", 0.166);
128 pair<string, double> relevance14 ("54", 0.444);
129 pair<string, double> relevance15 ("42", 0);
131 pair<string, double> relevance21 ("35", 0.444);
132 pair<string, double> relevance22 ("47", 0);
133 pair<string, double> relevance23 ("31", 0);
134 pair<string, double> relevance24 ("14", 0);
135 pair<string, double> relevance25 ("45", 0);
137 vector<vector<pair<string, double>>>
138 expected_relevances = {{relevance11, relevance12, relevance13,
139 relevance14, relevance15},
140 {relevance21, relevance22, relevance23,
141 relevance24, relevance25}};
143 int k = 0;
144 // Tests for SimplifiedDBN::get_predicted_relevances.
145 for (auto&& session : training_sessions) {
146 vector<pair<string, double>>
147 predicted_relevances = sdbn.get_predicted_relevances(session);
149 if (predicted_relevances.size() != expected_relevances[k].size()) {
150 cerr << "ERROR: Relevance lists sizes do not match." << endl;
151 ++failure_count;
152 // Skip subsequent tests.
153 continue;
156 for (size_t j = 0; j < expected_relevances[k].size(); ++j) {
157 if (!almost_equal(predicted_relevances[j].second,
158 expected_relevances[k][j].second)) {
159 cerr << "ERROR: Relevances do not match." << endl
160 << "Expected: " << expected_relevances[k][j].second
161 << " Received: " << predicted_relevances[j].second << endl;
162 ++failure_count;
165 ++k;
168 if (failure_count != 0) {
169 cout << "Total failures occurred: " << failure_count << endl;
170 exit(1);