UCT: Add tree_node_get_{wins,value}() parity-conscient accessors
[pachi.git] / uct / policy / ucb1amaf.c
blob8b0098777c63afb867d4d9b9777cf29ba5855f86
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "uct/internal.h"
12 #include "uct/tree.h"
/* This implements the UCB1 policy with an extra AMAF heuristic. */
/* Tunable parameters of the UCB1 policy with the AMAF (RAVE) extension;
 * stored in uct_policy->data. */
struct ucb1_policy_amaf {
	/* This is what the Modification of UCT with Patterns in Monte Carlo Go
	 * paper calls 'p'. Original UCB has this on 2, but this seems to
	 * produce way too wide searches; reduce this to get deeper and
	 * narrower readouts - try 0.2. */
	float explore_p;

	/* First Play Urgency - if set to less than infinity (the MoGo paper
	 * above reports 1.0 as the best), new branches are explored only
	 * if none of the existing ones has higher urgency than fpu. */
	float fpu;

	/* Equivalent experience for prior knowledge. MoGo paper recommends
	 * 50 playouts per source. eqex is the general value; the
	 * per-source values below are replaced by eqex when they are
	 * still negative after option parsing. */
	int eqex;
	int even_eqex;
	int gp_eqex;
	int policy_eqex;

	/* Urgency randomization, disabled when 0. urg_randoma adds
	 * (fast_random(urg_randoma) - urg_randoma / 2) / 1000 to the
	 * urgency; urg_randomm multiplies it by
	 * (fast_random(urg_randomm) + 5) / urg_randomm. */
	int urg_randoma;
	int urg_randomm;

	/* Exploration coefficient for the AMAF term; a negative value is
	 * replaced by explore_p after option parsing. */
	float explore_p_rave;

	/* Equivalence parameter steering the AMAF/own-stats mixing weight
	 * (beta) in the descend functions. */
	int equiv_rave;

	/* rave_prior: fold prior stats into the AMAF side of the formula.
	 * both_colors: also count AMAF moves of the other colour, with the
	 * result inverted. */
	bool rave_prior;
	bool both_colors;
};
/* Policy entry points reused from the plain UCB1 policy. They are not
 * defined in this file. ucb1_choose becomes the policy's choose hook.
 * ucb1_descend is the descend function selected by rave=0. ucb1_prior
 * is installed as the prior hook when eqex is non-zero. */
struct tree_node *
ucb1_choose(struct uct_policy *p, struct tree_node *node,
            struct board *b, enum stone color);

struct tree_node *
ucb1_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node,
             int parity, bool allow_pass);

void
ucb1_prior(struct uct_policy *p, struct tree *tree, struct tree_node *node,
           struct board *b, enum stone color, int parity);
43 /* Original RAVE function */
44 struct tree_node *
45 ucb1orave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
47 /* We want to count in the prior stats here after all. Otherwise,
48 * nodes with positive prior will get explored _LESS_ since the
49 * urgency will be always higher; even with normal FPU because
50 * of the explore coefficient. */
52 struct ucb1_policy_amaf *b = p->data;
53 float xpl = log(node->u.playouts + node->prior.playouts) * b->explore_p;
54 float xpl_rave = log(node->amaf.playouts + (b->rave_prior ? node->prior.playouts : 0)) * b->explore_p_rave;
55 float beta = sqrt((float)b->equiv_rave / (3 * (node->u.playouts + node->prior.playouts) + b->equiv_rave));
57 struct tree_node *nbest = node->children;
58 float best_urgency = -9999;
59 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
60 /* Do not consider passing early. */
61 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
62 continue;
63 int amaf_wins = ni->amaf.wins + (b->rave_prior ? ni->prior.wins : 0);
64 int amaf_playouts = ni->amaf.playouts + (b->rave_prior ? ni->prior.playouts : 0);
65 int uct_playouts = ni->u.playouts + ni->prior.playouts;
66 ni->amaf.value = (float)amaf_wins / amaf_playouts;
67 ni->prior.value = (float)ni->prior.wins / ni->prior.playouts;
68 float uctp = tree_node_get_value(tree, ni, u, parity) + sqrt(xpl / uct_playouts);
69 float ravep = tree_node_get_value(tree, ni, amaf, parity) + sqrt(xpl_rave / amaf_playouts);
70 float urgency = ni->u.playouts ? beta * ravep + (1 - beta) * uctp : b->fpu;
71 // fprintf(stderr, "uctp %f (uct %d/%d) ravep %f (xpl %f amaf %d/%d) beta %f => %f\n", uctp, ni->u.wins, ni->u.playouts, ravep, xpl_rave, amaf_wins, amaf_playouts, beta, urgency);
72 if (b->urg_randoma)
73 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
74 if (b->urg_randomm)
75 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
76 if (urgency > best_urgency) {
77 best_urgency = urgency;
78 nbest = ni;
81 return nbest;
84 float fast_sqrt(int x)
86 static const float table[] = {
89 1.41421356237309504880,
90 1.73205080756887729352,
91 2.00000000000000000000,
92 #if 0
93 2.23606797749978969640,
94 2.44948974278317809819,
95 2.64575131106459059050,
96 2.82842712474619009760,
97 3.00000000000000000000,
98 3.16227766016837933199,
99 3.31662479035539984911,
100 3.46410161513775458705,
101 3.60555127546398929311,
102 3.74165738677394138558,
103 3.87298334620741688517,
104 4.00000000000000000000,
105 4.12310562561766054982,
106 4.24264068711928514640,
107 4.35889894354067355223,
108 4.47213595499957939281,
109 4.58257569495584000658,
110 4.69041575982342955456,
111 4.79583152331271954159,
112 4.89897948556635619639,
113 5.00000000000000000000,
114 5.09901951359278483002,
115 5.19615242270663188058,
116 5.29150262212918118100,
117 5.38516480713450403125,
118 5.47722557505166113456,
119 5.56776436283002192211,
120 5.65685424949238019520,
121 5.74456264653802865985,
122 5.83095189484530047087,
123 5.91607978309961604256,
124 6.00000000000000000000,
125 6.08276253029821968899,
126 6.16441400296897645025,
127 6.24499799839839820584,
128 6.32455532033675866399,
129 6.40312423743284868648,
130 6.48074069840786023096,
131 6.55743852430200065234,
132 6.63324958071079969822,
133 6.70820393249936908922,
134 6.78232998312526813906,
135 6.85565460040104412493,
136 6.92820323027550917410,
137 7.00000000000000000000,
138 7.07106781186547524400,
139 7.14142842854284999799,
140 7.21110255092797858623,
141 7.28010988928051827109,
142 7.34846922834953429459,
143 7.41619848709566294871,
144 7.48331477354788277116,
145 7.54983443527074969723,
146 7.61577310586390828566,
147 7.68114574786860817576,
148 7.74596669241483377035,
149 7.81024967590665439412,
150 7.87400787401181101968,
151 7.93725393319377177150,
152 #endif
154 //printf("sqrt %d\n", x);
155 if (x < sizeof(table) / sizeof(*table)) {
156 return table[x];
157 } else {
158 return sqrt(x);
159 #if 0
160 int y = 0;
161 int base = 1 << (sizeof(int) * 8 - 2);
162 if ((x & 0xFFFF0000) == 0) base >>= 16;
163 if ((x & 0xFF00FF00) == 0) base >>= 8;
164 if ((x & 0xF0F0F0F0) == 0) base >>= 4;
165 if ((x & 0xCCCCCCCC) == 0) base >>= 2;
166 // "base" starts at the highest power of four <= the argument.
168 while (base > 0) {
169 if (x >= y + base) {
170 x -= y + base;
171 y += base << 1;
173 y >>= 1;
174 base >>= 2;
176 printf("sqrt %d = %d\n", x, y);
177 return y;
178 #endif
182 /* Sylvain RAVE function */
183 struct tree_node *
184 ucb1srave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
186 struct ucb1_policy_amaf *b = p->data;
187 float rave_coef = 1.0f / b->equiv_rave;
188 float nconf = 1.f, rconf = 1.f;
189 if (b->explore_p > 0)
190 nconf = sqrt(log(node->u.playouts + node->prior.playouts));
191 if (b->explore_p_rave > 0 && node->amaf.playouts)
192 rconf = sqrt(log(node->amaf.playouts + node->prior.playouts));
194 // XXX: Stack overflow danger on big boards?
195 struct tree_node *nbest[512] = { node->children }; int nbests = 1;
196 float best_urgency = -9999;
198 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
199 /* Do not consider passing early. */
200 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
201 continue;
203 /* TODO: Exploration? */
205 int ngames = ni->u.playouts;
206 int nwins = ni->u.wins;
207 int rgames = ni->amaf.playouts;
208 int rwins = ni->amaf.wins;
209 if (b->rave_prior) {
210 rgames += ni->prior.playouts;
211 rwins += ni->prior.wins;
212 } else {
213 ngames += ni->prior.playouts;
214 nwins += ni->prior.wins;
216 if (parity < 0) {
217 nwins = ngames - nwins;
218 rwins = rgames - rwins;
220 float nval = 0, rval = 0;
221 if (ngames) {
222 nval = (float) nwins / ngames;
223 if (b->explore_p > 0)
224 nval += b->explore_p * nconf / fast_sqrt(ngames);
226 if (rgames) {
227 rval = (float) rwins / rgames;
228 if (b->explore_p_rave > 0 && !is_pass(ni->coord))
229 rval += b->explore_p_rave * rconf / fast_sqrt(rgames);
232 /* XXX: We later compare urgency with best_urgency; this can
233 * be difficult given that urgency can be in register with
234 * higher precision than best_urgency, thus even though
235 * the numbers are in fact the same, urgency will be
236 * slightly higher (or lower). Thus, we declare urgency
237 * as volatile, attempting to force the compiler to keep
238 * everything as a float. Ideally, we should do some random
239 * __FLT_EPSILON__ magic instead. */
240 volatile float urgency;
241 if (ngames) {
242 if (rgames) {
243 /* At the beginning, beta is at 1 and RAVE is used.
244 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
245 float beta = (float) rgames / (rgames + ngames + rave_coef * ngames * rgames);
246 #if 0
247 //if (node->coord == 7*11+4) // D7
248 fprintf(stderr, "[beta %f = %d / (%d + %d + %f)]\n",
249 beta, rgames, rgames, ngames, rave_coef * ngames * rgames);
250 #endif
251 urgency = beta * rval + (1 - beta) * nval;
252 } else {
253 urgency = nval;
255 } else if (rgames) {
256 urgency = rval;
257 } else {
258 assert(!b->even_eqex);
259 urgency = b->fpu;
262 #if 0
263 struct board bb; bb.size = 11;
264 //if (node->coord == 7*11+4) // D7
265 fprintf(stderr, "%s<%lld>-%s<%lld> urgency %f (r %d / %d, n %d / %d)\n",
266 coord2sstr(ni->parent->coord, &bb), ni->parent->hash,
267 coord2sstr(ni->coord, &bb), ni->hash, urgency,
268 rwins, rgames, nwins, ngames);
269 #endif
270 if (b->urg_randoma)
271 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
272 if (b->urg_randomm)
273 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
275 if (urgency > best_urgency) {
276 best_urgency = urgency; nbests = 0;
278 if (urgency >= best_urgency) {
279 /* We want to always choose something else than a pass
280 * in case of a tie. pass causes degenerative behaviour. */
281 if (nbests == 1 && is_pass(nbest[0]->coord)) {
282 nbests--;
284 nbest[nbests++] = ni;
287 #if 0
288 struct board bb; bb.size = 11;
289 fprintf(stderr, "[%s %d: ", coord2sstr(node->coord, &bb), nbests);
290 for (int zz = 0; zz < nbests; zz++)
291 fprintf(stderr, "%s", coord2sstr(nbest[zz]->coord, &bb));
292 fprintf(stderr, "]\n");
293 #endif
294 return nbest[fast_random(nbests)];
297 static void
298 update_node(struct uct_policy *p, struct tree_node *node, int result)
300 node->u.playouts++;
301 node->u.wins += result;
302 tree_update_node_value(node);
304 static void
305 update_node_amaf(struct uct_policy *p, struct tree_node *node, int result)
307 node->amaf.playouts++;
308 node->amaf.wins += result;
309 tree_update_node_value(node);
312 void
313 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node, enum stone node_color, enum stone player_color, struct playout_amafmap *map, int result)
315 struct ucb1_policy_amaf *b = p->data;
316 enum stone child_color = stone_other(node_color);
318 #if 0
319 struct board bb; bb.size = 9+2;
320 for (struct tree_node *ni = node; ni; ni = ni->parent)
321 fprintf(stderr, "%s ", coord2sstr(ni->coord, &bb));
322 fprintf(stderr, "[color %d] update result %d (color %d)\n",
323 node_color, result, player_color);
324 #endif
326 while (node) {
327 if (node->parent == NULL)
328 assert(tree->root_color == stone_other(child_color));
330 if (p->descend != ucb1_descend)
331 node->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
332 update_node(p, node, result);
333 if (amaf_nakade(map->map[node->coord]))
334 amaf_op(map->map[node->coord], -);
336 /* This loop ignores symmetry considerations, but they should
337 * matter only at a point when AMAF doesn't help much. */
338 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
339 assert(map->map[ni->coord] != S_OFFBOARD);
340 if (map->map[ni->coord] == S_NONE)
341 continue;
342 assert(map->game_baselen >= 0);
343 enum stone amaf_color = map->map[ni->coord];
344 if (amaf_nakade(map->map[ni->coord])) {
345 /* We don't care to implement both_colors
346 * properly since it sucks anyway. */
347 int i;
348 for (i = map->game_baselen; i < map->gamelen; i++)
349 if (map->game[i].coord == ni->coord
350 && map->game[i].color == child_color)
351 break;
352 if (i == map->gamelen)
353 continue;
354 amaf_color = child_color;
357 int nres = result;
358 if (amaf_color != child_color) {
359 if (!b->both_colors)
360 continue;
361 nres = !nres;
363 /* For child_color != player_color, we still want
364 * to record the result unmodified; in that case,
365 * we will correctly negate them at the descend phase. */
367 if (p->descend != ucb1_descend)
368 ni->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
369 update_node_amaf(p, ni, nres);
371 #if 0
372 fprintf(stderr, "* %s<%lld> -> %s<%lld> [%d %d => %d/%d]\n", coord2sstr(node->coord, &bb), node->hash, coord2sstr(ni->coord, &bb), ni->hash, player_color, child_color, result);
373 #endif
376 if (!is_pass(node->coord)) {
377 map->game_baselen--;
379 node = node->parent; child_color = stone_other(child_color);
384 struct uct_policy *
385 policy_ucb1amaf_init(struct uct *u, char *arg)
387 struct uct_policy *p = calloc(1, sizeof(*p));
388 struct ucb1_policy_amaf *b = calloc(1, sizeof(*b));
389 p->uct = u;
390 p->data = b;
391 p->descend = ucb1srave_descend;
392 p->choose = ucb1_choose;
393 p->update = ucb1amaf_update;
394 p->wants_amaf = true;
396 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
397 b->explore_p = 0.1;
398 b->explore_p_rave = -1;
399 b->equiv_rave = 3000;
400 b->fpu = INFINITY;
401 // gp: 14 vs 0: 44% (+-3.5)
402 b->gp_eqex = 0;
403 b->even_eqex = b->policy_eqex = -1;
404 b->eqex = 6; /* Even number! */
405 b->rave_prior = true;
407 if (arg) {
408 char *optspec, *next = arg;
409 while (*next) {
410 optspec = next;
411 next += strcspn(next, ":");
412 if (*next) { *next++ = 0; } else { *next = 0; }
414 char *optname = optspec;
415 char *optval = strchr(optspec, '=');
416 if (optval) *optval++ = 0;
418 if (!strcasecmp(optname, "explore_p")) {
419 b->explore_p = atof(optval);
420 } else if (!strcasecmp(optname, "prior")) {
421 if (optval)
422 b->eqex = atoi(optval);
423 } else if (!strcasecmp(optname, "prior_even") && optval) {
424 b->even_eqex = atoi(optval);
425 } else if (!strcasecmp(optname, "prior_gp") && optval) {
426 b->gp_eqex = atoi(optval);
427 } else if (!strcasecmp(optname, "prior_policy") && optval) {
428 b->policy_eqex = atoi(optval);
429 } else if (!strcasecmp(optname, "fpu") && optval) {
430 b->fpu = atof(optval);
431 } else if (!strcasecmp(optname, "urg_randoma") && optval) {
432 b->urg_randoma = atoi(optval);
433 } else if (!strcasecmp(optname, "urg_randomm") && optval) {
434 b->urg_randomm = atoi(optval);
435 } else if (!strcasecmp(optname, "rave")) {
436 if (optval && *optval == '0')
437 p->descend = ucb1_descend;
438 else if (optval && *optval == 'o')
439 p->descend = ucb1orave_descend;
440 else if (optval && *optval == 's')
441 p->descend = ucb1srave_descend;
442 } else if (!strcasecmp(optname, "explore_p_rave") && optval) {
443 b->explore_p_rave = atof(optval);
444 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
445 b->equiv_rave = atof(optval);
446 } else if (!strcasecmp(optname, "rave_prior") && optval) {
447 // 46% (+-3.5)
448 b->rave_prior = atoi(optval);
449 } else if (!strcasecmp(optname, "both_colors")) {
450 b->both_colors = true;
451 } else {
452 fprintf(stderr, "ucb1: Invalid policy argument %s or missing value\n", optname);
457 if (b->eqex) p->prior = ucb1_prior;
458 if (b->even_eqex < 0) b->even_eqex = b->eqex;
459 if (b->gp_eqex < 0) b->gp_eqex = b->eqex;
460 if (b->policy_eqex < 0) b->policy_eqex = b->eqex;
461 if (b->explore_p_rave < 0) b->explore_p_rave = b->explore_p;
463 return p;