Merge pull request #28 from lemonsqueeze/dcnn
[pachi.git] / uct / prior.c
blob7261d380754432da20223c5e0291e4aec6e2a172
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #define DEBUG
8 #include "board.h"
9 #include "debug.h"
10 #include "joseki/base.h"
11 #include "move.h"
12 #include "random.h"
13 #include "dcnn.h"
14 #include "tactics/ladder.h"
15 #include "tactics/util.h"
16 #include "uct/internal.h"
17 #include "uct/plugins.h"
18 #include "uct/prior.h"
19 #include "uct/tree.h"
21 /* Applying heuristic values to the tree nodes, skewing the reading in
22 * most interesting directions. */
24 /* TODO: Introduce foreach_fpoint() to iterate only over non-occupied
25 * positions. */
/* Configuration of the prior-knowledge sources used by uct_prior(). */
struct uct_prior {
	/* Equivalent experience for prior knowledge. MoGo paper recommends
	 * 50 playouts per source; in practice, esp. with RAVE, about 6
	 * playouts per source seems best. */
	int eqex;

	/* Per-source equivalent experience, in playouts. While options are
	 * parsed, negative values mean "hundredths of eqex"; uct_prior_init()
	 * resolves them to absolute values afterwards. Zero disables the
	 * source. */
	int even_eqex;
	int policy_eqex;
	int b19_eqex;
	int eye_eqex;
	int ko_eqex;
	int plugin_eqex;
	int joseki_eqex;
	int pattern_eqex;
	int dcnn_eqex;

	/* Common-fate-graph distance bonuses: cfgd_eqex has cfgdn + 1
	 * entries and cfgd_eqex[d] is the bonus for distance d (1..cfgdn).
	 * Entry 0 is set to 0; uct_prior_cfgd() asserts distance != 0. */
	int cfgdn;
	int *cfgd_eqex;

	/* Remove ladder moves from consideration before applying priors
	 * (skipped while board_playing_ko_threat()). */
	bool prune_ladders;
};
38 void
39 uct_prior_even(struct uct *u, struct tree_node *node, struct prior_map *map)
41 /* Q_{even} */
42 /* This may be dubious for normal UCB1 but is essential for
43 * reading stability of RAVE, it appears. */
44 add_prior_value(map, pass, 0.5, u->prior->even_eqex);
45 foreach_free_point(map->b) {
46 if (!map->consider[c])
47 continue;
48 add_prior_value(map, c, 0.5, u->prior->even_eqex);
49 } foreach_free_point_end;
52 void
53 uct_prior_eye(struct uct *u, struct tree_node *node, struct prior_map *map)
55 /* Discourage playing into our own eyes. However, we cannot
56 * completely prohibit it:
57 * #######
58 * ...XX.#
59 * XOOOXX#
60 * X.OOOO#
61 * .XXXX.# */
62 foreach_free_point(map->b) {
63 if (!map->consider[c])
64 continue;
65 if (!board_is_one_point_eye(map->b, c, map->to_play))
66 continue;
67 add_prior_value(map, c, 0, u->prior->eye_eqex);
68 } foreach_free_point_end;
71 #ifdef DCNN
73 #define DCNN_BEST_N 5
75 static void
76 find_dcnn_best_moves(struct prior_map *map, float *r, coord_t *best, float *best_r)
78 struct board *b = map->b;
80 for (int i = 0; i < DCNN_BEST_N; i++)
81 best[i] = pass;
83 foreach_free_point(b) {
84 if (!map->consider[c])
85 continue;
87 int k = (coord_x(c, b) - 1) * 19 + (coord_y(c, b) - 1);
88 for (int i = 0; i < DCNN_BEST_N; i++)
89 if (r[k] > best_r[i]) {
90 for (int j = DCNN_BEST_N - 1; j > i; j--) { // shift
91 best_r[j] = best_r[j - 1];
92 best[j] = best[j - 1];
94 best_r[i] = r[k];
95 best[i] = c;
96 break;
98 } foreach_free_point_end;
101 static void
102 print_dcnn_best_moves(struct tree_node *node, struct prior_map *map,
103 coord_t *best, float *best_r)
105 fprintf(stderr, "dcnn best: [ ");
106 for (int i = 0; i < DCNN_BEST_N; i++)
107 fprintf(stderr, "%s ", coord2sstr(best[i], map->b));
108 fprintf(stderr, "] ");
110 fprintf(stderr, "[ ");
111 for (int i = 0; i < DCNN_BEST_N; i++)
112 fprintf(stderr, "%.2f ", best_r[i]);
113 fprintf(stderr, "]\n");
116 static void
117 uct_prior_dcnn(struct uct *u, struct tree_node *node, struct prior_map *map)
119 float r[19 * 19];
120 float best_r[DCNN_BEST_N] = { 0.0, };
121 coord_t best_moves[DCNN_BEST_N];
122 dcnn_get_moves(map->b, map->to_play, r);
123 find_dcnn_best_moves(map, r, best_moves, best_r);
124 print_dcnn_best_moves(node, map, best_moves, best_r);
126 foreach_free_point(map->b) {
127 if (!map->consider[c])
128 continue;
130 int i = coord_x(c, map->b) - 1;
131 int j = coord_y(c, map->b) - 1;
132 assert(i >= 0 && i < 19);
133 assert(j >= 0 && j < 19);
134 float val = r[i * 19 + j];
135 if (isnan(val) || val < 0.001)
136 continue;
137 assert(val >= 0.0 && val <= 1.0);
138 add_prior_value(map, c, 1, sqrt(val) * u->prior->dcnn_eqex);
139 } foreach_free_point_end;
142 #else
143 #define uct_prior_dcnn(u, node, map)
144 #endif /* DCNN */
147 void
148 uct_prior_ko(struct uct *u, struct tree_node *node, struct prior_map *map)
150 /* Favor fighting ko, if we took it le 10 moves ago. */
151 coord_t ko = map->b->last_ko.coord;
152 if (is_pass(ko) || map->b->moves - map->b->last_ko_age > 10 || !map->consider[ko])
153 return;
154 // fprintf(stderr, "prior ko-fight @ %s %s\n", stone2str(map->to_play), coord2sstr(ko, map->b));
155 add_prior_value(map, ko, 1, u->prior->ko_eqex);
158 void
159 uct_prior_b19(struct uct *u, struct tree_node *node, struct prior_map *map)
161 /* Q_{b19} */
162 /* Specific hints for 19x19 board - priors for certain edge distances. */
163 foreach_free_point(map->b) {
164 if (!map->consider[c])
165 continue;
166 int d = coord_edge_distance(c, map->b);
167 if (d != 0 && d != 2)
168 continue;
169 /* The bonus applies only with no stones in immediate
170 * vincinity. */
171 if (board_stone_radar(map->b, c, 2))
172 continue;
173 /* First line: 0 */
174 /* Third line: 1 */
175 add_prior_value(map, c, d == 2, u->prior->b19_eqex);
176 } foreach_free_point_end;
179 void
180 uct_prior_playout(struct uct *u, struct tree_node *node, struct prior_map *map)
182 /* Q_{playout-policy} */
183 if (u->playout->assess)
184 u->playout->assess(u->playout, map, u->prior->policy_eqex);
187 void
188 uct_prior_cfgd(struct uct *u, struct tree_node *node, struct prior_map *map)
190 /* Q_{common_fate_graph_distance} */
191 /* Give bonus to moves local to the last move, where "local" means
192 * local in terms of groups, not just manhattan distance. */
193 if (is_pass(map->b->last_move.coord) || is_resign(map->b->last_move.coord))
194 return;
196 foreach_free_point(map->b) {
197 if (!map->consider[c])
198 continue;
199 if (map->distances[c] > u->prior->cfgdn)
200 continue;
201 assert(map->distances[c] != 0);
202 int bonus = u->prior->cfgd_eqex[map->distances[c]];
203 add_prior_value(map, c, 1, bonus);
204 } foreach_free_point_end;
207 void
208 uct_prior_joseki(struct uct *u, struct tree_node *node, struct prior_map *map)
210 /* Q_{joseki} */
211 if (!u->jdict)
212 return;
213 for (int i = 0; i < 4; i++) {
214 hash_t h = map->b->qhash[i] & joseki_hash_mask;
215 coord_t *cc = u->jdict->patterns[h].moves[map->to_play - 1];
216 if (!cc) continue;
217 for (; !is_pass(*cc); cc++) {
218 if (coord_quadrant(*cc, map->b) != i)
219 continue;
220 add_prior_value(map, *cc, 1.0, u->prior->joseki_eqex);
225 void
226 uct_prior_pattern(struct uct *u, struct tree_node *node, struct prior_map *map)
228 /* Q_{pattern} */
229 if (!u->pat.pd)
230 return;
232 struct board *b = map->b;
233 struct pattern pats[b->flen];
234 floating_t probs[b->flen];
235 pattern_rate_moves(&u->pat, b, map->to_play, pats, probs);
236 if (UDEBUGL(5)) {
237 fprintf(stderr, "Pattern prior at node %s\n", coord2sstr(node->coord, b));
238 board_print(b, stderr);
241 for (int f = 0; f < b->flen; f++) {
242 if (isnan(probs[f]) || probs[f] < 0.001)
243 continue;
244 assert(!is_pass(b->f[f]));
245 if (UDEBUGL(5)) {
246 char s[256]; pattern2str(s, &pats[f]);
247 fprintf(stderr, "\t%s: %.3f %s\n", coord2sstr(b->f[f], b), probs[f], s);
249 add_prior_value(map, b->f[f], 1.0, sqrt(probs[f]) * u->prior->pattern_eqex);
253 void
254 uct_prior(struct uct *u, struct tree_node *node, struct prior_map *map)
256 if (u->prior->prune_ladders && !board_playing_ko_threat(map->b)) {
257 foreach_free_point(map->b) {
258 if (!map->consider[c])
259 continue;
260 group_t atari_neighbor = board_get_atari_neighbor(map->b, c, map->to_play);
261 if (atari_neighbor && is_ladder(map->b, c, atari_neighbor, true)) {
262 if (UDEBUGL(5))
263 fprintf(stderr, "Pruning ladder move %s\n", coord2sstr(c, map->b));
264 map->consider[c] = false;
266 } foreach_free_point_end;
269 if (u->prior->even_eqex)
270 uct_prior_even(u, node, map);
271 if (u->prior->eye_eqex)
272 uct_prior_eye(u, node, map);
273 if (u->prior->ko_eqex)
274 uct_prior_ko(u, node, map);
275 if (u->prior->b19_eqex)
276 uct_prior_b19(u, node, map);
278 if (!node->parent) // Use dcnn for root priors
279 if (u->prior->dcnn_eqex)
280 uct_prior_dcnn(u, node, map);
282 if (u->prior->policy_eqex)
283 uct_prior_playout(u, node, map);
284 if (u->prior->cfgd_eqex)
285 uct_prior_cfgd(u, node, map);
286 if (u->prior->joseki_eqex)
287 uct_prior_joseki(u, node, map);
288 if (u->prior->pattern_eqex)
289 uct_prior_pattern(u, node, map);
290 if (u->prior->plugin_eqex)
291 plugin_prior(u->plugins, node, map, u->prior->plugin_eqex);
294 struct uct_prior *
295 uct_prior_init(char *arg, struct board *b, struct uct *u)
297 struct uct_prior *p = calloc2(1, sizeof(struct uct_prior));
299 p->even_eqex = p->policy_eqex = p->b19_eqex = p->eye_eqex = p->ko_eqex = p->plugin_eqex = -100;
300 /* FIXME: Optimal pattern_eqex is about -1000 with small playout counts
301 * but only -400 on a cluster. We need a better way to set the default
302 * here. */
303 p->pattern_eqex = -800;
304 /* Best value for dcnn_eqex so far seems to be 1300 with ~88% winrate
305 * against regular pachi. Below 1200 is bad (50% winrate and worse), more
306 * gives diminishing returns (1500 -> 78%, 2000 -> 70% ...) */
307 p->dcnn_eqex = 1300;
308 p->joseki_eqex = -200;
309 p->cfgdn = -1;
311 /* Even number! */
312 p->eqex = board_large(b) ? 20 : 14;
314 p->prune_ladders = true;
316 if (arg) {
317 char *optspec, *next = arg;
318 while (*next) {
319 optspec = next;
320 next += strcspn(next, ":");
321 if (*next) { *next++ = 0; } else { *next = 0; }
323 char *optname = optspec;
324 char *optval = strchr(optspec, '=');
325 if (optval) *optval++ = 0;
327 if (!strcasecmp(optname, "eqex") && optval) {
328 p->eqex = atoi(optval);
330 /* In the following settings, you can use negative
331 * numbers to give the hundredths of default eqex.
332 * E.g. -100 is default eqex, -50 is half of the
333 * default eqex, -200 is double the default eqex. */
334 } else if (!strcasecmp(optname, "even") && optval) {
335 p->even_eqex = atoi(optval);
336 } else if (!strcasecmp(optname, "policy") && optval) {
337 p->policy_eqex = atoi(optval);
338 } else if (!strcasecmp(optname, "b19") && optval) {
339 p->b19_eqex = atoi(optval);
340 } else if (!strcasecmp(optname, "cfgd") && optval) {
341 /* cfgd=3%40%20%20 - 3 levels; immediate libs
342 * of last move => 40 wins, their neighbors
343 * 20 wins, 2nd-level neighbors 20 wins;
344 * neighbors are group-transitive. */
345 p->cfgdn = atoi(optval); optval += strcspn(optval, "%");
346 p->cfgd_eqex = calloc2(p->cfgdn + 1, sizeof(*p->cfgd_eqex));
347 p->cfgd_eqex[0] = 0;
348 int i;
349 for (i = 1; *optval; i++, optval += strcspn(optval, "%")) {
350 optval++;
351 p->cfgd_eqex[i] = atoi(optval);
353 if (i != p->cfgdn + 1) {
354 fprintf(stderr, "uct: Missing prior cfdn level %d/%d\n", i, p->cfgdn);
355 exit(1);
358 } else if (!strcasecmp(optname, "joseki") && optval) {
359 p->joseki_eqex = atoi(optval);
360 } else if (!strcasecmp(optname, "eye") && optval) {
361 p->eye_eqex = atoi(optval);
362 } else if (!strcasecmp(optname, "ko") && optval) {
363 p->ko_eqex = atoi(optval);
364 } else if (!strcasecmp(optname, "pattern") && optval) {
365 /* Pattern-based prior eqex. */
366 /* Note that this prior is still going to be
367 * used only if you have downloaded or
368 * generated the pattern files! */
369 p->pattern_eqex = atoi(optval);
370 } else if (!strcasecmp(optname, "plugin") && optval) {
371 /* Unlike others, this is just a *recommendation*. */
372 p->plugin_eqex = atoi(optval);
373 } else if (!strcasecmp(optname, "prune_ladders")) {
374 p->prune_ladders = !optval || atoi(optval);
375 #ifdef DCNN
376 } else if (!strcasecmp(optname, "dcnn") && optval) {
377 p->dcnn_eqex = atoi(optval);
378 #endif
379 } else {
380 fprintf(stderr, "uct: Invalid prior argument %s or missing value\n", optname);
381 exit(1);
386 if (p->even_eqex < 0) p->even_eqex = p->eqex * -p->even_eqex / 100;
387 if (p->policy_eqex < 0) p->policy_eqex = p->eqex * -p->policy_eqex / 100;
388 if (p->b19_eqex < 0) p->b19_eqex = p->eqex * -p->b19_eqex / 100;
389 if (p->eye_eqex < 0) p->eye_eqex = p->eqex * -p->eye_eqex / 100;
390 if (p->ko_eqex < 0) p->ko_eqex = p->eqex * -p->ko_eqex / 100;
391 if (p->joseki_eqex < 0) p->joseki_eqex = p->eqex * -p->joseki_eqex / 100;
392 if (p->pattern_eqex < 0) p->pattern_eqex = p->eqex * -p->pattern_eqex / 100;
393 if (p->plugin_eqex < 0) p->plugin_eqex = p->eqex * -p->plugin_eqex / 100;
394 if (p->dcnn_eqex < 0) p->dcnn_eqex = p->eqex * -p->dcnn_eqex / 100;
396 if (!using_dcnn(b))
397 p->dcnn_eqex = 0;
399 if (p->cfgdn < 0) {
400 static int large_bonuses[] = { 0, 55, 50, 15 };
401 static int small_bonuses[] = { 0, 45, 40, 15 };
402 p->cfgdn = 3;
403 p->cfgd_eqex = calloc2(p->cfgdn + 1, sizeof(*p->cfgd_eqex));
404 memcpy(p->cfgd_eqex, board_large(b) ? large_bonuses : small_bonuses, sizeof(large_bonuses));
406 if (p->cfgdn > TREE_NODE_D_MAX) {
407 fprintf(stderr, "uct: CFG distances only up to %d available\n", TREE_NODE_D_MAX);
408 exit(1);
411 if (p->pattern_eqex)
412 u->want_pat = true;
414 return p;
417 void
418 uct_prior_done(struct uct_prior *p)
420 assert(p->cfgd_eqex);
421 free(p->cfgd_eqex);
422 free(p);