quest: Fix formatting of --help output
[xapian.git] / xapian-core / cluster / cluster.cc
blob990174a77724a80f329e565b544f96861425be4f
1 /** @file
2 * @brief Cluster API
3 */
4 /* Copyright (C) 2010 Richard Boulton
5 * Copyright (C) 2016 Richhiey Thomas
7 * This program is free software; you can redistribute it and/or
8 * modify it under the terms of the GNU General Public License as
9 * published by the Free Software Foundation; either version 2 of the
10 * License, or (at your option) any later version.
12 * This program is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
17 * You should have received a copy of the GNU General Public License
18 * along with this program; if not, write to the Free Software
19 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
20 * USA
23 #include <config.h>
25 #include "xapian/cluster.h"
27 #include "cluster/clusterinternal.h"
28 #include "xapian/error.h"
29 #include "api/termlist.h"
30 #include "debuglog.h"
31 #include "omassert.h"
33 #include <cmath>
34 #include <unordered_map>
35 #include <vector>
37 using namespace Xapian;
38 using namespace std;
40 FreqSource::~FreqSource()
42 LOGCALL_DTOR(API, "FreqSource");
45 Similarity::~Similarity()
47 LOGCALL_DTOR(API, "Similarity");
50 Clusterer::~Clusterer()
52 LOGCALL_DTOR(API, "Clusterer");
55 TermListGroup::TermListGroup(const MSet& docs, const Stopper* stopper)
57 LOGCALL_CTOR(API, "TermListGroup", docs | stopper);
58 for (MSetIterator it = docs.begin(); it != docs.end(); ++it)
59 add_document(it.get_document(), stopper);
60 num_of_documents = docs.size();
63 void
64 TermListGroup::add_document(const Document& document, const Stopper* stopper)
66 LOGCALL_VOID(API, "TermListGroup::add_document", document | stopper);
68 TermIterator titer(document.termlist_begin());
70 for (; titer != document.termlist_end(); ++titer) {
71 const string& term = *titer;
73 // Remove stopwords by using the Xapian::Stopper object
74 if (stopper && (*stopper)(term))
75 continue;
77 // Remove unstemmed terms since document vector should
78 // contain only stemmed terms
79 if (term[0] != 'Z')
80 continue;
82 unordered_map<string, doccount>::iterator i;
83 i = termfreq.find(term);
84 if (i == termfreq.end())
85 termfreq[term] = 1;
86 else
87 ++i->second;
91 doccount
92 TermListGroup::get_doccount() const
94 LOGCALL(API, doccount, "TermListGroup::get_doccount", NO_ARGS);
95 return num_of_documents;
98 doccount
99 TermListGroup::get_termfreq(const string& tname) const
101 LOGCALL(API, doccount, "TermListGroup::get_termfreq", tname);
102 unordered_map<string, doccount>::const_iterator it = termfreq.find(tname);
103 if (it != termfreq.end())
104 return it->second;
105 else
106 return 0;
109 DocumentSet::DocumentSet(const DocumentSet&) = default;
111 DocumentSet&
112 DocumentSet::operator=(const DocumentSet&) = default;
114 DocumentSet::DocumentSet(DocumentSet&&) = default;
116 DocumentSet&
117 DocumentSet::operator=(DocumentSet&&) = default;
119 DocumentSet::DocumentSet()
120 : internal(new Xapian::DocumentSet::Internal)
124 doccount
125 DocumentSet::size() const
127 LOGCALL(API, doccount, "DocumentSet::size", NO_ARGS);
128 return internal->size();
131 doccount
132 DocumentSet::Internal::size() const
134 return documents.size();
137 void
138 DocumentSet::add_document(const Document& document)
140 LOGCALL_VOID(API, "DocumentSet::add_document", document);
141 internal->add_document(document);
144 void
145 DocumentSet::Internal::add_document(const Document& document)
147 documents.push_back(document);
150 Document&
151 DocumentSet::operator[](doccount i)
153 return internal->get_document(i);
156 Document&
157 DocumentSet::Internal::get_document(doccount i)
159 return documents[i];
162 const Document&
163 DocumentSet::operator[](doccount i) const
165 return internal->get_document(i);
168 const Document&
169 DocumentSet::Internal::get_document(doccount i) const
171 return documents[i];
174 DocumentSet::~DocumentSet()
176 LOGCALL_DTOR(API, "DocumentSet");
179 class PointTermIterator : public TermIterator::Internal {
180 unordered_map<string, double>::const_iterator i;
181 unordered_map<string, double>::const_iterator end;
182 termcount size;
183 bool started;
184 public:
185 PointTermIterator(const unordered_map<string, double>& termlist)
186 : i(termlist.begin()), end(termlist.end()),
187 size(termlist.size()), started(false)
189 termcount get_approx_size() const { return size; }
190 termcount get_wdf() const {
191 throw UnimplementedError("PointIterator doesn't support get_wdf()");
193 string get_termname() const { return i->first; }
194 doccount get_termfreq() const {
195 throw UnimplementedError("PointIterator doesn't support "
196 "get_termfreq()");
198 Internal* next();
199 termcount positionlist_count() const {
200 throw UnimplementedError("PointTermIterator doesn't support "
201 "positionlist_count()");
203 bool at_end() const;
204 PositionList* positionlist_begin() const {
205 throw UnimplementedError("PointTermIterator doesn't support "
206 "positionlist_begin()");
208 Internal* skip_to(const string&) {
209 throw UnimplementedError("PointTermIterator doesn't support skip_to()");
213 TermIterator::Internal*
214 PointTermIterator::next()
216 if (!started) {
217 started = true;
218 return NULL;
220 Assert(i != end);
221 ++i;
222 return NULL;
225 bool
226 PointTermIterator::at_end() const
228 if (!started) return false;
229 return i == end;
232 TermIterator
233 PointType::termlist_begin() const
235 LOGCALL(API, TermIterator, "PointType::termlist_begin", NO_ARGS);
236 return TermIterator(new PointTermIterator(weights));
239 bool
240 PointType::contains(const string& term) const
242 LOGCALL(API, bool, "PointType::contains", term);
243 return weights.find(term) != weights.end();
246 double
247 PointType::get_weight(const string& term) const
249 LOGCALL(API, double, "PointType::get_weight", term);
250 unordered_map<string, double>::const_iterator it = weights.find(term);
251 return (it == weights.end()) ? 0.0 : it->second;
254 double
255 PointType::get_magnitude() const {
256 LOGCALL(API, double, "PointType::get_magnitude", NO_ARGS);
257 return magnitude;
260 void
261 PointType::add_weight(const string& term, double weight)
263 LOGCALL_VOID(API, "PointType::add_weight", term | weight);
264 unordered_map<string, double>::iterator it;
265 it = weights.find(term);
266 if (it != weights.end())
267 it->second += weight;
268 else
269 weights[term] = weight;
272 void
273 PointType::set_weight(const string& term, double weight)
275 LOGCALL_VOID(API, "PointType::set_weight", term | weight);
276 weights[term] = weight;
279 termcount
280 PointType::termlist_size() const
282 LOGCALL(API, termcount, "PointType::termlist_size", NO_ARGS);
283 return weights.size();
286 Document
287 Point::get_document() const
289 LOGCALL(API, Document, "Point::get_document", NO_ARGS);
290 return document;
293 Point::Point(const FreqSource& freqsource, const Document& document_)
295 LOGCALL_CTOR(API, "Point::initialize", freqsource | document_);
296 doccount size = freqsource.get_doccount();
297 document = document_;
298 for (TermIterator it = document.termlist_begin();
299 it != document.termlist_end();
300 ++it) {
301 doccount wdf = it.get_wdf();
302 string term = *it;
303 double termfreq = freqsource.get_termfreq(term);
305 // If the term exists in only one document, or if it exists in
306 // every document within the MSet, or if it is a filter term, then
307 // these terms are not used for document vector calculations
308 if (wdf < 1 || termfreq <= 1 || size == termfreq)
309 continue;
311 double tf = 1 + log(double(wdf));
312 double idf = log(size / termfreq);
313 double wt = tf * idf;
315 weights[term] = wt;
316 magnitude += wt * wt;
320 Centroid::Centroid(const Point& point)
322 LOGCALL_CTOR(API, "Centroid", point);
323 for (TermIterator it = point.termlist_begin();
324 it != point.termlist_end();
325 ++it) {
326 weights[*it] = point.get_weight(*it);
328 magnitude = point.get_magnitude();
331 void
332 Centroid::divide(double cluster_size)
334 LOGCALL_VOID(API, "Centroid::divide", cluster_size);
335 magnitude = 0;
336 unordered_map<string, double>::iterator it;
337 for (it = weights.begin(); it != weights.end(); ++it) {
338 double new_weight = it->second / cluster_size;
339 it->second = new_weight;
340 magnitude += new_weight * new_weight;
344 void
345 Centroid::clear()
347 LOGCALL_VOID(API, "Centroid::clear", NO_ARGS);
348 weights.clear();
351 Cluster&
352 Cluster::operator=(const Cluster&) = default;
354 Cluster::Cluster(const Cluster&) = default;
356 Cluster::Cluster(Cluster&&) = default;
358 Cluster&
359 Cluster::operator=(Cluster&&) = default;
361 Cluster::Cluster() : internal(new Xapian::Cluster::Internal)
363 LOGCALL_CTOR(API, "Cluster", NO_ARGS);
366 Cluster::Cluster(const Centroid& centroid)
367 : internal(new Xapian::Cluster::Internal(centroid))
369 LOGCALL_CTOR(API, "Cluster", centroid);
372 Cluster::~Cluster()
374 LOGCALL_DTOR(API, "Cluster");
377 Centroid::Centroid()
379 LOGCALL_CTOR(API, "Centroid", NO_ARGS);
382 DocumentSet
383 Cluster::get_documents() const
385 LOGCALL(API, DocumentSet, "Cluster::get_documents", NO_ARGS);
386 return internal->get_documents();
389 DocumentSet
390 Cluster::Internal::get_documents() const
392 DocumentSet docs;
393 for (auto&& point : cluster_docs)
394 docs.add_document(point.get_document());
395 return docs;
398 Point&
399 Cluster::operator[](Xapian::doccount i)
401 return internal->get_point(i);
404 Point&
405 Cluster::Internal::get_point(Xapian::doccount i)
407 return cluster_docs[i];
410 const Point&
411 Cluster::operator[](Xapian::doccount i) const
413 return internal->get_point(i);
416 const Point&
417 Cluster::Internal::get_point(Xapian::doccount i) const
419 return cluster_docs[i];
422 ClusterSet&
423 ClusterSet::operator=(const ClusterSet&) = default;
425 ClusterSet::ClusterSet(const ClusterSet&) = default;
427 ClusterSet&
428 ClusterSet::operator=(ClusterSet&&) = default;
430 ClusterSet::ClusterSet(ClusterSet&&) = default;
432 ClusterSet::ClusterSet() : internal(new Xapian::ClusterSet::Internal)
436 ClusterSet::~ClusterSet()
440 doccount
441 ClusterSet::Internal::size() const
443 return clusters.size();
446 doccount
447 ClusterSet::size() const
449 LOGCALL(API, doccount, "ClusterSet::size", NO_ARGS);
450 return internal->size();
453 void
454 ClusterSet::Internal::add_cluster(const Cluster& cluster)
456 clusters.push_back(cluster);
459 void
460 ClusterSet::add_cluster(const Cluster& cluster)
462 LOGCALL_VOID(API, "ClusterSet::add_cluster", cluster);
463 internal->add_cluster(cluster);
466 Cluster&
467 ClusterSet::Internal::get_cluster(doccount i)
469 return clusters[i];
472 Cluster&
473 ClusterSet::operator[](doccount i)
475 return internal->get_cluster(i);
478 const Cluster&
479 ClusterSet::Internal::get_cluster(doccount i) const
481 return clusters[i];
484 const Cluster&
485 ClusterSet::operator[](doccount i) const
487 return internal->get_cluster(i);
490 void
491 ClusterSet::Internal::add_to_cluster(const Point& point, unsigned int index)
493 clusters[index].add_point(point);
496 void
497 ClusterSet::add_to_cluster(const Point& point, unsigned int index)
499 LOGCALL_VOID(API, "ClusterSet::add_to_cluster", point | index);
500 internal->add_to_cluster(point, index);
503 void
504 ClusterSet::Internal::recalculate_centroids()
506 for (auto&& cluster : clusters)
507 cluster.recalculate();
510 void
511 ClusterSet::recalculate_centroids()
513 LOGCALL_VOID(API, "ClusterSet::recalculate_centroids", NO_ARGS);
514 internal->recalculate_centroids();
517 void
518 ClusterSet::clear_clusters()
520 LOGCALL_VOID(API, "ClusterSet::clear_clusters", NO_ARGS);
521 internal->clear_clusters();
524 void
525 ClusterSet::Internal::clear_clusters()
527 for (auto&& cluster : clusters)
528 cluster.clear();
531 doccount
532 Cluster::size() const
534 LOGCALL(API, doccount, "Cluster::size", NO_ARGS);
535 return internal->size();
538 doccount
539 Cluster::Internal::size() const
541 return (cluster_docs.size());
544 void
545 Cluster::add_point(const Point& point)
547 LOGCALL_VOID(API, "Cluster::add_point", point);
548 internal->add_point(point);
551 void
552 Cluster::Internal::add_point(const Point& point)
554 cluster_docs.push_back(point);
557 void
558 Cluster::clear()
560 LOGCALL_VOID(API, "Cluster::clear", NO_ARGS);
561 internal->clear();
564 void
565 Cluster::Internal::clear()
567 cluster_docs.clear();
570 const Centroid&
571 Cluster::get_centroid() const
573 LOGCALL(API, Centroid, "Cluster::get_centroid", NO_ARGS);
574 return internal->get_centroid();
577 const Centroid&
578 Cluster::Internal::get_centroid() const
580 return centroid;
583 void
584 Cluster::set_centroid(const Centroid& centroid_)
586 LOGCALL_VOID(API, "Cluster::set_centroid", centroid_);
587 internal->set_centroid(centroid_);
590 void
591 Cluster::Internal::set_centroid(const Centroid& centroid_)
593 centroid = centroid_;
596 void
597 Cluster::recalculate()
599 LOGCALL_VOID(API, "Cluster::recalculate", NO_ARGS);
600 internal->recalculate();
603 void
604 Cluster::Internal::recalculate()
606 centroid.clear();
607 for (const Point& temp : cluster_docs) {
608 for (TermIterator titer = temp.termlist_begin();
609 titer != temp.termlist_end();
610 ++titer) {
611 centroid.add_weight(*titer, temp.get_weight(*titer));
614 centroid.divide(size());
617 StemStopper::StemStopper(const Stem& stemmer_, stem_strategy strategy)
618 : stem_action(strategy), stemmer(stemmer_)
620 LOGCALL_CTOR(API, "StemStopper", stemmer_ | strategy);
623 string
624 StemStopper::get_description() const
626 string desc("Xapian::StemStopper(");
627 unordered_set<string>::const_iterator i;
628 for (i = stop_words.begin(); i != stop_words.end(); ++i) {
629 if (i != stop_words.begin()) desc += ' ';
630 desc += *i;
632 desc += ')';
633 return desc;
636 void
637 StemStopper::add(const string& term)
639 LOGCALL_VOID(API, "StemStopper::add", term);
640 switch (stem_action) {
641 case STEM_NONE:
642 stop_words.insert(term);
643 break;
644 case STEM_ALL_Z:
645 stop_words.insert('Z' + stemmer(term));
646 break;
647 case STEM_ALL:
648 stop_words.insert(stemmer(term));
649 break;
650 case STEM_SOME:
651 case STEM_SOME_FULL_POS:
652 stop_words.insert(term);
653 stop_words.insert('Z' + stemmer(term));
654 break;