UCT: Add force_seed to force playout of given game with certain seed
[pachi.git] / uct / tree.c
blob13c0ea6a6439f5153e7e0501019995f8e09c8ba8
1 #include <assert.h>
2 #include <math.h>
3 #include <stddef.h>
4 #include <stdint.h>
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
9 #include "board.h"
10 #include "debug.h"
11 #include "engine.h"
12 #include "move.h"
13 #include "playout.h"
14 #include "uct/internal.h"
15 #include "uct/tree.h"
18 static struct tree_node *
19 tree_init_node(struct tree *t, coord_t coord, int depth)
21 struct tree_node *n = calloc(1, sizeof(*n));
22 n->coord = coord;
23 n->depth = depth;
24 static long c = 1000000;
25 n->hash = c++;
26 if (depth > t->max_depth)
27 t->max_depth = depth;
28 return n;
31 struct tree *
32 tree_init(struct board *board, enum stone color)
34 struct tree *t = calloc(1, sizeof(*t));
35 t->board = board;
36 /* The root PASS move is only virtual, we never play it. */
37 t->root = tree_init_node(t, pass, 0);
38 t->root_symmetry = board->symmetry;
39 return t;
43 static void
44 tree_done_node(struct tree *t, struct tree_node *n)
46 struct tree_node *ni = n->children;
47 while (ni) {
48 struct tree_node *nj = ni->sibling;
49 tree_done_node(t, ni);
50 ni = nj;
52 free(n);
55 void
56 tree_done(struct tree *t)
58 tree_done_node(t, t->root);
59 free(t);
63 static void
64 tree_node_dump(struct tree *tree, struct tree_node *node, int l, int thres)
66 for (int i = 0; i < l; i++) fputc(' ', stderr);
67 int children = 0;
68 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
69 children++;
70 fprintf(stderr, "[%s] %f (%d/%d playouts [prior %d/%d amaf %d/%d]; hints %x; %d children) <%lld>\n", coord2sstr(node->coord, tree->board), node->u.value, node->u.wins, node->u.playouts, node->prior.wins, node->prior.playouts, node->amaf.wins, node->amaf.playouts, node->hints, children, node->hash);
72 /* Print nodes sorted by #playouts. */
74 struct tree_node *nbox[1000]; int nboxl = 0;
75 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
76 if (ni->u.playouts > thres)
77 nbox[nboxl++] = ni;
79 while (true) {
80 int best = -1;
81 for (int i = 0; i < nboxl; i++)
82 if (nbox[i] && (best < 0 || nbox[i]->u.playouts > nbox[best]->u.playouts))
83 best = i;
84 if (best < 0)
85 break;
86 tree_node_dump(tree, nbox[best], l + 1, /* node->u.value < 0.1 ? 0 : */ thres);
87 nbox[best] = NULL;
91 void
92 tree_dump(struct tree *tree, int thres)
94 if (thres && tree->root->u.playouts / thres > 100) {
95 /* Be a bit sensible about this; the opening book can create
96 * huge dumps at first. */
97 thres = tree->root->u.playouts / 100 * (thres < 1000 ? 1 : thres / 1000);
99 tree_node_dump(tree, tree->root, 0, thres);
103 static char *
104 tree_book_name(struct board *b)
106 static char buf[256];
107 if (b->handicap > 0) {
108 sprintf(buf, "uctbook-%d-%02.01f-h%d.pachitree", b->size - 2, b->komi, b->handicap);
109 } else {
110 sprintf(buf, "uctbook-%d-%02.01f.pachitree", b->size - 2, b->komi);
112 return buf;
115 static void
116 tree_node_save(FILE *f, struct tree_node *node, int thres)
118 fputc(1, f);
119 fwrite(((void *) node) + offsetof(struct tree_node, depth),
120 sizeof(struct tree_node) - offsetof(struct tree_node, depth),
121 1, f);
123 if (node->u.playouts >= thres)
124 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
125 tree_node_save(f, ni, thres);
127 fputc(0, f);
130 void
131 tree_save(struct tree *tree, struct board *b, int thres)
133 char *filename = tree_book_name(b);
134 FILE *f = fopen(filename, "wb");
135 if (!f) {
136 perror("fopen");
137 return;
139 tree_node_save(f, tree->root, thres);
140 fputc(0, f);
141 fclose(f);
145 void
146 tree_node_load(FILE *f, struct tree_node *node, int *num, bool invert)
148 (*num)++;
150 fread(((void *) node) + offsetof(struct tree_node, depth),
151 sizeof(struct tree_node) - offsetof(struct tree_node, depth),
152 1, f);
154 /* Keep values in sane scale, otherwise we start overflowing.
155 * We may go slow here but we must be careful about not getting
156 * too huge integers.*/
157 #define MAX_PLAYOUTS 10000000
158 if (node->u.playouts > MAX_PLAYOUTS) {
159 int over = node->u.playouts - MAX_PLAYOUTS;
160 node->u.wins -= ((double) node->u.wins / node->u.playouts) * over;
161 node->u.playouts = MAX_PLAYOUTS;
163 if (node->amaf.playouts > MAX_PLAYOUTS) {
164 int over = node->amaf.playouts - MAX_PLAYOUTS;
165 node->amaf.wins -= ((double) node->amaf.wins / node->amaf.playouts) * over;
166 node->amaf.playouts = MAX_PLAYOUTS;
169 if (invert) {
170 node->u.wins = node->u.playouts - node->u.wins;
171 node->u.value = 1 - node->u.value;
172 node->amaf.wins = node->amaf.playouts - node->amaf.wins;
173 node->amaf.value = 1 - node->amaf.value;
174 node->prior.wins = node->prior.playouts - node->prior.wins;
175 node->prior.value = 1 - node->prior.value;
178 struct tree_node *ni = NULL, *ni_prev = NULL;
179 while (fgetc(f)) {
180 ni_prev = ni; ni = calloc(1, sizeof(*ni));
181 if (!node->children)
182 node->children = ni;
183 else
184 ni_prev->sibling = ni;
185 ni->parent = node;
186 tree_node_load(f, ni, num, invert);
190 void
191 tree_load(struct tree *tree, struct board *b, enum stone color)
193 char *filename = tree_book_name(b);
194 FILE *f = fopen(filename, "rb");
195 if (!f)
196 return;
198 fprintf(stderr, "Loading opening book %s...\n", filename);
200 int num = 0;
201 if (fgetc(f))
202 tree_node_load(f, tree->root, &num, color != S_BLACK);
203 fprintf(stderr, "Loaded %d nodes.\n", num);
205 fclose(f);
209 static struct tree_node *
210 tree_node_copy(struct tree_node *node)
212 struct tree_node *n2 = malloc(sizeof(*n2));
213 *n2 = *node;
214 if (!node->children)
215 return n2;
216 struct tree_node *ni = node->children;
217 struct tree_node *ni2 = tree_node_copy(ni);
218 n2->children = ni2; ni2->parent = n2;
219 while ((ni = ni->sibling)) {
220 ni2->sibling = tree_node_copy(ni);
221 ni2 = ni2->sibling; ni2->parent = n2;
223 return n2;
226 struct tree *
227 tree_copy(struct tree *tree)
229 struct tree *t2 = malloc(sizeof(*t2));
230 *t2 = *tree;
231 t2->root = tree_node_copy(tree->root);
232 return t2;
236 static void
237 tree_node_merge(struct tree_node *dest, struct tree_node *src)
239 dest->hints |= src->hints;
241 /* Merge the children, both are coord-sorted lists. */
242 struct tree_node *di = dest->children, *dip = NULL;
243 struct tree_node *si = src->children, *sip = NULL;
244 while (di && si) {
245 if (di->coord != si->coord) {
246 /* src has some extra items or misses di */
247 struct tree_node *si2 = si->sibling;
248 while (si2 && di->coord != si2->coord) {
249 si2 = si2->sibling;
251 if (!si2)
252 goto next_di; /* src misses di, move on */
253 /* chain the extra [si,si2) items before di */
254 if (dip)
255 dip->sibling = si;
256 else
257 dest->children = si;
258 while (si->sibling != si2) {
259 si->parent = dest;
260 si = si->sibling;
262 si->sibling = di;
263 si = si2;
264 if (sip)
265 sip->sibling = si;
266 else
267 src->children = si;
269 /* Matching nodes - recurse... */
270 tree_node_merge(di, si);
271 /* ...and move on. */
272 sip = si; si = si->sibling;
273 next_di:
274 dip = di; di = di->sibling;
276 if (si) {
277 if (dip)
278 dip->sibling = si;
279 else
280 dest->children = si;
281 while (si) {
282 si->parent = dest;
283 si = si->sibling;
285 if (sip)
286 sip->sibling = NULL;
287 else
288 src->children = NULL;
291 /* In case of prior playouts, we do not want to accumulate them
292 * over merges - they remain static after setup. However, different
293 * trees may have different priors non-deterministically. We just
294 * take the average. */
295 if (dest->prior.playouts != src->prior.playouts
296 || dest->prior.wins != src->prior.wins) {
297 dest->prior.playouts = (dest->prior.playouts + src->prior.playouts) / 2;
298 dest->prior.wins = (dest->prior.wins + src->prior.wins) / 2;
299 if (dest->prior.playouts)
300 dest->prior.value = dest->prior.wins / dest->prior.playouts;
303 dest->amaf.playouts += src->amaf.playouts;
304 dest->amaf.wins += src->amaf.wins;
305 if (dest->amaf.playouts)
306 dest->amaf.value = dest->amaf.wins / dest->amaf.playouts;
308 dest->u.playouts += src->u.playouts;
309 dest->u.wins += src->u.wins;
310 if (dest->prior.playouts + dest->amaf.playouts + dest->u.playouts)
311 tree_update_node_value(dest);
314 /* Merge two trees built upon the same board. Note that the operation is
315 * destructive on src. */
316 void
317 tree_merge(struct tree *dest, struct tree *src)
319 if (src->max_depth > dest->max_depth)
320 dest->max_depth = src->max_depth;
321 tree_node_merge(dest->root, src->root);
325 /* Tree symmetry: When possible, we will localize the tree to a single part
326 * of the board in tree_expand_node() and possibly flip along symmetry axes
327 * to another part of the board in tree_promote_at(). We follow b->symmetry
328 * guidelines here. */
331 void
332 tree_expand_node(struct tree *t, struct tree_node *node, struct board *b, enum stone color, int radar, struct uct_policy *policy, int parity)
334 struct tree_node *ni = tree_init_node(t, pass, node->depth + 1);
335 ni->parent = node; node->children = ni;
336 if (policy->prior)
337 policy->prior(policy, t, ni, b, color, parity);
339 /* The loop considers only the symmetry playground. */
340 if (UDEBUGL(6)) {
341 fprintf(stderr, "expanding %s within [%d,%d],[%d,%d] %d-%d\n",
342 coord2sstr(node->coord, b),
343 b->symmetry.x1, b->symmetry.y1,
344 b->symmetry.x2, b->symmetry.y2,
345 b->symmetry.type, b->symmetry.d);
347 for (int i = b->symmetry.x1; i <= b->symmetry.x2; i++) {
348 for (int j = b->symmetry.y1; j <= b->symmetry.y2; j++) {
349 if (b->symmetry.d) {
350 int x = b->symmetry.type == SYM_DIAG_DOWN ? board_size(b) - 1 - i : i;
351 if (x > j) {
352 if (UDEBUGL(7))
353 fprintf(stderr, "drop %d,%d\n", i, j);
354 continue;
358 coord_t c = coord_xy_otf(i, j, t->board);
359 if (board_at(b, c) != S_NONE)
360 continue;
361 assert(c != node->coord); // I have spotted "C3 C3" in some sequence...
362 /* This looks very useful on large boards - weeds out huge amount of crufty moves. */
363 if (b->hash /* not empty board */ && radar && !board_stone_radar(b, c, radar))
364 continue;
366 struct tree_node *nj = tree_init_node(t, c, node->depth + 1);
367 nj->parent = node; ni->sibling = nj; ni = nj;
369 if (policy->prior)
370 policy->prior(policy, t, ni, b, color, parity);
376 static coord_t
377 flip_coord(struct board *b, coord_t c,
378 bool flip_horiz, bool flip_vert, int flip_diag)
380 int x = coord_x(c, b), y = coord_y(c, b);
381 if (flip_diag) {
382 int z = x; x = y; y = z;
384 if (flip_horiz) {
385 x = board_size(b) - 1 - x;
387 if (flip_vert) {
388 y = board_size(b) - 1 - y;
390 return coord_xy_otf(x, y, b);
393 static void
394 tree_fix_node_symmetry(struct board *b, struct tree_node *node,
395 bool flip_horiz, bool flip_vert, int flip_diag)
397 node->coord = flip_coord(b, node->coord, flip_horiz, flip_vert, flip_diag);
399 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
400 tree_fix_node_symmetry(b, ni, flip_horiz, flip_vert, flip_diag);
403 static void
404 tree_fix_symmetry(struct tree *tree, struct board *b, coord_t c)
406 struct board_symmetry *s = &tree->root_symmetry;
407 int cx = coord_x(c, b), cy = coord_y(c, b);
409 /* playground X->h->v->d normalization
410 * :::.. .d...
411 * .::.. v....
412 * ..:.. .....
413 * ..... h...X
414 * ..... ..... */
415 bool flip_horiz = cx < s->x1 || cx > s->x2;
416 bool flip_vert = cy < s->y1 || cy > s->y2;
418 bool flip_diag = 0;
419 if (s->d) {
420 bool dir = (s->type == SYM_DIAG_DOWN);
421 int x = dir ^ flip_horiz ^ flip_vert ? board_size(b) - 1 - cx : cx;
422 if (flip_vert ? x < cy : x > cy) {
423 flip_diag = 1;
427 if (UDEBUGL(4)) {
428 fprintf(stderr, "%s will flip %d %d %d -> %s, sym %d (%d) -> %d (%d)\n",
429 coord2sstr(c, b), flip_horiz, flip_vert, flip_diag,
430 coord2sstr(flip_coord(b, c, flip_horiz, flip_vert, flip_diag), b),
431 s->type, s->d, b->symmetry.type, b->symmetry.d);
433 tree_fix_node_symmetry(b, tree->root, flip_horiz, flip_vert, flip_diag);
437 static void
438 tree_unlink_node(struct tree_node *node)
440 struct tree_node *ni = node->parent;
441 if (ni->children == node) {
442 ni->children = node->sibling;
443 } else {
444 ni = ni->children;
445 while (ni->sibling != node)
446 ni = ni->sibling;
447 ni->sibling = node->sibling;
/* Detach `node` from its parent and free it together with its subtree. */
void
tree_delete_node(struct tree *tree, struct tree_node *node)
{
	struct tree_node *victim = node;
	tree_unlink_node(victim);
	tree_done_node(tree, victim);
}
458 void
459 tree_promote_node(struct tree *tree, struct tree_node *node)
461 assert(node->parent == tree->root);
462 tree_unlink_node(node);
463 tree_done_node(tree, tree->root);
464 tree->root = node;
465 board_symmetry_update(tree->board, &tree->root_symmetry, node->coord);
466 node->parent = NULL;
469 bool
470 tree_promote_at(struct tree *tree, struct board *b, coord_t c)
472 tree_fix_symmetry(tree, b, c);
474 for (struct tree_node *ni = tree->root->children; ni; ni = ni->sibling) {
475 if (ni->coord == c) {
476 tree_promote_node(tree, ni);
477 return true;
480 return false;
483 bool
484 tree_leaf_node(struct tree_node *node)
486 return !(node->children);
489 void
490 tree_update_node_value(struct tree_node *node)
492 bool noamaf = node->hints & NODE_HINT_NOAMAF;
493 node->u.value = (float)(node->u.wins + node->prior.wins + (!noamaf ? node->amaf.wins : 0))
494 / (node->u.playouts + node->prior.playouts + (!noamaf ? node->amaf.playouts : 0));
495 #if 0
496 { struct board b2; board_size(&b2) = 9+2;
497 fprintf(stderr, "%s->%s %d/%d %d/%d %f\n", node->parent ? coord2sstr(node->parent->coord, &b2) : NULL, coord2sstr(node->coord, &b2), node->u.wins, node->u.playouts, node->prior.wins, node->prior.playouts, node->u.value); }
498 #endif