UCT: Add compile-time UCB1_TUNED support
[pachi.git] / uct / tree.c
blobdb8e00ee02995ff4aff3d40eb06a8378306f6c2a
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "engine.h"
10 #include "move.h"
11 #include "playout.h"
12 #include "uct/tree.h"
15 static struct tree_node *
16 tree_init_node(coord_t coord)
18 struct tree_node *n = calloc(1, sizeof(*n));
19 n->coord = coord;
20 return n;
23 struct tree *
24 tree_init(struct board *board)
26 struct tree *t = calloc(1, sizeof(*t));
27 /* The root PASS move is only virtual, we never play it. */
28 t->root = tree_init_node(pass);
29 t->board = board;
30 return t;
34 static void
35 tree_done_node(struct tree_node *n)
37 struct tree_node *ni = n->children;
38 while (ni) {
39 struct tree_node *nj = ni->sibling;
40 tree_done_node(ni);
41 ni = nj;
43 free(n);
46 void
47 tree_done(struct tree *t)
49 tree_done_node(t->root);
50 free(t);
54 static void
55 tree_node_dump(struct tree *tree, struct tree_node *node, int l)
57 for (int i = 0; i < l; i++) fputc(' ', stderr);
58 fprintf(stderr, "[%s] %f (%d playouts)\n", coord2sstr(node->coord, tree->board), node->value, node->playouts);
60 /* Print nodes sorted by #playouts. */
62 struct tree_node *nbox[1000]; int nboxl = 0;
63 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
64 if (ni->playouts > 200)
65 nbox[nboxl++] = ni;
67 while (true) {
68 int best = -1;
69 for (int i = 0; i < nboxl; i++)
70 if (nbox[i] && (best < 0 || nbox[i]->playouts > nbox[best]->playouts))
71 best = i;
72 if (best < 0)
73 break;
74 tree_node_dump(tree, nbox[best], l + 1);
75 nbox[best] = NULL;
79 void
80 tree_dump(struct tree *tree)
82 tree_node_dump(tree, tree->root, 0);
86 void
87 tree_expand_node(struct tree *t, struct tree_node *node)
89 assert(!node->children);
91 struct tree_node *ni = tree_init_node(pass);
92 ni->parent = node; node->children = ni;
94 /* The loop excludes the offboard margin. */
95 for (int i = 1; i < t->board->size; i++) {
96 for (int j = 1; j < t->board->size; j++) {
97 struct tree_node *nj = tree_init_node(coord_xy_otf(i, j, t->board));
98 nj->parent = node; ni->sibling = nj; ni = nj;
103 void
104 tree_delete_node(struct tree_node *node)
106 /* Unlink */
107 struct tree_node *ni = node->parent;
108 if (ni->children == node) {
109 ni->children = node->sibling;
110 } else {
111 ni = ni->children;
112 while (ni->sibling != node)
113 ni = ni->sibling;
114 ni->sibling = node->sibling;
117 /* Free */
118 assert(!node->children);
119 free(node);
122 bool
123 tree_leaf_node(struct tree_node *node)
125 return !(node->children);
128 struct tree_node *
129 tree_best_child(struct tree_node *node)
131 struct tree_node *nbest = node->children;
132 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
133 // we compare playouts and choose the best-explored
134 // child; comparing values is more brittle
135 if (ni->playouts > nbest->playouts)
136 nbest = ni;
137 return nbest;
141 struct tree_node *
142 tree_uct_descend(struct tree *tree, struct tree_node *node, int parity)
144 float xpl = log(node->playouts) * tree->explore_p;
146 struct tree_node *nbest = node->children;
147 float best_urgency = -9999;
148 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
149 #ifdef UCB1_TUNED
150 float xpl_loc = (ni->value - ni->value * ni->value);
151 if (parity < 0) xpl_loc = 1 - xpl_loc;
152 xpl_loc += sqrt(xpl / ni->playouts);
153 if (xpl_loc > 1.0/4) xpl_loc = 1.0/4;
154 float urgency = ni->value * parity + sqrt(xpl * xpl_loc / ni->playouts);
155 #else
156 float urgency = ni->value * parity + sqrt(xpl / ni->playouts);
157 #endif
158 if (urgency > best_urgency) {
159 best_urgency = urgency;
160 nbest = ni;
163 return nbest;
166 void
167 tree_uct_update(struct tree_node *node, int result)
169 for (; node; node = node->parent) {
170 node->playouts++;
171 node->wins += result;
172 node->value = (float)node->wins / node->playouts;