Add new stemming mode STEM_SOME_FULL_POS
[xapian.git] / xapian-core / cluster / cluster.cc
blobfdd6fed40856fcbe68604071ac6ca60f3fd09619
1 /** @file cluster.cc
2 * @brief Cluster API
3 */
4 /* Copyright (C) 2010 Richard Boulton
5 * Copyright (C) 2016 Richhiey Thomas
7 * This program is free software; you can redistribute it and/or
8 * modify it under the terms of the GNU General Public License as
9 * published by the Free Software Foundation; either version 2 of the
10 * License, or (at your option) any later version.
12 * This program is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
17 * You should have received a copy of the GNU General Public License
18 * along with this program; if not, write to the Free Software
19 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
20 * USA
23 #include <config.h>
25 #include "xapian/cluster.h"
27 #include "cluster/clusterinternal.h"
28 #include "xapian/error.h"
29 #include "api/termlist.h"
30 #include "debuglog.h"
31 #include "omassert.h"
33 #include <cmath>
34 #include <unordered_map>
35 #include <vector>
37 using namespace Xapian;
38 using namespace std;
40 FreqSource::~FreqSource()
42 LOGCALL_DTOR(API, "FreqSource");
45 Similarity::~Similarity()
47 LOGCALL_DTOR(API, "Similarity");
50 Clusterer::~Clusterer()
52 LOGCALL_DTOR(API, "Clusterer");
55 TermListGroup::TermListGroup(const MSet &docs, const Stopper *stopper)
57 LOGCALL_CTOR(API, "TermListGroup", docs | stopper);
58 for (MSetIterator it = docs.begin(); it != docs.end(); ++it)
59 add_document(it.get_document(), stopper);
60 num_of_documents = docs.size();
63 doccount
64 DummyFreqSource::get_termfreq(const string &) const
66 LOGCALL(API, doccount, "DummyFreqSource::get_termfreq", NO_ARGS);
67 return 1;
70 doccount
71 DummyFreqSource::get_doccount() const
73 LOGCALL(API, doccount, "DummyFreqSource::get_doccount", NO_ARGS);
74 return 1;
77 void
78 TermListGroup::add_document(const Document &document, const Stopper *stopper)
80 LOGCALL_VOID(API, "TermListGroup::add_document", document | stopper);
82 TermIterator titer(document.termlist_begin());
84 for (; titer != document.termlist_end(); ++titer) {
85 const string &term = *titer;
87 // Remove stopwords by using the Xapian::Stopper object
88 if (stopper && (*stopper)(term))
89 continue;
91 // Remove unstemmed terms since document vector should
92 // contain only stemmed terms
93 if (term[0] != 'Z')
94 continue;
96 unordered_map<string, doccount>::iterator i;
97 i = termfreq.find(term);
98 if (i == termfreq.end())
99 termfreq[term] = 1;
100 else
101 ++i->second;
105 doccount
106 TermListGroup::get_doccount() const
108 LOGCALL(API, doccount, "TermListGroup::get_doccount", NO_ARGS);
109 return num_of_documents;
112 doccount
113 TermListGroup::get_termfreq(const string &tname) const
115 LOGCALL(API, doccount, "TermListGroup::get_termfreq", tname);
116 unordered_map<string, doccount>::const_iterator it = termfreq.find(tname);
117 if (it != termfreq.end())
118 return it->second;
119 else
120 return 0;
123 DocumentSet::DocumentSet(const DocumentSet &) = default;
125 DocumentSet&
126 DocumentSet::operator=(const DocumentSet &) = default;
128 DocumentSet::DocumentSet(DocumentSet &&) = default;
130 DocumentSet&
131 DocumentSet::operator=(DocumentSet &&) = default;
133 DocumentSet::DocumentSet()
134 : internal(new Xapian::DocumentSet::Internal)
138 doccount
139 DocumentSet::size() const
141 LOGCALL(API, doccount, "DocumentSet::size", NO_ARGS);
142 return internal->size();
145 doccount
146 DocumentSet::Internal::size() const
148 return documents.size();
151 void
152 DocumentSet::add_document(const Document &document)
154 LOGCALL_VOID(API, "DocumentSet::add_document", document);
155 internal->add_document(document);
157 void
158 DocumentSet::Internal::add_document(const Document &document)
160 documents.push_back(document);
163 Document&
164 DocumentSet::operator[](doccount i)
166 return internal->get_document(i);
169 Document&
170 DocumentSet::Internal::get_document(doccount i)
172 return documents[i];
175 const Document&
176 DocumentSet::operator[](doccount i) const
178 return internal->get_document(i);
181 const Document&
182 DocumentSet::Internal::get_document(doccount i) const
184 return documents[i];
187 DocumentSet::~DocumentSet()
189 LOGCALL_DTOR(API, "DocumentSet");
192 class PointTermIterator : public TermIterator::Internal {
193 unordered_map<string, double>::const_iterator i;
194 unordered_map<string, double>::const_iterator end;
195 termcount size;
196 bool started;
197 public:
198 PointTermIterator(const unordered_map<string, double> &termlist)
199 : i(termlist.begin()), end(termlist.end()),
200 size(termlist.size()), started(false)
202 termcount get_approx_size() const { return size; }
203 termcount get_wdf() const { throw UnimplementedError("PointIterator doesn't support get_wdf()"); }
204 string get_termname() const { return i->first; }
205 doccount get_termfreq() const { throw UnimplementedError("PointIterator doesn't support get_termfreq()"); }
206 Internal * next();
207 termcount positionlist_count() const {
208 throw UnimplementedError("PointTermIterator doesn't support positionlist_count()");
210 bool at_end() const;
211 PositionList* positionlist_begin() const {
212 throw UnimplementedError("PointTermIterator doesn't support positionlist_begin()");
214 Internal * skip_to(const string &) {
215 throw UnimplementedError("PointTermIterator doesn't support skip_to()");
219 TermIterator::Internal *
220 PointTermIterator::next()
222 if (!started) {
223 started = true;
224 return NULL;
226 Assert(i != end);
227 ++i;
228 return NULL;
231 bool
232 PointTermIterator::at_end() const
234 if (!started) return false;
235 return i == end;
238 TermIterator
239 PointType::termlist_begin() const
241 LOGCALL(API, TermIterator, "PointType::termlist_begin", NO_ARGS);
242 return TermIterator(new PointTermIterator(weights));
245 bool
246 PointType::contains(const string &term) const
248 LOGCALL(API, bool, "PointType::contains", term);
249 return weights.find(term) != weights.end();
252 double
253 PointType::get_weight(const string &term) const
255 LOGCALL(API, double, "PointType::get_weight", term);
256 unordered_map<string, double>::const_iterator it = weights.find(term);
257 return (it == weights.end()) ? 0.0 : it->second;
260 double
261 PointType::get_magnitude() const {
262 LOGCALL(API, double, "PointType::get_magnitude", NO_ARGS);
263 return magnitude;
266 void
267 PointType::add_weight(const string &term, double weight)
269 LOGCALL_VOID(API, "PointType::add_weight", term | weight);
270 unordered_map<string, double>::iterator it;
271 it = weights.find(term);
272 if (it != weights.end())
273 it->second += weight;
274 else
275 weights[term] = weight;
278 void
279 PointType::set_weight(const string &term, double weight)
281 LOGCALL_VOID(API, "PointType::set_weight", term | weight);
282 weights[term] = weight;
285 termcount
286 PointType::termlist_size() const
288 LOGCALL(API, termcount, "PointType::termlist_size", NO_ARGS);
289 return weights.size();
292 Document
293 Point::get_document() const
295 LOGCALL(API, Document, "Point::get_document", NO_ARGS);
296 return document;
299 Point::Point(const TermListGroup &tlg, const Document &document_)
301 LOGCALL_CTOR(API, "Point::initialize", tlg | document_);
302 doccount size = tlg.get_doccount();
303 document = document_;
304 for (TermIterator it = document.termlist_begin(); it != document.termlist_end(); ++it) {
305 doccount wdf = it.get_wdf();
306 string term = *it;
307 double termfreq = tlg.get_termfreq(term);
309 // If the term exists in only one document, or if it exists in
310 // every document within the MSet, or if it is a filter term, then
311 // these terms are not used for document vector calculations
312 if (wdf < 1 || termfreq <= 1 || size == termfreq)
313 continue;
315 double tf = 1 + log((double)wdf);
316 double idf = log(size / termfreq);
317 double wt = tf * idf;
319 weights[term] = wt;
320 magnitude += wt * wt;
324 Centroid::Centroid(const Point &point) {
325 LOGCALL_CTOR(API, "Centroid", point);
326 for (TermIterator it = point.termlist_begin(); it != point.termlist_end(); ++it)
327 weights[*it] = point.get_weight(*it);
328 magnitude = point.get_magnitude();
331 void
332 Centroid::divide(double cluster_size)
334 LOGCALL_VOID(API, "Centroid::divide", cluster_size);
335 magnitude = 0;
336 unordered_map<string, double>::iterator it;
337 for (it = weights.begin(); it != weights.end(); ++it) {
338 double new_weight = it->second / cluster_size;
339 it->second = new_weight;
340 magnitude += new_weight * new_weight;
344 void
345 Centroid::clear()
347 LOGCALL_VOID(API, "Centroid::clear", NO_ARGS);
348 weights.clear();
351 Cluster&
352 Cluster::operator=(const Cluster &) = default;
354 Cluster::Cluster(const Cluster &) = default;
356 Cluster::Cluster(Cluster &&) = default;
358 Cluster&
359 Cluster::operator=(Cluster &&) = default;
361 Cluster::Cluster() : internal(new Xapian::Cluster::Internal)
363 LOGCALL_CTOR(API, "Cluster", NO_ARGS);
366 Cluster::Cluster(const Centroid &centroid)
367 : internal(new Xapian::Cluster::Internal(centroid))
369 LOGCALL_CTOR(API, "Cluster", centroid);
372 Cluster::~Cluster()
374 LOGCALL_DTOR(API, "Cluster");
377 Centroid::Centroid()
379 LOGCALL_CTOR(API, "Centroid", NO_ARGS);
382 DocumentSet
383 Cluster::get_documents() const
385 LOGCALL(API, DocumentSet, "Cluster::get_documents", NO_ARGS);
386 return internal->get_documents();
389 DocumentSet
390 Cluster::Internal::get_documents() const
392 DocumentSet docs;
393 for (auto&& point : cluster_docs)
394 docs.add_document(point.get_document());
395 return docs;
398 Point&
399 Cluster::operator[](Xapian::doccount i)
401 return internal->get_point(i);
404 Point&
405 Cluster::Internal::get_point(Xapian::doccount i)
407 return cluster_docs[i];
410 const Point&
411 Cluster::operator[](Xapian::doccount i) const
413 return internal->get_point(i);
416 const Point&
417 Cluster::Internal::get_point(Xapian::doccount i) const
419 return cluster_docs[i];
422 ClusterSet&
423 ClusterSet::operator=(const ClusterSet &) = default;
425 ClusterSet::ClusterSet(const ClusterSet &) = default;
427 ClusterSet&
428 ClusterSet::operator=(ClusterSet &&) = default;
430 ClusterSet::ClusterSet(ClusterSet &&) = default;
432 ClusterSet::ClusterSet() : internal(new Xapian::ClusterSet::Internal)
436 ClusterSet::~ClusterSet()
440 doccount
441 ClusterSet::Internal::size() const
443 return clusters.size();
446 doccount
447 ClusterSet::size() const
449 LOGCALL(API, doccount, "ClusterSet::size", NO_ARGS);
450 return internal->size();
453 void
454 ClusterSet::Internal::add_cluster(const Cluster &cluster)
456 clusters.push_back(cluster);
459 void
460 ClusterSet::add_cluster(const Cluster &cluster)
462 LOGCALL_VOID(API, "ClusterSet::add_cluster", cluster);
463 internal->add_cluster(cluster);
466 Cluster&
467 ClusterSet::Internal::get_cluster(doccount i)
469 return clusters[i];
472 Cluster&
473 ClusterSet::operator[](doccount i)
475 return internal->get_cluster(i);
478 const Cluster&
479 ClusterSet::Internal::get_cluster(doccount i) const
481 return clusters[i];
484 const Cluster&
485 ClusterSet::operator[](doccount i) const
487 return internal->get_cluster(i);
490 void
491 ClusterSet::Internal::add_to_cluster(const Point &point, unsigned int index)
493 clusters[index].add_point(point);
496 void
497 ClusterSet::add_to_cluster(const Point &point, unsigned int index)
499 LOGCALL_VOID(API, "ClusterSet::add_to_cluster", point | index);
500 internal->add_to_cluster(point, index);
503 void
504 ClusterSet::Internal::recalculate_centroids()
506 for (auto&& cluster : clusters)
507 cluster.recalculate();
510 void
511 ClusterSet::recalculate_centroids()
513 LOGCALL_VOID(API, "ClusterSet::recalculate_centroids", NO_ARGS);
514 internal->recalculate_centroids();
517 void
518 ClusterSet::clear_clusters()
520 LOGCALL_VOID(API, "ClusterSet::clear_clusters", NO_ARGS);
521 internal->clear_clusters();
524 void
525 ClusterSet::Internal::clear_clusters()
527 for (auto&& cluster : clusters)
528 cluster.clear();
531 doccount
532 Cluster::size() const
534 LOGCALL(API, doccount, "Cluster::size", NO_ARGS);
535 return internal->size();
538 doccount
539 Cluster::Internal::size() const
541 return (cluster_docs.size());
544 void
545 Cluster::add_point(const Point &point)
547 LOGCALL_VOID(API, "Cluster::add_point", point);
548 internal->add_point(point);
551 void
552 Cluster::Internal::add_point(const Point &point)
554 cluster_docs.push_back(point);
557 void
558 Cluster::clear()
560 LOGCALL_VOID(API, "Cluster::clear", NO_ARGS);
561 internal->clear();
564 void
565 Cluster::Internal::clear()
567 cluster_docs.clear();
570 const Centroid&
571 Cluster::get_centroid() const
573 LOGCALL(API, Centroid, "Cluster::get_centroid", NO_ARGS);
574 return internal->get_centroid();
577 const Centroid&
578 Cluster::Internal::get_centroid() const
580 return centroid;
583 void
584 Cluster::set_centroid(const Centroid &centroid_)
586 LOGCALL_VOID(API, "Cluster::set_centroid", centroid_);
587 internal->set_centroid(centroid_);
590 void
591 Cluster::Internal::set_centroid(const Centroid &centroid_)
593 centroid = centroid_;
596 void
597 Cluster::recalculate()
599 LOGCALL_VOID(API, "Cluster::recalculate", NO_ARGS);
600 internal->recalculate();
603 void
604 Cluster::Internal::recalculate()
606 centroid.clear();
607 for (const Point& temp : cluster_docs) {
608 for (TermIterator titer = temp.termlist_begin(); titer != temp.termlist_end(); ++titer)
609 centroid.add_weight(*titer, temp.get_weight(*titer));
611 centroid.divide(size());
614 StemStopper::StemStopper(const Stem &stemmer_, stem_strategy strategy)
615 : stem_action(strategy), stemmer(stemmer_)
617 LOGCALL_CTOR(API, "StemStopper", stemmer_ | strategy);
620 string
621 StemStopper::get_description() const
623 string desc("Xapian::StemStopper(");
624 unordered_set<string>::const_iterator i;
625 for (i = stop_words.begin(); i != stop_words.end(); ++i) {
626 if (i != stop_words.begin()) desc += ' ';
627 desc += *i;
629 desc += ')';
630 return desc;
633 void
634 StemStopper::add(const string &term)
636 LOGCALL_VOID(API, "StemStopper::add", term);
637 switch (stem_action) {
638 case STEM_NONE:
639 stop_words.insert(term);
640 break;
641 case STEM_ALL_Z:
642 stop_words.insert('Z' + stemmer(term));
643 break;
644 case STEM_ALL:
645 stop_words.insert(stemmer(term));
646 break;
647 case STEM_SOME:
648 case STEM_SOME_FULL_POS:
649 stop_words.insert(term);
650 stop_words.insert('Z' + stemmer(term));
651 break;