tex: Strength Estimation Analysis introduction
[gostyle.git] / knn.py
blob2390fc6420e13624306af79c1cccf7a60ef58a07
1 #!/usr/bin/python
2 import sys
3 from gostyle import *
4 from math import sqrt
6 from data_about_players import Data
class KNNOutputVectorGenerator(VectorGenerator):
    """k-NearestNeighbour output vector generator.

    Calling an instance with an input vector returns an output vector built
    from the output vectors of the ``k`` reference input vectors nearest to
    it (scaled Euclidean distance, see ``distance``), combined with weights
    that depend on the distance (see ``weight_fc``).
    """

    def __init__(self, ref_dict, k=5, weight_param=0.8, dist_mult=10):
        """
        ref_dict     -- dictionary mapping reference input vectors (tuples)
                        to reference output vectors, e.g.
                        ref_dict = {(1.0, 2.0): (9.0, 16.0, 21.0)}
        k            -- number of nearest neighbours combined per query.
        weight_param -- base of the weighting function: a neighbour at
                        scaled distance d gets weight weight_param ** d.
        dist_mult    -- multiplicative constant applied to the Euclidean
                        distance (empirically determined for correct scaling).
        """
        self.ref_dict = ref_dict
        self.k = k
        # NOTE: the attribute name is misspelled ("weigth"); it is kept as-is
        # because code outside this class may read or assign it.
        self.weigth_param = weight_param
        self.dist_mult = dist_mult

    def __call__(self, player_vector):
        """Return the weighted kNN estimate of the output vector for player_vector."""
        # (distance, reference input vector) pairs.  Sorting whole tuples
        # orders by distance and breaks ties by comparing the vectors.
        ranked = sorted((self.distance(ref_vec, player_vector), ref_vec)
                        for ref_vec in self.ref_dict)
        nearest = ranked[:self.k]
        ref_output_vecs = [self.ref_dict[ref_vec] for _, ref_vec in nearest]
        coefs = [self.weight_fc(dist) for dist, _ in nearest]
        # NOTE(review): whether linear_combination normalises coefs (weighted
        # mean vs. weighted sum) is defined in gostyle, not here -- confirm.
        return linear_combination(ref_output_vecs, coefs)

    def weight_fc(self, distance):
        """Return the weight of a neighbour at the given scaled distance.

        The weight is weigth_param ** distance: 1 at distance 0 and, for a
        base below 1 (default 0.8), decaying exponentially with distance.
        """
        return self.weigth_param ** distance

    def distance(self, vec1, vec2):
        """Return dist_mult times the Euclidean distance between vec1 and vec2.

        Raises RuntimeError if the two vectors differ in length.
        """
        if len(vec1) != len(vec2):
            raise RuntimeError("Dimensions of vectors mismatch.")
        # dist_mult (default 10) is an empirically determined scaling constant.
        squared = sum([(float(a) - float(b)) ** 2 for a, b in zip(vec1, vec2)])
        return self.dist_mult * sqrt(squared)
if __name__ == '__main__':
    main_pat_filename = Data.main_pat_filename
    filename_play_other = 'knn_other.data'
    filename_play_ref = 'knn_ref.data'
    filename_play_ref_orig = 'knn_ref_orig.data'
    num_features = 300
    k = 4
    # Players with known output vectors; these form the kNN reference set.
    player_vector = Data.questionare_total
    players_ignore = [ "Yi Ch'ang-ho 2004-", "Yi Ch'ang-ho"] #,"Takao Shinji","Hane Naoki","Kobayashi Koichi" ]

    # Sets give O(1) membership tests; list order of the results is preserved.
    ignored = set(players_ignore)
    players_all = [p for p in Data.players_all if p not in ignored]
    players_ref = [p for p in player_vector if p not in ignored]
    ref_set = set(players_ref)
    players_other = [p for p in players_all if p not in ref_set]

    ### Object creating input vector when called
    print >>sys.stderr, "Creating input vector generator from main pat file:", main_pat_filename
    i = InputVectorGenerator(main_pat_filename, num_features)

    # Create list of input vectors
    input_vectors_ref = [i(Data.pat_files_folder + name) for name in players_ref]
    input_vectors_other = [i(Data.pat_files_folder + name) for name in players_other]

    # Exit with a non-zero status so callers (make, shell scripts) see the failure.
    if not input_vectors_ref:
        print >>sys.stderr, "No reference vectors."
        sys.exit(1)
    if not input_vectors_other:
        print >>sys.stderr, "No vectors to process."
        sys.exit(1)

    ### PCA example usage
    # Change this to False, if you do not want to use PCA
    use_pca = False
    if use_pca:
        # Create PCA object, trained on input_vectors
        print >>sys.stderr, "Running PCA."
        pca = PCA(input_vectors_ref + input_vectors_other, reduce=True)
        # Perform a PCA on input vectors
        input_vectors_ref = pca.process_list_of_vectors(input_vectors_ref)
        input_vectors_other = pca.process_list_of_vectors(input_vectors_other)
        # Creates a Composed object that first generates an input vector
        # and then performs a PCA analysis on it.
        i = Compose(i, pca)

    ### Object creating output vector when called;
    ref_dict = {}
    for name, input_vector in zip(players_ref, input_vectors_ref):
        key = tuple(input_vector)
        # Identical input vectors would silently replace an earlier
        # reference player; make that visible.
        if key in ref_dict:
            print >>sys.stderr, "Warning: duplicate reference input vector for player:", name
        ref_dict[key] = player_vector[name]

    oknn = KNNOutputVectorGenerator(ref_dict, k=k)

    # Create list of output vectors using weighted kNN algorithm approximating output_vector
    output_vectors_other = [oknn(input_vector) for input_vector in input_vectors_other]
    output_vectors_ref = [oknn(input_vector) for input_vector in input_vectors_ref]

    def print_me(names, vecs, where):
        """Write one line per player to the file `where`.

        Each line is the player's name (whitespace runs replaced by '_')
        followed by the components of the matching vector, written via
        print_vector. Raises RuntimeError if names and vecs differ in length.
        """
        if len(names) != len(vecs):
            raise RuntimeError("Dimensions of vectors mismatch.")

        print >>sys.stderr, "Saving output_vectors to file:", where
        # `with` guarantees the file is closed even if print_vector raises.
        with open(where, 'w') as f:
            for name, vec in zip(names, vecs):
                name_to_print = '_'.join(name.split())
                print_vector([name_to_print] + list(vec), f)

    print_me(players_ref, [player_vector[name] for name in players_ref], filename_play_ref_orig)
    print_me(players_ref, output_vectors_ref, filename_play_ref)
    print_me(players_other, output_vectors_other, filename_play_other)

    print >> sys.stderr, "\nNow plot that in Gnuplot by:"
    print >> sys.stderr, 'set xtics 1 ; set ytics 1'
    print >> sys.stderr, 'set grid ; set size square'
    print >> sys.stderr, 'plot "%s" using 2:3:1 with labels font "arial,11" point lt 10 pt 4 left, "%s" using 2:3:1 with labels font "arial,11" point lt 12 pt 4 left'%(filename_play_other, filename_play_ref)
115 #print >> sys.stderr, 'set xrange[0:%d] ; set yrange[0:%d]'%(size,size)
116 print >> sys.stderr, 'set xtics 1 ; set ytics 1'
117 print >> sys.stderr, 'set grid ; set size square'
118 print >> sys.stderr, 'plot "%s" using 2:3:1 with labels font "arial,11" point lt 10 pt 4 left, "%s" using 2:3:1 with labels font "arial,11" point lt 12 pt 4 left'%(filename_play_other, filename_play_ref)