Moggy assess: Support 2lib prior; halve the recommended games
[pachi.git] / uct / policy / ucb1amaf.c
blob602c50d20f77def05c71cfaef8b5e900c298d7e4
#include <assert.h>
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

#include "board.h"
#include "debug.h"
#include "move.h"
#include "random.h"
#include "uct/internal.h"
#include "uct/tree.h"
/* This implements the UCB1 policy with an extra AMAF heuristic. */
/* Tunables of the UCB1 policy extended with AMAF (RAVE) statistics. */
struct ucb1_policy_amaf {
	/* This is what the Modification of UCT with Patterns in Monte Carlo Go
	 * paper calls 'p'. Original UCB has this on 2, but this seems to
	 * produce way too wide searches; reduce this to get deeper and
	 * narrower readouts - try 0.2. */
	float explore_p;
	/* First Play Urgency - if set to less than infinity (the MoGo paper
	 * above reports 1.0 as the best), new branches are explored only
	 * if none of the existing ones has higher urgency than fpu. */
	float fpu;
	/* Random perturbation of child urgency during descent: urg_randoma
	 * is additive, urg_randomm multiplicative; 0 disables either. */
	int urg_randoma, urg_randomm;
	/* Exploration coefficient applied to the AMAF (RAVE) value. */
	float explore_p_rave;
	/* RAVE equivalence parameter controlling how quickly the AMAF value
	 * is phased out in favour of the regular value. Kept as a float so
	 * the beta formulas never degrade into integer division (which would
	 * pin beta to 0 or 1) and fractional option values survive parsing. */
	float equiv_rave;
	/* Also credit AMAF statistics for moves of the other color, with the
	 * playout result inverted. */
	bool both_colors;
	/* For nakade-marked points, give AMAF credit only if the child's
	 * color actually played there during the recorded playout. */
	bool check_nakade;
	/* Compute beta per child from its own regular/AMAF playout counts
	 * instead of from the parent's playout count alone. */
	bool sylvain_rave;
};
/* Entry points of the plain UCB1 policy, implemented in another
 * translation unit (presumably the plain ucb1 policy source — confirm).
 * ucb1_choose is installed as this policy's choose hook in
 * policy_ucb1amaf_init; ucb1_descend is declared but not referenced here. */
struct tree_node *
ucb1_choose(struct uct_policy *p, struct tree_node *node, struct board *b, enum stone color);

struct tree_node *
ucb1_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass);
40 static inline float fast_sqrt(int x)
42 static const float table[] = {
43 0, 1, 1.41421356237309504880, 1.73205080756887729352,
44 2.00000000000000000000, 2.23606797749978969640,
45 2.44948974278317809819, 2.64575131106459059050,
46 2.82842712474619009760, 3.00000000000000000000,
47 3.16227766016837933199, 3.31662479035539984911,
48 3.46410161513775458705, 3.60555127546398929311,
49 3.74165738677394138558, 3.87298334620741688517,
50 4.00000000000000000000, 4.12310562561766054982,
51 4.24264068711928514640, 4.35889894354067355223,
52 4.47213595499957939281, 4.58257569495584000658,
53 4.69041575982342955456, 4.79583152331271954159,
54 4.89897948556635619639, 5.00000000000000000000,
55 5.09901951359278483002, 5.19615242270663188058,
56 5.29150262212918118100, 5.38516480713450403125,
57 5.47722557505166113456, 5.56776436283002192211,
58 5.65685424949238019520, 5.74456264653802865985,
59 5.83095189484530047087, 5.91607978309961604256,
60 6.00000000000000000000, 6.08276253029821968899,
61 6.16441400296897645025, 6.24499799839839820584,
62 6.32455532033675866399, 6.40312423743284868648,
63 6.48074069840786023096, 6.55743852430200065234,
64 6.63324958071079969822, 6.70820393249936908922,
65 6.78232998312526813906, 6.85565460040104412493,
66 6.92820323027550917410, 7.00000000000000000000,
67 7.07106781186547524400, 7.14142842854284999799,
68 7.21110255092797858623, 7.28010988928051827109,
69 7.34846922834953429459, 7.41619848709566294871,
70 7.48331477354788277116, 7.54983443527074969723,
71 7.61577310586390828566, 7.68114574786860817576,
72 7.74596669241483377035, 7.81024967590665439412,
73 7.87400787401181101968, 7.93725393319377177150,
75 //printf("sqrt %d\n", x);
76 if (x < sizeof(table) / sizeof(*table)) {
77 return table[x];
78 } else {
79 return sqrt(x);
83 struct tree_node *
84 ucb1rave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
86 struct ucb1_policy_amaf *b = p->data;
87 float beta = 0;
88 float nconf = 1.f, rconf = 1.f;
89 if (b->explore_p > 0)
90 nconf = sqrt(log(node->u.playouts + node->prior.playouts));
91 if (b->explore_p_rave > 0 && node->amaf.playouts)
92 rconf = sqrt(log(node->amaf.playouts + node->prior.playouts));
94 if (!b->sylvain_rave)
95 beta = sqrt(b->equiv_rave / (3 * node->u.playouts + b->equiv_rave));
97 // XXX: Stack overflow danger on big boards?
98 struct tree_node *nbest[512] = { node->children }; int nbests = 1;
99 float best_urgency = -9999;
101 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
102 /* Do not consider passing early. */
103 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
104 continue;
106 /* TODO: Exploration? */
108 int ngames = ni->u.playouts;
109 int nwins = ni->u.wins;
110 int rgames = ni->amaf.playouts;
111 int rwins = ni->amaf.wins;
112 if (p->uct->amaf_prior) {
113 rgames += ni->prior.playouts;
114 rwins += ni->prior.wins;
115 } else {
116 ngames += ni->prior.playouts;
117 nwins += ni->prior.wins;
119 if (tree_parity(tree, parity) < 0) {
120 nwins = ngames - nwins;
121 rwins = rgames - rwins;
123 float nval = 0, rval = 0;
124 if (ngames) {
125 nval = (float) nwins / ngames;
126 if (b->explore_p > 0)
127 nval += b->explore_p * nconf / fast_sqrt(ngames);
129 if (rgames) {
130 rval = (float) rwins / rgames;
131 if (b->explore_p_rave > 0 && !is_pass(ni->coord))
132 rval += b->explore_p_rave * rconf / fast_sqrt(rgames);
135 float urgency;
136 if (ngames) {
137 if (rgames) {
138 /* At the beginning, beta is at 1 and RAVE is used.
139 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
140 if (b->sylvain_rave)
141 beta = (float) rgames / (rgames + ngames + ngames * rgames / b->equiv_rave);
142 #if 0
143 //if (node->coord == 7*11+4) // D7
144 fprintf(stderr, "[beta %f = %d / (%d + %d + %f)]\n",
145 beta, rgames, rgames, ngames, ngames * rgames / b->equiv_rave);
146 #endif
147 urgency = beta * rval + (1.f - beta) * nval;
148 } else {
149 urgency = nval;
151 } else if (rgames) {
152 urgency = rval;
153 } else {
154 /* assert(!u->even_eqex); */
155 urgency = b->fpu;
158 #if 0
159 struct board bb; bb.size = 11;
160 //if (node->coord == 7*11+4) // D7
161 fprintf(stderr, "%s<%lld>-%s<%lld> urgency %f (r %d / %d + e = %f, n %d / %d + e = %f)\n",
162 coord2sstr(ni->parent->coord, &bb), ni->parent->hash,
163 coord2sstr(ni->coord, &bb), ni->hash, urgency,
164 rwins, rgames, rval, nwins, ngames, nval);
165 #endif
166 if (b->urg_randoma)
167 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
168 if (b->urg_randomm)
169 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
171 if (urgency - best_urgency > __FLT_EPSILON__) { // urgency > best_urgency
172 best_urgency = urgency; nbests = 0;
174 if (urgency - best_urgency > -__FLT_EPSILON__) { // urgency >= best_urgency
175 /* We want to always choose something else than a pass
176 * in case of a tie. pass causes degenerative behaviour. */
177 if (nbests == 1 && is_pass(nbest[0]->coord)) {
178 nbests--;
180 nbest[nbests++] = ni;
183 #if 0
184 struct board bb; bb.size = 11;
185 fprintf(stderr, "RESULT [%s %d: ", coord2sstr(node->coord, &bb), nbests);
186 for (int zz = 0; zz < nbests; zz++)
187 fprintf(stderr, "%s", coord2sstr(nbest[zz]->coord, &bb));
188 fprintf(stderr, "]\n");
189 #endif
190 return nbest[fast_random(nbests)];
193 static void
194 update_node(struct uct_policy *p, struct tree_node *node, int result)
196 node->u.playouts++;
197 node->u.wins += result;
198 tree_update_node_value(node, p->uct->amaf_prior);
201 static void
202 update_node_amaf(struct uct_policy *p, struct tree_node *node, int result)
204 node->amaf.playouts++;
205 node->amaf.wins += result;
206 tree_update_node_rvalue(node, p->uct->amaf_prior);
209 void
210 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node, enum stone node_color, enum stone player_color, struct playout_amafmap *map, int result)
212 struct ucb1_policy_amaf *b = p->data;
213 enum stone child_color = stone_other(node_color);
215 #if 0
216 struct board bb; bb.size = 9+2;
217 for (struct tree_node *ni = node; ni; ni = ni->parent)
218 fprintf(stderr, "%s ", coord2sstr(ni->coord, &bb));
219 fprintf(stderr, "[color %d] update result %d (color %d)\n",
220 node_color, result, player_color);
221 #endif
223 while (node) {
224 if (node->parent == NULL)
225 assert(tree->root_color == stone_other(child_color));
227 update_node(p, node, result);
228 if (amaf_nakade(map->map[node->coord]))
229 amaf_op(map->map[node->coord], -);
231 /* This loop ignores symmetry considerations, but they should
232 * matter only at a point when AMAF doesn't help much. */
233 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
234 assert(map->map[ni->coord] != S_OFFBOARD);
235 if (map->map[ni->coord] == S_NONE)
236 continue;
237 assert(map->game_baselen >= 0);
238 enum stone amaf_color = map->map[ni->coord];
239 if (amaf_nakade(map->map[ni->coord])) {
240 if (!b->check_nakade)
241 continue;
242 /* We don't care to implement both_colors
243 * properly since it sucks anyway. */
244 int i;
245 for (i = map->game_baselen; i < map->gamelen; i++)
246 if (map->game[i].coord == ni->coord
247 && map->game[i].color == child_color)
248 break;
249 if (i == map->gamelen)
250 continue;
251 amaf_color = child_color;
254 int nres = result;
255 if (amaf_color != child_color) {
256 if (!b->both_colors)
257 continue;
258 nres = !nres;
260 /* For child_color != player_color, we still want
261 * to record the result unmodified; in that case,
262 * we will correctly negate them at the descend phase. */
264 update_node_amaf(p, ni, nres);
266 #if 0
267 fprintf(stderr, "* %s<%lld> -> %s<%lld> [%d %d => %d/%d]\n", coord2sstr(node->coord, &bb), node->hash, coord2sstr(ni->coord, &bb), ni->hash, player_color, child_color, result);
268 #endif
271 if (!is_pass(node->coord)) {
272 map->game_baselen--;
274 node = node->parent; child_color = stone_other(child_color);
279 struct uct_policy *
280 policy_ucb1amaf_init(struct uct *u, char *arg)
282 struct uct_policy *p = calloc(1, sizeof(*p));
283 struct ucb1_policy_amaf *b = calloc(1, sizeof(*b));
284 p->uct = u;
285 p->data = b;
286 p->descend = ucb1rave_descend;
287 p->choose = ucb1_choose;
288 p->update = ucb1amaf_update;
289 p->wants_amaf = true;
291 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
292 b->explore_p = 0.1;
293 b->explore_p_rave = 0.01;
294 b->equiv_rave = 3000;
295 b->fpu = INFINITY;
296 b->check_nakade = true;
297 b->sylvain_rave = true;
299 if (arg) {
300 char *optspec, *next = arg;
301 while (*next) {
302 optspec = next;
303 next += strcspn(next, ":");
304 if (*next) { *next++ = 0; } else { *next = 0; }
306 char *optname = optspec;
307 char *optval = strchr(optspec, '=');
308 if (optval) *optval++ = 0;
310 if (!strcasecmp(optname, "explore_p")) {
311 b->explore_p = atof(optval);
312 } else if (!strcasecmp(optname, "fpu") && optval) {
313 b->fpu = atof(optval);
314 } else if (!strcasecmp(optname, "urg_randoma") && optval) {
315 b->urg_randoma = atoi(optval);
316 } else if (!strcasecmp(optname, "urg_randomm") && optval) {
317 b->urg_randomm = atoi(optval);
318 } else if (!strcasecmp(optname, "explore_p_rave") && optval) {
319 b->explore_p_rave = atof(optval);
320 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
321 b->equiv_rave = atof(optval);
322 } else if (!strcasecmp(optname, "both_colors")) {
323 b->both_colors = true;
324 } else if (!strcasecmp(optname, "sylvain_rave")) {
325 b->sylvain_rave = !optval || *optval == '1';
326 } else if (!strcasecmp(optname, "check_nakade")) {
327 b->check_nakade = !optval || *optval == '1';
328 } else {
329 fprintf(stderr, "ucb1: Invalid policy argument %s or missing value\n", optname);
334 if (b->explore_p_rave < 0) b->explore_p_rave = b->explore_p;
336 return p;