Merge branch 'master' of git://repo.or.cz/pachi
[pachi.git] / uct / uct.c
blob1224b3dae095fa0667983b6ab11e1221e5ad8450
1 #include <assert.h>
2 #include <pthread.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #define DEBUG
9 #include "debug.h"
10 #include "board.h"
11 #include "move.h"
12 #include "playout.h"
13 #include "playout/moggy.h"
14 #include "playout/old.h"
15 #include "playout/light.h"
16 #include "random.h"
17 #include "uct/internal.h"
18 #include "uct/tree.h"
19 #include "uct/uct.h"
/* Tree policy constructors (defined elsewhere). Each receives the engine
 * state and the policy's own option string, which may be NULL. */
struct uct_policy *policy_ucb1_init(struct uct *state, char *policy_arg);
struct uct_policy *policy_ucb1tuned_init(struct uct *state, char *policy_arg);
struct uct_policy *policy_ucb1amaf_init(struct uct *state, char *policy_arg);

/* Default playout budget per move (overridable with the "games" option). */
#define MC_GAMES	40000
/* Default game length limit handed to play_random_game() ("gamelen" option). */
#define MC_GAMELEN	400
30 static void
31 progress_status(struct uct *u, struct tree *t, enum stone color, int playouts)
33 if (!UDEBUGL(0))
34 return;
36 /* Best move */
37 struct tree_node *best = u->policy->choose(u->policy, t->root, t->board, color);
38 if (!best) {
39 fprintf(stderr, "... No moves left\n");
40 return;
42 fprintf(stderr, "[%d] ", playouts);
43 fprintf(stderr, "best %f ", best->u.value);
45 /* Max depth */
46 fprintf(stderr, "deepest % 2d ", t->max_depth - t->root->depth);
48 /* Best sequence */
49 fprintf(stderr, "| seq ");
50 for (int depth = 0; depth < 6; depth++) {
51 if (best && best->u.playouts >= 25) {
52 fprintf(stderr, "%3s ", coord2sstr(best->coord, t->board));
53 best = u->policy->choose(u->policy, best, t->board, color);
54 } else {
55 fprintf(stderr, " ");
59 /* Best candidates */
60 fprintf(stderr, "| can ");
61 int cans = 4;
62 struct tree_node *can[cans];
63 memset(can, 0, sizeof(can));
64 best = t->root->children;
65 while (best) {
66 int c = 0;
67 while ((!can[c] || best->u.playouts > can[c]->u.playouts) && ++c < cans);
68 for (int d = 0; d < c; d++) can[d] = can[d + 1];
69 if (c > 0) can[c - 1] = best;
70 best = best->sibling;
72 while (--cans >= 0) {
73 if (can[cans]) {
74 fprintf(stderr, "%3s(%.3f) ", coord2sstr(can[cans]->coord, t->board), can[cans]->u.value);
75 } else {
76 fprintf(stderr, " ");
80 fprintf(stderr, "\n");
84 static int
85 uct_playout(struct uct *u, struct board *b, enum stone color, struct tree *t)
87 struct board b2;
88 board_copy(&b2, b);
90 struct playout_amafmap *amaf = NULL;
91 if (u->policy->wants_amaf) {
92 amaf = calloc(1, sizeof(*amaf));
93 amaf->map = calloc(board_size2(&b2) + 1, sizeof(*amaf->map));
94 amaf->map++; // -1 is pass
97 /* Walk the tree until we find a leaf, then expand it and do
98 * a random playout. */
99 struct tree_node *n = t->root;
100 enum stone orig_color = color;
101 int result;
102 int pass_limit = (board_size(&b2) - 2) * (board_size(&b2) - 2) / 2;
103 int passes = is_pass(b->last_move.coord);
104 if (UDEBUGL(8))
105 fprintf(stderr, "--- UCT walk with color %d\n", color);
106 for (; pass; color = stone_other(color)) {
107 if (tree_leaf_node(n)) {
108 if (n->u.playouts >= u->expand_p)
109 tree_expand_node(t, n, &b2, color, u->radar_d, u->policy, (color == orig_color ? 1 : -1));
111 result = play_random_game(&b2, color, u->gamelen, u->playout_amaf ? amaf : NULL, u->playout);
112 if (orig_color != color && result >= 0)
113 result = !result;
114 if (UDEBUGL(7))
115 fprintf(stderr, "[%d..%d] %s random playout result %d\n", orig_color, color, coord2sstr(n->coord, t->board), result);
117 /* Reset color to the @n color. */
118 color = stone_other(color);
119 break;
122 n = u->policy->descend(u->policy, t, n, (color == orig_color ? 1 : -1), pass_limit);
123 assert(n == t->root || n->parent);
124 if (UDEBUGL(7))
125 fprintf(stderr, "-- UCT sent us to [%s] %f\n", coord2sstr(n->coord, t->board), n->u.value);
126 if (amaf && n->coord >= -1)
127 amaf->map[n->coord] = color;
128 struct move m = { n->coord, color };
129 int res = board_play(&b2, &m);
131 if (res < 0 || (!is_pass(m.coord) && !group_at(&b2, m.coord)) /* suicide */
132 || b2.superko_violation) {
133 if (UDEBUGL(3)) {
134 for (struct tree_node *ni = n; ni; ni = ni->parent)
135 fprintf(stderr, "%s ", coord2sstr(ni->coord, t->board));
136 fprintf(stderr, "deleting invalid %s node %d,%d res %d group %d spk %d\n",
137 stone2str(color), coord_x(n->coord,b), coord_y(n->coord,b),
138 res, group_at(&b2, m.coord), b2.superko_violation);
140 tree_delete_node(t, n);
141 result = -1;
142 goto end;
145 if (is_pass(n->coord)) {
146 passes++;
147 if (passes >= 2) {
148 float score = board_official_score(&b2);
149 result = (orig_color == S_BLACK) ? score < 0 : score > 0;
150 //if (UDEBUGL(5))
151 fprintf(stderr, "[%d..%d] %s p-p scoring playout result %d (W %f)\n", orig_color, color, coord2sstr(n->coord, t->board), result, score);
152 if (UDEBUGL(6))
153 board_print(&b2, stderr);
154 break;
156 } else {
157 passes = 0;
161 assert(n == t->root || n->parent);
162 if (result >= 0)
163 u->policy->update(u->policy, t, n, color, amaf, result);
165 end:
166 if (amaf) {
167 free(amaf->map - 1);
168 free(amaf);
170 board_done_noalloc(&b2);
171 return result;
174 static void
175 prepare_move(struct engine *e, struct board *b, enum stone color, coord_t promote)
177 struct uct *u = e->data;
179 if (!b->moves && u->t) {
180 /* Stale state from last game */
181 tree_done(u->t);
182 u->t = NULL;
185 if (!u->t) {
186 u->t = tree_init(b, color);
187 //board_print(b, stderr);
188 tree_load(u->t, b, color);
191 /* XXX: We hope that the opponent didn't suddenly play
192 * several moves in the row. */
193 if (!is_resign(promote) && !tree_promote_at(u->t, b, promote)) {
194 fprintf(stderr, "CANNOT FIND NODE TO PROMOTE!\n");
195 /* Reset tree */
196 tree_done(u->t);
197 u->t = tree_init(b, color);
201 static int
202 uct_playouts(struct uct *u, struct board *b, enum stone color, struct tree *t)
204 int i, games = u->games - (t->root->u.playouts / 1.5);
205 for (i = 0; i < games; i++) {
206 int result = uct_playout(u, b, color, t);
207 if (result < 0) {
208 /* Tree descent has hit invalid move. */
209 continue;
212 if (i > 0 && !(i % 10000)) {
213 progress_status(u, t, color, i);
216 if (i > 0 && !(i % 500)) {
217 struct tree_node *best = u->policy->choose(u->policy, t->root, b, color);
218 if (best && best->u.playouts >= 1000 && best->u.value >= u->loss_threshold)
219 break;
223 progress_status(u, t, color, i);
224 if (UDEBUGL(3))
225 tree_dump(t, u->dumpthres);
226 return i;
229 struct spawn_ctx {
230 struct uct *u;
231 struct board *b;
232 enum stone color;
233 struct tree *t;
234 unsigned long seed;
235 int games;
238 static void *
239 spawn_helper(void *ctx_)
241 struct spawn_ctx *ctx = ctx_;
242 fast_srandom(ctx->seed);
243 ctx->games = uct_playouts(ctx->u, ctx->b, ctx->color, ctx->t);
244 return ctx;
247 static void
248 uct_notify_play(struct engine *e, struct board *b, struct move *m)
250 prepare_move(e, b, stone_other(m->color), m->coord);
253 static coord_t *
254 uct_genmove(struct engine *e, struct board *b, enum stone color)
256 struct uct *u = e->data;
258 /* Seed the tree. */
259 prepare_move(e, b, color, resign);
261 int played_games = 0;
262 if (!u->threads) {
263 played_games = uct_playouts(u, b, color, u->t);
264 } else {
265 pthread_t threads[u->threads];
266 for (int ti = 0; ti < u->threads; ti++) {
267 struct spawn_ctx *ctx = malloc(sizeof(*ctx));
268 ctx->u = u; ctx->b = b; ctx->color = color;
269 ctx->t = tree_copy(u->t);
270 ctx->seed = fast_random(65536) + ti;
271 pthread_create(&threads[ti], NULL, spawn_helper, ctx);
272 if (UDEBUGL(2))
273 fprintf(stderr, "Spawned thread %d\n", ti);
275 for (int ti = 0; ti < u->threads; ti++) {
276 struct spawn_ctx *ctx;
277 pthread_join(threads[ti], (void **) &ctx);
278 played_games += ctx->games;
279 tree_merge(u->t, ctx->t);
280 tree_done(ctx->t);
281 free(ctx);
282 if (UDEBUGL(2))
283 fprintf(stderr, "Joined thread %d\n", ti);
287 if (UDEBUGL(2))
288 tree_dump(u->t, u->dumpthres);
290 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
291 if (!best) {
292 tree_done(u->t); u->t = NULL;
293 return coord_copy(pass);
295 if (UDEBUGL(0))
296 fprintf(stderr, "*** WINNER is %s (%d,%d) with score %1.4f (%d/%d:%d games)\n", coord2sstr(best->coord, b), coord_x(best->coord, b), coord_y(best->coord, b), best->u.value, best->u.playouts, u->t->root->u.playouts, played_games);
297 if (best->u.value < u->resign_ratio && !is_pass(best->coord)) {
298 tree_done(u->t); u->t = NULL;
299 return coord_copy(resign);
301 tree_promote_node(u->t, best);
302 return coord_copy(best->coord);
305 bool
306 uct_genbook(struct engine *e, struct board *b, enum stone color)
308 struct uct *u = e->data;
309 u->t = tree_init(b, color);
310 tree_load(u->t, b, color);
312 int i;
313 for (i = 0; i < u->games; i++) {
314 int result = uct_playout(u, b, color, u->t);
315 if (result < 0) {
316 /* Tree descent has hit invalid move. */
317 continue;
320 if (i > 0 && !(i % 10000)) {
321 progress_status(u, u->t, color, i);
324 progress_status(u, u->t, color, i);
326 tree_save(u->t, b, u->games / 100);
328 tree_done(u->t);
330 return true;
333 void
334 uct_dumpbook(struct engine *e, struct board *b, enum stone color)
336 struct uct *u = e->data;
337 u->t = tree_init(b, color);
338 tree_load(u->t, b, color);
339 tree_dump(u->t, 0);
340 tree_done(u->t);
344 struct uct *
345 uct_state_init(char *arg)
347 struct uct *u = calloc(1, sizeof(struct uct));
349 u->debug_level = 1;
350 u->games = MC_GAMES;
351 u->gamelen = MC_GAMELEN;
352 u->expand_p = 2;
353 u->dumpthres = 500;
355 if (arg) {
356 char *optspec, *next = arg;
357 while (*next) {
358 optspec = next;
359 next += strcspn(next, ",");
360 if (*next) { *next++ = 0; } else { *next = 0; }
362 char *optname = optspec;
363 char *optval = strchr(optspec, '=');
364 if (optval) *optval++ = 0;
366 if (!strcasecmp(optname, "debug")) {
367 if (optval)
368 u->debug_level = atoi(optval);
369 else
370 u->debug_level++;
371 } else if (!strcasecmp(optname, "games") && optval) {
372 u->games = atoi(optval);
373 } else if (!strcasecmp(optname, "gamelen") && optval) {
374 u->gamelen = atoi(optval);
375 } else if (!strcasecmp(optname, "expand_p") && optval) {
376 u->expand_p = atoi(optval);
377 } else if (!strcasecmp(optname, "radar_d") && optval) {
378 /* For 19x19, it is good idea to set this to 3. */
379 u->radar_d = atoi(optval);
380 } else if (!strcasecmp(optname, "dumpthres") && optval) {
381 u->dumpthres = atoi(optval);
382 } else if (!strcasecmp(optname, "playout_amaf")) {
383 /* Whether to include random playout moves in
384 * AMAF as well. (Otherwise, only tree moves
385 * are included in AMAF. Of course makes sense
386 * only in connection with an AMAF policy.) */
387 u->playout_amaf = true;
388 } else if (!strcasecmp(optname, "policy") && optval) {
389 char *policyarg = strchr(optval, ':');
390 if (policyarg)
391 *policyarg++ = 0;
392 if (!strcasecmp(optval, "ucb1")) {
393 u->policy = policy_ucb1_init(u, policyarg);
394 } else if (!strcasecmp(optval, "ucb1tuned")) {
395 u->policy = policy_ucb1tuned_init(u, policyarg);
396 } else if (!strcasecmp(optval, "ucb1amaf")) {
397 u->policy = policy_ucb1amaf_init(u, policyarg);
398 } else {
399 fprintf(stderr, "UCT: Invalid tree policy %s\n", optval);
401 } else if (!strcasecmp(optname, "playout") && optval) {
402 char *playoutarg = strchr(optval, ':');
403 if (playoutarg)
404 *playoutarg++ = 0;
405 if (!strcasecmp(optval, "old")) {
406 u->playout = playout_old_init(playoutarg);
407 } else if (!strcasecmp(optval, "moggy")) {
408 u->playout = playout_moggy_init(playoutarg);
409 } else if (!strcasecmp(optval, "light")) {
410 u->playout = playout_light_init(playoutarg);
411 } else {
412 fprintf(stderr, "UCT: Invalid playout policy %s\n", optval);
414 } else if (!strcasecmp(optname, "threads") && optval) {
415 u->threads = atoi(optval);
416 } else {
417 fprintf(stderr, "uct: Invalid engine argument %s or missing value\n", optname);
422 u->resign_ratio = 0.2; /* Resign when most games are lost. */
423 u->loss_threshold = 0.95; /* Stop reading if after at least 500 playouts this is best value. */
424 if (!u->policy)
425 u->policy = policy_ucb1_init(u, NULL);
427 if (!u->playout)
428 u->playout = playout_moggy_init(NULL);
429 u->playout->debug_level = u->debug_level;
431 return u;
435 struct engine *
436 engine_uct_init(char *arg)
438 struct uct *u = uct_state_init(arg);
439 struct engine *e = calloc(1, sizeof(struct engine));
440 e->name = "UCT Engine";
441 e->comment = "I'm playing UCT. When we both pass, I will consider all the stones on the board alive. If you are reading this, write 'yes'. Please bear with me at the game end, I need to fill the whole board; if you help me, we will both be happier. Filling the board will not lose points (NZ rules).";
442 e->genmove = uct_genmove;
443 e->notify_play = uct_notify_play;
444 e->data = u;
446 return e;