mean and standard deviation functions in util
[neural-net.git] / src / neural_net / k_means.clj
blob0924815d66c22dd4fd40b39353e3c2a8f27fdb5e
1 (ns neural-net.k-means
2   (:use neural-net.core)
3   (:use neural-net.util)
4   (:use clojure.contrib.math))
6 (defn euclid-dist [a b] (sqrt (reduce + (map (comp #(* % %) -) a b))))
8 (defn k-means-update
9   "A single iteration of k-means clustering."  [centers samples]
10   (reduce        ; reposition each center at the center of its samples
11    (fn [a [center mine]]
12      (assoc a [(mean (map first mine)) (mean (map second mine))] mine))
13    {}
14    (reduce                 ; group each sample with the nearest center
15     (fn [centers sample]
16       (let [closest (first (sort-by (partial euclid-dist sample)
17                                     (keys centers)))]
18         (assoc centers closest (cons sample (get centers closest '())))))
19     (apply hash-map (mapcat (fn [c] [c '()]) centers))
20     samples)))
22 (defn k-means [n samples]
23   (loop [clst (k-means-update (pick samples n) samples)]
24     (let [new-clst (k-means-update (keys clst) (mapcat identity (vals clst)))]
25       (if (= (set (map set (vals new-clst))) (set (map set (vals clst))))
26         (keys clst) (recur new-clst)))))