def CrossValidation(indices, k):
    """k-fold cross validation.

    Split ``indices`` into ``k`` contiguous folds and yield one
    ``(training set, validation set)`` pair per fold, in fold order.
    Each validation set is one fold; the matching training set is every
    other element, kept in the original order.

    Fold sizes differ by at most one: when ``len(indices)`` is not a
    multiple of ``k``, the first ``len(indices) % k`` folds get one extra
    element.

    Args:
        indices: any iterable of items to split. It is copied into a list,
            so ranges and generators are accepted.
        k: number of folds; must satisfy ``1 <= k <= len(indices)``.

    Yields:
        ``(train, validation)`` tuples of lists.

    Raises:
        ValueError: if ``k`` is out of range. Because this is a generator,
            the error surfaces on the first iteration.
    """
    items = list(indices)
    if not 1 <= k <= len(items):
        raise ValueError(
            "k must be between 1 and len(indices) (%d), got %r" % (len(items), k))

    # divmod gives an integer base size on both Python 2 and 3
    # (plain "/" would produce a float on Python 3).
    base, extra = divmod(len(items), k)

    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        stop = start + size
        # The training set skips exactly the validation slice [start, stop).
        yield items[:start] + items[stop:], items[start:stop]
        start = stop
def Shuffled(cross_val):
    """Wrap a splitter so it runs on a shuffled copy of its indices.

    Args:
        cross_val: a callable taking ``(indices, k)``, such as
            ``CrossValidation``.

    Returns:
        A function with the same ``(indices, k)`` signature. It shuffles a
        copy of ``indices`` with the module-level ``random`` generator and
        returns ``cross_val(shuffled_copy, k)``. The caller's ``indices`` is
        never modified.
    """
    import functools

    @functools.wraps(cross_val)
    def cross_val_shuff(indices, k):
        # list() both copies and produces a mutable sequence. copy.copy would
        # return a tuple or Python 3 range unchanged, and random.shuffle
        # cannot permute those in place.
        indices_copy = list(indices)
        random.shuffle(indices_copy)
        return cross_val(indices_copy, k)

    return cross_val_shuff
if __name__ == '__main__':
    import sys

    # Demo: a plain 5-fold split of 24 indices, then the same split over a
    # shuffled copy. sys.stdout.write produces the same text as the old
    # Python 2 `print "train: ", ...` statements on both Python 2 and 3.
    for folds in (CrossValidation(range(24), 5),
                  Shuffled(CrossValidation)(range(24), 5)):
        for tr, val in folds:
            sys.stdout.write("train:  %d %s\n" % (len(tr), tr))
            sys.stdout.write("valid:  %d %s\n" % (len(val), val))