t-unit: display stats at the end
[pachi.git] / dcnn.cpp
blob8919f82e49d910f72533ecd87928ae9eeba8c6a0
1 #define DEBUG
2 #include <assert.h>
3 #include <ctype.h>
4 #include <math.h>
5 #include <stdarg.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <string.h>
9 #include <unistd.h>
10 #include <sys/stat.h>
12 #define CPU_ONLY 1
13 #include <caffe/caffe.hpp>
14 using namespace caffe;
16 extern "C" {
17 #include "debug.h"
18 #include "board.h"
19 #include "dcnn.h"
/* The loaded network, shared by all functions in this file. It starts out
 * empty; dcnn_init() fills it in, while using_dcnn() and dcnn_get_moves()
 * test it for being non-empty before use. */
22 static shared_ptr<Net<float> > net;
24 bool
25 using_dcnn(struct board *b)
27 return (real_board_size(b) == 19 && net);
30 /* Make caffe quiet */
31 void
32 dcnn_quiet_caffe(int argc, char *argv[])
34 if (DEBUGL(7) || getenv("GLOG_minloglevel"))
35 return;
37 setenv("GLOG_minloglevel", "2", 1);
38 execvp(argv[0], argv); /* Sucks that we have to do this */
41 void
42 dcnn_init()
44 if (net)
45 return;
47 struct stat s;
48 const char *model_file = "golast19.prototxt";
49 const char *trained_file = "golast.trained";
50 if (stat(model_file, &s) != 0 || stat(trained_file, &s) != 0) {
51 if (DEBUGL(1))
52 fprintf(stderr, "No dcnn files found, will not use dcnn code.\n");
53 return;
56 Caffe::set_mode(Caffe::CPU);
58 /* Load the network. */
59 net.reset(new Net<float>(model_file, TEST));
60 net->CopyTrainedLayersFrom(trained_file);
62 if (DEBUGL(1))
63 fprintf(stderr, "Initialized dcnn.\n");
66 void
67 dcnn_get_moves(struct board *b, enum stone color, float result[])
69 assert(real_board_size(b) == 19);
71 int size = 19;
72 int dsize = 13 * size * size;
73 float *data = new float[dsize];
74 for (int i = 0; i < dsize; i++)
75 data[i] = 0.0;
77 for (int j = 0; j < size; j++) {
78 for(int k = 0; k < size; k++) {
79 int p = size * j + k;
80 coord_t c = coord_xy(b, j+1, k+1);
81 group_t g = group_at(b, c);
82 enum stone bc = board_at(b, c);
83 int libs = board_group_info(b, g).libs - 1;
84 if (libs > 3) libs = 3;
85 if (bc == S_NONE)
86 data[8*size*size + p] = 1.0;
87 else if (bc == color)
88 data[(0+libs)*size*size + p] = 1.0;
89 else if (bc == stone_other(color))
90 data[(4+libs)*size*size + p] = 1.0;
92 if (c == b->last_move.coord)
93 data[9*size*size + p] = 1.0;
94 else if (c == b->last_move2.coord)
95 data[10*size*size + p] = 1.0;
96 else if (c == b->last_move3.coord)
97 data[11*size*size + p] = 1.0;
98 else if (c == b->last_move4.coord)
99 data[12*size*size + p] = 1.0;
104 Blob<float> *blob = new Blob<float>(1,13,size,size);
105 blob->set_cpu_data(data);
106 vector<Blob<float>*> bottom;
107 bottom.push_back(blob);
108 assert(net);
109 const vector<Blob<float>*>& rr = net->Forward(bottom);
111 for (int i = 0; i < size * size; i++) {
112 result[i] = rr[0]->cpu_data()[i];
113 if (result[i] < 0.00001)
114 result[i] = 0.00001;
116 delete[] data;
117 delete blob;
121 } /* extern "C" */