2 * @brief tests for the Simplified DBN clickmodel class.
4 /* Copyright (C) 2017 Vivek Pal
6 * This program is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU General Public License as
8 * published by the Free Software Foundation; either version 2 of the
9 * License, or (at your option) any later version.
11 * This program is distributed in the hope that it will be useful
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software Foundation,
18 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
30 #include <clickmodel/simplifieddbn.h>
38 struct sessions_items
{
44 struct sessions_testcase
{
46 const sessions_items sessions
;
49 static string
get_srcdir() {
50 char *p
= getenv("srcdir");
/** Compare two doubles for approximate equality.
 *
 *  @param x        First value to compare.
 *  @param y        Second value to compare.
 *  @param epsilon  Tolerance; defaults to 0.001.
 *
 *  @return true if the absolute difference between x and y is strictly
 *          less than epsilon, false otherwise.
 */
static bool
almost_equal(double x, double y, double epsilon = 0.001) {
    // Take the magnitude by hand rather than calling unqualified abs():
    // depending on which headers are in scope that call can bind to the C
    // int overload, which truncates the difference before the comparison.
    const double diff = x - y;
    return (diff < 0 ? -diff : diff) < epsilon;
}
61 string srcdir
= get_srcdir();
63 string sample_log1
= srcdir
+ "/clickmodel/testdata/test1.log";
64 string sample_log2
= srcdir
+ "/clickmodel/testdata/test2.log";
65 string sample_log3
= srcdir
+ "/clickmodel/testdata/test3.log";
68 sessions_testcase sessions_tests
[] = {
69 {sample_log1
, {"821f03288846297c2cf43c34766a38f7",
71 "45:0,36:0,14:0,54:2,42:0"}},
72 {sample_log2
, {"","",""}}
77 int failure_count
= 0;
79 // Tests for SimplifiedDBN::build_sessions method.
80 for (size_t i
= 0; i
< sizeof(sessions_tests
) / sizeof(sessions_tests
[0]);
82 vector
<Session
> result
;
84 result
= sdbn
.build_sessions(sessions_tests
[i
].logfile
);
85 } catch (std::exception
&ex
) {
86 cout
<< ex
.what() << endl
;
88 // Specified file doesn't exist. Skip subsequent tests.
92 for (auto&& x
: result
) {
93 if (x
.get_qid() != sessions_tests
[i
].sessions
.qid
) {
94 cerr
<< "ERROR: Query ID mismatch occurred. " << endl
95 << "Expected: " << sessions_tests
[i
].sessions
.qid
96 << " Received: "<< x
.get_qid() << endl
;
99 if (x
.get_docids() != sessions_tests
[i
].sessions
.docids
) {
100 cerr
<< "ERROR: Doc ID(s) mismatch occurred. " << endl
101 << "Expected: " << sessions_tests
[i
].sessions
.docids
102 << " Received: " << x
.get_docids() << endl
;
105 if (x
.get_clicks() != sessions_tests
[i
].sessions
.clicks
) {
106 cerr
<< "ERROR: Clicks mismatch occurred. " << endl
107 << "Expected: " << sessions_tests
[i
].sessions
.clicks
108 << " Received: " << x
.get_clicks() << endl
;
114 // Train Simplified DBN model on a dummy training file.
115 vector
<Session
> training_sessions
;
117 training_sessions
= sdbn
.build_sessions(sample_log3
);
118 } catch (std::exception
&ex
) {
119 cout
<< ex
.what() << endl
;
123 sdbn
.train(training_sessions
);
// A (document id, relevance) pair as compared against the model's output.
using DocRelevance = std::pair<std::string, double>;

// Expected values for the first list of relevances.
DocRelevance relevance11{"45", 0.166};
DocRelevance relevance12{"36", 0.166};
DocRelevance relevance13{"14", 0.166};
DocRelevance relevance14{"54", 0.444};
DocRelevance relevance15{"42", 0.0};

// Expected values for the second list of relevances.
DocRelevance relevance21{"35", 0.444};
DocRelevance relevance22{"47", 0.0};
DocRelevance relevance23{"31", 0.0};
DocRelevance relevance24{"14", 0.0};
DocRelevance relevance25{"45", 0.0};

// One inner list per expected result set, in the order they are checked.
std::vector<std::vector<DocRelevance>> expected_relevances = {
    {relevance11, relevance12, relevance13, relevance14, relevance15},
    {relevance21, relevance22, relevance23, relevance24, relevance25}
};
144 // Tests for SimplifiedDBN::get_predicted_relevances.
145 for (auto&& session
: training_sessions
) {
146 vector
<pair
<string
, double>>
147 predicted_relevances
= sdbn
.get_predicted_relevances(session
);
149 if (predicted_relevances
.size() != expected_relevances
[k
].size()) {
150 cerr
<< "ERROR: Relevance lists sizes do not match." << endl
;
152 // Skip subsequent tests.
156 for (size_t j
= 0; j
< expected_relevances
[k
].size(); ++j
) {
157 if (!almost_equal(predicted_relevances
[j
].second
,
158 expected_relevances
[k
][j
].second
)) {
159 cerr
<< "ERROR: Relevances do not match." << endl
160 << "Expected: " << expected_relevances
[k
][j
].second
161 << " Received: " << predicted_relevances
[j
].second
<< endl
;
168 if (failure_count
!= 0) {
169 cout
<< "Total failures occurred: " << failure_count
<< endl
;