Keep only pure libmap stuff in libmap.*, move goal tracking to tactics/goals.*
[pachi.git] / uct / policy / ucb1amaf.c
blob94ac9c0a5eb010b05261b8b909a85b32374bf038
1 #include <assert.h>
2 #include <limits.h>
3 #include <math.h>
4 #include <stdio.h>
5 #include <stdlib.h>
6 #include <string.h>
8 #include "board.h"
9 #include "debug.h"
10 #include "move.h"
11 #include "random.h"
12 #include "tactics/goals.h"
13 #include "tactics/util.h"
14 #include "uct/internal.h"
15 #include "uct/tree.h"
16 #include "uct/policy/generic.h"
/* This implements the UCB1 policy with extra AMAF heuristics. */
20 struct ucb1_policy_amaf {
21 /* This is what the Modification of UCT with Patterns in Monte Carlo Go
22 * paper calls 'p'. Original UCB has this on 2, but this seems to
23 * produce way too wide searches; reduce this to get deeper and
24 * narrower readouts - try 0.2. */
25 floating_t explore_p;
26 /* In distributed mode, encourage different slaves to work on different
27 * parts of the tree by adding virtual wins to different nodes. */
28 int virtual_win;
29 int root_virtual_win;
30 int vwin_min_playouts;
31 /* First Play Urgency - if set to less than infinity (the MoGo paper
32 * above reports 1.0 as the best), new branches are explored only
33 * if none of the existing ones has higher urgency than fpu. */
34 floating_t fpu;
35 unsigned int equiv_rave;
36 bool sylvain_rave;
37 /* Give more weight to moves played earlier. */
38 int distance_rave;
39 /* Give 0 or negative rave bonus to ko threats before taking the ko.
40 1=normal bonus, 0=no bonus, -1=invert rave bonus, -2=double penalty... */
41 int threat_rave;
42 /* Coefficient of local tree values embedded in RAVE. */
43 floating_t ltree_rave;
44 /* Coefficient of criticality embedded in RAVE. */
45 floating_t crit_rave;
46 int crit_min_playouts;
47 floating_t crit_plthres_coef;
48 bool crit_negative;
49 bool crit_negflip;
50 bool crit_amaf;
51 bool crit_lvalue;
52 /* Coefficient of tactical rating embedded in RAVE. */
53 floating_t libmap_rave;
57 static inline floating_t fast_sqrt(unsigned int x)
59 static const floating_t table[] = {
60 0, 1, 1.41421356237309504880, 1.73205080756887729352,
61 2.00000000000000000000, 2.23606797749978969640,
62 2.44948974278317809819, 2.64575131106459059050,
63 2.82842712474619009760, 3.00000000000000000000,
64 3.16227766016837933199, 3.31662479035539984911,
65 3.46410161513775458705, 3.60555127546398929311,
66 3.74165738677394138558, 3.87298334620741688517,
67 4.00000000000000000000, 4.12310562561766054982,
68 4.24264068711928514640, 4.35889894354067355223,
69 4.47213595499957939281, 4.58257569495584000658,
70 4.69041575982342955456, 4.79583152331271954159,
71 4.89897948556635619639, 5.00000000000000000000,
72 5.09901951359278483002, 5.19615242270663188058,
73 5.29150262212918118100, 5.38516480713450403125,
74 5.47722557505166113456, 5.56776436283002192211,
75 5.65685424949238019520, 5.74456264653802865985,
76 5.83095189484530047087, 5.91607978309961604256,
77 6.00000000000000000000, 6.08276253029821968899,
78 6.16441400296897645025, 6.24499799839839820584,
79 6.32455532033675866399, 6.40312423743284868648,
80 6.48074069840786023096, 6.55743852430200065234,
81 6.63324958071079969822, 6.70820393249936908922,
82 6.78232998312526813906, 6.85565460040104412493,
83 6.92820323027550917410, 7.00000000000000000000,
84 7.07106781186547524400, 7.14142842854284999799,
85 7.21110255092797858623, 7.28010988928051827109,
86 7.34846922834953429459, 7.41619848709566294871,
87 7.48331477354788277116, 7.54983443527074969723,
88 7.61577310586390828566, 7.68114574786860817576,
89 7.74596669241483377035, 7.81024967590665439412,
90 7.87400787401181101968, 7.93725393319377177150,
92 if (x < sizeof(table) / sizeof(*table)) {
93 return table[x];
94 } else {
95 return sqrt(x);
99 #define URAVE_DEBUG if (0)
100 static inline floating_t
101 ucb1rave_evaluate(struct uct_policy *p, struct tree *tree, struct uct_descent *descent, int parity)
103 struct ucb1_policy_amaf *b = p->data;
104 struct tree_node *node = descent->node;
105 struct tree_node *lnode = descent->lnode;
107 struct move_stats n = node->u, r = node->amaf;
108 if (p->uct->amaf_prior) {
109 stats_merge(&r, &node->prior);
110 } else {
111 stats_merge(&n, &node->prior);
114 /* Local tree heuristics. */
115 assert(!lnode || lnode->parent);
116 if (p->uct->local_tree && b->ltree_rave > 0 && lnode
117 && (p->uct->local_tree_rootchoose || lnode->parent->parent)) {
118 struct move_stats l = lnode->u;
119 l.playouts = ((floating_t) l.playouts) * b->ltree_rave / LTREE_PLAYOUTS_MULTIPLIER;
120 URAVE_DEBUG fprintf(stderr, "[ltree] adding [%s] %f%%%d to [%s] RAVE %f%%%d\n",
121 coord2sstr(node_coord(lnode), tree->board), l.value, l.playouts,
122 coord2sstr(node_coord(node), tree->board), r.value, r.playouts);
123 stats_merge(&r, &l);
126 /* Criticality heuristics. */
127 if (b->crit_rave > 0 && (b->crit_plthres_coef > 0
128 ? node->u.playouts > tree->root->u.playouts * b->crit_plthres_coef
129 : node->u.playouts > b->crit_min_playouts)) {
130 floating_t crit = tree_node_criticality(tree, node);
131 if (b->crit_negative || crit > 0) {
132 floating_t val = 1.0f;
133 if (b->crit_negflip && crit < 0) {
134 val = 0;
135 crit = -crit;
137 struct move_stats c = {
138 .value = tree_node_get_value(tree, parity, val),
139 .playouts = crit * r.playouts * b->crit_rave
141 URAVE_DEBUG fprintf(stderr, "[crit] adding %f%%%d to [%s] RAVE %f%%%d\n",
142 c.value, c.playouts,
143 coord2sstr(node_coord(node), tree->board), r.value, r.playouts);
144 stats_merge(&r, &c);
148 /* Tactical rating (liberty map) heuristics. */
149 if (b->libmap_rave > 0 && tree->board->libmap) {
150 /* We look at tactical rating of a move relative to
151 * all neighbors. */
152 /* XXX: We should rather record hashes pertaining this move
153 * in the tree. We entirely miss counter-atari information. */
154 enum stone color = tree_node_color(tree, node);
155 struct move m = { .coord = node->coord, .color = color };
156 struct move_stats l = libmap_board_move_stats(descent->board->libmap, descent->board, m);
157 if (l.playouts > 0) {
158 l.value = tree_node_get_value(tree, parity, l.value);
159 l.playouts *= b->libmap_rave;
161 URAVE_DEBUG fprintf(stderr, "[libmap] adding %f%%%d to [%s %s] RAVE %f%%%d\n",
162 l.value, l.playouts, stone2str(color),
163 coord2sstr(node->coord, descent->board), r.value, r.playouts);
164 stats_merge(&r, &l);
169 floating_t value = 0;
170 if (n.playouts) {
171 if (r.playouts) {
172 /* At the beginning, beta is at 1 and RAVE is used.
173 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
174 floating_t beta;
175 if (b->sylvain_rave) {
176 beta = (floating_t) r.playouts / (r.playouts + n.playouts
177 + (floating_t) n.playouts * r.playouts / b->equiv_rave);
178 } else {
179 /* XXX: This can be cached in descend; but we don't use this by default. */
180 beta = sqrt(b->equiv_rave / (3 * node->parent->u.playouts + b->equiv_rave));
183 value = beta * r.value + (1.f - beta) * n.value;
184 URAVE_DEBUG fprintf(stderr, "\t%s value = %f * %f + (1 - %f) * %f (prior %f)\n",
185 coord2sstr(node_coord(node), tree->board), beta, r.value, beta, n.value, node->prior.value);
186 } else {
187 value = n.value;
188 URAVE_DEBUG fprintf(stderr, "\t%s value = %f (prior %f)\n",
189 coord2sstr(node_coord(node), tree->board), n.value, node->prior.value);
191 } else if (r.playouts) {
192 value = r.value;
193 URAVE_DEBUG fprintf(stderr, "\t%s value = rave %f (prior %f)\n",
194 coord2sstr(node_coord(node), tree->board), r.value, node->prior.value);
196 descent->value.playouts = r.playouts + n.playouts;
197 descent->value.value = value;
198 return tree_node_get_value(tree, parity, value);
201 void
202 ucb1rave_descend(struct uct_policy *p, struct tree *tree, struct uct_descent *descent, int parity, bool allow_pass)
204 struct ucb1_policy_amaf *b = p->data;
205 floating_t nconf = 1.f;
206 if (b->explore_p > 0)
207 nconf = sqrt(log(descent->node->u.playouts + descent->node->prior.playouts));
208 struct uct *u = p->uct;
209 int vwin = 0;
210 if (u->max_slaves > 0 && u->slave_index >= 0)
211 vwin = descent->node == tree->root ? b->root_virtual_win : b->virtual_win;
212 int child = 0;
214 uctd_try_node_children(tree, descent, allow_pass, parity, u->tenuki_d, di, urgency) {
215 struct tree_node *ni = di.node;
216 urgency = ucb1rave_evaluate(p, tree, &di, parity);
218 /* In distributed mode, encourage different slaves to work on different
219 * parts of the tree. We rely on the fact that children (if they exist)
220 * are the same and in the same order in all slaves. */
221 if (vwin > 0 && ni->u.playouts > b->vwin_min_playouts && (child - u->slave_index) % u->max_slaves == 0)
222 urgency += vwin / (ni->u.playouts + vwin);
224 if (ni->u.playouts > 0 && b->explore_p > 0) {
225 urgency += b->explore_p * nconf / fast_sqrt(ni->u.playouts);
227 } else if (ni->u.playouts + ni->amaf.playouts + ni->prior.playouts == 0) {
228 /* assert(!u->even_eqex); */
229 urgency = b->fpu;
231 } uctd_set_best_child(di, urgency);
233 uctd_get_best_child(descent);
/* Length of the current ko fight at the start of @ko_capture_map, i.e. the
 * number of moves up to the last ko capture in the sequence. Returns 0 when
 * the sequence is empty or does not start with a ko capture. Subsequent
 * captures are expected every third move, e.g.:
 *   B captures a ko
 *   W plays a ko threat
 *   B answers ko threat
 *   W re-captures the ko <- return 4
 *   B plays a ko threat
 *   W connects the ko */
static inline int ko_length(bool *ko_capture_map, int map_length)
{
	if (map_length < 1)
		return 0;
	if (!ko_capture_map[0])
		return 0;

	int len = 1;
	/* Follow the capture/threat/answer cycle while it continues. */
	for (int next = 3; next < map_length && ko_capture_map[next]; next += 3)
		len = next + 1;
	return len;
}
253 void
254 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node,
255 enum stone node_color, enum stone player_color,
256 struct playout_amafmap *map, struct board *final_board,
257 floating_t result)
259 struct ucb1_policy_amaf *b = p->data;
260 enum stone winner_color = result > 0.5 ? S_BLACK : S_WHITE;
262 /* Record of the random playout - for each intersection coord,
263 * first_move[coord] is the index map->game of the first move
264 * at this coordinate, or INT_MAX if the move was not played.
265 * The parity gives the color of this move.
267 int first_map[board_size2(final_board)+1];
268 int *first_move = &first_map[1]; // +1 for pass
270 #if 0
271 struct board bb; bb.size = 9+2;
272 for (struct tree_node *ni = node; ni; ni = ni->parent)
273 fprintf(stderr, "%s ", coord2sstr(node_coord(ni), &bb));
274 fprintf(stderr, "[color %d] update result %d (color %d)\n",
275 node_color, result, player_color);
276 #endif
278 /* Initialize first_move */
279 for (int i = pass; i < board_size2(final_board); i++) first_move[i] = INT_MAX;
280 int move;
281 assert(map->gamelen > 0);
282 for (move = map->gamelen - 1; move >= map->game_baselen; move--)
283 first_move[map->game[move]] = move;
285 while (node) {
286 if (!b->crit_amaf && !is_pass(node_coord(node))) {
287 stats_add_result(&node->winner_owner, board_local_value(b->crit_lvalue, final_board, node_coord(node), winner_color), 1);
288 stats_add_result(&node->black_owner, board_local_value(b->crit_lvalue, final_board, node_coord(node), S_BLACK), 1);
290 stats_add_result(&node->u, result, 1);
292 bool *ko_capture_map = &map->is_ko_capture[move+1];
293 int max_threat_dist = b->threat_rave <= 0 ? ko_length(ko_capture_map, map->gamelen - (move+1)) : -1;
295 /* This loop ignores symmetry considerations, but they should
296 * matter only at a point when AMAF doesn't help much. */
297 assert(map->game_baselen >= 0);
298 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
299 if (is_pass(node_coord(ni))) continue;
301 /* Use the child move only if it was first played by the same color. */
302 int first = first_move[node_coord(ni)];
303 if (first == INT_MAX) continue;
304 assert(first > move && first < map->gamelen);
305 int distance = first - (move + 1);
306 if (distance & 1) continue;
308 int weight = 1;
309 floating_t res = result;
311 /* Don't give amaf bonus to a ko threat before taking the ko.
312 * http://www.grappa.univ-lille3.fr/~coulom/Aja_PhD_Thesis.pdf
314 if (distance <= max_threat_dist && distance % 6 == 4) {
315 weight = - b->threat_rave;
316 res = 1.0 - res;
317 } else if (b->distance_rave != 0) {
318 /* Give more weight to moves played earlier */
319 weight += b->distance_rave * (map->gamelen - first) / (map->gamelen - move);
321 stats_add_result(&ni->amaf, res, weight);
323 if (b->crit_amaf) {
324 stats_add_result(&ni->winner_owner, board_local_value(b->crit_lvalue, final_board, node_coord(ni), winner_color), 1);
325 stats_add_result(&ni->black_owner, board_local_value(b->crit_lvalue, final_board, node_coord(ni), S_BLACK), 1);
327 #if 0
328 struct board bb; bb.size = 9+2;
329 fprintf(stderr, "* %s<%"PRIhash"> -> %s<%"PRIhash"> [%d/%f => %d/%f]\n",
330 coord2sstr(node_coord(node), &bb), node->hash,
331 coord2sstr(node_coord(ni), &bb), ni->hash,
332 player_color, result, move, res);
333 #endif
335 if (node->parent) {
336 assert(move >= 0 && map->game[move] == node_coord(node) && first_move[node_coord(node)] > move);
337 first_move[node_coord(node)] = move;
338 move--;
340 node = node->parent;
345 struct uct_policy *
346 policy_ucb1amaf_init(struct uct *u, char *arg, struct board *board)
348 struct uct_policy *p = calloc2(1, sizeof(*p));
349 struct ucb1_policy_amaf *b = calloc2(1, sizeof(*b));
350 p->uct = u;
351 p->data = b;
352 p->choose = uctp_generic_choose;
353 p->winner = uctp_generic_winner;
354 p->evaluate = ucb1rave_evaluate;
355 p->descend = ucb1rave_descend;
356 p->update = ucb1amaf_update;
357 p->wants_amaf = true;
359 b->explore_p = 0;
360 b->equiv_rave = board_large(board) ? 4000 : 3000;
361 b->fpu = INFINITY;
362 b->sylvain_rave = true;
363 b->distance_rave = 3;
364 b->threat_rave = 0;
365 b->ltree_rave = 0.75f;
367 b->crit_rave = 1.1f;
368 b->crit_min_playouts = 2000;
369 b->crit_negative = 1;
370 b->crit_amaf = 0;
372 b->virtual_win = 5;
373 b->root_virtual_win = 30;
374 b->vwin_min_playouts = 1000;
376 if (arg) {
377 char *optspec, *next = arg;
378 while (*next) {
379 optspec = next;
380 next += strcspn(next, ":");
381 if (*next) { *next++ = 0; } else { *next = 0; }
383 char *optname = optspec;
384 char *optval = strchr(optspec, '=');
385 if (optval) *optval++ = 0;
387 if (!strcasecmp(optname, "explore_p")) {
388 b->explore_p = atof(optval);
389 } else if (!strcasecmp(optname, "fpu") && optval) {
390 b->fpu = atof(optval);
391 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
392 b->equiv_rave = atof(optval);
393 } else if (!strcasecmp(optname, "sylvain_rave")) {
394 b->sylvain_rave = !optval || *optval == '1';
395 } else if (!strcasecmp(optname, "distance_rave") && optval) {
396 b->distance_rave = atoi(optval);
397 } else if (!strcasecmp(optname, "threat_rave") && optval) {
398 b->threat_rave = atoi(optval);
399 } else if (!strcasecmp(optname, "ltree_rave") && optval) {
400 b->ltree_rave = atof(optval);
401 } else if (!strcasecmp(optname, "crit_rave") && optval) {
402 b->crit_rave = atof(optval);
403 } else if (!strcasecmp(optname, "crit_min_playouts") && optval) {
404 b->crit_min_playouts = atoi(optval);
405 } else if (!strcasecmp(optname, "crit_plthres_coef") && optval) {
406 b->crit_plthres_coef = atof(optval);
407 } else if (!strcasecmp(optname, "crit_negative")) {
408 b->crit_negative = !optval || *optval == '1';
409 } else if (!strcasecmp(optname, "crit_negflip")) {
410 b->crit_negflip = !optval || *optval == '1';
411 } else if (!strcasecmp(optname, "crit_amaf")) {
412 b->crit_amaf = !optval || *optval == '1';
413 } else if (!strcasecmp(optname, "crit_lvalue")) {
414 b->crit_lvalue = !optval || *optval == '1';
415 } else if (!strcasecmp(optname, "libmap_rave") && optval) {
416 b->libmap_rave = atof(optval);
417 } else if (!strcasecmp(optname, "virtual_win") && optval) {
418 b->virtual_win = atoi(optval);
419 } else if (!strcasecmp(optname, "root_virtual_win") && optval) {
420 b->root_virtual_win = atoi(optval);
421 } else if (!strcasecmp(optname, "vwin_min_playouts") && optval) {
422 b->vwin_min_playouts = atoi(optval);
423 } else {
424 fprintf(stderr, "ucb1amaf: Invalid policy argument %s or missing value\n",
425 optname);
426 exit(1);
431 return p;