UCT: Move prior computation out of tree policy
[pachi.git] / uct / policy / ucb1amaf.c
blob989a1fa832e0649655bb87635b00c2f40d74ff25
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "uct/internal.h"
12 #include "uct/tree.h"
14 /* This implements the UCB1 policy with an additional AMAF heuristic. */
/* Tunable parameters of the UCB1 + AMAF tree policy. An instance is
 * allocated by policy_ucb1amaf_init() and stored in uct_policy.data. */
struct ucb1_policy_amaf {
	/* This is what the Modification of UCT with Patterns in Monte Carlo Go
	 * paper calls 'p'. Original UCB has this on 2, but this seems to
	 * produce way too wide searches; reduce this to get deeper and
	 * narrower readouts - try 0.2. */
	float explore_p;
	/* First Play Urgency - if set to less than infinity (the MoGo paper
	 * above reports 1.0 as the best), new branches are explored only
	 * if none of the existing ones has higher urgency than fpu. */
	float fpu;
	/* If non-zero, the descend functions add a random offset derived
	 * from fast_random(urg_randoma) to each computed urgency. */
	int urg_randoma;
	/* If non-zero, the descend functions multiply each computed urgency
	 * by a random factor derived from fast_random(urg_randomm). */
	int urg_randomm;
	/* Exploration coefficient applied to the AMAF (RAVE) statistics;
	 * a negative value makes the initializer fall back to explore_p. */
	float explore_p_rave;
	/* Enters the beta formulas of the descend functions and thereby
	 * controls how the weight shifts from AMAF to regular statistics
	 * as the number of regular playouts grows. */
	int equiv_rave;
	/* Count the node prior into the AMAF statistics during descend. */
	bool rave_prior;
	/* Also credit AMAF moves played by the other color, with the
	 * result inverted (see ucb1amaf_update()). */
	bool both_colors;
	/* For map entries flagged as nakade, credit the move only if
	 * the game record shows the child's color played there; when unset,
	 * such entries are skipped altogether. */
	bool check_nakade;
};
/* Callbacks of the plain UCB1 policy that this policy reuses:
 * ucb1_choose is installed as the 'choose' callback, and ucb1_descend
 * is the 'descend' callback selected when RAVE is switched off. */
struct tree_node *ucb1_choose(struct uct_policy *p, struct tree_node *node,
                              struct board *b, enum stone color);

struct tree_node *ucb1_descend(struct uct_policy *p, struct tree *tree,
                               struct tree_node *node, int parity,
                               bool allow_pass);
39 /* Original RAVE function */
40 struct tree_node *
41 ucb1orave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
43 /* We want to count in the prior stats here after all. Otherwise,
44 * nodes with positive prior will get explored _LESS_ since the
45 * urgency will be always higher; even with normal FPU because
46 * of the explore coefficient. */
48 struct ucb1_policy_amaf *b = p->data;
49 float xpl = log(node->u.playouts + node->prior.playouts) * b->explore_p;
50 float xpl_rave = log(node->amaf.playouts + (b->rave_prior ? node->prior.playouts : 0)) * b->explore_p_rave;
51 float beta = sqrt((float)b->equiv_rave / (3 * (node->u.playouts + node->prior.playouts) + b->equiv_rave));
53 struct tree_node *nbest = node->children;
54 float best_urgency = -9999;
55 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
56 /* Do not consider passing early. */
57 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
58 continue;
59 int amaf_wins = ni->amaf.wins + (b->rave_prior ? ni->prior.wins : 0);
60 int amaf_playouts = ni->amaf.playouts + (b->rave_prior ? ni->prior.playouts : 0);
61 int uct_playouts = ni->u.playouts + ni->prior.playouts;
62 ni->amaf.value = (float)amaf_wins / amaf_playouts;
63 ni->prior.value = (float)ni->prior.wins / ni->prior.playouts;
64 float uctp = tree_node_get_value(tree, ni, u, parity) + sqrt(xpl / uct_playouts);
65 float ravep = tree_node_get_value(tree, ni, amaf, parity) + sqrt(xpl_rave / amaf_playouts);
66 float urgency = ni->u.playouts ? beta * ravep + (1 - beta) * uctp : b->fpu;
67 // fprintf(stderr, "uctp %f (uct %d/%d) ravep %f (xpl %f amaf %d/%d) beta %f => %f\n", uctp, ni->u.wins, ni->u.playouts, ravep, xpl_rave, amaf_wins, amaf_playouts, beta, urgency);
68 if (b->urg_randoma)
69 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
70 if (b->urg_randomm)
71 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
72 if (urgency > best_urgency) {
73 best_urgency = urgency;
74 nbest = ni;
77 return nbest;
80 float fast_sqrt(int x)
82 static const float table[] = {
85 1.41421356237309504880,
86 1.73205080756887729352,
87 2.00000000000000000000,
88 #if 0
89 2.23606797749978969640,
90 2.44948974278317809819,
91 2.64575131106459059050,
92 2.82842712474619009760,
93 3.00000000000000000000,
94 3.16227766016837933199,
95 3.31662479035539984911,
96 3.46410161513775458705,
97 3.60555127546398929311,
98 3.74165738677394138558,
99 3.87298334620741688517,
100 4.00000000000000000000,
101 4.12310562561766054982,
102 4.24264068711928514640,
103 4.35889894354067355223,
104 4.47213595499957939281,
105 4.58257569495584000658,
106 4.69041575982342955456,
107 4.79583152331271954159,
108 4.89897948556635619639,
109 5.00000000000000000000,
110 5.09901951359278483002,
111 5.19615242270663188058,
112 5.29150262212918118100,
113 5.38516480713450403125,
114 5.47722557505166113456,
115 5.56776436283002192211,
116 5.65685424949238019520,
117 5.74456264653802865985,
118 5.83095189484530047087,
119 5.91607978309961604256,
120 6.00000000000000000000,
121 6.08276253029821968899,
122 6.16441400296897645025,
123 6.24499799839839820584,
124 6.32455532033675866399,
125 6.40312423743284868648,
126 6.48074069840786023096,
127 6.55743852430200065234,
128 6.63324958071079969822,
129 6.70820393249936908922,
130 6.78232998312526813906,
131 6.85565460040104412493,
132 6.92820323027550917410,
133 7.00000000000000000000,
134 7.07106781186547524400,
135 7.14142842854284999799,
136 7.21110255092797858623,
137 7.28010988928051827109,
138 7.34846922834953429459,
139 7.41619848709566294871,
140 7.48331477354788277116,
141 7.54983443527074969723,
142 7.61577310586390828566,
143 7.68114574786860817576,
144 7.74596669241483377035,
145 7.81024967590665439412,
146 7.87400787401181101968,
147 7.93725393319377177150,
148 #endif
150 //printf("sqrt %d\n", x);
151 if (x < sizeof(table) / sizeof(*table)) {
152 return table[x];
153 } else {
154 return sqrt(x);
155 #if 0
156 int y = 0;
157 int base = 1 << (sizeof(int) * 8 - 2);
158 if ((x & 0xFFFF0000) == 0) base >>= 16;
159 if ((x & 0xFF00FF00) == 0) base >>= 8;
160 if ((x & 0xF0F0F0F0) == 0) base >>= 4;
161 if ((x & 0xCCCCCCCC) == 0) base >>= 2;
162 // "base" starts at the highest power of four <= the argument.
164 while (base > 0) {
165 if (x >= y + base) {
166 x -= y + base;
167 y += base << 1;
169 y >>= 1;
170 base >>= 2;
172 printf("sqrt %d = %d\n", x, y);
173 return y;
174 #endif
178 /* Sylvain RAVE function */
179 struct tree_node *
180 ucb1srave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
182 struct ucb1_policy_amaf *b = p->data;
183 float rave_coef = 1.0f / b->equiv_rave;
184 float nconf = 1.f, rconf = 1.f;
185 if (b->explore_p > 0)
186 nconf = sqrt(log(node->u.playouts + node->prior.playouts));
187 if (b->explore_p_rave > 0 && node->amaf.playouts)
188 rconf = sqrt(log(node->amaf.playouts + node->prior.playouts));
190 // XXX: Stack overflow danger on big boards?
191 struct tree_node *nbest[512] = { node->children }; int nbests = 1;
192 float best_urgency = -9999;
194 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
195 /* Do not consider passing early. */
196 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
197 continue;
199 /* TODO: Exploration? */
201 int ngames = ni->u.playouts;
202 int nwins = ni->u.wins;
203 int rgames = ni->amaf.playouts;
204 int rwins = ni->amaf.wins;
205 if (b->rave_prior) {
206 rgames += ni->prior.playouts;
207 rwins += ni->prior.wins;
208 } else {
209 ngames += ni->prior.playouts;
210 nwins += ni->prior.wins;
212 if (tree_parity(tree, parity) < 0) {
213 nwins = ngames - nwins;
214 rwins = rgames - rwins;
216 float nval = 0, rval = 0;
217 if (ngames) {
218 nval = (float) nwins / ngames;
219 if (b->explore_p > 0)
220 nval += b->explore_p * nconf / fast_sqrt(ngames);
222 if (rgames) {
223 rval = (float) rwins / rgames;
224 if (b->explore_p_rave > 0 && !is_pass(ni->coord))
225 rval += b->explore_p_rave * rconf / fast_sqrt(rgames);
228 /* XXX: We later compare urgency with best_urgency; this can
229 * be difficult given that urgency can be in register with
230 * higher precision than best_urgency, thus even though
231 * the numbers are in fact the same, urgency will be
232 * slightly higher (or lower). Thus, we declare urgency
233 * as volatile, attempting to force the compiler to keep
234 * everything as a float. Ideally, we should do some random
235 * __FLT_EPSILON__ magic instead. */
236 volatile float urgency;
237 if (ngames) {
238 if (rgames) {
239 /* At the beginning, beta is at 1 and RAVE is used.
240 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
241 float beta = (float) rgames / (rgames + ngames + rave_coef * ngames * rgames);
242 #if 0
243 //if (node->coord == 7*11+4) // D7
244 fprintf(stderr, "[beta %f = %d / (%d + %d + %f)]\n",
245 beta, rgames, rgames, ngames, rave_coef * ngames * rgames);
246 #endif
247 urgency = beta * rval + (1 - beta) * nval;
248 } else {
249 urgency = nval;
251 } else if (rgames) {
252 urgency = rval;
253 } else {
254 /* assert(!u->even_eqex); */
255 urgency = b->fpu;
258 #if 0
259 struct board bb; bb.size = 11;
260 //if (node->coord == 7*11+4) // D7
261 fprintf(stderr, "%s<%lld>-%s<%lld> urgency %f (r %d / %d, n %d / %d)\n",
262 coord2sstr(ni->parent->coord, &bb), ni->parent->hash,
263 coord2sstr(ni->coord, &bb), ni->hash, urgency,
264 rwins, rgames, nwins, ngames);
265 #endif
266 if (b->urg_randoma)
267 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
268 if (b->urg_randomm)
269 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
271 if (urgency > best_urgency) {
272 best_urgency = urgency; nbests = 0;
274 if (urgency >= best_urgency) {
275 /* We want to always choose something else than a pass
276 * in case of a tie. pass causes degenerative behaviour. */
277 if (nbests == 1 && is_pass(nbest[0]->coord)) {
278 nbests--;
280 nbest[nbests++] = ni;
283 #if 0
284 struct board bb; bb.size = 11;
285 fprintf(stderr, "[%s %d: ", coord2sstr(node->coord, &bb), nbests);
286 for (int zz = 0; zz < nbests; zz++)
287 fprintf(stderr, "%s", coord2sstr(nbest[zz]->coord, &bb));
288 fprintf(stderr, "]\n");
289 #endif
290 return nbest[fast_random(nbests)];
293 static void
294 update_node(struct uct_policy *p, struct tree_node *node, int result)
296 node->u.playouts++;
297 node->u.wins += result;
298 tree_update_node_value(node);
300 static void
301 update_node_amaf(struct uct_policy *p, struct tree_node *node, int result)
303 node->amaf.playouts++;
304 node->amaf.wins += result;
305 tree_update_node_value(node);
308 void
309 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node, enum stone node_color, enum stone player_color, struct playout_amafmap *map, int result)
311 struct ucb1_policy_amaf *b = p->data;
312 enum stone child_color = stone_other(node_color);
314 #if 0
315 struct board bb; bb.size = 9+2;
316 for (struct tree_node *ni = node; ni; ni = ni->parent)
317 fprintf(stderr, "%s ", coord2sstr(ni->coord, &bb));
318 fprintf(stderr, "[color %d] update result %d (color %d)\n",
319 node_color, result, player_color);
320 #endif
322 while (node) {
323 if (node->parent == NULL)
324 assert(tree->root_color == stone_other(child_color));
326 if (p->descend != ucb1_descend)
327 node->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
328 update_node(p, node, result);
329 if (amaf_nakade(map->map[node->coord]))
330 amaf_op(map->map[node->coord], -);
332 /* This loop ignores symmetry considerations, but they should
333 * matter only at a point when AMAF doesn't help much. */
334 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
335 assert(map->map[ni->coord] != S_OFFBOARD);
336 if (map->map[ni->coord] == S_NONE)
337 continue;
338 assert(map->game_baselen >= 0);
339 enum stone amaf_color = map->map[ni->coord];
340 if (amaf_nakade(map->map[ni->coord])) {
341 if (!b->check_nakade)
342 continue;
343 /* We don't care to implement both_colors
344 * properly since it sucks anyway. */
345 int i;
346 for (i = map->game_baselen; i < map->gamelen; i++)
347 if (map->game[i].coord == ni->coord
348 && map->game[i].color == child_color)
349 break;
350 if (i == map->gamelen)
351 continue;
352 amaf_color = child_color;
355 int nres = result;
356 if (amaf_color != child_color) {
357 if (!b->both_colors)
358 continue;
359 nres = !nres;
361 /* For child_color != player_color, we still want
362 * to record the result unmodified; in that case,
363 * we will correctly negate them at the descend phase. */
365 if (p->descend != ucb1_descend)
366 ni->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
367 update_node_amaf(p, ni, nres);
369 #if 0
370 fprintf(stderr, "* %s<%lld> -> %s<%lld> [%d %d => %d/%d]\n", coord2sstr(node->coord, &bb), node->hash, coord2sstr(ni->coord, &bb), ni->hash, player_color, child_color, result);
371 #endif
374 if (!is_pass(node->coord)) {
375 map->game_baselen--;
377 node = node->parent; child_color = stone_other(child_color);
382 struct uct_policy *
383 policy_ucb1amaf_init(struct uct *u, char *arg)
385 struct uct_policy *p = calloc(1, sizeof(*p));
386 struct ucb1_policy_amaf *b = calloc(1, sizeof(*b));
387 p->uct = u;
388 p->data = b;
389 p->descend = ucb1srave_descend;
390 p->choose = ucb1_choose;
391 p->update = ucb1amaf_update;
392 p->wants_amaf = true;
394 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
395 b->explore_p = 0.1;
396 b->explore_p_rave = 0.01;
397 b->equiv_rave = 3000;
398 b->fpu = INFINITY;
399 b->rave_prior = true;
400 b->check_nakade = true;
402 if (arg) {
403 char *optspec, *next = arg;
404 while (*next) {
405 optspec = next;
406 next += strcspn(next, ":");
407 if (*next) { *next++ = 0; } else { *next = 0; }
409 char *optname = optspec;
410 char *optval = strchr(optspec, '=');
411 if (optval) *optval++ = 0;
413 if (!strcasecmp(optname, "explore_p")) {
414 b->explore_p = atof(optval);
415 } else if (!strcasecmp(optname, "fpu") && optval) {
416 b->fpu = atof(optval);
417 } else if (!strcasecmp(optname, "urg_randoma") && optval) {
418 b->urg_randoma = atoi(optval);
419 } else if (!strcasecmp(optname, "urg_randomm") && optval) {
420 b->urg_randomm = atoi(optval);
421 } else if (!strcasecmp(optname, "rave")) {
422 if (optval && *optval == '0')
423 p->descend = ucb1_descend;
424 else if (optval && *optval == 'o')
425 p->descend = ucb1orave_descend;
426 else if (optval && *optval == 's')
427 p->descend = ucb1srave_descend;
428 } else if (!strcasecmp(optname, "explore_p_rave") && optval) {
429 b->explore_p_rave = atof(optval);
430 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
431 b->equiv_rave = atof(optval);
432 } else if (!strcasecmp(optname, "rave_prior") && optval) {
433 // 46% (+-3.5)
434 b->rave_prior = atoi(optval);
435 } else if (!strcasecmp(optname, "both_colors")) {
436 b->both_colors = true;
437 } else if (!strcasecmp(optname, "check_nakade")) {
438 b->check_nakade = !optval || *optval == '1';
439 } else {
440 fprintf(stderr, "ucb1: Invalid policy argument %s or missing value\n", optname);
445 if (b->explore_p_rave < 0) b->explore_p_rave = b->explore_p;
447 return p;