UCT: Play pass only when we are ready to score
[pachi.git] / uct / tree.c
blob70979874513e78a46edb16a17ffe4aa46edd0037
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "engine.h"
10 #include "move.h"
11 #include "playout.h"
12 #include "uct/tree.h"
15 static struct tree_node *
16 tree_init_node(coord_t coord)
18 struct tree_node *n = calloc(1, sizeof(*n));
19 n->coord = coord;
20 return n;
23 struct tree *
24 tree_init(struct board *board)
26 struct tree *t = calloc(1, sizeof(*t));
27 /* The root PASS move is only virtual, we never play it. */
28 t->root = tree_init_node(pass);
29 t->board = board;
30 return t;
34 static void
35 tree_done_node(struct tree_node *n)
37 struct tree_node *ni = n->children;
38 while (ni) {
39 struct tree_node *nj = ni->sibling;
40 tree_done_node(ni);
41 ni = nj;
43 free(n);
46 void
47 tree_done(struct tree *t)
49 tree_done_node(t->root);
50 free(t);
54 static void
55 tree_node_dump(struct tree *tree, struct tree_node *node, int l)
57 for (int i = 0; i < l; i++) fputc(' ', stderr);
58 fprintf(stderr, "[%s] %f (%d/%d playouts)\n", coord2sstr(node->coord, tree->board), node->value, node->wins, node->playouts);
60 /* Print nodes sorted by #playouts. */
62 struct tree_node *nbox[1000]; int nboxl = 0;
63 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
64 if (ni->playouts > 200)
65 nbox[nboxl++] = ni;
67 while (true) {
68 int best = -1;
69 for (int i = 0; i < nboxl; i++)
70 if (nbox[i] && (best < 0 || nbox[i]->playouts > nbox[best]->playouts))
71 best = i;
72 if (best < 0)
73 break;
74 tree_node_dump(tree, nbox[best], l + 1);
75 nbox[best] = NULL;
79 void
80 tree_dump(struct tree *tree)
82 tree_node_dump(tree, tree->root, 0);
86 void
87 tree_expand_node(struct tree *t, struct tree_node *node, struct board *b)
89 assert(!node->children);
91 struct tree_node *ni = tree_init_node(pass);
92 ni->parent = node; node->children = ni;
94 /* The loop excludes the offboard margin. */
95 for (int i = 1; i < t->board->size; i++) {
96 for (int j = 1; j < t->board->size; j++) {
97 coord_t c = coord_xy_otf(i, j, t->board);
98 if (board_at(b, c) != S_NONE)
99 continue;
100 struct tree_node *nj = tree_init_node(coord_xy_otf(i, j, t->board));
101 nj->parent = node; ni->sibling = nj; ni = nj;
106 static void
107 tree_unlink_node(struct tree_node *node)
109 struct tree_node *ni = node->parent;
110 if (ni->children == node) {
111 ni->children = node->sibling;
112 } else {
113 ni = ni->children;
114 while (ni->sibling != node)
115 ni = ni->sibling;
116 ni->sibling = node->sibling;
120 void
121 tree_delete_node(struct tree_node *node)
123 tree_unlink_node(node);
124 assert(!node->children);
125 free(node);
128 void
129 tree_promote_node(struct tree *tree, struct tree_node *node)
131 assert(node->parent == tree->root);
132 tree_unlink_node(node);
133 tree_done_node(tree->root);
134 tree->root = node;
135 node->parent = NULL;
138 bool
139 tree_leaf_node(struct tree_node *node)
141 return !(node->children);
144 struct tree_node *
145 tree_best_child(struct tree_node *node, struct board *b, enum stone color)
147 struct tree_node *nbest = NULL;
148 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
149 // we compare playouts and choose the best-explored
150 // child; comparing values is more brittle
151 if (!nbest || ni->playouts > nbest->playouts) {
152 /* Play pass only if we can afford scoring */
153 if (is_pass(ni->coord)) {
154 float score = board_fast_score(b);
155 if (color == S_BLACK)
156 score = -score;
157 //fprintf(stderr, "%d score %f\n", b->last_move.color, score);
158 if (score <= 0)
159 continue;
161 nbest = ni;
163 return nbest;
167 struct tree_node *
168 tree_uct_descend(struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
170 float xpl = log(node->playouts) * tree->explore_p;
172 struct tree_node *nbest = node->children;
173 float best_urgency = -9999;
174 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
175 /* Do not consider passing early. */
176 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
177 continue;
178 #ifdef UCB1_TUNED
179 float xpl_loc = (ni->value - ni->value * ni->value);
180 if (parity < 0) xpl_loc = 1 - xpl_loc;
181 xpl_loc += sqrt(xpl / ni->playouts);
182 if (xpl_loc > 1.0/4) xpl_loc = 1.0/4;
183 float urgency = ni->value * parity + sqrt(xpl * xpl_loc / ni->playouts);
184 #else
185 float urgency = ni->value * parity + sqrt(xpl / ni->playouts);
186 #endif
187 if (urgency > best_urgency) {
188 best_urgency = urgency;
189 nbest = ni;
192 return nbest;
195 void
196 tree_uct_update(struct tree_node *node, int result)
198 for (; node; node = node->parent) {
199 node->playouts++;
200 node->wins += result;
201 node->value = (float)node->wins / node->playouts;