seems to be working, if slowly...
[neural-net.git] / neural_net / backprop.clj
blob63ab0711a9f75279ce38aa4d8937119570d98195
1 (ns neural-net.backprop
2   (:use neural-net.core)
3   (:use [clojure.contrib math]))
5 (defn make-neuron "Make a simple neuron"
6   ([] (make-neuron sigmoid)) ([phi] (ref {:phi phi})))
8 (defn initial-weight "A random initial weight." [] (- (rand 0.4) 0.2))
10 (defn make-network
11   "Create a fully connected feed-forward network of neurons.  Layers
12 should specify the number of neurons in each layer."
13   [layers]
14   (let [net (cons
15              ;; input neurons
16              (map (fn [_] (ref {:y nil})) (range (first layers)))
17              ;; real neurons
18              (map (comp (partial map (fn [_] (make-neuron))) range) (rest layers)))]
19     (dorun
20      (map
21       (fn [in cur out]
22         (doseq [n cur]
23           (dosync (ref-set n
24                            (assoc @n
25                              ;; hash of weights and inputs
26                              :dendrite
27                              (when (not (empty? in))
28                                (cons
29                                 ;; bias neuron
30                                 {:n (ref {:y 1}) :w (initial-weight)}
31                                 ;; input neurons
32                                 (map #(hash-map :n % :w (initial-weight)) in)))
33                              :axon out)))))
34       (cons '() (butlast net)) net (concat (rest net) (list '()))))
35     net))
37 (defn reset-network
38   "Reset the activation of the neurons in a network." [net]
39   (doseq [row net]
40     (doseq [n row]
41       (dosync (ref-set n (dissoc @n :v :y :phi-p :e :grad)))))
42   net)
44 (defn set-inputs "Set the input values of a network." [net inputs]
45   (dorun (map (fn [n i] (dosync (ref-set n (assoc @n :y i)))) (first net) inputs))
46   net)
48 (defn set-outputs "Set the output values of a network." [net outputs]
49   (dorun
50    (map (fn [n d] (dosync (ref-set n (assoc @n :d d)))) (last net) outputs))
51   net)
53 (defn eval-neuron "Calculate v phi-p and y for the neuron." [n]
54   (if (@n :y)
55     n
56     (let [phi (@n :phi)
57           v (reduce + (map (fn [d] (* (d :w) (@(eval-neuron (d :n)) :y)))
58                            (@n :dendrite)))
59           y (phi v) phi-p (phi v 'prime)]
60       (dosync (ref-set n (assoc @n :v v :y y :phi-p phi-p))) n)))
62 (defn eval-network "Evaluate an entire neuroal network." [net]
63   (doseq [row net] (doseq [n row] (eval-neuron n)))
64   net)
66 (defn back-prop-neuron
67   "Calculate the weight change for a neruon." [n eta]
68   ;; when not an input neuron and already back-prop'd
69   (when (not (or (not (@n :phi))
70                  (reduce #(and %1 %2)
71                          (cons (@n :grad) (map (comp :delta-w deref :n)
72                                                (@n :dendrite))))))
73     (eval-neuron n)
74     (cond
75      ;; output neuron
76      (@n :d)
77      (let [e (- (@n :d) (@n :y))
78            grad (* e (@n :phi-p))
79            dendrite (map (fn [d] (assoc d :delta-w (* eta grad (@(d :n) :y))))
80                          (@n :dendrite))]
81        (dosync (ref-set n (assoc @n :e e :grad grad :dendrite dendrite))))
82      ;; hidden neuron -- input neurons have no dendrites or phi
83      (and (@n :dendrite) (@n :phi))
84      (let [grad (* (@n :phi-p)
85                    (reduce +
86                            (map (fn [a]
87                                   (back-prop-neuron a eta)
88                                   (* (@a :grad)
89                                      ((first
90                                        (filter (fn [d] (= n (d :n)))
91                                                (@a :dendrite))) :w)))
92                                 (@n :axon))))
93            dendrite (map (fn [d] (assoc d :delta-w (* eta grad (@(d :n) :y))))
94                          (@n :dendrite))]
95        (dosync (ref-set n (assoc @n :grad grad :dendrite dendrite))))))
96   n)
98 (defn back-prop-network "Back-propagate an entire neuroal network." [net eta]
99   (doseq [row (reverse net)] (doseq [n row] (back-prop-neuron n eta)))
100   net)
102 (defn apply-delta-weights "Apply delta weights across a network." [net]
103   (doseq [row net]
104     (doseq [n row]
105       (dosync (ref-set n (assoc @n
106                            :dendrite
107                            (map
108                             (fn [d] (assoc d :w (+ (d :delta-w) (d :w))))
109                             (@n :dendrite))))))))
111 (defn train "Train a network on an epic." [net eta epic]
112   (doseq [train-pt epic]
113     (reset-network net)
114     ;; clamp down inputs and outputs
115     (set-inputs net (first train-pt))
116     (set-outputs net (second train-pt))
117     ;; evaluate network
118     (eval-network net)
119     ;; back-prop and re-weight
120     (back-prop-network net eta)
121     (apply-delta-weights net)))
123 (defn epic-error
124   "Run an epic without re-weighting and report the error." [net epic]
125   (sqrt
126    (/ (reduce + (map
127                  (fn [train-pt]
128                    ;; clean out the network
129                    (reset-network net)
130                    ;; clamp down inputs and outputs
131                    (set-inputs net (first train-pt))
132                    (set-outputs net (second train-pt))
133                    ;; evaluate network
134                    (eval-network net)
135                    ;; return the error
136                    (reduce + (map (fn [n] (expt (- (@n :d) (@n :y)) 2))
137                                   (last net))))
138                  epic))
139       (* (count epic) (count (last net))))))
141 (defn classification-error
142   "Return the percent of trials in the epic classified correctly." [net epic]
143   )
145 ;; network inspection
146 (defn weights "Return the weights of each neuron in the network." [net]
147   (map (partial map (comp (partial map :w) :dendrite deref)) net))
149 (defn delta-weights
150   "Return the weight delta  of each neuron in the network." [net]
151   (map (partial map (comp (partial map :delta-w) :dendrite deref)) net))