2 * @brief KMeans clustering API
4 /* Copyright (C) 2016 Richhiey Thomas
6 * This program is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU General Public License as
8 * published by the Free Software Foundation; either version 2 of the
9 * License, or (at your option) any later version.
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
24 #include "xapian/cluster.h"
25 #include "xapian/error.h"
// Threshold value for checking convergence in KMeans: clustering is treated
// as converged once no centroid has moved further than this from its
// position in the previous iteration.
static const double CONVERGENCE_THRESHOLD = 1e-10;

/** Maximum number of times the KMeans algorithm will iterate.
 *
 *  This default is used when the KMeans constructor is passed
 *  max_iters_ == 0.
 */
static const unsigned int MAX_ITERS = 1000;
// Lets the definitions below use Xapian names (e.g. InvalidArgumentError,
// MSet, MSetIterator) without qualifying each one.
using namespace Xapian;
/** Construct a KMeans clusterer.
 *
 *  @param k_          Number of clusters required.
 *  @param max_iters_  Maximum number of iterations to run; 0 selects the
 *                     default MAX_ITERS.
 *
 *  @exception InvalidArgumentError  thrown when the requested number of
 *             clusters is not acceptable (see the error message below).
 *
 *  NOTE(review): this copy of the constructor is incomplete. As shown, the
 *  throw below is unconditional, so every construction would throw. The
 *  guarding condition, the rest of the error message, the opening and
 *  closing braces and the initialisation of the k member used by the other
 *  methods are not present here. Restore them from version control rather
 *  than reconstructing them by hand.
 */
KMeans::KMeans(unsigned int k_, unsigned int max_iters_)
    LOGCALL_CTOR(API, "KMeans", k_ | max_iters_);
    // A max_iters_ of 0 means "use the default iteration cap".
    max_iters = (max_iters_ == 0) ? MAX_ITERS : max_iters_;
    throw InvalidArgumentError("Number of required clusters should be "
// Return a string describing this KMeans object.
// NOTE(review): only the signature of get_description() is present in this
// copy. Its return type and body are missing and need to be restored.
KMeans::get_description() const
/** Set the Stopper to use when clustering.
 *
 *  @param stopper_  Stopper to use. Presumably it ends up in the stopper
 *                   member that initialise_points() passes to
 *                   TermListGroup (TODO confirm: that statement is not
 *                   present in this copy).
 *
 *  NOTE(review): only the signature and the logging call survive here. The
 *  braces and whatever statement uses stopper_ are missing.
 */
KMeans::set_stopper(const Stopper* stopper_)
    LOGCALL_VOID(API, "KMeans::set_stopper", stopper_);
67 KMeans::initialise_clusters(ClusterSet
& cset
, doccount num_of_points
)
69 LOGCALL_VOID(API
, "KMeans::initialise_clusters", cset
| num_of_points
);
70 // Initial centroids are selected by picking points at roughly even
71 // intervals within the MSet. This is cheap and helps pick diverse
72 // elements since the MSet is usually sorted by some sort of key
73 for (unsigned int i
= 0; i
< k
; ++i
) {
74 unsigned int x
= (i
* num_of_points
) / k
;
75 cset
.add_cluster(Cluster(Centroid(points
[x
])));
80 KMeans::initialise_points(const MSet
& source
)
82 LOGCALL_VOID(API
, "KMeans::initialise_points", source
);
83 TermListGroup
tlg(source
, stopper
.get());
84 for (MSetIterator it
= source
.begin(); it
!= source
.end(); ++it
)
85 points
.push_back(Point(tlg
, it
.get_document()));
89 KMeans::cluster(const MSet
& mset
)
91 LOGCALL(API
, ClusterSet
, "KMeans::cluster", mset
);
92 doccount size
= mset
.size();
95 initialise_points(mset
);
97 initialise_clusters(cset
, size
);
98 CosineDistance distance
;
99 vector
<Centroid
> previous_centroids
;
100 for (unsigned int i
= 0; i
< max_iters
; ++i
) {
101 // Assign each point to the cluster corresponding to its
102 // closest cluster centroid
103 cset
.clear_clusters();
104 for (unsigned int j
= 0; j
< size
; ++j
) {
105 double closest_cluster_distance
= numeric_limits
<double>::max();
106 unsigned int closest_cluster
= 0;
107 for (unsigned int c
= 0; c
< k
; ++c
) {
108 const Centroid
& centroid
= cset
[c
].get_centroid();
109 double dist
= distance
.similarity(points
[j
], centroid
);
110 if (closest_cluster_distance
> dist
) {
111 closest_cluster_distance
= dist
;
115 cset
.add_to_cluster(points
[j
], closest_cluster
);
118 // Remember the previous centroids
119 previous_centroids
.clear();
120 for (unsigned int j
= 0; j
< k
; ++j
)
121 previous_centroids
.push_back(cset
[j
].get_centroid());
123 // Recalculate the centroids for current iteration
124 cset
.recalculate_centroids();
126 // Check whether centroids have converged
127 bool has_converged
= true;
128 for (unsigned int j
= 0; j
< k
; ++j
) {
129 const Centroid
& centroid
= cset
[j
].get_centroid();
130 double dist
= distance
.similarity(previous_centroids
[j
], centroid
);
131 // If distance between any two centroids has changed by
132 // more than the threshold, then KMeans hasn't converged
133 if (dist
> CONVERGENCE_THRESHOLD
) {
134 has_converged
= false;
138 // If converged, then break from the loop