13 #include <caffe/caffe.hpp>
14 using namespace caffe
;
22 static shared_ptr
<Net
<float> > net
;
25 using_dcnn(struct board
*b
)
27 return (real_board_size(b
) == 19 && net
);
30 /* Make caffe quiet */
32 dcnn_quiet_caffe(int argc
, char *argv
[])
34 if (DEBUGL(7) || getenv("GLOG_minloglevel"))
37 setenv("GLOG_minloglevel", "2", 1);
38 execvp(argv
[0], argv
); /* Sucks that we have to do this */
/* File names of the network description and of the trained weights.
 * Both are plain relative names, so they are looked up relative to the
 * current working directory. */
48 const char *model_file
= "golast19.prototxt";
49 const char *trained_file
= "golast.trained";
/* Both files must be stat()-able; if either is not, report on stderr that
 * the dcnn code will not be used.
 * NOTE(review): the declaration of `s` and whatever follows the message
 * (closing brace, presumably an early return) are not visible in this
 * extract -- confirm against the full file. */
50 if (stat(model_file
, &s
) != 0 || stat(trained_file
, &s
) != 0) {
52 fprintf(stderr
, "No dcnn files found, will not use dcnn code.\n");
/* Select Caffe's CPU mode before the network is constructed. */
56 Caffe::set_mode(Caffe::CPU
);
58 /* Load the network: construct it from the model description (TEST is passed as the second argument), then copy in the trained weights. */
59 net
.reset(new Net
<float>(model_file
, TEST
));
60 net
->CopyTrainedLayersFrom(trained_file
);
/* Report on stderr that the network is ready.
 * NOTE(review): a guarding condition may have been dropped from this
 * extract just above -- confirm against the full file. */
63 fprintf(stderr
, "Initialized dcnn.\n");
67 dcnn_get_moves(struct board
*b
, enum stone color
, float result
[])
69 assert(real_board_size(b
) == 19);
72 int dsize
= 13 * size
* size
;
73 float *data
= new float[dsize
];
74 for (int i
= 0; i
< dsize
; i
++)
77 for (int j
= 0; j
< size
; j
++) {
78 for(int k
= 0; k
< size
; k
++) {
80 coord_t c
= coord_xy(b
, j
+1, k
+1);
81 group_t g
= group_at(b
, c
);
82 enum stone bc
= board_at(b
, c
);
83 int libs
= board_group_info(b
, g
).libs
- 1;
84 if (libs
> 3) libs
= 3;
86 data
[8*size
*size
+ p
] = 1.0;
88 data
[(0+libs
)*size
*size
+ p
] = 1.0;
89 else if (bc
== stone_other(color
))
90 data
[(4+libs
)*size
*size
+ p
] = 1.0;
92 if (c
== b
->last_move
.coord
)
93 data
[9*size
*size
+ p
] = 1.0;
94 else if (c
== b
->last_move2
.coord
)
95 data
[10*size
*size
+ p
] = 1.0;
96 else if (c
== b
->last_move3
.coord
)
97 data
[11*size
*size
+ p
] = 1.0;
98 else if (c
== b
->last_move4
.coord
)
99 data
[12*size
*size
+ p
] = 1.0;
104 Blob
<float> *blob
= new Blob
<float>(1,13,size
,size
);
105 blob
->set_cpu_data(data
);
106 vector
<Blob
<float>*> bottom
;
107 bottom
.push_back(blob
);
109 const vector
<Blob
<float>*>& rr
= net
->Forward(bottom
);
111 for (int i
= 0; i
< size
* size
; i
++) {
112 result
[i
] = rr
[0]->cpu_data()[i
];
113 if (result
[i
] < 0.00001)