TESTS: Few historical 19x19 results
[pachi.git] / uct / uct.c
blob 941a523d03511eeaef2d8c228deb1cf1475eda26
1 #include <assert.h>
2 #include <pthread.h>
3 #include <signal.h>
4 #include <stdio.h>
5 #include <stdlib.h>
6 #include <string.h>
8 #define DEBUG
10 #include "debug.h"
11 #include "board.h"
12 #include "move.h"
13 #include "playout.h"
14 #include "playout/moggy.h"
15 #include "playout/light.h"
16 #include "random.h"
17 #include "tactics.h"
18 #include "uct/internal.h"
19 #include "uct/prior.h"
20 #include "uct/tree.h"
21 #include "uct/uct.h"
/* Tree policy constructors selected by the "policy" engine option
 * ("ucb1" / "ucb1amaf"); their definitions are not in this file. */
23 struct uct_policy *policy_ucb1_init(struct uct *u, char *arg);
24 struct uct_policy *policy_ucb1amaf_init(struct uct *u, char *arg);
/* Default number of playouts per move; overridable with "games=N". */
27 #define MC_GAMES 80000
/* Default playout length limit; overridable with "gamelen=N". */
28 #define MC_GAMELEN MAX_GAMELEN
31 static bool
32 can_pass(struct board *b, enum stone color)
34 float score = board_official_score(b);
35 if (color == S_BLACK)
36 score = -score;
37 //fprintf(stderr, "%d score %f\n", color, score);
38 return (score > 0);
41 static float
42 get_extra_komi(struct uct *u, struct board *b, enum stone player_color)
44 float extra_komi = board_effective_handicap(b) * (u->dynkomi - b->moves) / u->dynkomi;
45 if (player_color == S_WHITE)
46 extra_komi *= -1;
47 return extra_komi;
50 static void
51 progress_status(struct uct *u, struct tree *t, enum stone color, int playouts)
53 if (!UDEBUGL(0))
54 return;
56 /* Best move */
57 struct tree_node *best = u->policy->choose(u->policy, t->root, t->board, color);
58 if (!best) {
59 fprintf(stderr, "... No moves left\n");
60 return;
62 fprintf(stderr, "[%d] ", playouts);
63 fprintf(stderr, "best %f ", tree_node_get_value(t, best, u, 1));
65 /* Max depth */
66 fprintf(stderr, "deepest % 2d ", t->max_depth - t->root->depth);
68 /* Best sequence */
69 fprintf(stderr, "| seq ");
70 for (int depth = 0; depth < 6; depth++) {
71 if (best && best->u.playouts >= 25) {
72 fprintf(stderr, "%3s ", coord2sstr(best->coord, t->board));
73 best = u->policy->choose(u->policy, best, t->board, color);
74 } else {
75 fprintf(stderr, " ");
79 /* Best candidates */
80 fprintf(stderr, "| can ");
81 int cans = 4;
82 struct tree_node *can[cans];
83 memset(can, 0, sizeof(can));
84 best = t->root->children;
85 while (best) {
86 int c = 0;
87 while ((!can[c] || best->u.playouts > can[c]->u.playouts) && ++c < cans);
88 for (int d = 0; d < c; d++) can[d] = can[d + 1];
89 if (c > 0) can[c - 1] = best;
90 best = best->sibling;
92 while (--cans >= 0) {
93 if (can[cans]) {
94 fprintf(stderr, "%3s(%.3f) ",
95 coord2sstr(can[cans]->coord, t->board),
96 tree_node_get_value(t, can[cans], u, 1));
97 } else {
98 fprintf(stderr, " ");
102 fprintf(stderr, "\n");
106 static int
107 uct_leaf_node(struct uct *u, struct board *b, enum stone player_color,
108 struct playout_amafmap *amaf,
109 struct tree *t, struct tree_node *n, enum stone node_color,
110 char *spaces)
112 enum stone next_color = stone_other(node_color);
113 int parity = (next_color == player_color ? 1 : -1);
114 if (n->u.playouts >= u->expand_p) {
115 // fprintf(stderr, "expanding %s (%p ^-%p)\n", coord2sstr(n->coord, b), n, n->parent);
116 tree_expand_node(t, n, b, next_color, u->radar_d, u, parity);
118 if (UDEBUGL(7))
119 fprintf(stderr, "%s*-- UCT playout #%d start [%s] %f\n",
120 spaces, n->u.playouts, coord2sstr(n->coord, t->board),
121 tree_node_get_value(t, n, u, parity));
123 int result = play_random_game(b, next_color, u->gamelen, u->playout_amaf ? amaf : NULL, NULL, u->playout);
124 if (next_color == S_WHITE) {
125 /* We need the result from black's perspective. */
126 result = - result;
128 if (UDEBUGL(7))
129 fprintf(stderr, "%s -- [%d..%d] %s random playout result %d\n",
130 spaces, player_color, next_color, coord2sstr(n->coord, t->board), result);
132 return result;
135 static int
136 uct_playout(struct uct *u, struct board *b, enum stone player_color, struct tree *t)
138 struct board b2;
139 board_copy(&b2, b);
141 struct playout_amafmap *amaf = NULL;
142 if (u->policy->wants_amaf) {
143 amaf = calloc(1, sizeof(*amaf));
144 amaf->map = calloc(board_size2(&b2) + 1, sizeof(*amaf->map));
145 amaf->map++; // -1 is pass
148 /* Walk the tree until we find a leaf, then expand it and do
149 * a random playout. */
150 struct tree_node *n = t->root;
151 enum stone node_color = stone_other(player_color);
152 assert(node_color == t->root_color);
154 int result;
155 int pass_limit = (board_size(&b2) - 2) * (board_size(&b2) - 2) / 2;
156 int passes = is_pass(b->last_move.coord) && b->moves > 0;
158 /* debug */
159 int depth = 0;
160 static char spaces[] = "\0 ";
161 /* /debug */
162 if (UDEBUGL(8))
163 fprintf(stderr, "--- UCT walk with color %d\n", player_color);
165 while (!tree_leaf_node(n) && passes < 2) {
166 spaces[depth++] = ' '; spaces[depth] = 0;
168 /* Parity is chosen already according to the child color, since
169 * it is applied to children. */
170 node_color = stone_other(node_color);
171 int parity = (node_color == player_color ? 1 : -1);
172 n = u->policy->descend(u->policy, t, n, parity, pass_limit);
174 assert(n == t->root || n->parent);
175 if (UDEBUGL(7))
176 fprintf(stderr, "%s+-- UCT sent us to [%s:%d] %f\n",
177 spaces, coord2sstr(n->coord, t->board), n->coord,
178 tree_node_get_value(t, n, u, parity));
180 assert(n->coord >= -1);
181 if (amaf && !is_pass(n->coord)) {
182 if (amaf->map[n->coord] == S_NONE || amaf->map[n->coord] == node_color) {
183 amaf->map[n->coord] = node_color;
184 } else { // XXX: Respect amaf->record_nakade
185 amaf_op(amaf->map[n->coord], +);
187 amaf->game[amaf->gamelen].coord = n->coord;
188 amaf->game[amaf->gamelen].color = node_color;
189 amaf->gamelen++;
190 assert(amaf->gamelen < sizeof(amaf->game) / sizeof(amaf->game[0]));
193 struct move m = { n->coord, node_color };
194 int res = board_play(&b2, &m);
196 if (res < 0 || (!is_pass(m.coord) && !group_at(&b2, m.coord)) /* suicide */
197 || b2.superko_violation) {
198 if (UDEBUGL(3)) {
199 for (struct tree_node *ni = n; ni; ni = ni->parent)
200 fprintf(stderr, "%s<%lld> ", coord2sstr(ni->coord, t->board), ni->hash);
201 fprintf(stderr, "deleting invalid %s node %d,%d res %d group %d spk %d\n",
202 stone2str(node_color), coord_x(n->coord,b), coord_y(n->coord,b),
203 res, group_at(&b2, m.coord), b2.superko_violation);
205 tree_delete_node(t, n);
206 result = 0;
207 goto end;
210 if (is_pass(n->coord))
211 passes++;
212 else
213 passes = 0;
216 if (amaf) {
217 amaf->game_baselen = amaf->gamelen;
218 amaf->record_nakade = u->playout_amaf_nakade;
221 if (u->dynkomi > b2.moves)
222 b2.komi += get_extra_komi(u, &b2, player_color);
224 if (passes >= 2) {
225 float score = board_official_score(&b2);
226 /* Result from black's perspective (no matter who
227 * the player; black's perspective is always
228 * what the tree stores. */
229 result = - (score * 2);
231 if (UDEBUGL(5))
232 fprintf(stderr, "[%d..%d] %s p-p scoring playout result %d (W %f)\n",
233 player_color, node_color, coord2sstr(n->coord, t->board), result, score);
234 if (UDEBUGL(6))
235 board_print(&b2, stderr);
237 } else { assert(tree_leaf_node(n));
238 result = uct_leaf_node(u, &b2, player_color, amaf, t, n, node_color, spaces);
241 if (amaf && u->playout_amaf_cutoff) {
242 int cutoff = amaf->game_baselen;
243 cutoff += (amaf->gamelen - amaf->game_baselen) * u->playout_amaf_cutoff / 100;
244 /* Now, reconstruct the amaf map. */
245 memset(amaf->map, 0, board_size2(&b2) * sizeof(*amaf->map));
246 for (int i = 0; i < cutoff; i++) {
247 coord_t coord = amaf->game[i].coord;
248 enum stone color = amaf->game[i].color;
249 if (amaf->map[coord] == S_NONE || amaf->map[coord] == color) {
250 amaf->map[coord] = color;
251 /* Nakade always recorded for in-tree part */
252 } else if (amaf->record_nakade || i <= amaf->game_baselen) {
253 amaf_op(amaf->map[n->coord], +);
258 assert(n == t->root || n->parent);
259 if (result != 0) {
260 float rval = result > 0;
261 if (u->val_scale) {
262 float sval = (float) abs(result) / u->val_points;
263 sval = sval > 1 ? 1 : sval;
264 if (result < 0) sval = 1 - sval;
265 rval = (1 - u->val_scale) * rval + u->val_scale * sval;
266 // fprintf(stderr, "score %d => sval %f, rval %f\n", result, sval, rval);
268 u->policy->update(u->policy, t, n, node_color, player_color, amaf, rval);
271 end:
272 if (amaf) {
273 free(amaf->map - 1);
274 free(amaf);
276 board_done_noalloc(&b2);
277 return result;
280 static void
281 prepare_move(struct engine *e, struct board *b, enum stone color, coord_t promote)
283 struct uct *u = e->data;
285 if (u->t && (!b->moves || color != stone_other(u->t->root_color))) {
286 /* Stale state from last game */
287 tree_done(u->t);
288 u->t = NULL;
291 if (!u->t) {
292 u->t = tree_init(b, color);
293 if (u->force_seed)
294 fast_srandom(u->force_seed);
295 if (UDEBUGL(0))
296 fprintf(stderr, "Fresh board with random seed %lu\n", fast_getseed());
297 //board_print(b, stderr);
298 if (!u->no_book && b->moves < 2)
299 tree_load(u->t, b);
302 /* XXX: We hope that the opponent didn't suddenly play
303 * several moves in the row. */
304 if (!is_resign(promote) && !tree_promote_at(u->t, b, promote)) {
305 if (UDEBUGL(2))
306 fprintf(stderr, "<cannot find node to promote>\n");
307 /* Reset tree */
308 tree_done(u->t);
309 u->t = tree_init(b, color);
312 if (u->dynkomi)
313 u->t->extra_komi = get_extra_komi(u, b, color);
/* Stop request for uct_playouts(), which polls it after each valid
 * playout. uct_genmove() resets it before spawning helper threads and
 * raises it once at least half of them have been joined.
 * NOTE(review): no signal handler touches this in the visible source; it
 * is read by helper threads without a lock, which C11 treats as a data
 * race despite volatile — an atomic flag would be the portable choice. */
316 /* Set in main thread in case the playouts should stop. */
317 static volatile sig_atomic_t halt = 0;
319 static int
320 uct_playouts(struct uct *u, struct board *b, enum stone color, struct tree *t)
322 int i, games = u->games;
323 if (t->root->children)
324 games -= t->root->u.playouts / 1.5;
325 /* else this is highly read-out but dead-end branch of opening book;
326 * we need to start from scratch; XXX: Maybe actually base the readout
327 * count based on number of playouts of best node? */
328 for (i = 0; i < games; i++) {
329 int result = uct_playout(u, b, color, t);
330 if (result == 0) {
331 /* Tree descent has hit invalid move. */
332 continue;
335 if (i > 0 && !(i % 10000)) {
336 progress_status(u, t, color, i);
339 if (i > 0 && !(i % 500)) {
340 struct tree_node *best = u->policy->choose(u->policy, t->root, b, color);
341 if (best && ((best->u.playouts >= 5000 && tree_node_get_value(t, best, u, 1) >= u->loss_threshold)
342 || (best->u.playouts >= 500 && tree_node_get_value(t, best, u, 1) >= 0.95)))
343 break;
346 if (halt) {
347 if (UDEBUGL(2))
348 fprintf(stderr, "<halting early, %d games skipped>\n", games - i);
349 break;
353 progress_status(u, t, color, i);
354 if (UDEBUGL(3))
355 tree_dump(t, u->dumpthres);
356 return i;
/* Completion handshake between spawn_helper() threads and the collecting
 * loop in uct_genmove(). A finishing helper:
 *   1. takes finish_serializer, so only one result is reported at a time;
 *   2. under finish_mutex, stores its tid in finish_thread;
 *   3. signals finish_cond.
 * The main thread joins that helper, then unlocks finish_serializer for
 * the next one.
 * NOTE(review): finish_serializer is locked by a helper and unlocked by
 * the main thread. POSIX leaves that undefined for default mutexes; a
 * semaphore would express this hand-off portably. */
359 static pthread_mutex_t finish_mutex = PTHREAD_MUTEX_INITIALIZER;
360 static pthread_cond_t finish_cond = PTHREAD_COND_INITIALIZER;
361 static volatile int finish_thread; /* tid of the helper that just finished; accessed under finish_mutex */
362 static pthread_mutex_t finish_serializer = PTHREAD_MUTEX_INITIALIZER;
/* Per-thread work item: allocated by uct_genmove(), passed to
 * spawn_helper(), and returned through pthread_join() for merging. */
364 struct spawn_ctx {
365 int tid; /* index of this thread in uct_genmove()'s threads[] */
366 struct uct *u; /* engine state (shared) */
367 struct board *b; /* position to search; shared, copied per playout by uct_playout() */
368 enum stone color; /* side to move */
369 struct tree *t; /* this thread's own tree_copy() of u->t */
370 unsigned long seed; /* passed to fast_srandom() at thread start */
371 int games; /* out: count returned by uct_playouts() */
/* pthread entry point for uct_genmove() helper threads.
 * Seeds the RNG with ctx->seed, runs uct_playouts() on ctx->t and stores
 * the count in ctx->games. It then reports completion through the
 * finish_* handshake. Returns @ctx_ so the main thread can merge and
 * free it after pthread_join(). */
374 static void *
375 spawn_helper(void *ctx_)
377 struct spawn_ctx *ctx = ctx_;
378 /* Setup */
379 fast_srandom(ctx->seed);
380 /* Run */
381 ctx->games = uct_playouts(ctx->u, ctx->b, ctx->color, ctx->t);
382 /* Finish */
/* finish_serializer is intentionally left locked here. The main thread
 * unlocks it after joining this thread, so only one finished thread at a
 * time publishes its tid in finish_thread. */
383 pthread_mutex_lock(&finish_serializer);
384 pthread_mutex_lock(&finish_mutex);
385 finish_thread = ctx->tid;
386 pthread_cond_signal(&finish_cond);
387 pthread_mutex_unlock(&finish_mutex);
388 return ctx;
391 static void
392 uct_notify_play(struct engine *e, struct board *b, struct move *m)
394 prepare_move(e, b, m->color, m->coord);
397 static coord_t *
398 uct_genmove(struct engine *e, struct board *b, enum stone color)
400 struct uct *u = e->data;
402 /* Seed the tree. */
403 prepare_move(e, b, color, resign);
405 if (b->superko_violation) {
406 fprintf(stderr, "!!! WARNING: SUPERKO VIOLATION OCCURED BEFORE THIS MOVE\n");
407 fprintf(stderr, "Maybe you play with situational instead of positional superko?\n");
408 fprintf(stderr, "I'm going to ignore the violation, but note that I may miss\n");
409 fprintf(stderr, "some moves valid under this ruleset because of this.\n");
410 b->superko_violation = false;
413 /* If the opponent just passes and we win counting, just
414 * pass as well. */
415 if (b->moves > 1 && is_pass(b->last_move.coord) && can_pass(b, color))
416 return coord_copy(pass);
418 int played_games = 0;
419 if (!u->threads) {
420 played_games = uct_playouts(u, b, color, u->t);
421 } else {
422 pthread_t threads[u->threads];
423 int joined = 0;
424 halt = 0;
425 pthread_mutex_lock(&finish_mutex);
426 /* Spawn threads... */
427 for (int ti = 0; ti < u->threads; ti++) {
428 struct spawn_ctx *ctx = malloc(sizeof(*ctx));
429 ctx->u = u; ctx->b = b; ctx->color = color;
430 ctx->t = tree_copy(u->t); ctx->tid = ti;
431 ctx->seed = fast_random(65536) + ti;
432 pthread_create(&threads[ti], NULL, spawn_helper, ctx);
433 if (UDEBUGL(2))
434 fprintf(stderr, "Spawned thread %d\n", ti);
436 /* ...and collect them back: */
437 while (joined < u->threads) {
438 /* Wait for some thread to finish... */
439 pthread_cond_wait(&finish_cond, &finish_mutex);
440 /* ...and gather its remnants. */
441 struct spawn_ctx *ctx;
442 pthread_join(threads[finish_thread], (void **) &ctx);
443 played_games += ctx->games;
444 joined++;
445 tree_merge(u->t, ctx->t);
446 tree_done(ctx->t);
447 free(ctx);
448 if (UDEBUGL(2))
449 fprintf(stderr, "Joined thread %d\n", finish_thread);
450 /* Do not get stalled by slow threads. */
451 if (joined >= u->threads / 2)
452 halt = 1;
453 pthread_mutex_unlock(&finish_serializer);
455 pthread_mutex_unlock(&finish_mutex);
457 tree_normalize(u->t, u->threads);
460 if (UDEBUGL(2))
461 tree_dump(u->t, u->dumpthres);
463 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
464 if (!best) {
465 tree_done(u->t); u->t = NULL;
466 return coord_copy(pass);
468 if (UDEBUGL(0))
469 progress_status(u, u->t, color, played_games);
470 if (UDEBUGL(1))
471 fprintf(stderr, "*** WINNER is %s (%d,%d) with score %1.4f (%d/%d:%d games)\n",
472 coord2sstr(best->coord, b), coord_x(best->coord, b), coord_y(best->coord, b),
473 tree_node_get_value(u->t, best, u, 1),
474 best->u.playouts, u->t->root->u.playouts, played_games);
475 if (tree_node_get_value(u->t, best, u, 1) < u->resign_ratio && !is_pass(best->coord)) {
476 tree_done(u->t); u->t = NULL;
477 return coord_copy(resign);
479 tree_promote_node(u->t, best);
480 return coord_copy(best->coord);
483 bool
484 uct_genbook(struct engine *e, struct board *b, enum stone color)
486 struct uct *u = e->data;
487 u->t = tree_init(b, color);
488 tree_load(u->t, b);
490 int i;
491 for (i = 0; i < u->games; i++) {
492 int result = uct_playout(u, b, color, u->t);
493 if (result == 0) {
494 /* Tree descent has hit invalid move. */
495 continue;
498 if (i > 0 && !(i % 10000)) {
499 progress_status(u, u->t, color, i);
502 progress_status(u, u->t, color, i);
504 tree_save(u->t, b, u->games / 100);
506 tree_done(u->t);
508 return true;
511 void
512 uct_dumpbook(struct engine *e, struct board *b, enum stone color)
514 struct uct *u = e->data;
515 u->t = tree_init(b, color);
516 tree_load(u->t, b);
517 tree_dump(u->t, 0);
518 tree_done(u->t);
522 struct uct *
523 uct_state_init(char *arg)
525 struct uct *u = calloc(1, sizeof(struct uct));
527 u->debug_level = 1;
528 u->games = MC_GAMES;
529 u->gamelen = MC_GAMELEN;
530 u->expand_p = 2;
531 u->dumpthres = 1000;
532 u->playout_amaf = true;
533 u->playout_amaf_nakade = false;
534 u->amaf_prior = true;
536 if (arg) {
537 char *optspec, *next = arg;
538 while (*next) {
539 optspec = next;
540 next += strcspn(next, ",");
541 if (*next) { *next++ = 0; } else { *next = 0; }
543 char *optname = optspec;
544 char *optval = strchr(optspec, '=');
545 if (optval) *optval++ = 0;
547 if (!strcasecmp(optname, "debug")) {
548 if (optval)
549 u->debug_level = atoi(optval);
550 else
551 u->debug_level++;
552 } else if (!strcasecmp(optname, "games") && optval) {
553 u->games = atoi(optval);
554 } else if (!strcasecmp(optname, "gamelen") && optval) {
555 u->gamelen = atoi(optval);
556 } else if (!strcasecmp(optname, "expand_p") && optval) {
557 u->expand_p = atoi(optval);
558 } else if (!strcasecmp(optname, "radar_d") && optval) {
559 /* For 19x19, it is good idea to set this to 3. */
560 u->radar_d = atoi(optval);
561 } else if (!strcasecmp(optname, "dumpthres") && optval) {
562 u->dumpthres = atoi(optval);
563 } else if (!strcasecmp(optname, "playout_amaf")) {
564 /* Whether to include random playout moves in
565 * AMAF as well. (Otherwise, only tree moves
566 * are included in AMAF. Of course makes sense
567 * only in connection with an AMAF policy.) */
568 /* with-without: 55.5% (+-4.1) */
569 if (optval && *optval == '0')
570 u->playout_amaf = false;
571 else
572 u->playout_amaf = true;
573 } else if (!strcasecmp(optname, "playout_amaf_nakade")) {
574 /* Whether to include nakade moves from playouts
575 * in the AMAF statistics; this tends to nullify
576 * the playout_amaf effect by adding too much
577 * noise. */
578 if (optval && *optval == '0')
579 u->playout_amaf_nakade = false;
580 else
581 u->playout_amaf_nakade = true;
582 } else if (!strcasecmp(optname, "playout_amaf_cutoff") && optval) {
583 /* Keep only first N% of playout stage AMAF
584 * information. */
585 u->playout_amaf_cutoff = atoi(optval);
586 } else if (!strcasecmp(optname, "policy") && optval) {
587 char *policyarg = strchr(optval, ':');
588 if (policyarg)
589 *policyarg++ = 0;
590 if (!strcasecmp(optval, "ucb1")) {
591 u->policy = policy_ucb1_init(u, policyarg);
592 } else if (!strcasecmp(optval, "ucb1amaf")) {
593 u->policy = policy_ucb1amaf_init(u, policyarg);
594 } else {
595 fprintf(stderr, "UCT: Invalid tree policy %s\n", optval);
597 } else if (!strcasecmp(optname, "playout") && optval) {
598 char *playoutarg = strchr(optval, ':');
599 if (playoutarg)
600 *playoutarg++ = 0;
601 if (!strcasecmp(optval, "moggy")) {
602 u->playout = playout_moggy_init(playoutarg);
603 } else if (!strcasecmp(optval, "light")) {
604 u->playout = playout_light_init(playoutarg);
605 } else {
606 fprintf(stderr, "UCT: Invalid playout policy %s\n", optval);
608 } else if (!strcasecmp(optname, "prior") && optval) {
609 u->prior = uct_prior_init(optval);
610 } else if (!strcasecmp(optname, "amaf_prior") && optval) {
611 u->amaf_prior = atoi(optval);
612 } else if (!strcasecmp(optname, "threads") && optval) {
613 u->threads = atoi(optval);
614 } else if (!strcasecmp(optname, "force_seed") && optval) {
615 u->force_seed = atoi(optval);
616 } else if (!strcasecmp(optname, "no_book")) {
617 u->no_book = true;
618 } else if (!strcasecmp(optname, "dynkomi")) {
619 /* Dynamic komi in handicap game; linearly
620 * decreases to basic settings until move
621 * #optval. */
622 u->dynkomi = optval ? atoi(optval) : 150;
623 } else if (!strcasecmp(optname, "val_scale") && optval) {
624 /* How much of the game result value should be
625 * influenced by win size. */
626 u->val_scale = atof(optval);
627 } else if (!strcasecmp(optname, "val_points") && optval) {
628 /* Maximum size of win to be scaled into game
629 * result value. */
630 u->val_points = atoi(optval) * 2; // result values are doubled
631 } else {
632 fprintf(stderr, "uct: Invalid engine argument %s or missing value\n", optname);
633 exit(1);
638 u->resign_ratio = 0.2; /* Resign when most games are lost. */
639 u->loss_threshold = 0.85; /* Stop reading if after at least 5000 playouts this is best value. */
640 if (!u->policy)
641 u->policy = policy_ucb1amaf_init(u, NULL);
643 if (!u->prior)
644 u->prior = uct_prior_init(NULL);
646 if (!u->playout)
647 u->playout = playout_moggy_init(NULL);
648 u->playout->debug_level = u->debug_level;
650 return u;
654 struct engine *
655 engine_uct_init(char *arg)
657 struct uct *u = uct_state_init(arg);
658 struct engine *e = calloc(1, sizeof(struct engine));
659 e->name = "UCT Engine";
660 e->comment = "I'm playing UCT. When we both pass, I will consider all the stones on the board alive. If you are reading this, write 'yes'. Please capture all dead stones before passing; it will not cost you points (area scoring is used).";
661 e->genmove = uct_genmove;
662 e->notify_play = uct_notify_play;
663 e->data = u;
665 return e;