Board+UCT: Symmetry Folding Support
[pachi/peepo.git] / uct / tree.c
bloba8e939e56ea1c0e82829067654a69c5189140a01
1 #include <assert.h>
2 #include <math.h>
3 #include <stddef.h>
4 #include <stdio.h>
5 #include <stdlib.h>
6 #include <string.h>
8 #include "board.h"
9 #include "debug.h"
10 #include "engine.h"
11 #include "move.h"
12 #include "playout.h"
13 #include "uct/internal.h"
14 #include "uct/tree.h"
17 static struct tree_node *
18 tree_init_node(struct tree *t, coord_t coord, int depth)
20 struct tree_node *n = calloc(1, sizeof(*n));
21 n->coord = coord;
22 n->depth = depth;
23 if (depth > t->max_depth)
24 t->max_depth = depth;
25 return n;
28 struct tree *
29 tree_init(struct board *board, enum stone color)
31 struct tree *t = calloc(1, sizeof(*t));
32 t->board = board;
33 /* The root PASS move is only virtual, we never play it. */
34 t->root = tree_init_node(t, pass, 0);
35 return t;
39 static void
40 tree_done_node(struct tree *t, struct tree_node *n)
42 struct tree_node *ni = n->children;
43 while (ni) {
44 struct tree_node *nj = ni->sibling;
45 tree_done_node(t, ni);
46 ni = nj;
48 free(n);
51 void
52 tree_done(struct tree *t)
54 tree_done_node(t, t->root);
55 free(t);
59 static void
60 tree_node_dump(struct tree *tree, struct tree_node *node, int l, int thres)
62 for (int i = 0; i < l; i++) fputc(' ', stderr);
63 fprintf(stderr, "[%s] %f (%d/%d playouts [prior %d/%d amaf %d/%d]; hints %x)\n", coord2sstr(node->coord, tree->board), node->u.value, node->u.wins, node->u.playouts, node->prior.wins, node->prior.playouts, node->amaf.wins, node->amaf.playouts, node->hints);
65 /* Print nodes sorted by #playouts. */
67 struct tree_node *nbox[1000]; int nboxl = 0;
68 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
69 if (ni->u.playouts > thres)
70 nbox[nboxl++] = ni;
72 while (true) {
73 int best = -1;
74 for (int i = 0; i < nboxl; i++)
75 if (nbox[i] && (best < 0 || nbox[i]->u.playouts > nbox[best]->u.playouts))
76 best = i;
77 if (best < 0)
78 break;
79 tree_node_dump(tree, nbox[best], l + 1, thres);
80 nbox[best] = NULL;
84 void
85 tree_dump(struct tree *tree, int thres)
87 tree_node_dump(tree, tree->root, 0, thres);
91 static char *
92 tree_book_name(struct board *b)
94 static char buf[256];
95 sprintf(buf, "uct-%d-%02.01f.pachibook", b->size - 2, b->komi);
96 return buf;
99 static void
100 tree_node_save(FILE *f, struct tree_node *node, int thres)
102 if (node->u.playouts < thres)
103 return;
105 fputc(1, f);
106 fwrite(((void *) node) + offsetof(struct tree_node, depth),
107 sizeof(struct tree_node) - offsetof(struct tree_node, depth),
108 1, f);
110 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
111 tree_node_save(f, ni, thres);
114 fputc(0, f);
117 void
118 tree_save(struct tree *tree, struct board *b, int thres)
120 char *filename = tree_book_name(b);
121 FILE *f = fopen(filename, "wb");
122 if (!f) {
123 perror("fopen");
124 return;
126 tree_node_save(f, tree->root, thres);
127 fputc(0, f);
128 fclose(f);
132 void
133 tree_node_load(FILE *f, struct tree_node *node, int *num)
135 (*num)++;
137 fread(((void *) node) + offsetof(struct tree_node, depth),
138 sizeof(struct tree_node) - offsetof(struct tree_node, depth),
139 1, f);
141 struct tree_node *ni = NULL, *ni_prev = NULL;
142 while (fgetc(f)) {
143 ni_prev = ni; ni = calloc(1, sizeof(*ni));
144 if (!node->children)
145 node->children = ni;
146 else
147 ni_prev->sibling = ni;
148 ni->parent = node;
149 tree_node_load(f, ni, num);
153 void
154 tree_load(struct tree *tree, struct board *b)
156 char *filename = tree_book_name(b);
157 FILE *f = fopen(filename, "rb");
158 if (!f)
159 return;
161 fprintf(stderr, "Loading opening book %s...\n", filename);
163 int num = 0;
164 if (fgetc(f))
165 tree_node_load(f, tree->root, &num);
166 fprintf(stderr, "Loaded %d nodes.\n", num);
168 fclose(f);
/* Tree symmetry: when the position is symmetric, tree_expand_node() restricts
 * the tree to a single part of the board (the playground described by
 * b->symmetry). tree_promote_at() may later flip the tree along the symmetry
 * axes onto another part of the board. */
178 void
179 tree_expand_node(struct tree *t, struct tree_node *node, struct board *b, enum stone color, int radar, struct uct_policy *policy, int parity)
181 struct tree_node *ni = tree_init_node(t, pass, node->depth + 1);
182 ni->parent = node; node->children = ni;
184 /* The loop considers only the symmetry playground. */
185 for (int i = b->symmetry.x1; i <= b->symmetry.x2; i++) {
186 for (int j = b->symmetry.y1; j <= b->symmetry.y2; j++) {
187 if (b->symmetry.d) {
188 int x = b->symmetry.type == SYM_DIAG_DOWN ? board_size(b) - i : i;
189 if (b->symmetry.d < 0 ? x < j : x > j)
190 continue;
193 coord_t c = coord_xy_otf(i, j, t->board);
194 if (board_at(b, c) != S_NONE)
195 continue;
196 /* This looks very useful on large boards - weeds out huge amount of crufty moves. */
197 if (b->hash /* not empty board */ && radar && !board_stone_radar(b, c, radar))
198 continue;
200 struct tree_node *nj = tree_init_node(t, c, node->depth + 1);
201 nj->parent = node; ni->sibling = nj; ni = nj;
203 if (policy->prior)
204 policy->prior(policy, t, ni, b, color, parity);
210 static void
211 tree_fix_node_symmetry(struct board *b, struct tree_node *node,
212 bool flip_horiz, bool flip_vert, int flip_diag)
214 int x = coord_x(node->coord, b), y = coord_y(node->coord, b);
215 if (flip_diag) {
216 int z = x;
217 x = flip_diag == 1 ? y : board_size(b) - y;
218 y = flip_diag == 1 ? z : board_size(b) - z;
220 if (flip_horiz) {
221 x = board_size(b) - x;
223 if (flip_vert) {
224 y = board_size(b) - y;
226 node->coord = coord_xy_otf(x, y, b);
228 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
229 tree_fix_node_symmetry(b, ni, flip_horiz, flip_vert, flip_diag);
232 static void
233 tree_fix_symmetry(struct tree *tree, struct board *b, coord_t c)
235 /* XXX: We hard-coded assume c is the same move that tree-root,
236 * just possibly flipped. */
238 if (c == tree->root->coord)
239 return;
241 int cx = coord_x(c, b), cy = coord_y(c, b);
242 int rx = coord_x(tree->root->coord, b), ry = coord_y(tree->root->coord, b);
244 /* playground X->h->v->d normalization
245 * :::.. .d...
246 * .::.. v....
247 * ..:.. .....
248 * ..... h...X
249 * ..... ..... */
250 bool flip_horiz = cy == ry;
251 bool flip_vert = cx == rx;
253 int nx = flip_horiz ? board_size(b) - rx : rx;
254 int ny = flip_vert ? board_size(b) - ry : ry;
256 int flip_diag = 0;
257 if (nx == cy && ny == cx) {
258 flip_diag = 1;
259 } else if (board_size(b) - nx == cy && ny == board_size(b) - cx) {
260 flip_diag = 2;
263 tree_fix_node_symmetry(b, tree->root, flip_horiz, flip_vert, flip_diag);
267 static void
268 tree_unlink_node(struct tree_node *node)
270 struct tree_node *ni = node->parent;
271 if (ni->children == node) {
272 ni->children = node->sibling;
273 } else {
274 ni = ni->children;
275 while (ni->sibling != node)
276 ni = ni->sibling;
277 ni->sibling = node->sibling;
/* Remove @node from the tree and free it along with its subtree. */
void
tree_delete_node(struct tree *tree, struct tree_node *node)
{
	/* Detach first, so the parent never points at freed memory. */
	tree_unlink_node(node);
	tree_done_node(tree, node);
}
288 void
289 tree_promote_node(struct tree *tree, struct tree_node *node)
291 assert(node->parent == tree->root);
292 tree_unlink_node(node);
293 tree_done_node(tree, tree->root);
294 tree->root = node;
295 node->parent = NULL;
298 bool
299 tree_promote_at(struct tree *tree, struct board *b, coord_t c)
301 tree_fix_symmetry(tree, b, c);
303 for (struct tree_node *ni = tree->root->children; ni; ni = ni->sibling)
304 if (ni->coord == c) {
305 tree_promote_node(tree, ni);
306 return true;
308 return false;
311 bool
312 tree_leaf_node(struct tree_node *node)
314 return !(node->children);
317 void
318 tree_update_node_value(struct tree_node *node, bool add_amaf)
320 node->u.value = (float)(node->u.wins + node->prior.wins + (add_amaf ? node->amaf.wins : 0))
321 / (node->u.playouts + node->prior.playouts + (add_amaf ? node->amaf.playouts : 0));
322 #if 0
323 { struct board b2; board_size(&b2) = 9+2;
324 fprintf(stderr, "%s->%s %d/%d %d/%d %f\n", node->parent ? coord2sstr(node->parent->coord, &b2) : NULL, coord2sstr(node->coord, &b2), node->u.wins, node->u.playouts, node->prior.wins, node->prior.playouts, node->u.value); }
325 #endif