merge_groups(): Use group_base() for connecting groups
[pachi/peepo.git] / uct / tree.c
blobdda2ec7c9d16638343ae60c182731a531f0debd3
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "engine.h"
10 #include "move.h"
11 #include "playout.h"
12 #include "uct/internal.h"
13 #include "uct/tree.h"
16 static struct tree_node *
17 tree_init_node(struct tree *t, coord_t coord, int depth)
19 struct tree_node *n = calloc(1, sizeof(*n));
20 n->coord = coord;
21 n->depth = depth;
22 if (depth > t->max_depth)
23 t->max_depth = depth;
24 return n;
27 struct tree *
28 tree_init(struct board *board, enum stone color)
30 struct tree *t = calloc(1, sizeof(*t));
31 t->board = board;
32 /* The root PASS move is only virtual, we never play it. */
33 t->root = tree_init_node(t, pass, 0);
34 return t;
38 static void
39 tree_done_node(struct tree *t, struct tree_node *n)
41 struct tree_node *ni = n->children;
42 while (ni) {
43 struct tree_node *nj = ni->sibling;
44 tree_done_node(t, ni);
45 ni = nj;
47 free(n);
50 void
51 tree_done(struct tree *t)
53 tree_done_node(t, t->root);
54 free(t);
58 static void
59 tree_node_dump(struct tree *tree, struct tree_node *node, int l, int thres)
61 for (int i = 0; i < l; i++) fputc(' ', stderr);
62 fprintf(stderr, "[%s] %f (%d/%d playouts [prior %d/%d amaf %d/%d]; hints %x)\n", coord2sstr(node->coord, tree->board), node->u.value, node->u.wins, node->u.playouts, node->prior.wins, node->prior.playouts, node->amaf.wins, node->amaf.playouts, node->hints);
64 /* Print nodes sorted by #playouts. */
66 struct tree_node *nbox[1000]; int nboxl = 0;
67 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
68 if (ni->u.playouts > thres)
69 nbox[nboxl++] = ni;
71 while (true) {
72 int best = -1;
73 for (int i = 0; i < nboxl; i++)
74 if (nbox[i] && (best < 0 || nbox[i]->u.playouts > nbox[best]->u.playouts))
75 best = i;
76 if (best < 0)
77 break;
78 tree_node_dump(tree, nbox[best], l + 1, thres);
79 nbox[best] = NULL;
83 void
84 tree_dump(struct tree *tree, int thres)
86 tree_node_dump(tree, tree->root, 0, thres);
90 void
91 tree_expand_node(struct tree *t, struct tree_node *node, struct board *b, enum stone color, int radar, struct uct_policy *policy, int parity)
93 struct tree_node *ni = tree_init_node(t, pass, node->depth + 1);
94 ni->parent = node; node->children = ni;
96 /* The loop excludes the offboard margin. */
97 for (int i = 1; i < board_size(t->board); i++) {
98 for (int j = 1; j < board_size(t->board); j++) {
99 coord_t c = coord_xy_otf(i, j, t->board);
100 if (board_at(b, c) != S_NONE)
101 continue;
102 /* This looks very useful on large boards - weeds out huge amount of crufty moves. */
103 if (b->hash /* not empty board */ && radar && !board_stone_radar(b, c, radar))
104 continue;
106 struct tree_node *nj = tree_init_node(t, c, node->depth + 1);
107 nj->parent = node; ni->sibling = nj; ni = nj;
109 if (policy->prior)
110 policy->prior(policy, t, ni, b, color, parity);
115 static void
116 tree_unlink_node(struct tree_node *node)
118 struct tree_node *ni = node->parent;
119 if (ni->children == node) {
120 ni->children = node->sibling;
121 } else {
122 ni = ni->children;
123 while (ni->sibling != node)
124 ni = ni->sibling;
125 ni->sibling = node->sibling;
/* Detach @node from its parent and free it along with its entire subtree.
 * @node must have a parent; the pointer is invalid once this returns. */
void
tree_delete_node(struct tree *tree, struct tree_node *node)
{
	/* Unlink first so the parent never points at freed memory. */
	tree_unlink_node(node);
	tree_done_node(tree, node);
}
136 void
137 tree_promote_node(struct tree *tree, struct tree_node *node)
139 assert(node->parent == tree->root);
140 tree_unlink_node(node);
141 tree_done_node(tree, tree->root);
142 tree->root = node;
143 node->parent = NULL;
146 bool
147 tree_leaf_node(struct tree_node *node)
149 return !(node->children);
152 void
153 tree_update_node_value(struct tree_node *node, bool add_amaf)
155 node->u.value = (float)(node->u.wins + node->prior.wins + (add_amaf ? node->amaf.wins : 0))
156 / (node->u.playouts + node->prior.playouts + (add_amaf ? node->amaf.playouts : 0));
157 #if 0
158 { struct board b2; board_size(&b2) = 9+2;
159 fprintf(stderr, "%s->%s %d/%d %d/%d %f\n", node->parent ? coord2sstr(node->parent->coord, &b2) : NULL, coord2sstr(node->coord, &b2), node->u.wins, node->u.playouts, node->prior.wins, node->prior.playouts, node->u.value); }
160 #endif