gostyle_old, commit old
[gostyle.git] / cross_val.py
blob4a8f7e1d38ddf4a7a65f01a3aa8113f7cc557456
1 import random
2 import copy
def CrossValidation(indices, k):
    '''k-fold cross validation.

    Splits ``indices`` into ``k`` contiguous folds whose sizes differ by at
    most one: every fold gets ``len(indices) // k`` elements and the first
    ``len(indices) % k`` folds get one extra.  Yields one
    ``(training set, validation set)`` pair per fold, in fold order.  The
    validation set is the fold itself.  The training set is every other
    element, kept in its original order.  Both are new lists.

    indices -- a sliceable sequence (list, tuple, range, ...)
    k       -- number of folds, must satisfy 1 <= k <= len(indices)

    Raises ValueError when k is out of range.  Because this is a generator,
    the error surfaces on the first iteration, not at call time.
    '''
    # An explicit raise (instead of assert) keeps the check under ``python -O``.
    if not 0 < k <= len(indices):
        raise ValueError('k must satisfy 0 < k <= len(indices), got k=%r for %d indices'
                         % (k, len(indices)))

    # divmod gives integral sizes even when true division is in effect.
    base_size, remainder = divmod(len(indices), k)

    start = 0
    for fold in range(k):
        size = base_size + (1 if fold < remainder else 0)
        end = start + size
        # list() normalises slices of tuples/ranges so callers always get lists.
        train_set = list(indices[:start]) + list(indices[end:])
        validation_set = list(indices[start:end])
        yield (train_set, validation_set)
        # Running offset replaces the per-fold prefix sums (O(k) instead of O(k^2)).
        start = end
def Shuffled(cross_val, rng=None):
    '''Wrap a cross-validation generator so it sees the indices in random order.

    Returns a function with the same ``(indices, k)`` signature as
    ``cross_val``.  Each call copies ``indices`` into a new list, shuffles
    the copy in place and passes it on to ``cross_val``.  The caller's
    sequence is never modified.

    cross_val -- callable taking ``(indices, k)``, e.g. CrossValidation
    rng       -- optional object with a ``shuffle`` method, such as a seeded
                 ``random.Random``, for reproducible splits.  Defaults to
                 the module-level ``random`` functions.
    '''
    shuffler = random if rng is None else rng

    def cross_val_shuff(indices, k):
        # list() rather than copy.copy(): copy.copy returns immutable
        # sequences (tuple, Python 3 range) unchanged, and shuffling those
        # raises TypeError.  list() always gives a fresh mutable copy.
        indices_copy = list(indices)
        shuffler.shuffle(indices_copy)
        return cross_val(indices_copy, k)
    return cross_val_shuff
if __name__ == '__main__':
    # Small demo: 24 indices split into 5 folds (sizes 5, 5, 5, 5, 4),
    # first in the original order, then after shuffling.
    # Each print(...) takes one pre-formatted string, so the output is
    # identical under Python 2's print statement and Python 3's print().
    for tr, val in CrossValidation(range(24), 5):
        print('train:  %d %s' % (len(tr), tr))
        print('valid:  %d %s' % (len(val), val))
    print('')
    for tr, val in Shuffled(CrossValidation)(range(24), 5):
        print('train:  %d %s' % (len(tr), tr))
        print('valid:  %d %s' % (len(val), val))