1 #!/usr/bin/python
2 import sys
3 from gostyle import *
4 from math import sqrt
5 import random
6 import numpy
# kohonen.py is an external library that is not shipped with this project.
try:
    import kohonen
    from kohonen import DistanceMetric
except ImportError:
    # Only a missing module gets the "could not locate" hint; any other error
    # raised while importing kohonen propagates unchanged. Re-raise so the
    # script still stops with the original traceback.
    print >>sys.stderr, "Could not locate kohonen.py library. Visit http://code.google.com/p/python-kohonen/ to obtain it manually."
    raise
class ComponentEuclideanMetric(DistanceMetric):
    '''Euclidean distance (L-2 norm), optionally restricted to one component.

    k -- None to use the whole vector (the ordinary L-2 norm over the last
         axis), or an integer index selecting the single component of the
         last axis to compare. In that case the result is
         |x[..., k] - y[..., k]|.
    '''
    def __init__(self, k=None):
        # Component index to compare, or None for the full vector.
        self.k = k

    def __call__(self, x, y):
        d = x - y
        # Compare with None explicitly: k=0 is a valid component index and
        # must not fall through to the full-vector distance.
        if self.k is not None:
            # Select along the last axis, matching the axis=-1 reduction below,
            # so both branches return the same shape when x - y broadcasts to
            # more than one dimension.
            return numpy.sqrt((d * d)[..., self.k])
        return numpy.sqrt(numpy.sum(d * d, axis=-1))
28 if __name__ == '__main__':
29 main_pat_filename = Data.main_pat_filename
30 filename = 'koh_all.data'
31 num_features = 300
32 size=16
33 players_ignore = [ 'Honinbo Shusaku', 'Kuwahara Shusaku', 'Yasuda Shusaku', 'Go Seigen', 'Cho Tae-hyeon','Rui Naiwei']
34 players_all = Data.players_all
35 players = [ p for p in players_all if p not in players_ignore ]
36 ### Object creating input vector when called
37 print >>sys.stderr, "Creating input vector generator from main pat file:", main_pat_filename
38 i = InputVectorGenerator(main_pat_filename, num_features)
40 # Create list of input vectors
41 input_vectors = []
42 for name in players:
43 input_vectors += [i(Data.pat_files_folder + name)]
45 if len(input_vectors) == 0:
46 print >>sys.stderr, "No vectors."
47 sys.exit()
49 ### PCA example usage
50 # Change this to False, if you do not want to use PCA
51 use_pca = True
52 if use_pca:
53 # Create PCA object, trained on input_vectors
54 print >>sys.stderr, "Running PCA."
55 pca = PCA(input_vectors, reduce=True)
56 # Perform a PCA on input vectors
57 input_vectors = pca.process_list_of_vectors(input_vectors)
58 dim = len(input_vectors[0])
60 m = kohonen.Map(kohonen.Parameters(dimension=dim,
61 shape=(size,size),
62 learning_rate=kohonen.ExponentialTimeseries(-5e-4, 0.5, 0.2),
63 noise_variance=None))
64 m.reset()
66 lc =[]
67 num_linc=2000
68 for i in xrange(num_linc):
69 if i % (num_linc/100) == 0:
70 print >>sys.stderr, "Generating training set: %d%%\r"%((100*i)/num_linc),
72 num = random.randint(2, 20)
73 coefs = get_random_norm_coefs(num)
74 vecs = [ random.choice(input_vectors) for _ in xrange(num) ]
75 lc.append(linear_combination(vecs, coefs))
77 print >>sys.stderr
79 input_vectors = [numpy.array(vec) for vec in input_vectors]
80 input_vectors_lc = [numpy.array(vec) for vec in lc]
81 total = input_vectors + input_vectors_lc
82 print len(total)
84 print >>sys.stderr, "Training Kohonen net."
85 num_iter = 2000
86 for i in xrange(num_iter):
87 if i % (num_iter/10) == 0:
88 err = sum( [ m.distances(random.choice(total)).min() for _ in xrange(10) ] ) / 10.0
89 print >>sys.stderr, "%2d%% (%4d): error = %5f alpha = %5f"%(100 * i/num_iter,i, err, m._learning_rate.last)
90 if err < 0.2:
91 print >>sys.stderr, "Current error is good enough."
92 break
93 if i > 0 and err > 1.5 and False:
94 print >>sys.stderr, "This error sucks, reset."
95 m.reset()
97 m.learn( random.choice(total) )
99 #im = m.neuron_heatmap()
100 #im.show()
102 orig_m = m._metric
103 m._metric = ComponentEuclideanMetric(0)
104 im2 = m.distance_heatmap(input_vectors[0])
105 im2.show()
106 m._metric = orig_m
109 winner_neurons= [ m.winner(input_vector) for input_vector in input_vectors ]
110 #print winner_neurons[0]/16, winner_neurons[0]%16
111 ### Get rid of overlapping labels in the plot by merging names of players represented by the same neuron
112 vecx = []
113 vecy = []
114 for neuron in winner_neurons:
115 vecx.append(neuron/size)
116 vecy.append(neuron%size)
117 trip = zip( vecx, vecy, players)
118 trip.sort()
119 uniq=[]
120 last = list(trip[0])
121 for next in trip[1:]:
122 if next[0] == last[0] and next[1] == last[1]:
123 last[2] += "\\n" + next[2]
124 else:
125 uniq.append(last)
126 last = list(next)
128 f = open(filename, 'w')
129 print >>sys.stderr, "Saving output_vectors to file:", filename
131 for x,y,name in uniq:
132 name_to_print = '_'.join(name.split())
133 print >>f, name_to_print, x, y
135 f.close()
137 print >> sys.stderr, "\nNow plot that in Gnuplot by:"
138 print >> sys.stderr, 'set xrange[0:%d] ; set yrange[0:%d]'%(size,size)
139 print >> sys.stderr, 'set xtics 1 ; set ytics 1'
140 print >> sys.stderr, 'set grid ; set size square'
141 print >> sys.stderr, 'plot "%s" using 2:3:1 with labels font "arial,10" point lt 10 pt 5 left'%(filename,)