UCT Prior: Add ko prior, by default 0
[pachi.git] / uct / uct.c
blob21650d372881450da15dfd8a676b91a409272034
1 #include <assert.h>
2 #include <pthread.h>
3 #include <signal.h>
4 #include <stdio.h>
5 #include <stdlib.h>
6 #include <string.h>
8 #define DEBUG
10 #include "debug.h"
11 #include "board.h"
12 #include "move.h"
13 #include "playout.h"
14 #include "playout/moggy.h"
15 #include "playout/light.h"
16 #include "random.h"
17 #include "uct/internal.h"
18 #include "uct/prior.h"
19 #include "uct/tree.h"
20 #include "uct/uct.h"
/* Tree-policy constructors; defined elsewhere (presumably in the uct
 * policy sources). Each returns a policy bound to the given UCT state,
 * configured from the optional argument string (may be NULL). */
struct uct_policy *policy_ucb1_init(struct uct *state, char *policy_arg);
struct uct_policy *policy_ucb1amaf_init(struct uct *state, char *policy_arg);

/* Default number of playouts per move; overridden by the games=N option. */
#define MC_GAMES	80000
/* Default playout length limit; overridden by the gamelen=N option. */
#define MC_GAMELEN	MAX_GAMELEN
30 static bool
31 can_pass(struct board *b, enum stone color)
33 float score = board_official_score(b);
34 if (color == S_BLACK)
35 score = -score;
36 //fprintf(stderr, "%d score %f\n", color, score);
37 return (score > 0);
40 static void
41 progress_status(struct uct *u, struct tree *t, enum stone color, int playouts)
43 if (!UDEBUGL(0))
44 return;
46 /* Best move */
47 struct tree_node *best = u->policy->choose(u->policy, t->root, t->board, color);
48 if (!best) {
49 fprintf(stderr, "... No moves left\n");
50 return;
52 fprintf(stderr, "[%d] ", playouts);
53 fprintf(stderr, "best %f ", tree_node_get_value(t, best, u, 1));
55 /* Max depth */
56 fprintf(stderr, "deepest % 2d ", t->max_depth - t->root->depth);
58 /* Best sequence */
59 fprintf(stderr, "| seq ");
60 for (int depth = 0; depth < 6; depth++) {
61 if (best && best->u.playouts >= 25) {
62 fprintf(stderr, "%3s ", coord2sstr(best->coord, t->board));
63 best = u->policy->choose(u->policy, best, t->board, color);
64 } else {
65 fprintf(stderr, " ");
69 /* Best candidates */
70 fprintf(stderr, "| can ");
71 int cans = 4;
72 struct tree_node *can[cans];
73 memset(can, 0, sizeof(can));
74 best = t->root->children;
75 while (best) {
76 int c = 0;
77 while ((!can[c] || best->u.playouts > can[c]->u.playouts) && ++c < cans);
78 for (int d = 0; d < c; d++) can[d] = can[d + 1];
79 if (c > 0) can[c - 1] = best;
80 best = best->sibling;
82 while (--cans >= 0) {
83 if (can[cans]) {
84 fprintf(stderr, "%3s(%.3f) ",
85 coord2sstr(can[cans]->coord, t->board),
86 tree_node_get_value(t, can[cans], u, 1));
87 } else {
88 fprintf(stderr, " ");
92 fprintf(stderr, "\n");
96 static int
97 uct_leaf_node(struct uct *u, struct board *b, enum stone player_color,
98 struct playout_amafmap *amaf,
99 struct tree *t, struct tree_node *n, enum stone node_color,
100 char *spaces)
102 enum stone next_color = stone_other(node_color);
103 int parity = (next_color == player_color ? 1 : -1);
104 if (n->u.playouts >= u->expand_p) {
105 // fprintf(stderr, "expanding %s (%p ^-%p)\n", coord2sstr(n->coord, b), n, n->parent);
106 tree_expand_node(t, n, b, next_color, u->radar_d, u, parity);
108 if (UDEBUGL(7))
109 fprintf(stderr, "%s*-- UCT playout #%d start [%s] %f\n",
110 spaces, n->u.playouts, coord2sstr(n->coord, t->board),
111 tree_node_get_value(t, n, u, parity));
113 int result = play_random_game(b, next_color, u->gamelen, u->playout_amaf ? amaf : NULL, NULL, u->playout);
114 if (next_color == S_WHITE && result >= 0) {
115 /* We need the result from black's perspective. */
116 result = !result;
118 if (UDEBUGL(7))
119 fprintf(stderr, "%s -- [%d..%d] %s random playout result %d\n",
120 spaces, player_color, next_color, coord2sstr(n->coord, t->board), result);
122 return result;
125 static int
126 uct_playout(struct uct *u, struct board *b, enum stone player_color, struct tree *t)
128 struct board b2;
129 board_copy(&b2, b);
131 struct playout_amafmap *amaf = NULL;
132 if (u->policy->wants_amaf) {
133 amaf = calloc(1, sizeof(*amaf));
134 amaf->map = calloc(board_size2(&b2) + 1, sizeof(*amaf->map));
135 amaf->map++; // -1 is pass
138 /* Walk the tree until we find a leaf, then expand it and do
139 * a random playout. */
140 struct tree_node *n = t->root;
141 enum stone node_color = stone_other(player_color);
142 assert(node_color == t->root_color);
144 int result;
145 int pass_limit = (board_size(&b2) - 2) * (board_size(&b2) - 2) / 2;
146 int passes = is_pass(b->last_move.coord);
148 /* debug */
149 int depth = 0;
150 static char spaces[] = "\0 ";
151 /* /debug */
152 if (UDEBUGL(8))
153 fprintf(stderr, "--- UCT walk with color %d\n", player_color);
155 while (!tree_leaf_node(n) && passes < 2) {
156 spaces[depth++] = ' '; spaces[depth] = 0;
158 /* Parity is chosen already according to the child color, since
159 * it is applied to children. */
160 node_color = stone_other(node_color);
161 int parity = (node_color == player_color ? 1 : -1);
162 n = u->policy->descend(u->policy, t, n, parity, pass_limit);
164 assert(n == t->root || n->parent);
165 if (UDEBUGL(7))
166 fprintf(stderr, "%s+-- UCT sent us to [%s:%d] %f\n",
167 spaces, coord2sstr(n->coord, t->board), n->coord,
168 tree_node_get_value(t, n, u, parity));
170 assert(n->coord >= -1);
171 if (amaf && !is_pass(n->coord)) {
172 if (amaf->map[n->coord] == S_NONE || amaf->map[n->coord] == node_color) {
173 amaf->map[n->coord] = node_color;
174 } else { // XXX: Respect amaf->record_nakade
175 amaf_op(amaf->map[n->coord], +);
177 amaf->game[amaf->gamelen].coord = n->coord;
178 amaf->game[amaf->gamelen].color = node_color;
179 amaf->gamelen++;
180 assert(amaf->gamelen < sizeof(amaf->game) / sizeof(amaf->game[0]));
183 struct move m = { n->coord, node_color };
184 int res = board_play(&b2, &m);
186 if (res < 0 || (!is_pass(m.coord) && !group_at(&b2, m.coord)) /* suicide */
187 || b2.superko_violation) {
188 if (UDEBUGL(3)) {
189 for (struct tree_node *ni = n; ni; ni = ni->parent)
190 fprintf(stderr, "%s<%lld> ", coord2sstr(ni->coord, t->board), ni->hash);
191 fprintf(stderr, "deleting invalid %s node %d,%d res %d group %d spk %d\n",
192 stone2str(node_color), coord_x(n->coord,b), coord_y(n->coord,b),
193 res, group_at(&b2, m.coord), b2.superko_violation);
195 tree_delete_node(t, n);
196 result = -1;
197 goto end;
200 if (is_pass(n->coord))
201 passes++;
202 else
203 passes = 0;
206 if (amaf) {
207 amaf->game_baselen = amaf->gamelen;
208 amaf->record_nakade = u->playout_amaf_nakade;
211 if (passes >= 2) {
212 float score = board_official_score(&b2);
213 /* Result from black's perspective (no matter who
214 * the player; black's perspective is always
215 * what the tree stores. */
216 result = score < 0;
218 if (UDEBUGL(5))
219 fprintf(stderr, "[%d..%d] %s p-p scoring playout result %d (W %f)\n",
220 player_color, node_color, coord2sstr(n->coord, t->board), result, score);
221 if (UDEBUGL(6))
222 board_print(&b2, stderr);
224 } else { assert(tree_leaf_node(n));
225 result = uct_leaf_node(u, &b2, player_color, amaf, t, n, node_color, spaces);
228 if (amaf && u->playout_amaf_cutoff) {
229 int cutoff = amaf->game_baselen;
230 cutoff += (amaf->gamelen - amaf->game_baselen) * u->playout_amaf_cutoff / 100;
231 /* Now, reconstruct the amaf map. */
232 memset(amaf->map, 0, board_size2(&b2) * sizeof(*amaf->map));
233 for (int i = 0; i < cutoff; i++) {
234 coord_t coord = amaf->game[i].coord;
235 enum stone color = amaf->game[i].color;
236 if (amaf->map[coord] == S_NONE || amaf->map[coord] == color) {
237 amaf->map[coord] = color;
238 /* Nakade always recorded for in-tree part */
239 } else if (amaf->record_nakade || i <= amaf->game_baselen) {
240 amaf_op(amaf->map[n->coord], +);
245 assert(n == t->root || n->parent);
246 if (result >= 0)
247 u->policy->update(u->policy, t, n, node_color, player_color, amaf, result);
249 end:
250 if (amaf) {
251 free(amaf->map - 1);
252 free(amaf);
254 board_done_noalloc(&b2);
255 return result;
258 static void
259 prepare_move(struct engine *e, struct board *b, enum stone color, coord_t promote)
261 struct uct *u = e->data;
263 if (u->t && (!b->moves || color != stone_other(u->t->root_color))) {
264 /* Stale state from last game */
265 tree_done(u->t);
266 u->t = NULL;
269 if (!u->t) {
270 u->t = tree_init(b, color);
271 if (u->force_seed)
272 fast_srandom(u->force_seed);
273 if (UDEBUGL(0))
274 fprintf(stderr, "Fresh board with random seed %lu\n", fast_getseed());
275 //board_print(b, stderr);
276 if (!u->no_book && b->moves < 2)
277 tree_load(u->t, b);
280 /* XXX: We hope that the opponent didn't suddenly play
281 * several moves in the row. */
282 if (!is_resign(promote) && !tree_promote_at(u->t, b, promote)) {
283 if (UDEBUGL(2))
284 fprintf(stderr, "<cannot find node to promote>\n");
285 /* Reset tree */
286 tree_done(u->t);
287 u->t = tree_init(b, color);
291 /* Set in main thread in case the playouts should stop. */
292 static volatile sig_atomic_t halt = 0;
294 static int
295 uct_playouts(struct uct *u, struct board *b, enum stone color, struct tree *t)
297 int i, games = u->games;
298 if (t->root->children)
299 games -= t->root->u.playouts / 1.5;
300 /* else this is highly read-out but dead-end branch of opening book;
301 * we need to start from scratch; XXX: Maybe actually base the readout
302 * count based on number of playouts of best node? */
303 for (i = 0; i < games; i++) {
304 int result = uct_playout(u, b, color, t);
305 if (result < 0) {
306 /* Tree descent has hit invalid move. */
307 continue;
310 if (i > 0 && !(i % 10000)) {
311 progress_status(u, t, color, i);
314 if (i > 0 && !(i % 500)) {
315 struct tree_node *best = u->policy->choose(u->policy, t->root, b, color);
316 if (best && ((best->u.playouts >= 5000 && tree_node_get_value(t, best, u, 1) >= u->loss_threshold)
317 || (best->u.playouts >= 500 && tree_node_get_value(t, best, u, 1) >= 0.95)))
318 break;
321 if (halt) {
322 if (UDEBUGL(2))
323 fprintf(stderr, "<halting early, %d games skipped>\n", games - i);
324 break;
328 progress_status(u, t, color, i);
329 if (UDEBUGL(3))
330 tree_dump(t, u->dumpthres);
331 return i;
334 static pthread_mutex_t finish_mutex = PTHREAD_MUTEX_INITIALIZER;
335 static pthread_cond_t finish_cond = PTHREAD_COND_INITIALIZER;
336 static volatile int finish_thread;
337 static pthread_mutex_t finish_serializer = PTHREAD_MUTEX_INITIALIZER;
339 struct spawn_ctx {
340 int tid;
341 struct uct *u;
342 struct board *b;
343 enum stone color;
344 struct tree *t;
345 unsigned long seed;
346 int games;
349 static void *
350 spawn_helper(void *ctx_)
352 struct spawn_ctx *ctx = ctx_;
353 /* Setup */
354 fast_srandom(ctx->seed);
355 /* Run */
356 ctx->games = uct_playouts(ctx->u, ctx->b, ctx->color, ctx->t);
357 /* Finish */
358 pthread_mutex_lock(&finish_serializer);
359 pthread_mutex_lock(&finish_mutex);
360 finish_thread = ctx->tid;
361 pthread_cond_signal(&finish_cond);
362 pthread_mutex_unlock(&finish_mutex);
363 return ctx;
366 static void
367 uct_notify_play(struct engine *e, struct board *b, struct move *m)
369 prepare_move(e, b, m->color, m->coord);
372 static coord_t *
373 uct_genmove(struct engine *e, struct board *b, enum stone color)
375 struct uct *u = e->data;
377 /* Seed the tree. */
378 prepare_move(e, b, color, resign);
380 if (b->superko_violation) {
381 fprintf(stderr, "!!! WARNING: SUPERKO VIOLATION OCCURED BEFORE THIS MOVE\n");
382 fprintf(stderr, "Maybe you play with situational instead of positional superko?\n");
383 fprintf(stderr, "I'm going to ignore the violation, but note that I may miss\n");
384 fprintf(stderr, "some moves valid under this ruleset because of this.\n");
385 b->superko_violation = false;
388 /* If the opponent just passes and we win counting, just
389 * pass as well. */
390 if (b->moves > 1 && is_pass(b->last_move.coord) && can_pass(b, color))
391 return coord_copy(pass);
393 int played_games = 0;
394 if (!u->threads) {
395 played_games = uct_playouts(u, b, color, u->t);
396 } else {
397 pthread_t threads[u->threads];
398 int joined = 0;
399 halt = 0;
400 pthread_mutex_lock(&finish_mutex);
401 /* Spawn threads... */
402 for (int ti = 0; ti < u->threads; ti++) {
403 struct spawn_ctx *ctx = malloc(sizeof(*ctx));
404 ctx->u = u; ctx->b = b; ctx->color = color;
405 ctx->t = tree_copy(u->t); ctx->tid = ti;
406 ctx->seed = fast_random(65536) + ti;
407 pthread_create(&threads[ti], NULL, spawn_helper, ctx);
408 if (UDEBUGL(2))
409 fprintf(stderr, "Spawned thread %d\n", ti);
411 /* ...and collect them back: */
412 while (joined < u->threads) {
413 /* Wait for some thread to finish... */
414 pthread_cond_wait(&finish_cond, &finish_mutex);
415 /* ...and gather its remnants. */
416 struct spawn_ctx *ctx;
417 pthread_join(threads[finish_thread], (void **) &ctx);
418 played_games += ctx->games;
419 joined++;
420 tree_merge(u->t, ctx->t);
421 tree_done(ctx->t);
422 free(ctx);
423 if (UDEBUGL(2))
424 fprintf(stderr, "Joined thread %d\n", finish_thread);
425 /* Do not get stalled by slow threads. */
426 if (joined >= u->threads / 2)
427 halt = 1;
428 pthread_mutex_unlock(&finish_serializer);
430 pthread_mutex_unlock(&finish_mutex);
432 tree_normalize(u->t, u->threads);
435 if (UDEBUGL(2))
436 tree_dump(u->t, u->dumpthres);
438 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
439 if (!best) {
440 tree_done(u->t); u->t = NULL;
441 return coord_copy(pass);
443 if (UDEBUGL(0))
444 progress_status(u, u->t, color, played_games);
445 if (UDEBUGL(1))
446 fprintf(stderr, "*** WINNER is %s (%d,%d) with score %1.4f (%d/%d:%d games)\n",
447 coord2sstr(best->coord, b), coord_x(best->coord, b), coord_y(best->coord, b),
448 tree_node_get_value(u->t, best, u, 1),
449 best->u.playouts, u->t->root->u.playouts, played_games);
450 if (tree_node_get_value(u->t, best, u, 1) < u->resign_ratio && !is_pass(best->coord)) {
451 tree_done(u->t); u->t = NULL;
452 return coord_copy(resign);
454 tree_promote_node(u->t, best);
455 return coord_copy(best->coord);
458 bool
459 uct_genbook(struct engine *e, struct board *b, enum stone color)
461 struct uct *u = e->data;
462 u->t = tree_init(b, color);
463 tree_load(u->t, b);
465 int i;
466 for (i = 0; i < u->games; i++) {
467 int result = uct_playout(u, b, color, u->t);
468 if (result < 0) {
469 /* Tree descent has hit invalid move. */
470 continue;
473 if (i > 0 && !(i % 10000)) {
474 progress_status(u, u->t, color, i);
477 progress_status(u, u->t, color, i);
479 tree_save(u->t, b, u->games / 100);
481 tree_done(u->t);
483 return true;
486 void
487 uct_dumpbook(struct engine *e, struct board *b, enum stone color)
489 struct uct *u = e->data;
490 u->t = tree_init(b, color);
491 tree_load(u->t, b);
492 tree_dump(u->t, 0);
493 tree_done(u->t);
497 struct uct *
498 uct_state_init(char *arg)
500 struct uct *u = calloc(1, sizeof(struct uct));
502 u->debug_level = 1;
503 u->games = MC_GAMES;
504 u->gamelen = MC_GAMELEN;
505 u->expand_p = 2;
506 u->dumpthres = 1000;
507 u->playout_amaf = true;
508 u->playout_amaf_nakade = false;
509 u->amaf_prior = true;
511 // gp: 14 vs 0: 44% (+-3.5)
512 u->gp_eqex = u->ko_eqex = 0;
513 u->even_eqex = u->policy_eqex = u->b19_eqex = u->cfgd_eqex = u->eye_eqex = -1;
514 u->eqex = 6; /* Even number! */
516 if (arg) {
517 char *optspec, *next = arg;
518 while (*next) {
519 optspec = next;
520 next += strcspn(next, ",");
521 if (*next) { *next++ = 0; } else { *next = 0; }
523 char *optname = optspec;
524 char *optval = strchr(optspec, '=');
525 if (optval) *optval++ = 0;
527 if (!strcasecmp(optname, "debug")) {
528 if (optval)
529 u->debug_level = atoi(optval);
530 else
531 u->debug_level++;
532 } else if (!strcasecmp(optname, "games") && optval) {
533 u->games = atoi(optval);
534 } else if (!strcasecmp(optname, "gamelen") && optval) {
535 u->gamelen = atoi(optval);
536 } else if (!strcasecmp(optname, "expand_p") && optval) {
537 u->expand_p = atoi(optval);
538 } else if (!strcasecmp(optname, "radar_d") && optval) {
539 /* For 19x19, it is good idea to set this to 3. */
540 u->radar_d = atoi(optval);
541 } else if (!strcasecmp(optname, "dumpthres") && optval) {
542 u->dumpthres = atoi(optval);
543 } else if (!strcasecmp(optname, "playout_amaf")) {
544 /* Whether to include random playout moves in
545 * AMAF as well. (Otherwise, only tree moves
546 * are included in AMAF. Of course makes sense
547 * only in connection with an AMAF policy.) */
548 /* with-without: 55.5% (+-4.1) */
549 if (optval && *optval == '0')
550 u->playout_amaf = false;
551 else
552 u->playout_amaf = true;
553 } else if (!strcasecmp(optname, "playout_amaf_nakade")) {
554 /* Whether to include nakade moves from playouts
555 * in the AMAF statistics; this tends to nullify
556 * the playout_amaf effect by adding too much
557 * noise. */
558 if (optval && *optval == '0')
559 u->playout_amaf_nakade = false;
560 else
561 u->playout_amaf_nakade = true;
562 } else if (!strcasecmp(optname, "playout_amaf_cutoff") && optval) {
563 /* Keep only first N% of playout stage AMAF
564 * information. */
565 u->playout_amaf_cutoff = atoi(optval);
566 } else if (!strcasecmp(optname, "policy") && optval) {
567 char *policyarg = strchr(optval, ':');
568 if (policyarg)
569 *policyarg++ = 0;
570 if (!strcasecmp(optval, "ucb1")) {
571 u->policy = policy_ucb1_init(u, policyarg);
572 } else if (!strcasecmp(optval, "ucb1amaf")) {
573 u->policy = policy_ucb1amaf_init(u, policyarg);
574 } else {
575 fprintf(stderr, "UCT: Invalid tree policy %s\n", optval);
577 } else if (!strcasecmp(optname, "playout") && optval) {
578 char *playoutarg = strchr(optval, ':');
579 if (playoutarg)
580 *playoutarg++ = 0;
581 if (!strcasecmp(optval, "moggy")) {
582 u->playout = playout_moggy_init(playoutarg);
583 } else if (!strcasecmp(optval, "light")) {
584 u->playout = playout_light_init(playoutarg);
585 } else {
586 fprintf(stderr, "UCT: Invalid playout policy %s\n", optval);
588 } else if (!strcasecmp(optname, "prior") && optval) {
589 u->eqex = atoi(optval);
590 } else if (!strcasecmp(optname, "prior_even") && optval) {
591 u->even_eqex = atoi(optval);
592 } else if (!strcasecmp(optname, "prior_gp") && optval) {
593 u->gp_eqex = atoi(optval);
594 } else if (!strcasecmp(optname, "prior_policy") && optval) {
595 u->policy_eqex = atoi(optval);
596 } else if (!strcasecmp(optname, "prior_b19") && optval) {
597 u->b19_eqex = atoi(optval);
598 } else if (!strcasecmp(optname, "prior_cfgd") && optval) {
599 u->cfgd_eqex = atoi(optval);
600 } else if (!strcasecmp(optname, "prior_eye") && optval) {
601 u->eye_eqex = atoi(optval);
602 } else if (!strcasecmp(optname, "prior_ko") && optval) {
603 u->ko_eqex = atoi(optval);
604 } else if (!strcasecmp(optname, "amaf_prior") && optval) {
605 u->amaf_prior = atoi(optval);
606 } else if (!strcasecmp(optname, "threads") && optval) {
607 u->threads = atoi(optval);
608 } else if (!strcasecmp(optname, "force_seed") && optval) {
609 u->force_seed = atoi(optval);
610 } else if (!strcasecmp(optname, "no_book")) {
611 u->no_book = true;
612 } else {
613 fprintf(stderr, "uct: Invalid engine argument %s or missing value\n", optname);
614 exit(1);
619 if (u->even_eqex < 0) u->even_eqex = u->eqex;
620 if (u->gp_eqex < 0) u->gp_eqex = u->eqex;
621 if (u->policy_eqex < 0) u->policy_eqex = u->eqex;
622 if (u->b19_eqex < 0) u->b19_eqex = u->eqex;
623 if (u->cfgd_eqex < 0) u->cfgd_eqex = u->eqex;
625 u->resign_ratio = 0.2; /* Resign when most games are lost. */
626 u->loss_threshold = 0.85; /* Stop reading if after at least 5000 playouts this is best value. */
627 if (!u->policy)
628 u->policy = policy_ucb1amaf_init(u, NULL);
630 if (!u->playout)
631 u->playout = playout_moggy_init(NULL);
632 u->playout->debug_level = u->debug_level;
634 return u;
638 struct engine *
639 engine_uct_init(char *arg)
641 struct uct *u = uct_state_init(arg);
642 struct engine *e = calloc(1, sizeof(struct engine));
643 e->name = "UCT Engine";
644 e->comment = "I'm playing UCT. When we both pass, I will consider all the stones on the board alive. If you are reading this, write 'yes'. Please capture all dead stones before passing; it will not cost you points (area scoring is used).";
645 e->genmove = uct_genmove;
646 e->notify_play = uct_notify_play;
647 e->data = u;
649 return e;