quest: Fix formatting of --help output
[xapian.git] / xapian-core / cluster / kmeans.cc
blob49c376f4644d3359cc86f443cce31a81d0d175f8
/** @file
 * @brief KMeans clustering API
 */
/* Copyright (C) 2016 Richhiey Thomas
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
 * USA
 */
22 #include <config.h>
24 #include "xapian/cluster.h"
25 #include "xapian/error.h"
27 #include "debuglog.h"
29 #include <limits>
30 #include <vector>
// Threshold value for checking convergence in KMeans: KMeans::cluster()
// stops iterating once no centroid is further than this from its position
// in the previous iteration.
#define CONVERGENCE_THRESHOLD 0.0000000001

/** Maximum number of times KMeans algorithm will iterate
 *  till it converges.
 *
 *  Used as the iteration limit when KMeans is constructed with a
 *  max_iters_ argument of zero.
 */
#define MAX_ITERS 1000
40 using namespace Xapian;
41 using namespace std;
43 KMeans::KMeans(unsigned int k_, unsigned int max_iters_)
44 : k(k_)
46 LOGCALL_CTOR(API, "KMeans", k_ | max_iters_);
47 max_iters = (max_iters_ == 0) ? MAX_ITERS : max_iters_;
48 if (k_ == 0)
49 throw InvalidArgumentError("Number of required clusters should be "
50 "greater than zero");
53 string
54 KMeans::get_description() const
56 return "KMeans()";
59 void
60 KMeans::set_stopper(const Stopper* stopper_)
62 LOGCALL_VOID(API, "KMeans::set_stopper", stopper_);
63 stopper = stopper_;
66 void
67 KMeans::initialise_clusters(ClusterSet& cset, doccount num_of_points)
69 LOGCALL_VOID(API, "KMeans::initialise_clusters", cset | num_of_points);
70 // Initial centroids are selected by picking points at roughly even
71 // intervals within the MSet. This is cheap and helps pick diverse
72 // elements since the MSet is usually sorted by some sort of key
73 for (unsigned int i = 0; i < k; ++i) {
74 unsigned int x = (i * num_of_points) / k;
75 cset.add_cluster(Cluster(Centroid(points[x])));
79 void
80 KMeans::initialise_points(const MSet& source)
82 LOGCALL_VOID(API, "KMeans::initialise_points", source);
83 TermListGroup tlg(source, stopper.get());
84 for (MSetIterator it = source.begin(); it != source.end(); ++it)
85 points.push_back(Point(tlg, it.get_document()));
88 ClusterSet
89 KMeans::cluster(const MSet& mset)
91 LOGCALL(API, ClusterSet, "KMeans::cluster", mset);
92 doccount size = mset.size();
93 if (k >= size)
94 k = size;
95 initialise_points(mset);
96 ClusterSet cset;
97 initialise_clusters(cset, size);
98 CosineDistance distance;
99 vector<Centroid> previous_centroids;
100 for (unsigned int i = 0; i < max_iters; ++i) {
101 // Assign each point to the cluster corresponding to its
102 // closest cluster centroid
103 cset.clear_clusters();
104 for (unsigned int j = 0; j < size; ++j) {
105 double closest_cluster_distance = numeric_limits<double>::max();
106 unsigned int closest_cluster = 0;
107 for (unsigned int c = 0; c < k; ++c) {
108 const Centroid& centroid = cset[c].get_centroid();
109 double dist = distance.similarity(points[j], centroid);
110 if (closest_cluster_distance > dist) {
111 closest_cluster_distance = dist;
112 closest_cluster = c;
115 cset.add_to_cluster(points[j], closest_cluster);
118 // Remember the previous centroids
119 previous_centroids.clear();
120 for (unsigned int j = 0; j < k; ++j)
121 previous_centroids.push_back(cset[j].get_centroid());
123 // Recalculate the centroids for current iteration
124 cset.recalculate_centroids();
126 // Check whether centroids have converged
127 bool has_converged = true;
128 for (unsigned int j = 0; j < k; ++j) {
129 const Centroid& centroid = cset[j].get_centroid();
130 double dist = distance.similarity(previous_centroids[j], centroid);
131 // If distance between any two centroids has changed by
132 // more than the threshold, then KMeans hasn't converged
133 if (dist > CONVERGENCE_THRESHOLD) {
134 has_converged = false;
135 break;
138 // If converged, then break from the loop
139 if (has_converged)
140 break;
142 return cset;