som gui -- show training points and weights over time in 2D weight space
[neural-net.git] / scripts / radial-basis-exp
blob9ba8db2ebd970baffd8d5a2bd721747cfca0855e
1 #!/usr/bin/env clj-env
2 ;; -*- mode: clojure -*-
4 (ns neural-net.radial-basis-exp
5   (:use neural-net.core)
6   (:use neural-net.util)
7   (:use neural-net.data)
8   (:use neural-net.back-propagation)
9   (:use neural-net.radial-basis)
10   (:use neural-net.k-means)
11   (:use [clojure.contrib math command-line]))
13 (with-command-line *command-line-args*
14   "Run a radial basis experiment outputting generalization data."
15   [[n   "Number of hidden radial basis neurons" 4]
16    [k   "Use k-means to position radial basis neurons" false]
17    [runs "optional maximum number of runs"]
18    [eta "Value to use for eta" 0.00001]]
19   (let [n (if (string? n) (read-string n) n)
20         eta (if (string? eta) (read-string eta) eta)
21         runs (when runs (read-string runs))
22         x1s (map (fn [[[x1 x2] o]] x1) training)
23         x2s (map (fn [[[x1 x2] o]] x2) training)
24         std (sqrt (/ (+ (expt (- (apply max x1s) (apply min x1s)) 2)
25                         (expt (- (apply max x2s) (apply min x2s)) 2))
26                      (sqrt 2)))
27         rb {:std std :phi radial:run :weights [1 1]}
28         rbs (map (partial assoc rb :t)
29                  (if k
30                    (k-means n (map (fn [[in out]] in) training))
31                    (take n (repeatedly (fn [] [(pick x1s) (pick x2s)])))))
32         train-epic (map (fn [[in out]] [(vec (map (fn [el] {:y el}) in)) out])
33                         (map (fn [[in out]] [(run rbs in) out]) training))
34         test-epic (map (fn [[in out]] [(vec (map (fn [el] {:y el}) in)) out])
35                        (map (fn [[in out]] [(run rbs in) out]) testing))
36         bp {:phi back-prop:run
37             :d-phi (fn [_] 1)
38             :learn back-prop:learn
39             :train (fn [n delta] (assoc n :weights
40                                        (vec (map + (n :weights)
41                                                  (map :delta-w delta)))))
42             :eta eta
43             :weights (take n (repeatedly rand-weight))}]
44     (doseq [[key val] [[:dimensions n]
45                        [:eta eta]
46                        [:std std]
47                        [:runs runs]
48                        [:centers (map :t rbs)]
49                        [:starting-weights (bp :weights)]]]
50       (println "#" key val))
51     (loop [net bp count 0 last-train-err 1000 last-test-err 1000 runs runs]
52       (let [train-err (rms-error net train-epic :y)
53             test-err (rms-error net test-epic :y)]
54         (println (format "%S\t%S\t%S" count train-err test-err))
55         (when (and (or (< train-err last-train-err)
56                        (< test-err last-test-err))
57                    (or (not runs) (> runs 0)))
58           (recur (reduce (fn [n [in out]]
59                            (second (train n in nil {:desired out})))
60                          net train-epic)
61                  (inc count)
62                  train-err
63                  test-err
64                  (when runs (dec runs))))))
65     (println "#" :ending-weights (bp :weights))))