UCB1AMAF: fast_sqrt() static inline
[pachi.git] / uct / policy / ucb1amaf.c
blob a71d26238238afea5c50398fd07dd1c49d71f8f5
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "uct/internal.h"
12 #include "uct/tree.h"
14 /* This implements the UCB1 policy with an extra AMAF heuristic. */
/* Tunables of the UCB1 + AMAF (RAVE) policy. Defaults and the option
 * names that set them live in policy_ucb1amaf_init(). */
struct ucb1_policy_amaf {
	/* Exploration coefficient of the UCB term computed from a node's own
	 * statistics. The "Modification of UCT with Patterns in Monte Carlo
	 * Go" paper calls this 'p'. Original UCB uses 2, which tends to make
	 * the search far too wide; smaller values (try 0.2) give deeper and
	 * narrower readouts. */
	float explore_p;

	/* First Play Urgency: the urgency assigned to a child that has no
	 * statistics at all. When set below infinity (the MoGo paper reports
	 * 1.0 as the best value), new branches are explored only if none of
	 * the existing ones has a higher urgency than fpu. */
	float fpu;

	/* Additive urgency noise: when non-zero, urgency is shifted by
	 * (fast_random(urg_randoma) - urg_randoma / 2) / 1000. */
	int urg_randoma;

	/* Multiplicative urgency noise: when non-zero, urgency is scaled by
	 * (fast_random(urg_randomm) + 5) / urg_randomm. */
	int urg_randomm;

	/* Exploration coefficient of the UCB term computed from AMAF stats. */
	float explore_p_rave;

	/* Controls how quickly the RAVE weight (beta) decays as real playouts
	 * accumulate; the descend code uses 1 / equiv_rave as its coefficient. */
	int equiv_rave;

	/* If set, prior values are added to the AMAF statistics; otherwise
	 * they are added to the node's own statistics. */
	bool rave_prior;

	/* If set, AMAF updates also credit moves made by the opposite colour,
	 * with the playout result negated. */
	bool both_colors;

	/* If set, nakade-marked AMAF map entries are still credited when the
	 * child's colour actually played there in the remaining game record. */
	bool check_nakade;
};
/* Helpers shared with the plain UCB1 policy. They are defined outside this
 * file (presumably the plain ucb1 policy module — confirm). ucb1_choose is
 * installed as this policy's choose hook, and ucb1_descend is the
 * descend function selected by the "rave=0" option. */
struct tree_node *ucb1_choose(struct uct_policy *p, struct tree_node *node,
                              struct board *b, enum stone color);
struct tree_node *ucb1_descend(struct uct_policy *p, struct tree *tree,
                               struct tree_node *node, int parity, bool allow_pass);
39 static inline float fast_sqrt(int x)
41 static const float table[] = {
44 1.41421356237309504880,
45 1.73205080756887729352,
46 2.00000000000000000000,
48 //printf("sqrt %d\n", x);
49 if (x < sizeof(table) / sizeof(*table)) {
50 return table[x];
51 } else {
52 return sqrt(x);
56 /* Sylvain RAVE function */
57 struct tree_node *
58 ucb1srave_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
60 struct ucb1_policy_amaf *b = p->data;
61 float rave_coef = 1.0f / b->equiv_rave;
62 float nconf = 1.f, rconf = 1.f;
63 if (b->explore_p > 0)
64 nconf = sqrt(log(node->u.playouts + node->prior.playouts));
65 if (b->explore_p_rave > 0 && node->amaf.playouts)
66 rconf = sqrt(log(node->amaf.playouts + node->prior.playouts));
68 // XXX: Stack overflow danger on big boards?
69 struct tree_node *nbest[512] = { node->children }; int nbests = 1;
70 float best_urgency = -9999;
72 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
73 /* Do not consider passing early. */
74 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
75 continue;
77 /* TODO: Exploration? */
79 int ngames = ni->u.playouts;
80 int nwins = ni->u.wins;
81 int rgames = ni->amaf.playouts;
82 int rwins = ni->amaf.wins;
83 if (b->rave_prior) {
84 rgames += ni->prior.playouts;
85 rwins += ni->prior.wins;
86 } else {
87 ngames += ni->prior.playouts;
88 nwins += ni->prior.wins;
90 if (tree_parity(tree, parity) < 0) {
91 nwins = ngames - nwins;
92 rwins = rgames - rwins;
94 float nval = 0, rval = 0;
95 if (ngames) {
96 nval = (float) nwins / ngames;
97 if (b->explore_p > 0)
98 nval += b->explore_p * nconf / fast_sqrt(ngames);
100 if (rgames) {
101 rval = (float) rwins / rgames;
102 if (b->explore_p_rave > 0 && !is_pass(ni->coord))
103 rval += b->explore_p_rave * rconf / fast_sqrt(rgames);
106 /* XXX: We later compare urgency with best_urgency; this can
107 * be difficult given that urgency can be in register with
108 * higher precision than best_urgency, thus even though
109 * the numbers are in fact the same, urgency will be
110 * slightly higher (or lower). Thus, we declare urgency
111 * as volatile, attempting to force the compiler to keep
112 * everything as a float. Ideally, we should do some random
113 * __FLT_EPSILON__ magic instead. */
114 volatile float urgency;
115 if (ngames) {
116 if (rgames) {
117 /* At the beginning, beta is at 1 and RAVE is used.
118 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
119 float beta = (float) rgames / (rgames + ngames + rave_coef * ngames * rgames);
120 #if 0
121 //if (node->coord == 7*11+4) // D7
122 fprintf(stderr, "[beta %f = %d / (%d + %d + %f)]\n",
123 beta, rgames, rgames, ngames, rave_coef * ngames * rgames);
124 #endif
125 urgency = beta * rval + (1 - beta) * nval;
126 } else {
127 urgency = nval;
129 } else if (rgames) {
130 urgency = rval;
131 } else {
132 /* assert(!u->even_eqex); */
133 urgency = b->fpu;
136 #if 0
137 struct board bb; bb.size = 11;
138 //if (node->coord == 7*11+4) // D7
139 fprintf(stderr, "%s<%lld>-%s<%lld> urgency %f (r %d / %d, n %d / %d)\n",
140 coord2sstr(ni->parent->coord, &bb), ni->parent->hash,
141 coord2sstr(ni->coord, &bb), ni->hash, urgency,
142 rwins, rgames, nwins, ngames);
143 #endif
144 if (b->urg_randoma)
145 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
146 if (b->urg_randomm)
147 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
149 if (urgency > best_urgency) {
150 best_urgency = urgency; nbests = 0;
152 if (urgency >= best_urgency) {
153 /* We want to always choose something else than a pass
154 * in case of a tie. pass causes degenerative behaviour. */
155 if (nbests == 1 && is_pass(nbest[0]->coord)) {
156 nbests--;
158 nbest[nbests++] = ni;
161 #if 0
162 struct board bb; bb.size = 11;
163 fprintf(stderr, "[%s %d: ", coord2sstr(node->coord, &bb), nbests);
164 for (int zz = 0; zz < nbests; zz++)
165 fprintf(stderr, "%s", coord2sstr(nbest[zz]->coord, &bb));
166 fprintf(stderr, "]\n");
167 #endif
168 return nbest[fast_random(nbests)];
171 static void
172 update_node(struct uct_policy *p, struct tree_node *node, int result)
174 node->u.playouts++;
175 node->u.wins += result;
176 tree_update_node_value(node);
178 static void
179 update_node_amaf(struct uct_policy *p, struct tree_node *node, int result)
181 node->amaf.playouts++;
182 node->amaf.wins += result;
183 tree_update_node_value(node);
186 void
187 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node, enum stone node_color, enum stone player_color, struct playout_amafmap *map, int result)
189 struct ucb1_policy_amaf *b = p->data;
190 enum stone child_color = stone_other(node_color);
192 #if 0
193 struct board bb; bb.size = 9+2;
194 for (struct tree_node *ni = node; ni; ni = ni->parent)
195 fprintf(stderr, "%s ", coord2sstr(ni->coord, &bb));
196 fprintf(stderr, "[color %d] update result %d (color %d)\n",
197 node_color, result, player_color);
198 #endif
200 while (node) {
201 if (node->parent == NULL)
202 assert(tree->root_color == stone_other(child_color));
204 if (p->descend != ucb1_descend)
205 node->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
206 update_node(p, node, result);
207 if (amaf_nakade(map->map[node->coord]))
208 amaf_op(map->map[node->coord], -);
210 /* This loop ignores symmetry considerations, but they should
211 * matter only at a point when AMAF doesn't help much. */
212 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
213 assert(map->map[ni->coord] != S_OFFBOARD);
214 if (map->map[ni->coord] == S_NONE)
215 continue;
216 assert(map->game_baselen >= 0);
217 enum stone amaf_color = map->map[ni->coord];
218 if (amaf_nakade(map->map[ni->coord])) {
219 if (!b->check_nakade)
220 continue;
221 /* We don't care to implement both_colors
222 * properly since it sucks anyway. */
223 int i;
224 for (i = map->game_baselen; i < map->gamelen; i++)
225 if (map->game[i].coord == ni->coord
226 && map->game[i].color == child_color)
227 break;
228 if (i == map->gamelen)
229 continue;
230 amaf_color = child_color;
233 int nres = result;
234 if (amaf_color != child_color) {
235 if (!b->both_colors)
236 continue;
237 nres = !nres;
239 /* For child_color != player_color, we still want
240 * to record the result unmodified; in that case,
241 * we will correctly negate them at the descend phase. */
243 if (p->descend != ucb1_descend)
244 ni->hints |= NODE_HINT_NOAMAF; /* Rave, different update function */
245 update_node_amaf(p, ni, nres);
247 #if 0
248 fprintf(stderr, "* %s<%lld> -> %s<%lld> [%d %d => %d/%d]\n", coord2sstr(node->coord, &bb), node->hash, coord2sstr(ni->coord, &bb), ni->hash, player_color, child_color, result);
249 #endif
252 if (!is_pass(node->coord)) {
253 map->game_baselen--;
255 node = node->parent; child_color = stone_other(child_color);
260 struct uct_policy *
261 policy_ucb1amaf_init(struct uct *u, char *arg)
263 struct uct_policy *p = calloc(1, sizeof(*p));
264 struct ucb1_policy_amaf *b = calloc(1, sizeof(*b));
265 p->uct = u;
266 p->data = b;
267 p->descend = ucb1srave_descend;
268 p->choose = ucb1_choose;
269 p->update = ucb1amaf_update;
270 p->wants_amaf = true;
272 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
273 b->explore_p = 0.1;
274 b->explore_p_rave = 0.01;
275 b->equiv_rave = 3000;
276 b->fpu = INFINITY;
277 b->rave_prior = true;
278 b->check_nakade = true;
280 if (arg) {
281 char *optspec, *next = arg;
282 while (*next) {
283 optspec = next;
284 next += strcspn(next, ":");
285 if (*next) { *next++ = 0; } else { *next = 0; }
287 char *optname = optspec;
288 char *optval = strchr(optspec, '=');
289 if (optval) *optval++ = 0;
291 if (!strcasecmp(optname, "explore_p")) {
292 b->explore_p = atof(optval);
293 } else if (!strcasecmp(optname, "fpu") && optval) {
294 b->fpu = atof(optval);
295 } else if (!strcasecmp(optname, "urg_randoma") && optval) {
296 b->urg_randoma = atoi(optval);
297 } else if (!strcasecmp(optname, "urg_randomm") && optval) {
298 b->urg_randomm = atoi(optval);
299 } else if (!strcasecmp(optname, "rave")) {
300 if (optval && *optval == '0')
301 p->descend = ucb1_descend;
302 else if (optval && *optval == 's')
303 p->descend = ucb1srave_descend;
304 } else if (!strcasecmp(optname, "explore_p_rave") && optval) {
305 b->explore_p_rave = atof(optval);
306 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
307 b->equiv_rave = atof(optval);
308 } else if (!strcasecmp(optname, "rave_prior") && optval) {
309 // 46% (+-3.5)
310 b->rave_prior = atoi(optval);
311 } else if (!strcasecmp(optname, "both_colors")) {
312 b->both_colors = true;
313 } else if (!strcasecmp(optname, "check_nakade")) {
314 b->check_nakade = !optval || *optval == '1';
315 } else {
316 fprintf(stderr, "ucb1: Invalid policy argument %s or missing value\n", optname);
321 if (b->explore_p_rave < 0) b->explore_p_rave = b->explore_p;
323 return p;