UCT Tree: Unique id for tree nodes for debugging
[pachi.git] / uct / policy / ucb1amaf.c
blob864e667f0bf4fc10a5c7f83628f8774fca34584a
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "uct/internal.h"
12 #include "uct/tree.h"
/* This implements the UCB1 policy extended with an AMAF heuristic. */
/* Tunable parameters of the UCB1 + AMAF policy (kept in uct_policy->data). */
struct ucb1_policy_amaf {
	/* Exploration coefficient of the UCT term; the "Modification of UCT
	 * with Patterns in Monte Carlo Go" (MoGo) paper calls it 'p'.
	 * Plain UCB uses 2, which searches far too wide; smaller values
	 * (try 0.2) give deeper, narrower readouts. */
	float explore_p;

	/* First Play Urgency: the urgency assigned to a child that has no
	 * statistics yet. When below infinity (the MoGo paper reports 1.0
	 * as best), a fresh branch is only tried once no existing branch
	 * is more urgent than this. */
	float fpu;

	/* Equivalent experience (in playouts) given to prior knowledge;
	 * the MoGo paper recommends 50 per source. The per-source values
	 * even_eqex, gp_eqex and policy_eqex fall back to eqex when the
	 * initializer finds them negative. */
	int eqex;
	int even_eqex;
	int gp_eqex;
	int policy_eqex;

	/* Optional urgency noise; 0 disables. urg_randoma adds a random
	 * offset (in thousandths), urg_randomm applies a random factor. */
	int urg_randoma;
	int urg_randomm;

	/* Exploration coefficient of the AMAF/RAVE term. */
	float explore_p_rave;

	/* Playout scale at which the RAVE estimate stops dominating the
	 * UCT estimate; larger values keep RAVE influential for longer. */
	int equiv_rave;

	/* Count prior statistics as AMAF evidence (instead of UCT evidence). */
	bool rave_prior;

	/* Update AMAF statistics for moves of both colors, crediting the
	 * opponent's moves with the inverted result. */
	bool both_colors;
};
/* Plain-UCB1 policy callbacks reused by this policy; they are not
 * defined in this file. */
extern struct tree_node *ucb1_choose(struct uct_policy *p, struct tree_node *node,
                                     struct board *b, enum stone color);
extern struct tree_node *ucb1_descend(struct uct_policy *p, struct tree *tree,
                                      struct tree_node *node, int parity, bool allow_pass);
extern void ucb1_prior(struct uct_policy *p, struct tree *tree, struct tree_node *node,
                       struct board *b, enum stone color, int parity);
43 /* Original RAVE function */
44 struct tree_node *
45 ucb1orave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
47 /* We want to count in the prior stats here after all. Otherwise,
48 * nodes with positive prior will get explored _LESS_ since the
49 * urgency will be always higher; even with normal FPU because
50 * of the explore coefficient. */
52 struct ucb1_policy_amaf *b = p->data;
53 float xpl = log(node->u.playouts + node->prior.playouts) * b->explore_p;
54 float xpl_rave = log(node->amaf.playouts + (b->rave_prior ? node->prior.playouts : 0)) * b->explore_p_rave;
55 float beta = sqrt((float)b->equiv_rave / (3 * (node->u.playouts + node->prior.playouts) + b->equiv_rave));
57 struct tree_node *nbest = node->children;
58 float best_urgency = -9999;
59 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
60 /* Do not consider passing early. */
61 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
62 continue;
63 int amaf_wins = ni->amaf.wins + (b->rave_prior ? ni->prior.wins : 0);
64 int amaf_playouts = ni->amaf.playouts + (b->rave_prior ? ni->prior.playouts : 0);
65 int uct_playouts = ni->u.playouts + ni->prior.playouts;
66 ni->amaf.value = (float)amaf_wins / amaf_playouts;
67 ni->prior.value = (float)ni->prior.wins / ni->prior.playouts;
68 float uctp = (parity > 0 ? ni->u.value : 1 - ni->u.value) + sqrt(xpl / uct_playouts);
69 float ravep = (parity > 0 ? ni->amaf.value : 1 - ni->amaf.value) + sqrt(xpl_rave / amaf_playouts);
70 float urgency = ni->u.playouts ? beta * ravep + (1 - beta) * uctp : b->fpu;
71 // fprintf(stderr, "uctp %f (uct %d/%d) ravep %f (xpl %f amaf %d/%d) beta %f => %f\n", uctp, ni->u.wins, ni->u.playouts, ravep, xpl_rave, amaf_wins, amaf_playouts, beta, urgency);
72 if (b->urg_randoma)
73 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
74 if (b->urg_randomm)
75 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
76 if (urgency > best_urgency) {
77 best_urgency = urgency;
78 nbest = ni;
81 return nbest;
84 float fast_sqrt(int x)
86 static const float table[] = {
89 1.41421356237309504880,
90 1.73205080756887729352,
91 2.00000000000000000000,
92 #if 0
93 2.23606797749978969640,
94 2.44948974278317809819,
95 2.64575131106459059050,
96 2.82842712474619009760,
97 3.00000000000000000000,
98 3.16227766016837933199,
99 3.31662479035539984911,
100 3.46410161513775458705,
101 3.60555127546398929311,
102 3.74165738677394138558,
103 3.87298334620741688517,
104 4.00000000000000000000,
105 4.12310562561766054982,
106 4.24264068711928514640,
107 4.35889894354067355223,
108 4.47213595499957939281,
109 4.58257569495584000658,
110 4.69041575982342955456,
111 4.79583152331271954159,
112 4.89897948556635619639,
113 5.00000000000000000000,
114 5.09901951359278483002,
115 5.19615242270663188058,
116 5.29150262212918118100,
117 5.38516480713450403125,
118 5.47722557505166113456,
119 5.56776436283002192211,
120 5.65685424949238019520,
121 5.74456264653802865985,
122 5.83095189484530047087,
123 5.91607978309961604256,
124 6.00000000000000000000,
125 6.08276253029821968899,
126 6.16441400296897645025,
127 6.24499799839839820584,
128 6.32455532033675866399,
129 6.40312423743284868648,
130 6.48074069840786023096,
131 6.55743852430200065234,
132 6.63324958071079969822,
133 6.70820393249936908922,
134 6.78232998312526813906,
135 6.85565460040104412493,
136 6.92820323027550917410,
137 7.00000000000000000000,
138 7.07106781186547524400,
139 7.14142842854284999799,
140 7.21110255092797858623,
141 7.28010988928051827109,
142 7.34846922834953429459,
143 7.41619848709566294871,
144 7.48331477354788277116,
145 7.54983443527074969723,
146 7.61577310586390828566,
147 7.68114574786860817576,
148 7.74596669241483377035,
149 7.81024967590665439412,
150 7.87400787401181101968,
151 7.93725393319377177150,
152 #endif
154 //printf("sqrt %d\n", x);
155 if (x < sizeof(table) / sizeof(*table)) {
156 return table[x];
157 } else {
158 return sqrt(x);
159 #if 0
160 int y = 0;
161 int base = 1 << (sizeof(int) * 8 - 2);
162 if ((x & 0xFFFF0000) == 0) base >>= 16;
163 if ((x & 0xFF00FF00) == 0) base >>= 8;
164 if ((x & 0xF0F0F0F0) == 0) base >>= 4;
165 if ((x & 0xCCCCCCCC) == 0) base >>= 2;
166 // "base" starts at the highest power of four <= the argument.
168 while (base > 0) {
169 if (x >= y + base) {
170 x -= y + base;
171 y += base << 1;
173 y >>= 1;
174 base >>= 2;
176 printf("sqrt %d = %d\n", x, y);
177 return y;
178 #endif
182 /* Sylvain RAVE function */
183 struct tree_node *
184 ucb1srave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
186 struct ucb1_policy_amaf *b = p->data;
187 float rave_coef = 1.0f / b->equiv_rave;
188 float conf = 1.f;
189 if (b->explore_p > 0 || b->explore_p_rave > 0)
190 conf = sqrt(log(node->u.playouts + node->prior.playouts));
192 struct tree_node *nbest = node->children;
193 float best_urgency = -9999;
194 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
195 /* Do not consider passing early. */
196 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
197 continue;
199 /* TODO: Exploration? */
201 int ngames = ni->u.playouts;
202 int nwins = ni->u.wins;
203 int rgames = ni->amaf.playouts;
204 int rwins = ni->amaf.wins;
205 if (b->rave_prior) {
206 rgames += ni->prior.playouts;
207 rwins += ni->prior.wins;
208 } else {
209 ngames += ni->prior.playouts;
210 nwins += ni->prior.wins;
212 if (parity < 0) {
213 nwins = ngames - nwins;
214 rwins = rgames - rwins;
216 float nval = 0, rval = 0;
217 if (ngames) {
218 nval = (float) nwins / ngames;
219 if (b->explore_p > 0)
220 nval += b->explore_p * conf / fast_sqrt(ngames);
222 if (rgames) {
223 rval = (float) rwins / rgames;
224 if (b->explore_p_rave > 0)
225 rval += b->explore_p_rave * conf / fast_sqrt(rgames);
228 float urgency;
229 if (ngames) {
230 if (rgames) {
231 /* At the beginning, beta is at 1 and RAVE is used.
232 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
233 float beta = (float) rgames / (rgames + ngames + rave_coef * ngames * rgames);
234 #if 0
235 fprintf(stderr, "[beta %f = %d / (%d + %d + %f)]\n",
236 beta, rgames, rgames, ngames, rave_coef * ngames * rgames);
237 #endif
238 urgency = beta * rval + (1 - beta) * nval;
239 } else {
240 urgency = nval;
242 } else if (rgames) {
243 urgency = rval;
244 } else {
245 assert(!b->even_eqex);
246 urgency = parity < 0 ? 1 - b->fpu : b->fpu;
249 #if 0
250 struct board bb; bb.size = 11;
251 fprintf(stderr, "%s<%lld> urgency %f (r %d / %d, n %d / %d)\n",
252 coord2sstr(ni->coord, &bb), ni->hash, urgency, rwins, rgames, nwins, ngames);
253 #endif
254 if (b->urg_randoma)
255 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
256 if (b->urg_randomm)
257 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
258 /* The >= is important since we will always choose something
259 * else than a pass in case of a tie. pass causes degenerative
260 * behaviour. */
261 if (urgency >= best_urgency) {
262 best_urgency = urgency;
263 nbest = ni;
266 return nbest;
269 static void
270 update_node(struct uct_policy *p, struct tree_node *node, int result)
272 node->u.playouts++;
273 node->u.wins += result;
274 tree_update_node_value(node);
276 static void
277 update_node_amaf(struct uct_policy *p, struct tree_node *node, int result)
279 node->amaf.playouts++;
280 node->amaf.wins += result;
281 tree_update_node_value(node);
284 void
285 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node, enum stone color, struct playout_amafmap *map, int result)
287 struct ucb1_policy_amaf *b = p->data;
289 color = stone_other(color); // We will look in CHILDREN of the node!
290 for (; node; node = node->parent, color = stone_other(color)) {
291 if (p->descend != ucb1_descend)
292 node->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
293 update_node(p, node, result);
294 if (amaf_nakade(map->map[node->coord]))
295 amaf_op(map->map[node->coord], -);
296 /* This loop ignores symmetry considerations, but they should
297 * matter only at a point when AMAF doesn't help much. */
298 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
299 assert(map->map[ni->coord] != S_OFFBOARD);
300 if (map->map[ni->coord] == S_NONE || amaf_nakade(map->map[ni->coord]))
301 continue;
303 #if 0
304 struct board bb; bb.size = 9+2;
305 fprintf(stderr, "%s<%lld> -> %s<%lld> [%d %d => %d]\n", coord2sstr(node->coord, &bb), node->hash, coord2sstr(ni->coord, &bb), ni->hash, map->map[ni->coord], color, result);
306 #endif
307 if (b->both_colors) {
308 update_node_amaf(p, ni, map->map[ni->coord] == color ? result : !result);
309 } else if (map->map[ni->coord] == color) {
310 update_node_amaf(p, ni, result);
317 struct uct_policy *
318 policy_ucb1amaf_init(struct uct *u, char *arg)
320 struct uct_policy *p = calloc(1, sizeof(*p));
321 struct ucb1_policy_amaf *b = calloc(1, sizeof(*b));
322 p->uct = u;
323 p->data = b;
324 p->descend = ucb1srave_descend;
325 p->choose = ucb1_choose;
326 p->update = ucb1amaf_update;
327 p->wants_amaf = true;
329 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
330 b->explore_p = 0.1;
331 b->explore_p_rave = -1;
332 b->equiv_rave = 3000;
333 b->fpu = INFINITY;
334 b->even_eqex = 0;
335 b->gp_eqex = b->policy_eqex = -1;
336 b->eqex = 50;
338 if (arg) {
339 char *optspec, *next = arg;
340 while (*next) {
341 optspec = next;
342 next += strcspn(next, ":");
343 if (*next) { *next++ = 0; } else { *next = 0; }
345 char *optname = optspec;
346 char *optval = strchr(optspec, '=');
347 if (optval) *optval++ = 0;
349 if (!strcasecmp(optname, "explore_p")) {
350 b->explore_p = atof(optval);
351 } else if (!strcasecmp(optname, "prior")) {
352 if (optval)
353 b->eqex = atoi(optval);
354 } else if (!strcasecmp(optname, "prior_even") && optval) {
355 b->even_eqex = atoi(optval);
356 } else if (!strcasecmp(optname, "prior_gp") && optval) {
357 b->gp_eqex = atoi(optval);
358 } else if (!strcasecmp(optname, "prior_policy") && optval) {
359 b->policy_eqex = atoi(optval);
360 } else if (!strcasecmp(optname, "fpu") && optval) {
361 b->fpu = atof(optval);
362 } else if (!strcasecmp(optname, "urg_randoma") && optval) {
363 b->urg_randoma = atoi(optval);
364 } else if (!strcasecmp(optname, "urg_randomm") && optval) {
365 b->urg_randomm = atoi(optval);
366 } else if (!strcasecmp(optname, "rave")) {
367 if (optval && *optval == '0')
368 p->descend = ucb1_descend;
369 else if (optval && *optval == 'o')
370 p->descend = ucb1orave_descend;
371 else if (optval && *optval == 's')
372 p->descend = ucb1srave_descend;
373 } else if (!strcasecmp(optname, "explore_p_rave") && optval) {
374 b->explore_p_rave = atof(optval);
375 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
376 b->equiv_rave = atof(optval);
377 } else if (!strcasecmp(optname, "rave_prior")) {
378 b->rave_prior = true;
379 } else if (!strcasecmp(optname, "both_colors")) {
380 b->both_colors = true;
381 } else {
382 fprintf(stderr, "ucb1: Invalid policy argument %s or missing value\n", optname);
387 if (b->eqex) p->prior = ucb1_prior;
388 if (b->even_eqex < 0) b->even_eqex = b->eqex;
389 if (b->gp_eqex < 0) b->gp_eqex = b->eqex;
390 if (b->policy_eqex < 0) b->policy_eqex = b->eqex;
391 if (b->explore_p_rave < 0) b->explore_p_rave = b->explore_p;
393 return p;