UCT: Do not load opening book unless the game just began and we're black
[pachi/json.git] / uct / uct.c
blob4f9861bfabaef3fcfa7b99e53843d4b6825c6c6a
1 #include <assert.h>
2 #include <pthread.h>
3 #include <signal.h>
4 #include <stdio.h>
5 #include <stdlib.h>
6 #include <string.h>
8 #define DEBUG
10 #include "debug.h"
11 #include "board.h"
12 #include "move.h"
13 #include "playout.h"
14 #include "playout/moggy.h"
15 #include "playout/light.h"
16 #include "random.h"
17 #include "uct/internal.h"
18 #include "uct/tree.h"
19 #include "uct/uct.h"
/* Constructors of the tree policies selectable through the "policy"
 * engine argument (see uct_state_init()). */
struct uct_policy *policy_ucb1_init(struct uct *u, char *arg);
struct uct_policy *policy_ucb1tuned_init(struct uct *u, char *arg);
struct uct_policy *policy_ucb1amaf_init(struct uct *u, char *arg);

/* Default number of playouts per move (initial value of u->games). */
#define MC_GAMES 80000
/* Default playout length limit (initial value of u->gamelen). */
#define MC_GAMELEN 400
30 static void
31 progress_status(struct uct *u, struct tree *t, enum stone color, int playouts)
33 if (!UDEBUGL(0))
34 return;
36 /* Best move */
37 struct tree_node *best = u->policy->choose(u->policy, t->root, t->board, color);
38 if (!best) {
39 fprintf(stderr, "... No moves left\n");
40 return;
42 fprintf(stderr, "[%d] ", playouts);
43 fprintf(stderr, "best %f ", best->u.value);
45 /* Max depth */
46 fprintf(stderr, "deepest % 2d ", t->max_depth - t->root->depth);
48 /* Best sequence */
49 fprintf(stderr, "| seq ");
50 for (int depth = 0; depth < 6; depth++) {
51 if (best && best->u.playouts >= 25) {
52 fprintf(stderr, "%3s ", coord2sstr(best->coord, t->board));
53 best = u->policy->choose(u->policy, best, t->board, color);
54 } else {
55 fprintf(stderr, " ");
59 /* Best candidates */
60 fprintf(stderr, "| can ");
61 int cans = 4;
62 struct tree_node *can[cans];
63 memset(can, 0, sizeof(can));
64 best = t->root->children;
65 while (best) {
66 int c = 0;
67 while ((!can[c] || best->u.playouts > can[c]->u.playouts) && ++c < cans);
68 for (int d = 0; d < c; d++) can[d] = can[d + 1];
69 if (c > 0) can[c - 1] = best;
70 best = best->sibling;
72 while (--cans >= 0) {
73 if (can[cans]) {
74 fprintf(stderr, "%3s(%.3f) ", coord2sstr(can[cans]->coord, t->board), can[cans]->u.value);
75 } else {
76 fprintf(stderr, " ");
80 fprintf(stderr, "\n");
84 static int
85 uct_leaf_node(struct uct *u, struct board *b, enum stone player_color,
86 struct playout_amafmap *amaf,
87 struct tree *t, struct tree_node *n, enum stone node_color,
88 char *spaces)
90 enum stone next_color = stone_other(node_color);
91 if (n->u.playouts >= u->expand_p) {
92 // fprintf(stderr, "expanding %s (%p ^-%p)\n", coord2sstr(n->coord, b), n, n->parent);
93 tree_expand_node(t, n, b, next_color, u->radar_d, u->policy,
94 (next_color == player_color ? 1 : -1));
96 if (UDEBUGL(7))
97 fprintf(stderr, "%s*-- UCT playout #%d start [%s] %f\n",
98 spaces, n->u.playouts, coord2sstr(n->coord, t->board), n->u.value);
100 int result = play_random_game(b, next_color, u->gamelen, u->playout_amaf ? amaf : NULL, u->playout);
101 if (player_color != next_color && result >= 0)
102 result = !result;
103 if (UDEBUGL(7))
104 fprintf(stderr, "%s -- [%d..%d] %s random playout result %d\n",
105 spaces, player_color, next_color, coord2sstr(n->coord, t->board), result);
107 return result;
110 static int
111 uct_playout(struct uct *u, struct board *b, enum stone player_color, struct tree *t)
113 struct board b2;
114 board_copy(&b2, b);
116 struct playout_amafmap *amaf = NULL;
117 if (u->policy->wants_amaf) {
118 amaf = calloc(1, sizeof(*amaf));
119 amaf->map = calloc(board_size2(&b2) + 1, sizeof(*amaf->map));
120 amaf->map++; // -1 is pass
123 /* Walk the tree until we find a leaf, then expand it and do
124 * a random playout. */
125 struct tree_node *n = t->root;
126 enum stone node_color = stone_other(player_color);
127 assert(node_color == t->root_color);
129 int result;
130 int pass_limit = (board_size(&b2) - 2) * (board_size(&b2) - 2) / 2;
131 int passes = is_pass(b->last_move.coord);
133 /* debug */
134 int depth = 0;
135 static char spaces[] = "\0 ";
136 /* /debug */
137 if (UDEBUGL(8))
138 fprintf(stderr, "--- UCT walk with color %d\n", player_color);
140 while (!tree_leaf_node(n) && passes < 2) {
141 spaces[depth++] = ' '; spaces[depth] = 0;
143 /* Parity is chosen already according to the child color, since
144 * it is applied to children. */
145 node_color = stone_other(node_color);
146 n = u->policy->descend(u->policy, t, n, (node_color == player_color ? 1 : -1), pass_limit);
148 assert(n == t->root || n->parent);
149 if (UDEBUGL(7))
150 fprintf(stderr, "%s+-- UCT sent us to [%s:%d] %f\n",
151 spaces, coord2sstr(n->coord, t->board), n->coord, n->u.value);
153 if (amaf && n->coord >= -1 && !is_pass(n->coord)) {
154 if (amaf->map[n->coord] == S_NONE) {
155 amaf->map[n->coord] = node_color;
156 } else {
157 amaf_op(amaf->map[n->coord], +);
161 struct move m = { n->coord, node_color };
162 int res = board_play(&b2, &m);
164 if (res < 0 || (!is_pass(m.coord) && !group_at(&b2, m.coord)) /* suicide */
165 || b2.superko_violation) {
166 if (UDEBUGL(3)) {
167 for (struct tree_node *ni = n; ni; ni = ni->parent)
168 fprintf(stderr, "%s ", coord2sstr(ni->coord, t->board));
169 fprintf(stderr, "deleting invalid %s node %d,%d res %d group %d spk %d\n",
170 stone2str(node_color), coord_x(n->coord,b), coord_y(n->coord,b),
171 res, group_at(&b2, m.coord), b2.superko_violation);
173 tree_delete_node(t, n);
174 result = -1;
175 goto end;
178 if (is_pass(n->coord))
179 passes++;
180 else
181 passes = 0;
184 if (passes >= 2) {
185 float score = board_official_score(&b2);
186 result = (player_color == S_BLACK) ? score < 0 : score > 0;
188 if (UDEBUGL(5))
189 fprintf(stderr, "[%d..%d] %s p-p scoring playout result %d (W %f)\n",
190 player_color, node_color, coord2sstr(n->coord, t->board), result, score);
191 if (UDEBUGL(6))
192 board_print(&b2, stderr);
194 } else { assert(tree_leaf_node(n));
195 result = uct_leaf_node(u, &b2, player_color, amaf, t, n, node_color, spaces);
198 assert(n == t->root || n->parent);
199 if (result >= 0)
200 u->policy->update(u->policy, t, n, node_color, player_color, amaf, result);
202 end:
203 if (amaf) {
204 free(amaf->map - 1);
205 free(amaf);
207 board_done_noalloc(&b2);
208 return result;
211 static void
212 prepare_move(struct engine *e, struct board *b, enum stone color, coord_t promote)
214 struct uct *u = e->data;
216 if ((!b->moves || color != stone_other(u->t->root_color)) && u->t) {
217 /* Stale state from last game */
218 tree_done(u->t);
219 u->t = NULL;
222 if (!u->t) {
223 u->t = tree_init(b, color);
224 if (u->force_seed)
225 fast_srandom(u->force_seed);
226 if (UDEBUGL(0))
227 fprintf(stderr, "Fresh board with random seed %lu\n", fast_getseed());
228 //board_print(b, stderr);
229 if (!u->no_book && !b->moves && color == S_BLACK)
230 tree_load(u->t, b, color);
233 /* XXX: We hope that the opponent didn't suddenly play
234 * several moves in the row. */
235 if (!is_resign(promote) && !tree_promote_at(u->t, b, promote)) {
236 if (UDEBUGL(2))
237 fprintf(stderr, "<cannot find node to promote>\n");
238 /* Reset tree */
239 tree_done(u->t);
240 u->t = tree_init(b, color);
/* Stop request for running playout loops: uct_genmove() raises it from the
 * main thread and uct_playouts() polls it after every playout. */
static volatile sig_atomic_t halt = 0;
247 static int
248 uct_playouts(struct uct *u, struct board *b, enum stone color, struct tree *t)
250 int i, games = u->games;
251 if (t->root->children)
252 games -= t->root->u.playouts / 1.5;
253 /* else this is highly read-out but dead-end branch of opening book;
254 * we need to start from scratch; XXX: Maybe actually base the readout
255 * count based on number of playouts of best node? */
256 for (i = 0; i < games; i++) {
257 int result = uct_playout(u, b, color, t);
258 if (result < 0) {
259 /* Tree descent has hit invalid move. */
260 continue;
263 if (i > 0 && !(i % 10000)) {
264 progress_status(u, t, color, i);
267 if (i > 0 && !(i % 500)) {
268 struct tree_node *best = u->policy->choose(u->policy, t->root, b, color);
269 if (best && ((best->u.playouts >= 5000 && best->u.value >= u->loss_threshold)
270 || (best->u.playouts >= 500 && best->u.value >= 0.95)))
271 break;
274 if (halt) {
275 if (UDEBUGL(2))
276 fprintf(stderr, "<halting early, %d games skipped>\n", games - i);
277 break;
281 progress_status(u, t, color, i);
282 if (UDEBUGL(3))
283 tree_dump(t, u->dumpthres);
284 return i;
/* Hand-over of finished worker threads to uct_genmove():
 * finish_thread carries the id of the worker that just finished; it is
 * written under finish_mutex and announced through finish_cond. */
static pthread_mutex_t finish_mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t finish_cond = PTHREAD_COND_INITIALIZER;
static volatile int finish_thread;
/* Taken by a worker before it announces itself and released by
 * uct_genmove() once that worker has been collected, so that workers
 * finish one at a time. */
static pthread_mutex_t finish_serializer = PTHREAD_MUTEX_INITIALIZER;
292 struct spawn_ctx {
293 int tid;
294 struct uct *u;
295 struct board *b;
296 enum stone color;
297 struct tree *t;
298 unsigned long seed;
299 int games;
302 static void *
303 spawn_helper(void *ctx_)
305 struct spawn_ctx *ctx = ctx_;
306 /* Setup */
307 fast_srandom(ctx->seed);
308 /* Run */
309 ctx->games = uct_playouts(ctx->u, ctx->b, ctx->color, ctx->t);
310 /* Finish */
311 pthread_mutex_lock(&finish_serializer);
312 pthread_mutex_lock(&finish_mutex);
313 finish_thread = ctx->tid;
314 pthread_cond_signal(&finish_cond);
315 pthread_mutex_unlock(&finish_mutex);
316 return ctx;
319 static void
320 uct_notify_play(struct engine *e, struct board *b, struct move *m)
322 prepare_move(e, b, stone_other(m->color), m->coord);
325 static coord_t *
326 uct_genmove(struct engine *e, struct board *b, enum stone color)
328 struct uct *u = e->data;
330 /* Seed the tree. */
331 prepare_move(e, b, color, resign);
333 int played_games = 0;
334 if (!u->threads) {
335 played_games = uct_playouts(u, b, color, u->t);
336 } else {
337 pthread_t threads[u->threads];
338 int joined = 0;
339 halt = 0;
340 pthread_mutex_lock(&finish_mutex);
341 /* Spawn threads... */
342 for (int ti = 0; ti < u->threads; ti++) {
343 struct spawn_ctx *ctx = malloc(sizeof(*ctx));
344 ctx->u = u; ctx->b = b; ctx->color = color;
345 ctx->t = tree_copy(u->t); ctx->tid = ti;
346 ctx->seed = fast_random(65536) + ti;
347 pthread_create(&threads[ti], NULL, spawn_helper, ctx);
348 if (UDEBUGL(2))
349 fprintf(stderr, "Spawned thread %d\n", ti);
351 /* ...and collect them back: */
352 while (joined < u->threads) {
353 /* Wait for some thread to finish... */
354 pthread_cond_wait(&finish_cond, &finish_mutex);
355 /* ...and gather its remnants. */
356 struct spawn_ctx *ctx;
357 pthread_join(threads[finish_thread], (void **) &ctx);
358 played_games += ctx->games;
359 joined++;
360 tree_merge(u->t, ctx->t);
361 tree_done(ctx->t);
362 free(ctx);
363 if (UDEBUGL(2))
364 fprintf(stderr, "Joined thread %d\n", finish_thread);
365 /* Do not get stalled by slow threads. */
366 if (joined >= u->threads / 2)
367 halt = 1;
368 pthread_mutex_unlock(&finish_serializer);
370 pthread_mutex_unlock(&finish_mutex);
373 if (UDEBUGL(2))
374 tree_dump(u->t, u->dumpthres);
376 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
377 if (!best) {
378 tree_done(u->t); u->t = NULL;
379 return coord_copy(pass);
381 if (UDEBUGL(0))
382 progress_status(u, u->t, color, played_games);
383 if (UDEBUGL(1))
384 fprintf(stderr, "*** WINNER is %s (%d,%d) with score %1.4f (%d/%d:%d games)\n", coord2sstr(best->coord, b), coord_x(best->coord, b), coord_y(best->coord, b), best->u.value, best->u.playouts, u->t->root->u.playouts, played_games);
385 if (best->u.value < u->resign_ratio && !is_pass(best->coord)) {
386 tree_done(u->t); u->t = NULL;
387 return coord_copy(resign);
389 tree_promote_node(u->t, best);
390 return coord_copy(best->coord);
393 bool
394 uct_genbook(struct engine *e, struct board *b, enum stone color)
396 struct uct *u = e->data;
397 u->t = tree_init(b, color);
398 tree_load(u->t, b, color);
400 int i;
401 for (i = 0; i < u->games; i++) {
402 int result = uct_playout(u, b, color, u->t);
403 if (result < 0) {
404 /* Tree descent has hit invalid move. */
405 continue;
408 if (i > 0 && !(i % 10000)) {
409 progress_status(u, u->t, color, i);
412 progress_status(u, u->t, color, i);
414 tree_save(u->t, b, u->games / 100);
416 tree_done(u->t);
418 return true;
421 void
422 uct_dumpbook(struct engine *e, struct board *b, enum stone color)
424 struct uct *u = e->data;
425 u->t = tree_init(b, color);
426 tree_load(u->t, b, color);
427 tree_dump(u->t, 0);
428 tree_done(u->t);
432 struct uct *
433 uct_state_init(char *arg)
435 struct uct *u = calloc(1, sizeof(struct uct));
437 u->debug_level = 1;
438 u->games = MC_GAMES;
439 u->gamelen = MC_GAMELEN;
440 u->expand_p = 2;
441 u->dumpthres = 1000;
442 u->playout_amaf = false;
444 if (arg) {
445 char *optspec, *next = arg;
446 while (*next) {
447 optspec = next;
448 next += strcspn(next, ",");
449 if (*next) { *next++ = 0; } else { *next = 0; }
451 char *optname = optspec;
452 char *optval = strchr(optspec, '=');
453 if (optval) *optval++ = 0;
455 if (!strcasecmp(optname, "debug")) {
456 if (optval)
457 u->debug_level = atoi(optval);
458 else
459 u->debug_level++;
460 } else if (!strcasecmp(optname, "games") && optval) {
461 u->games = atoi(optval);
462 } else if (!strcasecmp(optname, "gamelen") && optval) {
463 u->gamelen = atoi(optval);
464 } else if (!strcasecmp(optname, "expand_p") && optval) {
465 u->expand_p = atoi(optval);
466 } else if (!strcasecmp(optname, "radar_d") && optval) {
467 /* For 19x19, it is good idea to set this to 3. */
468 u->radar_d = atoi(optval);
469 } else if (!strcasecmp(optname, "dumpthres") && optval) {
470 u->dumpthres = atoi(optval);
471 } else if (!strcasecmp(optname, "playout_amaf")) {
472 /* Whether to include random playout moves in
473 * AMAF as well. (Otherwise, only tree moves
474 * are included in AMAF. Of course makes sense
475 * only in connection with an AMAF policy.) */
476 /* with-without: 55.5% (+-4.1) */
477 if (optval && *optval == '0')
478 u->playout_amaf = false;
479 else
480 u->playout_amaf = true;
481 } else if (!strcasecmp(optname, "policy") && optval) {
482 char *policyarg = strchr(optval, ':');
483 if (policyarg)
484 *policyarg++ = 0;
485 if (!strcasecmp(optval, "ucb1")) {
486 u->policy = policy_ucb1_init(u, policyarg);
487 } else if (!strcasecmp(optval, "ucb1tuned")) {
488 u->policy = policy_ucb1tuned_init(u, policyarg);
489 } else if (!strcasecmp(optval, "ucb1amaf")) {
490 u->policy = policy_ucb1amaf_init(u, policyarg);
491 } else {
492 fprintf(stderr, "UCT: Invalid tree policy %s\n", optval);
494 } else if (!strcasecmp(optname, "playout") && optval) {
495 char *playoutarg = strchr(optval, ':');
496 if (playoutarg)
497 *playoutarg++ = 0;
498 if (!strcasecmp(optval, "moggy")) {
499 u->playout = playout_moggy_init(playoutarg);
500 } else if (!strcasecmp(optval, "light")) {
501 u->playout = playout_light_init(playoutarg);
502 } else {
503 fprintf(stderr, "UCT: Invalid playout policy %s\n", optval);
505 } else if (!strcasecmp(optname, "threads") && optval) {
506 u->threads = atoi(optval);
507 } else if (!strcasecmp(optname, "force_seed") && optval) {
508 u->force_seed = atoi(optval);
509 } else if (!strcasecmp(optname, "no_book")) {
510 u->no_book = true;
511 } else {
512 fprintf(stderr, "uct: Invalid engine argument %s or missing value\n", optname);
517 u->resign_ratio = 0.2; /* Resign when most games are lost. */
518 u->loss_threshold = 0.85; /* Stop reading if after at least 5000 playouts this is best value. */
519 if (!u->policy)
520 u->policy = policy_ucb1amaf_init(u, NULL);
522 if (!u->playout)
523 u->playout = playout_moggy_init(NULL);
524 u->playout->debug_level = u->debug_level;
526 return u;
530 struct engine *
531 engine_uct_init(char *arg)
533 struct uct *u = uct_state_init(arg);
534 struct engine *e = calloc(1, sizeof(struct engine));
535 e->name = "UCT Engine";
536 e->comment = "I'm playing UCT. When we both pass, I will consider all the stones on the board alive. If you are reading this, write 'yes'. Please bear with me at the game end, I need to fill the whole board; if you help me, we will both be happier. Filling the board will not lose points (NZ rules).";
537 e->genmove = uct_genmove;
538 e->notify_play = uct_notify_play;
539 e->data = u;
541 return e;