Merge branch 'master' of ssh://repo.or.cz/srv/git/gostyle
[gostyle.git] / knn_cross.py
blob 68210a79463619c8c79b254dbadeb8466c95fdeb
#!/usr/bin/python
"""Cross-validate the weighted kNN style predictor on the reference players.

Every player in ``Data.questionare_total`` (minus ``PLAYERS_IGNORE``) gets an
input vector built from its pattern file. The players are split into
consecutive validation folds of ``FOLD_SIZE``; the last fold also absorbs
the remainder. Each fold is predicted by a kNN model whose references are all
the other players, and the squared error of every prediction against the
player's vector from ``Data.questionare_total`` is reported.
"""
from __future__ import print_function

import sys
from gostyle import *
from math import sqrt
import numpy

from data_about_players import Data

from knn import KNNOutputVectorGenerator

# Number of features requested from InputVectorGenerator.
NUM_FEATURES = 300
# Parameters passed to KNNOutputVectorGenerator.
K = 5
WEIGHT_PARAM = 0.799
# Players per validation fold; the last fold also takes the remainder.
FOLD_SIZE = 4
# Set to True to project the input vectors with PCA before running kNN.
USE_PCA = False
# Players excluded from the evaluation.
PLAYERS_IGNORE = ["Yi Ch'ang-ho 2004-"]


def _fold_bounds(n, fold_size=FOLD_SIZE):
    """Return the interior fold boundaries for ``n`` items.

    The folds are ``[0, b1), [b1, b2), ..., [bk, n)``. Every fold but the last
    holds exactly ``fold_size`` items. The last one also takes the remainder,
    so it holds between ``fold_size`` and ``2 * fold_size - 1`` items. The
    result is empty when ``n < 2 * fold_size``, i.e. when there is no room
    for a second fold.
    """
    return list(range(fold_size, fold_size * (n // fold_size), fold_size))


def _square_error(output, desired):
    """Return the sum of squared component-wise differences of two vectors."""
    return sum(((1.0 * o - 1.0 * d) ** 2 for o, d in zip(output, desired)), 0.0)


def _print_fold_layout(start, stop, total):
    """Print one row marking reference players "R" and validation players "_"."""
    marks = ["R"] * start + ["_"] * (stop - start) + ["R"] * (total - stop)
    print("Reference set :", " ".join(marks))


def _print_vectors(label, vectors):
    """Print ``label`` followed by every vector, separated by ``;``."""
    parts = [label]
    for vec in vectors:
        parts.extend("%02.3f" % (x,) for x in vec)
        parts.append("; ")
    print(*parts)


def main():
    """Run the fold-wise cross-validation and print the error summary."""
    player_vector = Data.questionare_total
    players_all = [p for p in player_vector.keys() if p not in PLAYERS_IGNORE]

    print("Creating input vector generator from main pat file:",
          Data.main_pat_filename)
    print()
    generator = InputVectorGenerator(Data.main_pat_filename, NUM_FEATURES)
    input_vectors = [generator(Data.pat_files_folder + name)
                     for name in players_all]

    if not input_vectors:
        print("No reference vectors.", file=sys.stderr)
        sys.exit(1)  # non-zero status: this is a failure, not a normal end

    if USE_PCA:
        print("Running PCA.", file=sys.stderr)
        # Train the PCA on all input vectors, then project those same vectors.
        pca = PCA(input_vectors, reduce=True)
        input_vectors = pca.process_list_of_vectors(input_vectors)

    total = len(players_all)
    bounds = _fold_bounds(total)
    if not bounds:
        print("Pop too small.", file=sys.stderr)
        sys.exit(1)

    errs = []
    start = 0
    for stop in bounds + [total]:
        validation_set = range(start, stop)
        reference_set = list(range(0, start)) + list(range(stop, total))
        _print_fold_layout(start, stop, total)

        # NOTE(review): players with identical input vectors collide on this
        # key, and only the last one is kept as a kNN reference.
        ref_dict = dict(
            (tuple(input_vectors[idx]), player_vector[players_all[idx]])
            for idx in reference_set
        )
        oknn = KNNOutputVectorGenerator(ref_dict, k=K, weight_param=WEIGHT_PARAM)

        output_vectors = [oknn(input_vectors[idx]) for idx in validation_set]
        desired_vectors = [player_vector[players_all[idx]]
                           for idx in validation_set]
        _print_vectors("Output: ", output_vectors)
        _print_vectors("Desired:", desired_vectors)

        errs.extend(_square_error(o, d)
                    for o, d in zip(output_vectors, desired_vectors))
        print()
        start = stop

    errs_arr = numpy.array(errs)
    print("Total square err: %2.3f" % (sum(errs),))
    # Plain ASCII "+/-": printing u"\u00B1" raises UnicodeEncodeError under
    # Python 2 when stdout's encoding cannot represent it (e.g. when piped).
    print("Mean square err: %2.3f +/- %2.3f "
          % (errs_arr.mean(), sqrt(errs_arr.var())))
    print()
    print("Players sorted by mean square error:")
    # errs follows players_all order, because the folds cover 0..total-1
    # consecutively and in order.
    for err, name in sorted(zip(errs, players_all)):
        print("%2.3f %s" % (err, name))


if __name__ == '__main__':
    main()