gostyle_old, commit old
[gostyle.git] / kohonen_map.py
blob2f52c3bb720bb953041bd8f9b0bb5ffcf5d6a105
1 #!/usr/bin/python
2 import sys
3 from gostyle import *
4 from math import sqrt
5 import random
6 import numpy
8 from data_about_players import Data
10 try:
11 import kohonen
12 from kohonen import DistanceMetric
13 except:
14 print >>sys.stderr, "Could not locate kohonen.py library. Visit http://code.google.com/p/python-kohonen/ to obtain it manually."
15 raise
class ComponentEuclideanMetric(DistanceMetric):
    '''Euclidean distance (L-2 norm), optionally restricted to one component.

    With ``k`` left as None this is the ordinary euclidean distance taken
    along the last axis: ``sqrt(sum((x - y) ** 2, axis=-1))``.

    With an integer ``k`` only component ``k`` of the last axis contributes,
    which reduces to ``abs(x[..., k] - y[..., k])``.  The result then has the
    same shape as the full distance, with the last axis removed.
    '''

    def __init__(self, k=None):
        # Index into the last (component) axis, or None for the full distance.
        self.k = k

    def __call__(self, x, y):
        d = x - y
        # Compare against None explicitly: k == 0 is a valid component index.
        if self.k is not None:
            # Select along the same axis the full distance sums over, not
            # the first axis.  sqrt(d*d) == abs(d) for a single component.
            return numpy.abs(d[..., self.k])
        return numpy.sqrt(numpy.sum(d * d, axis=-1))
def merge_neuron_labels(neurons, names, size):
    '''Merge the names of players whose winning neurons share a map cell.

    Each entry of ``neurons`` is treated as a flat index into a map with
    ``size`` columns and is split into ``(row, column)`` as
    ``(index // size, index % size)``.  ``names`` is parallel to ``neurons``.
    Names landing on the same cell are joined with the two-character
    sequence backslash-n (not a real newline).  This lets several players
    share one plotted label instead of overlapping.

    Returns a list of ``(row, column, label)`` tuples sorted by position.
    Within one cell the names appear in sorted order.  Every occupied cell
    is returned, including the last one.
    '''
    import itertools

    cells = sorted((neuron // size, neuron % size, name)
                   for neuron, name in zip(neurons, names))
    merged = []
    for (row, col), group in itertools.groupby(cells, key=lambda cell: cell[:2]):
        merged.append((row, col, "\\n".join(cell[2] for cell in group)))
    return merged


def write_label_file(filename, labels):
    '''Write ``(x, y, name)`` labels to ``filename`` as ``name x y`` lines.

    Whitespace inside each name is replaced by underscores so that the name
    stays a single column of the data file.
    '''
    with open(filename, 'w') as f:
        for x, y, name in labels:
            f.write('%s %s %s\n' % ('_'.join(name.split()), x, y))


if __name__ == '__main__':
    main_pat_filename = Data.main_pat_filename
    filename = 'koh_all.data2'
    num_features = 300
    size = 16
    players_ignore = ['Honinbo Shusaku', 'Kuwahara Shusaku', 'Yasuda Shusaku',
                      'Go Seigen', 'Cho Tae-hyeon', 'Rui Naiwei']
    players = [p for p in Data.players_all if p not in players_ignore]

    ### Object creating input vector when called
    print >>sys.stderr, "Creating input vector generator from main pat file:", main_pat_filename
    vector_gen = InputVectorGenerator(main_pat_filename, num_features)

    # One input vector per player, in the same order as `players`.
    input_vectors = [vector_gen(Data.pat_files_folder + name) for name in players]
    if not input_vectors:
        print >>sys.stderr, "No vectors."
        sys.exit(1)

    ### PCA example usage
    # Change this to False, if you do not want to use PCA
    use_pca = True
    if use_pca:
        # Create PCA object, trained on input_vectors
        print >>sys.stderr, "Running PCA."
        pca = PCA(input_vectors, reduce=True)
        # Perform a PCA on input vectors
        input_vectors = pca.process_list_of_vectors(input_vectors)
    dim = len(input_vectors[0])

    # Build extra training samples.  Each is a random linear combination of
    # 2-20 randomly chosen player vectors, weighted by get_random_norm_coefs.
    # Presumably this enlarges the training set beyond the number of players.
    num_linc = 2000
    linc_report_every = max(1, num_linc // 100)  # guard against modulo by 0
    lc = []
    for step in xrange(num_linc):
        if step % linc_report_every == 0:
            print >>sys.stderr, "Generating training set: %d%%\r" % (100 * step // num_linc),
        num = random.randint(2, 20)
        coefs = get_random_norm_coefs(num)
        vecs = [random.choice(input_vectors) for _ in xrange(num)]
        lc.append(linear_combination(vecs, coefs))
    print >>sys.stderr

    input_vectors = [numpy.array(vec) for vec in input_vectors]
    total = input_vectors + [numpy.array(vec) for vec in lc]

    m = kohonen.Map(kohonen.Parameters(dimension=dim,
                                       shape=(size, size),
                                       learning_rate=kohonen.ExponentialTimeseries(-5e-4, 0.5, 0.2),
                                       noise_variance=None))

    def def_neurs(*args):
        # Initial neuron value: a random training vector (presumably invoked
        # once per neuron by Map.reset -- confirm against kohonen.py).
        return random.choice(total)

    m.reset(def_neurs)

    print >>sys.stderr, "Training Kohonen net."
    num_iter = 2000
    iter_report_every = max(1, num_iter // 10)  # guard against modulo by 0
    for step in xrange(num_iter):
        if step % iter_report_every == 0:
            # Mean distance from 10 random training samples to their
            # nearest neuron; used both for reporting and early stopping.
            err = sum([m.distances(random.choice(total)).min() for _ in xrange(10)]) / 10.0
            print >>sys.stderr, "%2d%% (%4d): error = %5f alpha = %5f" % (100 * step // num_iter, step, err, m._learning_rate.last)
            if err < 0.08:
                print >>sys.stderr, "Current error is good enough."
                break
        m.learn(random.choice(total))

    # Debug aid: number of leading vector components for which a
    # per-component neuron heatmap is shown (0 disables it).
    num_component_heatmaps = 0
    if num_component_heatmaps:
        orig_metric = m._metric
        try:
            for comp in xrange(num_component_heatmaps):
                m._metric = ComponentEuclideanMetric(comp)
                m.neuron_heatmap().show()
        finally:
            # Always restore the map's own metric, even if plotting fails.
            m._metric = orig_metric

    winner_neurons = [m.winner(vec) for vec in input_vectors]
    ### Get rid of overlapping labels in the plot by merging names of players represented by the same neuron
    labels = merge_neuron_labels(winner_neurons, players, size)

    print >>sys.stderr, "Saving output_vectors to file:", filename
    write_label_file(filename, labels)

    print >> sys.stderr, "\nNow plot that in Gnuplot by:"
    print >> sys.stderr, 'set xrange[0:%d] ; set yrange[0:%d]' % (size, size)
    print >> sys.stderr, 'set xtics 1 ; set ytics 1'
    print >> sys.stderr, 'set grid ; set size square'
    print >> sys.stderr, 'plot "%s" using 2:3:1 with labels font "arial,10" point lt 10 pt 5 left' % (filename,)