gostyle_old, commit old
[gostyle.git] / bayes.pl
blob31d0dae4eb5d91f4d97ed9b45ed579cf92676a1f
#!/usr/bin/perl
use warnings;
use strict;
use POSIX;

# Usage: bayes.pl DIM DEG
#   DIM - index of the questionnaire characteristic to predict (eg. 0=territory)
#   DEG - number of class bins the [1,10] answer scale is divided into
# Both are required: an undefined DEG would otherwise put every player into
# bin 0 and then die with a division by zero when the error is rescaled.
die "usage: $0 DIM DEG\n" unless @ARGV >= 2;
my $dim = $ARGV[0]; # id of characteristic; eg. 0=territory
my $deg = $ARGV[1]; # number of class bins
die "DIM must be a non-negative integer, got '$dim'\n" unless $dim =~ /\A\d+\z/;
die "DEG must be a positive integer, got '$deg'\n"     unless $deg =~ /\A[1-9]\d*\z/;

# Attribute names of the principal components; pca.data refers to them by
# 1-based index.
our @PCA = qw(PCA1 PCA2 PCA3 PCA4 PCA5 PCA6 PCA7 PCA8 PCA9 PCA10);

# %input:  player name => { PCA attribute name => value }
# %output: player name => class bin of the chosen characteristic
our (%input, %output);
# Read each player's PCA projection into %input. Every line of pca.data is
# expected to be "<Player_Name> <1-based component index> <value>", with
# underscores in the name standing for spaces.
open my $pca_fh, '<', 'pca.data' or die "$! - run pca.py?";
while (my $line = <$pca_fh>) {
    chomp $line;
    next unless $line =~ /\S/;    # ignore blank lines

    # Non-greedy name, so extra whitespace before the index is not kept in the key.
    my ($name, $index, $value) = $line =~ /^(\D+?)\s+(\d+)\s+(.*)$/;
    unless (defined $name) {
        # Without this check the previous line's captures would be reused.
        warn "pca.data line $.: malformed, skipped\n";
        next;
    }
    unless ($index >= 1 && $index <= @PCA) {
        # Index 0 would otherwise silently map to $PCA[-1].
        warn "pca.data line $.: component index $index outside 1..", scalar(@PCA), ", skipped\n";
        next;
    }

    $name =~ y/_/ /;
    $input{$name}{ $PCA[$index - 1] } = $value;
}
close $pca_fh;
# Read the averaged questionnaire answers printed by the python helper, one
# comma-separated line per player. Field 0 is the player name (matched
# against the names from pca.data). Field $dim + 2 holds characteristic $dim.
# Assuming answers lie on the [1,10] scale, each is mapped to a bin
# 0 .. $deg-1 and stored in %output.
use Scalar::Util qw(looks_like_number);

# List-form pipe open: runs python directly, with no shell involved.
open my $quest_fh, '-|', 'python', '-c',
    'import data_about_players; data_about_players.questionare_average_raw(data_about_players.Data.questionare_list)'
    or die "cannot run python: $!";
while (my $line = <$quest_fh>) {
    chomp $line;
    my @fields = split /,\s*/, $line;
    my $name   = $fields[0];
    my $answer = $fields[$dim + 2];

    # Only label players we have PCA input for.
    next unless defined $name and exists $input{$name};
    # A missing or non-numeric answer would otherwise become bin -1.
    next unless defined $answer and looks_like_number($answer);

    $output{$name} = POSIX::floor($deg * ($answer - 1) / 10);
}
# For a pipe, close() also reports a non-zero exit of the child through $?.
close $quest_fh or die "python helper failed (exit status $?)";
# k-fold cross-validation: randomly split the players into $k folds.
# Only players with a label are used. %output only contains players that
# also appear in %input, so every fold member has both input and label.
my @p = keys %output;
my $n = scalar @p;
my $k = 5;
my @folds;

# Deal randomly chosen players out round-robin. Fold sizes differ by at most
# one, and every player, including the n % k leftovers, is held out exactly once.
my $slot = 0;
while (@p) {
    push @{ $folds[$slot++ % $k] }, splice(@p, rand(scalar @p), 1);
}
use AI::NaiveBayes1;

# For each fold: train a naive Bayes model on the labelled players outside
# the fold, predict the class of each player in the fold, and accumulate
# the squared error. Print the mean squared error per prediction.
my $mse = 0;               # running sum of squared errors
my $num_predictions = 0;   # number of held-out players actually evaluated

foreach my $fold (@folds) {
    #print "--- Fold\n";
    my %in_fold = map { $_ => 1 } @$fold;

    # Labelled players available for training in this fold.
    my @train = grep { defined $output{$_} and not $in_fold{$_} } keys %input;

    # Training-set size of each class in this fold.
    my %class_size;
    $class_size{ $output{$_} }++ for @train;

    my $nb = AI::NaiveBayes1->new;
    $nb->set_real(@PCA);
    foreach my $player (@train) {
        # do not add instances with singular values within this fold
        next if $class_size{ $output{$player} } < 2;
        $nb->add_instance(attributes => $input{$player}, label => 'div=' . $output{$player});
    }
    $nb->train;
    #print $nb->print_model;

    foreach my $player (@$fold) {
        next unless defined $output{$player};   # cannot score a player without a label
        my $prob = $nb->predict(attributes => $input{$player});
        # Most probable label; ties are broken by name so the result does not
        # depend on hash order.
        my ($best) = sort { $prob->{$b} <=> $prob->{$a} || $a cmp $b } keys %$prob;
        (my $top = $best) =~ s/div=//;
        # my $top = rand($deg); - for comparing with random classifier
        # calculate squared error wrt. the original [1,10] scale
        my $se = (($output{$player} - $top) * (10 / $deg)) ** 2;
        $mse += $se;
        $num_predictions++;
        #print join ',', $player, $se, $output{$player}, $top, $prob->{'div='.$top}, "\n";
    }
}

die "no labelled players were evaluated\n" unless $num_predictions;
# Divide by the number of predictions. The old divisor, @folds*$k, was 5*5.
my $mean_se = $mse / $num_predictions;
print $mean_se, "\n";