PCA: Also print projection information in the output
[gostyle.git] / knn.py
blob8c2e923bbe4ffe9a722d1b8ff8fd3813dd93dbe6
1 #!/usr/bin/python
2 import sys
3 from gostyle import *
4 from math import sqrt
6 from data_about_players import Data
class KNNOutputVectorGenerator(VectorGenerator):
    """k-Nearest-Neighbour output vector generator.

    For a query input vector, the k reference input vectors closest to it
    (Euclidean distance) are selected, and their reference output vectors
    are combined with linear_combination using weights from weight_fc.
    """

    def __init__(self, ref_dict, k=2):
        """
        ref_dict is a dictionary of reference input/output vectors, keyed by
        input vector (keys must be hashable, e.g. tuples), e.g.
            ref_dict = {(1.0, 2.0): (9.0, 16.0, 21.0)}
        k is the number of nearest reference vectors combined per query.
        """
        self.ref_dict = ref_dict
        self.k = k

    def __call__(self, player_vector):
        """Return the kNN approximation of the output vector for player_vector.

        References are ranked by distance (ties broken by comparing the
        reference input vectors themselves). If there are fewer than k
        references, all of them are used. Whether the weights get
        normalized is up to linear_combination (presumably from gostyle --
        confirm there).
        """
        # Iterate this instance's own reference dict, not any module-level
        # name, so the generator also works when knn is imported as a module.
        ranked = sorted(
            (self.distance(ref_vec, player_vector), ref_vec)
            for ref_vec in self.ref_dict
        )
        nearest = ranked[:self.k]

        ref_output_vecs = [self.ref_dict[ref_vec] for _, ref_vec in nearest]
        coefs = [self.weight_fc(dist) for dist, _ in nearest]
        return linear_combination(ref_output_vecs, coefs)

    def weight_fc(self, distance):
        """Return the weight of a reference at the given distance.

        0.2 ** distance: 1.0 at distance 0, decaying exponentially, so
        closer references contribute more.
        """
        return 0.2 ** distance

    def distance(self, vec1, vec2):
        """Return the Euclidean distance between two equal-length vectors.

        Components are converted with float() first.
        Raises RuntimeError if the vectors differ in length.
        """
        if len(vec1) != len(vec2):
            raise RuntimeError("Dimensions of vectors mismatch.")
        return sqrt(sum((float(a) - float(b)) ** 2 for a, b in zip(vec1, vec2)))
if __name__ == '__main__':
    # Script entry point: build input vectors for reference players (those
    # with a known output vector in Data.player_vector) and for all other
    # players, approximate output vectors for both groups with weighted kNN,
    # and write the results plus the original reference vectors to files.
    main_pat_filename = Data.main_pat_filename
    filename_play_other = 'knn_other.data'
    filename_play_ref = 'knn_ref.data'
    filename_play_ref_orig = 'knn_ref_orig.data'
    num_features = 300
    players_all = Data.players_all
    players_ref = list(Data.player_vector.keys())
    # Set for O(1) membership tests while keeping players_all's order.
    players_ref_set = set(players_ref)
    players_other = [x for x in players_all if x not in players_ref_set]

    ### Object creating input vector when called
    sys.stderr.write("Creating input vector generator from main pat file: %s\n"
                     % main_pat_filename)
    i = InputVectorGenerator(main_pat_filename, num_features)

    # Create list of input vectors (same order as players_ref / players_other)
    input_vectors_ref = [i(Data.pat_files_folder + name) for name in players_ref]
    input_vectors_other = [i(Data.pat_files_folder + name) for name in players_other]

    # Nothing useful can be produced without both groups; report failure
    # with a non-zero exit status so calling scripts can detect it.
    if not input_vectors_ref:
        sys.stderr.write("No reference vectors.\n")
        sys.exit(1)
    if not input_vectors_other:
        sys.stderr.write("No vectors to process.\n")
        sys.exit(1)

    ### PCA example usage
    # Set this to True to reduce the input vectors with PCA before kNN.
    use_pca = False
    if use_pca:
        # Create PCA object, trained on all input vectors
        sys.stderr.write("Running PCA.\n")
        pca = PCA(input_vectors_ref + input_vectors_other, reduce=True)
        # Perform a PCA on input vectors
        input_vectors_ref = pca.process_list_of_vectors(input_vectors_ref)
        input_vectors_other = pca.process_list_of_vectors(input_vectors_other)
        # Composed generator: first generates an input vector, then applies
        # the PCA to it (not used further in this script).
        i = Compose(i, pca)

    ### Object creating output vector when called
    # Keys are input vectors converted to tuples so they are hashable.
    ref_dict = {}
    for name, input_vector in zip(players_ref, input_vectors_ref):
        key = tuple(input_vector)
        if key in ref_dict:
            # Identical input vectors collapse into one dict entry.
            sys.stderr.write("Warning: duplicate reference input vector for %s;"
                             " overriding earlier entry.\n" % name)
        ref_dict[key] = Data.player_vector[name]

    oknn = KNNOutputVectorGenerator(ref_dict, k=5)

    # Create list of output vectors using weighted kNN algorithm approximating output_vector
    output_vectors_other = [oknn(vec) for vec in input_vectors_other]
    # Every reference vector is itself a key of ref_dict, so it is one of its
    # own neighbours (distance 0, weight 1.0).
    output_vectors_ref = [oknn(vec) for vec in input_vectors_ref]

    def print_me(names, vecs, where):
        """Write one line per player to `where`: the name (whitespace runs
        replaced by '_') followed by the vector components, via print_vector.

        Raises RuntimeError if names and vecs differ in length.
        """
        if len(names) != len(vecs):
            raise RuntimeError("Dimensions of vectors mismatch.")

        # `with` guarantees the file is closed even if print_vector raises.
        with open(where, 'w') as f:
            sys.stderr.write("Saving output_vectors to file: %s\n" % where)
            for name, vec in zip(names, vecs):
                name_to_print = '_'.join(name.split())
                print_vector([name_to_print] + list(vec), f)

    print_me(players_ref, output_vectors_ref, filename_play_ref)
    print_me(players_other, output_vectors_other, filename_play_other)

    # Also dump the original (known) reference output vectors for comparison.
    ref_orig_items = list(Data.player_vector.items())
    print_me([name for name, _ in ref_orig_items],
             [vec for _, vec in ref_orig_items],
             filename_play_ref_orig)