Criticality crit_negflip: Add support
[pachi/t.git] / uct / policy / ucb1amaf.c
blob86757fb73c0b6945a8928cd53850f648cf946cc2
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "tactics/util.h"
12 #include "uct/internal.h"
13 #include "uct/tree.h"
14 #include "uct/policy/generic.h"
/* This implements the UCB1 policy extended with an AMAF (RAVE) heuristic. */
18 struct ucb1_policy_amaf {
19 /* This is what the Modification of UCT with Patterns in Monte Carlo Go
20 * paper calls 'p'. Original UCB has this on 2, but this seems to
21 * produce way too wide searches; reduce this to get deeper and
22 * narrower readouts - try 0.2. */
23 floating_t explore_p;
24 /* First Play Urgency - if set to less than infinity (the MoGo paper
25 * above reports 1.0 as the best), new branches are explored only
26 * if none of the existing ones has higher urgency than fpu. */
27 floating_t fpu;
28 unsigned int equiv_rave;
29 bool check_nakade;
30 bool sylvain_rave;
31 /* Coefficient of local tree values embedded in RAVE. */
32 floating_t ltree_rave;
33 /* Coefficient of criticality embedded in RAVE. */
34 floating_t crit_rave;
35 int crit_min_playouts;
36 bool crit_negative;
37 bool crit_negflip;
38 bool crit_amaf;
39 bool crit_lvalue;
43 static inline floating_t fast_sqrt(unsigned int x)
45 static const floating_t table[] = {
46 0, 1, 1.41421356237309504880, 1.73205080756887729352,
47 2.00000000000000000000, 2.23606797749978969640,
48 2.44948974278317809819, 2.64575131106459059050,
49 2.82842712474619009760, 3.00000000000000000000,
50 3.16227766016837933199, 3.31662479035539984911,
51 3.46410161513775458705, 3.60555127546398929311,
52 3.74165738677394138558, 3.87298334620741688517,
53 4.00000000000000000000, 4.12310562561766054982,
54 4.24264068711928514640, 4.35889894354067355223,
55 4.47213595499957939281, 4.58257569495584000658,
56 4.69041575982342955456, 4.79583152331271954159,
57 4.89897948556635619639, 5.00000000000000000000,
58 5.09901951359278483002, 5.19615242270663188058,
59 5.29150262212918118100, 5.38516480713450403125,
60 5.47722557505166113456, 5.56776436283002192211,
61 5.65685424949238019520, 5.74456264653802865985,
62 5.83095189484530047087, 5.91607978309961604256,
63 6.00000000000000000000, 6.08276253029821968899,
64 6.16441400296897645025, 6.24499799839839820584,
65 6.32455532033675866399, 6.40312423743284868648,
66 6.48074069840786023096, 6.55743852430200065234,
67 6.63324958071079969822, 6.70820393249936908922,
68 6.78232998312526813906, 6.85565460040104412493,
69 6.92820323027550917410, 7.00000000000000000000,
70 7.07106781186547524400, 7.14142842854284999799,
71 7.21110255092797858623, 7.28010988928051827109,
72 7.34846922834953429459, 7.41619848709566294871,
73 7.48331477354788277116, 7.54983443527074969723,
74 7.61577310586390828566, 7.68114574786860817576,
75 7.74596669241483377035, 7.81024967590665439412,
76 7.87400787401181101968, 7.93725393319377177150,
78 if (x < sizeof(table) / sizeof(*table)) {
79 return table[x];
80 } else {
81 return sqrt(x);
85 #define URAVE_DEBUG if (0)
86 static floating_t inline
87 ucb1rave_evaluate(struct uct_policy *p, struct tree *tree, struct uct_descent *descent, int parity)
89 struct ucb1_policy_amaf *b = p->data;
90 struct tree_node *node = descent->node;
91 struct tree_node *lnode = descent->lnode;
93 struct move_stats n = node->u, r = node->amaf;
94 if (p->uct->amaf_prior) {
95 stats_merge(&r, &node->prior);
96 } else {
97 stats_merge(&n, &node->prior);
100 /* Local tree heuristics. */
101 assert(!lnode || lnode->parent);
102 if (p->uct->local_tree && b->ltree_rave > 0 && lnode
103 && (p->uct->local_tree_rootchoose || lnode->parent->parent)) {
104 struct move_stats l = lnode->u;
105 l.playouts = ((floating_t) l.playouts) * b->ltree_rave / LTREE_PLAYOUTS_MULTIPLIER;
106 URAVE_DEBUG fprintf(stderr, "[ltree] adding [%s] %f%%%d to [%s] RAVE %f%%%d\n",
107 coord2sstr(node_coord(lnode), tree->board), l.value, l.playouts,
108 coord2sstr(node_coord(node), tree->board), r.value, r.playouts);
109 stats_merge(&r, &l);
112 /* Criticality heuristics. */
113 if (b->crit_rave > 0 && node->u.playouts > b->crit_min_playouts) {
114 floating_t crit = tree_node_criticality(tree, node);
115 if (b->crit_negative || crit > 0) {
116 floating_t val = 1.0f;
117 if (b->crit_negflip && crit < 0) {
118 val = 0;
119 crit = -crit;
121 struct move_stats c = {
122 .value = tree_node_get_value(tree, parity, val),
123 .playouts = crit * r.playouts * b->crit_rave
125 URAVE_DEBUG fprintf(stderr, "[crit] adding %f%%%d to [%s] RAVE %f%%%d\n",
126 c.value, c.playouts,
127 coord2sstr(node_coord(node), tree->board), r.value, r.playouts);
128 stats_merge(&r, &c);
133 floating_t value = 0;
134 if (n.playouts) {
135 if (r.playouts) {
136 /* At the beginning, beta is at 1 and RAVE is used.
137 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
138 floating_t beta;
139 if (b->sylvain_rave) {
140 beta = (floating_t) r.playouts / (r.playouts + n.playouts
141 + (floating_t) n.playouts * r.playouts / b->equiv_rave);
142 } else {
143 /* XXX: This can be cached in descend; but we don't use this by default. */
144 beta = sqrt(b->equiv_rave / (3 * node->parent->u.playouts + b->equiv_rave));
147 value = beta * r.value + (1.f - beta) * n.value;
148 URAVE_DEBUG fprintf(stderr, "\t%s value = %f * %f + (1 - %f) * %f (prior %f)\n",
149 coord2sstr(node_coord(node), tree->board), beta, r.value, beta, n.value, node->prior.value);
150 } else {
151 value = n.value;
152 URAVE_DEBUG fprintf(stderr, "\t%s value = %f (prior %f)\n",
153 coord2sstr(node_coord(node), tree->board), n.value, node->prior.value);
155 } else if (r.playouts) {
156 value = r.value;
157 URAVE_DEBUG fprintf(stderr, "\t%s value = rave %f (prior %f)\n",
158 coord2sstr(node_coord(node), tree->board), r.value, node->prior.value);
160 descent->value.playouts = r.playouts + n.playouts;
161 descent->value.value = value;
162 return tree_node_get_value(tree, parity, value);
165 void
166 ucb1rave_descend(struct uct_policy *p, struct tree *tree, struct uct_descent *descent, int parity, bool allow_pass)
168 struct ucb1_policy_amaf *b = p->data;
169 floating_t nconf = 1.f;
170 if (b->explore_p > 0)
171 nconf = sqrt(log(descent->node->u.playouts + descent->node->prior.playouts));
173 uctd_try_node_children(tree, descent, allow_pass, parity, p->uct->tenuki_d, di, urgency) {
174 struct tree_node *ni = di.node;
175 urgency = ucb1rave_evaluate(p, tree, &di, parity);
177 if (ni->u.playouts > 0 && b->explore_p > 0) {
178 urgency += b->explore_p * nconf / fast_sqrt(ni->u.playouts);
180 } else if (ni->u.playouts + ni->amaf.playouts + ni->prior.playouts == 0) {
181 /* assert(!u->even_eqex); */
182 urgency = b->fpu;
184 } uctd_set_best_child(di, urgency);
186 uctd_get_best_child(descent);
190 void
191 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node,
192 enum stone node_color, enum stone player_color,
193 struct playout_amafmap *map, struct board *final_board,
194 floating_t result)
196 struct ucb1_policy_amaf *b = p->data;
197 enum stone winner_color = result > 0.5 ? S_BLACK : S_WHITE;
198 enum stone child_color = stone_other(node_color);
200 #if 0
201 struct board bb; bb.size = 9+2;
202 for (struct tree_node *ni = node; ni; ni = ni->parent)
203 fprintf(stderr, "%s ", coord2sstr(node_coord(ni), &bb));
204 fprintf(stderr, "[color %d] update result %d (color %d)\n",
205 node_color, result, player_color);
206 #endif
208 while (node) {
209 if (node->parent == NULL)
210 assert(tree->root_color == stone_other(child_color));
212 if (!b->crit_amaf && !is_pass(node_coord(node))) {
213 stats_add_result(&node->winner_owner, board_local_value(b->crit_lvalue, final_board, node_coord(node), winner_color), 1);
214 stats_add_result(&node->black_owner, board_local_value(b->crit_lvalue, final_board, node_coord(node), S_BLACK), 1);
216 stats_add_result(&node->u, result, 1);
217 if (!is_pass(node_coord(node)) && amaf_nakade(map->map[node_coord(node)]))
218 amaf_op(map->map[node_coord(node)], -);
220 /* This loop ignores symmetry considerations, but they should
221 * matter only at a point when AMAF doesn't help much. */
222 assert(map->game_baselen >= 0);
223 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
224 enum stone amaf_color = map->map[node_coord(ni)];
225 assert(amaf_color != S_OFFBOARD);
226 if (amaf_color == S_NONE)
227 continue;
228 if (amaf_nakade(map->map[node_coord(ni)])) {
229 if (!b->check_nakade)
230 continue;
231 unsigned int i;
232 for (i = map->game_baselen; i < map->gamelen; i++)
233 if (map->game[i].coord == node_coord(ni)
234 && map->game[i].color == child_color)
235 break;
236 if (i == map->gamelen)
237 continue;
238 amaf_color = child_color;
241 floating_t nres = result;
242 if (amaf_color != child_color) {
243 continue;
245 /* For child_color != player_color, we still want
246 * to record the result unmodified; in that case,
247 * we will correctly negate them at the descend phase. */
249 if (b->crit_amaf && !is_pass(node_coord(node))) {
250 stats_add_result(&ni->winner_owner, board_local_value(b->crit_lvalue, final_board, node_coord(ni), winner_color), 1);
251 stats_add_result(&ni->black_owner, board_local_value(b->crit_lvalue, final_board, node_coord(ni), S_BLACK), 1);
253 stats_add_result(&ni->amaf, nres, 1);
255 #if 0
256 struct board bb; bb.size = 9+2;
257 fprintf(stderr, "* %s<%"PRIhash"> -> %s<%"PRIhash"> [%d/%f => %d/%f]\n",
258 coord2sstr(node_coord(node), &bb), node->hash,
259 coord2sstr(node_coord(ni), &bb), ni->hash,
260 player_color, result, child_color, nres);
261 #endif
264 if (!is_pass(node_coord(node))) {
265 map->game_baselen--;
267 node = node->parent; child_color = stone_other(child_color);
272 struct uct_policy *
273 policy_ucb1amaf_init(struct uct *u, char *arg)
275 struct uct_policy *p = calloc2(1, sizeof(*p));
276 struct ucb1_policy_amaf *b = calloc2(1, sizeof(*b));
277 p->uct = u;
278 p->data = b;
279 p->choose = uctp_generic_choose;
280 p->winner = uctp_generic_winner;
281 p->evaluate = ucb1rave_evaluate;
282 p->descend = ucb1rave_descend;
283 p->update = ucb1amaf_update;
284 p->wants_amaf = true;
286 b->explore_p = 0; // 0.02 can be also good on 19x19 with prior=eqex=40
287 b->equiv_rave = 3000;
288 b->fpu = INFINITY;
289 b->check_nakade = true;
290 b->sylvain_rave = true;
291 b->ltree_rave = 0.75f;
293 b->crit_rave = 1.0f;
294 b->crit_min_playouts = 2000;
295 b->crit_negative = 1;
296 b->crit_amaf = 0;
298 if (arg) {
299 char *optspec, *next = arg;
300 while (*next) {
301 optspec = next;
302 next += strcspn(next, ":");
303 if (*next) { *next++ = 0; } else { *next = 0; }
305 char *optname = optspec;
306 char *optval = strchr(optspec, '=');
307 if (optval) *optval++ = 0;
309 if (!strcasecmp(optname, "explore_p")) {
310 b->explore_p = atof(optval);
311 } else if (!strcasecmp(optname, "fpu") && optval) {
312 b->fpu = atof(optval);
313 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
314 b->equiv_rave = atof(optval);
315 } else if (!strcasecmp(optname, "sylvain_rave")) {
316 b->sylvain_rave = !optval || *optval == '1';
317 } else if (!strcasecmp(optname, "check_nakade")) {
318 b->check_nakade = !optval || *optval == '1';
319 } else if (!strcasecmp(optname, "ltree_rave") && optval) {
320 b->ltree_rave = atof(optval);
321 } else if (!strcasecmp(optname, "crit_rave") && optval) {
322 b->crit_rave = atof(optval);
323 } else if (!strcasecmp(optname, "crit_min_playouts") && optval) {
324 b->crit_min_playouts = atoi(optval);
325 } else if (!strcasecmp(optname, "crit_negative")) {
326 b->crit_negative = !optval || *optval == '1';
327 } else if (!strcasecmp(optname, "crit_negflip")) {
328 b->crit_negflip = !optval || *optval == '1';
329 } else if (!strcasecmp(optname, "crit_amaf")) {
330 b->crit_amaf = !optval || *optval == '1';
331 } else if (!strcasecmp(optname, "crit_lvalue")) {
332 b->crit_lvalue = !optval || *optval == '1';
333 } else {
334 fprintf(stderr, "ucb1amaf: Invalid policy argument %s or missing value\n",
335 optname);
336 exit(1);
341 return p;