UCB1AMAF: Set eqex to 6, must be an even number!
[pachi.git] / uct / policy / ucb1amaf.c
blob739b45d7959aa254e1fa332834fe90e9a1019b9b
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "uct/internal.h"
12 #include "uct/tree.h"
14 /* This implements the UCB1 policy with an additional AMAF heuristic. */
/* Tunable parameters of the UCB1+AMAF policy. One instance is allocated by
 * policy_ucb1amaf_init() and hung off uct_policy.data. */
struct ucb1_policy_amaf {
	/* This is what the Modification of UCT with Patterns in Monte Carlo Go
	 * paper calls 'p'. Original UCB has this on 2, but this seems to
	 * produce way too wide searches; reduce this to get deeper and
	 * narrower readouts - try 0.2. */
	float explore_p;
	/* First Play Urgency - if set to less than infinity (the MoGo paper
	 * above reports 1.0 as the best), new branches are explored only
	 * if none of the existing ones has higher urgency than fpu. */
	float fpu;
	/* Equivalent experience for prior knowledge. MoGo paper recommends
	 * 50 playouts per source. */
	int eqex;
	/* Per-source overrides of eqex. The init function replaces a negative
	 * value with eqex. */
	int even_eqex;
	int gp_eqex;
	int policy_eqex;
	/* Random perturbation of the urgency in the descend functions; zero
	 * disables it. urg_randoma adds
	 * (fast_random(urg_randoma) - urg_randoma / 2) / 1000,
	 * urg_randomm multiplies by
	 * (fast_random(urg_randomm) + 5) / urg_randomm. */
	int urg_randoma;
	int urg_randomm;
	/* Exploration coefficient applied to the AMAF part of the urgency.
	 * The init function replaces a negative value with explore_p. */
	float explore_p_rave;
	/* Controls how the weight of the AMAF estimate (beta) falls off
	 * against the regular estimate as playouts accumulate. */
	int equiv_rave;
	/* If set, prior playouts/wins are added to the AMAF statistics when
	 * computing urgency; the Sylvain variant adds them to the regular
	 * statistics otherwise. */
	bool rave_prior;
	/* If set, the AMAF update also credits children whose point was
	 * played by the other color, with the result inverted. */
	bool both_colors;
};
/* Forward declarations of the plain UCB1 policy hooks. They are not defined
 * in this part of the file; the init function below installs ucb1_choose and
 * ucb1_prior and can select ucb1_descend with the "rave=0" option. */
struct tree_node *ucb1_choose(struct uct_policy *p, struct tree_node *node,
			      struct board *b, enum stone color);

struct tree_node *ucb1_descend(struct uct_policy *p, struct tree *tree,
			       struct tree_node *node, int parity, bool allow_pass);

void ucb1_prior(struct uct_policy *p, struct tree *tree, struct tree_node *node,
		struct board *b, enum stone color, int parity);
43 /* Original RAVE function */
44 struct tree_node *
45 ucb1orave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
47 /* We want to count in the prior stats here after all. Otherwise,
48 * nodes with positive prior will get explored _LESS_ since the
49 * urgency will be always higher; even with normal FPU because
50 * of the explore coefficient. */
52 struct ucb1_policy_amaf *b = p->data;
53 float xpl = log(node->u.playouts + node->prior.playouts) * b->explore_p;
54 float xpl_rave = log(node->amaf.playouts + (b->rave_prior ? node->prior.playouts : 0)) * b->explore_p_rave;
55 float beta = sqrt((float)b->equiv_rave / (3 * (node->u.playouts + node->prior.playouts) + b->equiv_rave));
57 struct tree_node *nbest = node->children;
58 float best_urgency = -9999;
59 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
60 /* Do not consider passing early. */
61 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
62 continue;
63 int amaf_wins = ni->amaf.wins + (b->rave_prior ? ni->prior.wins : 0);
64 int amaf_playouts = ni->amaf.playouts + (b->rave_prior ? ni->prior.playouts : 0);
65 int uct_playouts = ni->u.playouts + ni->prior.playouts;
66 ni->amaf.value = (float)amaf_wins / amaf_playouts;
67 ni->prior.value = (float)ni->prior.wins / ni->prior.playouts;
68 float uctp = (parity > 0 ? ni->u.value : 1 - ni->u.value) + sqrt(xpl / uct_playouts);
69 float ravep = (parity > 0 ? ni->amaf.value : 1 - ni->amaf.value) + sqrt(xpl_rave / amaf_playouts);
70 float urgency = ni->u.playouts ? beta * ravep + (1 - beta) * uctp : b->fpu;
71 // fprintf(stderr, "uctp %f (uct %d/%d) ravep %f (xpl %f amaf %d/%d) beta %f => %f\n", uctp, ni->u.wins, ni->u.playouts, ravep, xpl_rave, amaf_wins, amaf_playouts, beta, urgency);
72 if (b->urg_randoma)
73 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
74 if (b->urg_randomm)
75 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
76 if (urgency > best_urgency) {
77 best_urgency = urgency;
78 nbest = ni;
81 return nbest;
84 float fast_sqrt(int x)
86 static const float table[] = {
89 1.41421356237309504880,
90 1.73205080756887729352,
91 2.00000000000000000000,
92 #if 0
93 2.23606797749978969640,
94 2.44948974278317809819,
95 2.64575131106459059050,
96 2.82842712474619009760,
97 3.00000000000000000000,
98 3.16227766016837933199,
99 3.31662479035539984911,
100 3.46410161513775458705,
101 3.60555127546398929311,
102 3.74165738677394138558,
103 3.87298334620741688517,
104 4.00000000000000000000,
105 4.12310562561766054982,
106 4.24264068711928514640,
107 4.35889894354067355223,
108 4.47213595499957939281,
109 4.58257569495584000658,
110 4.69041575982342955456,
111 4.79583152331271954159,
112 4.89897948556635619639,
113 5.00000000000000000000,
114 5.09901951359278483002,
115 5.19615242270663188058,
116 5.29150262212918118100,
117 5.38516480713450403125,
118 5.47722557505166113456,
119 5.56776436283002192211,
120 5.65685424949238019520,
121 5.74456264653802865985,
122 5.83095189484530047087,
123 5.91607978309961604256,
124 6.00000000000000000000,
125 6.08276253029821968899,
126 6.16441400296897645025,
127 6.24499799839839820584,
128 6.32455532033675866399,
129 6.40312423743284868648,
130 6.48074069840786023096,
131 6.55743852430200065234,
132 6.63324958071079969822,
133 6.70820393249936908922,
134 6.78232998312526813906,
135 6.85565460040104412493,
136 6.92820323027550917410,
137 7.00000000000000000000,
138 7.07106781186547524400,
139 7.14142842854284999799,
140 7.21110255092797858623,
141 7.28010988928051827109,
142 7.34846922834953429459,
143 7.41619848709566294871,
144 7.48331477354788277116,
145 7.54983443527074969723,
146 7.61577310586390828566,
147 7.68114574786860817576,
148 7.74596669241483377035,
149 7.81024967590665439412,
150 7.87400787401181101968,
151 7.93725393319377177150,
152 #endif
154 //printf("sqrt %d\n", x);
155 if (x < sizeof(table) / sizeof(*table)) {
156 return table[x];
157 } else {
158 return sqrt(x);
159 #if 0
160 int y = 0;
161 int base = 1 << (sizeof(int) * 8 - 2);
162 if ((x & 0xFFFF0000) == 0) base >>= 16;
163 if ((x & 0xFF00FF00) == 0) base >>= 8;
164 if ((x & 0xF0F0F0F0) == 0) base >>= 4;
165 if ((x & 0xCCCCCCCC) == 0) base >>= 2;
166 // "base" starts at the highest power of four <= the argument.
168 while (base > 0) {
169 if (x >= y + base) {
170 x -= y + base;
171 y += base << 1;
173 y >>= 1;
174 base >>= 2;
176 printf("sqrt %d = %d\n", x, y);
177 return y;
178 #endif
182 /* Sylvain RAVE function */
183 struct tree_node *
184 ucb1srave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
186 struct ucb1_policy_amaf *b = p->data;
187 float rave_coef = 1.0f / b->equiv_rave;
188 float conf = 1.f;
189 if (b->explore_p > 0 || b->explore_p_rave > 0)
190 conf = sqrt(log(node->u.playouts + node->prior.playouts));
192 // XXX: Stack overflow danger on big boards?
193 struct tree_node *nbest[512] = { node->children }; int nbests = 1;
194 float best_urgency = -9999;
196 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
197 /* Do not consider passing early. */
198 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
199 continue;
201 /* TODO: Exploration? */
203 int ngames = ni->u.playouts;
204 int nwins = ni->u.wins;
205 int rgames = ni->amaf.playouts;
206 int rwins = ni->amaf.wins;
207 if (b->rave_prior) {
208 rgames += ni->prior.playouts;
209 rwins += ni->prior.wins;
210 } else {
211 ngames += ni->prior.playouts;
212 nwins += ni->prior.wins;
214 if (parity < 0) {
215 nwins = ngames - nwins;
216 rwins = rgames - rwins;
218 float nval = 0, rval = 0;
219 if (ngames) {
220 nval = (float) nwins / ngames;
221 if (b->explore_p > 0)
222 nval += b->explore_p * conf / fast_sqrt(ngames);
224 if (rgames) {
225 rval = (float) rwins / rgames;
226 if (b->explore_p_rave > 0)
227 rval += b->explore_p_rave * conf / fast_sqrt(rgames);
230 /* XXX: We later compare urgency with best_urgency; this can
231 * be difficult given that urgency can be in register with
232 * higher precision than best_urgency, thus even though
233 * the numbers are in fact the same, urgency will be
234 * slightly higher (or lower). Thus, we declare urgency
235 * as volatile, attempting to force the compiler to keep
236 * everything as a float. Ideally, we should do some random
237 * __FLT_EPSILON__ magic instead. */
238 volatile float urgency;
239 if (ngames) {
240 if (rgames) {
241 /* At the beginning, beta is at 1 and RAVE is used.
242 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
243 float beta = (float) rgames / (rgames + ngames + rave_coef * ngames * rgames);
244 #if 0
245 //if (node->coord == 7*11+4) // D7
246 fprintf(stderr, "[beta %f = %d / (%d + %d + %f)]\n",
247 beta, rgames, rgames, ngames, rave_coef * ngames * rgames);
248 #endif
249 urgency = beta * rval + (1 - beta) * nval;
250 } else {
251 urgency = nval;
253 } else if (rgames) {
254 urgency = rval;
255 } else {
256 assert(!b->even_eqex);
257 urgency = b->fpu;
260 #if 0
261 struct board bb; bb.size = 11;
262 //if (node->coord == 7*11+4) // D7
263 fprintf(stderr, "%s<%lld>-%s<%lld> urgency %f (r %d / %d, n %d / %d)\n",
264 coord2sstr(ni->parent->coord, &bb), ni->parent->hash,
265 coord2sstr(ni->coord, &bb), ni->hash, urgency,
266 rwins, rgames, nwins, ngames);
267 #endif
268 if (b->urg_randoma)
269 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
270 if (b->urg_randomm)
271 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
273 if (urgency > best_urgency) {
274 best_urgency = urgency; nbests = 0;
276 if (urgency >= best_urgency) {
277 /* We want to always choose something else than a pass
278 * in case of a tie. pass causes degenerative behaviour. */
279 if (nbests == 1 && is_pass(nbest[0]->coord)) {
280 nbests--;
282 nbest[nbests++] = ni;
285 #if 0
286 struct board bb; bb.size = 11;
287 fprintf(stderr, "[%s %d: ", coord2sstr(node->coord, &bb), nbests);
288 for (int zz = 0; zz < nbests; zz++)
289 fprintf(stderr, "%s", coord2sstr(nbest[zz]->coord, &bb));
290 fprintf(stderr, "]\n");
291 #endif
292 return nbest[fast_random(nbests)];
295 static void
296 update_node(struct uct_policy *p, struct tree_node *node, int result)
298 node->u.playouts++;
299 node->u.wins += result;
300 tree_update_node_value(node);
302 static void
303 update_node_amaf(struct uct_policy *p, struct tree_node *node, int result)
305 node->amaf.playouts++;
306 node->amaf.wins += result;
307 tree_update_node_value(node);
310 void
311 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node, enum stone node_color, enum stone player_color, struct playout_amafmap *map, int result)
313 struct ucb1_policy_amaf *b = p->data;
314 enum stone child_color = stone_other(node_color);
316 #if 0
317 struct board bb; bb.size = 9+2;
318 for (struct tree_node *ni = node; ni; ni = ni->parent)
319 fprintf(stderr, "%s ", coord2sstr(ni->coord, &bb));
320 fprintf(stderr, "[color %d] update result %d (color %d)\n",
321 node_color, result, player_color);
322 #endif
324 while (node) {
325 if (node->parent == NULL)
326 assert(tree->root_color == stone_other(child_color));
328 if (p->descend != ucb1_descend)
329 node->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
330 update_node(p, node, result);
331 if (amaf_nakade(map->map[node->coord]))
332 amaf_op(map->map[node->coord], -);
334 /* This loop ignores symmetry considerations, but they should
335 * matter only at a point when AMAF doesn't help much. */
336 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
337 assert(map->map[ni->coord] != S_OFFBOARD);
338 if (map->map[ni->coord] == S_NONE
339 || amaf_nakade(map->map[ni->coord]))
340 continue;
342 int nres = result;
343 if (map->map[ni->coord] != child_color) {
344 if (!b->both_colors)
345 continue;
346 nres = !nres;
348 /* For child_color != player_color, we still want
349 * to record the result unmodified; in that case,
350 * we will correctly negate them at the descend phase. */
352 if (p->descend != ucb1_descend)
353 ni->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
354 update_node_amaf(p, ni, nres);
356 #if 0
357 fprintf(stderr, "* %s<%lld> -> %s<%lld> [%d %d => %d/%d]\n", coord2sstr(node->coord, &bb), node->hash, coord2sstr(ni->coord, &bb), ni->hash, player_color, child_color, result);
358 #endif
361 node = node->parent; child_color = stone_other(child_color);
366 struct uct_policy *
367 policy_ucb1amaf_init(struct uct *u, char *arg)
369 struct uct_policy *p = calloc(1, sizeof(*p));
370 struct ucb1_policy_amaf *b = calloc(1, sizeof(*b));
371 p->uct = u;
372 p->data = b;
373 p->descend = ucb1srave_descend;
374 p->choose = ucb1_choose;
375 p->update = ucb1amaf_update;
376 p->wants_amaf = true;
378 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
379 b->explore_p = 0.1;
380 b->explore_p_rave = -1;
381 b->equiv_rave = 3000;
382 b->fpu = INFINITY;
383 // gp: 14 vs 0: 44% (+-3.5)
384 b->gp_eqex = 0;
385 b->even_eqex = b->policy_eqex = -1;
386 b->eqex = 6; /* Even number! */
387 b->rave_prior = true;
389 if (arg) {
390 char *optspec, *next = arg;
391 while (*next) {
392 optspec = next;
393 next += strcspn(next, ":");
394 if (*next) { *next++ = 0; } else { *next = 0; }
396 char *optname = optspec;
397 char *optval = strchr(optspec, '=');
398 if (optval) *optval++ = 0;
400 if (!strcasecmp(optname, "explore_p")) {
401 b->explore_p = atof(optval);
402 } else if (!strcasecmp(optname, "prior")) {
403 if (optval)
404 b->eqex = atoi(optval);
405 } else if (!strcasecmp(optname, "prior_even") && optval) {
406 b->even_eqex = atoi(optval);
407 } else if (!strcasecmp(optname, "prior_gp") && optval) {
408 b->gp_eqex = atoi(optval);
409 } else if (!strcasecmp(optname, "prior_policy") && optval) {
410 b->policy_eqex = atoi(optval);
411 } else if (!strcasecmp(optname, "fpu") && optval) {
412 b->fpu = atof(optval);
413 } else if (!strcasecmp(optname, "urg_randoma") && optval) {
414 b->urg_randoma = atoi(optval);
415 } else if (!strcasecmp(optname, "urg_randomm") && optval) {
416 b->urg_randomm = atoi(optval);
417 } else if (!strcasecmp(optname, "rave")) {
418 if (optval && *optval == '0')
419 p->descend = ucb1_descend;
420 else if (optval && *optval == 'o')
421 p->descend = ucb1orave_descend;
422 else if (optval && *optval == 's')
423 p->descend = ucb1srave_descend;
424 } else if (!strcasecmp(optname, "explore_p_rave") && optval) {
425 b->explore_p_rave = atof(optval);
426 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
427 b->equiv_rave = atof(optval);
428 } else if (!strcasecmp(optname, "rave_prior") && optval) {
429 // 46% (+-3.5)
430 b->rave_prior = atoi(optval);
431 } else if (!strcasecmp(optname, "both_colors")) {
432 b->both_colors = true;
433 } else {
434 fprintf(stderr, "ucb1: Invalid policy argument %s or missing value\n", optname);
439 if (b->eqex) p->prior = ucb1_prior;
440 if (b->even_eqex < 0) b->even_eqex = b->eqex;
441 if (b->gp_eqex < 0) b->gp_eqex = b->eqex;
442 if (b->policy_eqex < 0) b->policy_eqex = b->eqex;
443 if (b->explore_p_rave < 0) b->explore_p_rave = b->explore_p;
445 return p;