time_in_byoyomi(): Rewrite, smoother to read and more flexible now
[pachi.git] / uct / policy / ucb1amaf.c
blobf961382377aff0951b155b747db40fb9de80ee26
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "uct/internal.h"
12 #include "uct/tree.h"
13 #include "uct/policy/generic.h"
/* This implements the UCB1 policy with an extra AMAF (RAVE) heuristic. */
/* Tunable parameters of the UCB1+AMAF policy (p->data). */
struct ucb1_policy_amaf {
	/* This is what the Modification of UCT with Patterns in Monte Carlo Go
	 * paper calls 'p'. Original UCB has this on 2, but this seems to
	 * produce way too wide searches; reduce this to get deeper and
	 * narrower readouts - try 0.2. */
	float explore_p;
	/* First Play Urgency - if set to less than infinity (the MoGo paper
	 * above reports 1.0 as the best), new branches are explored only
	 * if none of the existing ones has higher urgency than fpu. */
	float fpu;
	/* RAVE equivalence parameter: controls how quickly the weight of
	 * the AMAF statistics (beta) decays as real playouts accumulate.
	 * Kept as float: it is parsed with atof() and must not be truncated,
	 * and the non-sylvain beta formula needs a floating-point division. */
	float equiv_rave;
	/* Also record AMAF results for moves played by the other color
	 * (with the result inverted). */
	bool both_colors;
	/* Handle AMAF map entries flagged by amaf_nakade() by scanning the
	 * playout record instead of ignoring them. */
	bool check_nakade;
	/* Select the beta formula r/(r + n + n*r/equiv_rave) instead of
	 * sqrt(equiv_rave / (3*N + equiv_rave)). */
	bool sylvain_rave;
	/* Coefficient of root values embedded in RAVE. */
	float root_rave;
};
36 static inline float fast_sqrt(int x)
38 static const float table[] = {
39 0, 1, 1.41421356237309504880, 1.73205080756887729352,
40 2.00000000000000000000, 2.23606797749978969640,
41 2.44948974278317809819, 2.64575131106459059050,
42 2.82842712474619009760, 3.00000000000000000000,
43 3.16227766016837933199, 3.31662479035539984911,
44 3.46410161513775458705, 3.60555127546398929311,
45 3.74165738677394138558, 3.87298334620741688517,
46 4.00000000000000000000, 4.12310562561766054982,
47 4.24264068711928514640, 4.35889894354067355223,
48 4.47213595499957939281, 4.58257569495584000658,
49 4.69041575982342955456, 4.79583152331271954159,
50 4.89897948556635619639, 5.00000000000000000000,
51 5.09901951359278483002, 5.19615242270663188058,
52 5.29150262212918118100, 5.38516480713450403125,
53 5.47722557505166113456, 5.56776436283002192211,
54 5.65685424949238019520, 5.74456264653802865985,
55 5.83095189484530047087, 5.91607978309961604256,
56 6.00000000000000000000, 6.08276253029821968899,
57 6.16441400296897645025, 6.24499799839839820584,
58 6.32455532033675866399, 6.40312423743284868648,
59 6.48074069840786023096, 6.55743852430200065234,
60 6.63324958071079969822, 6.70820393249936908922,
61 6.78232998312526813906, 6.85565460040104412493,
62 6.92820323027550917410, 7.00000000000000000000,
63 7.07106781186547524400, 7.14142842854284999799,
64 7.21110255092797858623, 7.28010988928051827109,
65 7.34846922834953429459, 7.41619848709566294871,
66 7.48331477354788277116, 7.54983443527074969723,
67 7.61577310586390828566, 7.68114574786860817576,
68 7.74596669241483377035, 7.81024967590665439412,
69 7.87400787401181101968, 7.93725393319377177150,
71 if (x < sizeof(table) / sizeof(*table)) {
72 return table[x];
73 } else {
74 return sqrt(x);
78 static float inline
79 ucb1rave_evaluate(struct uct_policy *p, void **state, struct tree *tree, struct tree_node *node, int parity)
81 struct ucb1_policy_amaf *b = p->data;
83 struct move_stats n = node->u, r = node->amaf;
84 if (p->uct->amaf_prior) {
85 stats_merge(&r, &node->prior);
86 } else {
87 stats_merge(&n, &node->prior);
90 /* Root heuristics, if we aren't actually near the root. */
91 if (tree->chvals && b->root_rave > 0 && likely(!is_pass(node->coord))
92 && node->parent && node->parent->parent && node->parent->parent->parent) {
93 struct move_stats *rv = parity > 0 ? tree->chvals : tree->chchvals;
94 struct move_stats root = rv[node->coord];
95 root.playouts *= b->root_rave;
96 stats_merge(&r, &root);
99 if (tree_parity(tree, parity) < 0) {
100 stats_reverse_parity(&n);
101 stats_reverse_parity(&r);
104 float value = 0;
105 if (n.playouts) {
106 if (r.playouts) {
107 /* At the beginning, beta is at 1 and RAVE is used.
108 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
109 float beta;
110 if (b->sylvain_rave) {
111 beta = (float) r.playouts / (r.playouts + n.playouts
112 + (float) n.playouts * r.playouts / b->equiv_rave);
113 } else {
114 /* XXX: This can be cached in descend; but we don't use this by default. */
115 beta = sqrt(b->equiv_rave / (3 * node->parent->u.playouts + b->equiv_rave));
118 value = beta * r.value + (1.f - beta) * n.value;
119 } else {
120 value = n.value;
122 } else if (r.playouts) {
123 value = r.value;
125 return value;
128 struct tree_node *
129 ucb1rave_descend(struct uct_policy *p, void **state, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
131 struct ucb1_policy_amaf *b = p->data;
132 float nconf = 1.f;
133 if (b->explore_p > 0)
134 nconf = sqrt(log(node->u.playouts + node->prior.playouts));
136 uctd_try_node_children(node, allow_pass, ni, urgency) {
137 urgency = ucb1rave_evaluate(p, state, tree, ni, parity);
139 if (ni->u.playouts > 0 && b->explore_p > 0) {
140 urgency += b->explore_p * nconf / fast_sqrt(ni->u.playouts);
142 } else if (ni->u.playouts + ni->amaf.playouts + ni->prior.playouts == 0) {
143 /* assert(!u->even_eqex); */
144 urgency = b->fpu;
146 } uctd_set_best_child(ni, urgency);
148 return uctd_get_best_child();
152 void
153 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node,
154 enum stone node_color, enum stone player_color,
155 struct playout_amafmap *map, float result)
157 struct ucb1_policy_amaf *b = p->data;
158 enum stone child_color = stone_other(node_color);
160 #if 0
161 struct board bb; bb.size = 9+2;
162 for (struct tree_node *ni = node; ni; ni = ni->parent)
163 fprintf(stderr, "%s ", coord2sstr(ni->coord, &bb));
164 fprintf(stderr, "[color %d] update result %d (color %d)\n",
165 node_color, result, player_color);
166 #endif
168 while (node) {
169 if (node->parent == NULL)
170 assert(tree->root_color == stone_other(child_color));
172 stats_add_result(&node->u, result, 1);
173 if (amaf_nakade(map->map[node->coord]))
174 amaf_op(map->map[node->coord], -);
176 /* This loop ignores symmetry considerations, but they should
177 * matter only at a point when AMAF doesn't help much. */
178 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
179 assert(map->map[ni->coord] != S_OFFBOARD);
180 if (map->map[ni->coord] == S_NONE)
181 continue;
182 assert(map->game_baselen >= 0);
183 enum stone amaf_color = map->map[ni->coord];
184 if (amaf_nakade(map->map[ni->coord])) {
185 if (!b->check_nakade)
186 continue;
187 /* We don't care to implement both_colors
188 * properly since it sucks anyway. */
189 int i;
190 for (i = map->game_baselen; i < map->gamelen; i++)
191 if (map->game[i].coord == ni->coord
192 && map->game[i].color == child_color)
193 break;
194 if (i == map->gamelen)
195 continue;
196 amaf_color = child_color;
199 float nres = result;
200 if (amaf_color != child_color) {
201 if (!b->both_colors)
202 continue;
203 nres = 1 - nres;
205 /* For child_color != player_color, we still want
206 * to record the result unmodified; in that case,
207 * we will correctly negate them at the descend phase. */
209 stats_add_result(&ni->amaf, nres, 1);
211 #if 0
212 struct board bb; bb.size = 9+2;
213 fprintf(stderr, "* %s<%"PRIhash"> -> %s<%"PRIhash"> [%d/%f => %d/%f]\n",
214 coord2sstr(node->coord, &bb), node->hash,
215 coord2sstr(ni->coord, &bb), ni->hash,
216 player_color, result, child_color, nres);
217 #endif
220 if (!is_pass(node->coord)) {
221 map->game_baselen--;
223 node = node->parent; child_color = stone_other(child_color);
228 struct uct_policy *
229 policy_ucb1amaf_init(struct uct *u, char *arg)
231 struct uct_policy *p = calloc(1, sizeof(*p));
232 struct ucb1_policy_amaf *b = calloc(1, sizeof(*b));
233 p->uct = u;
234 p->data = b;
235 p->choose = uctp_generic_choose;
236 p->winner = uctp_generic_winner;
237 p->evaluate = ucb1rave_evaluate;
238 p->descend = ucb1rave_descend;
239 p->update = ucb1amaf_update;
240 p->wants_amaf = true;
242 b->explore_p = 0; // 0.02 can be also good on 19x19 with prior=eqex=40
243 b->equiv_rave = 3000;
244 b->fpu = INFINITY;
245 b->check_nakade = true;
246 b->sylvain_rave = true;
247 b->root_rave = 1.0f;
249 if (arg) {
250 char *optspec, *next = arg;
251 while (*next) {
252 optspec = next;
253 next += strcspn(next, ":");
254 if (*next) { *next++ = 0; } else { *next = 0; }
256 char *optname = optspec;
257 char *optval = strchr(optspec, '=');
258 if (optval) *optval++ = 0;
260 if (!strcasecmp(optname, "explore_p")) {
261 b->explore_p = atof(optval);
262 } else if (!strcasecmp(optname, "fpu") && optval) {
263 b->fpu = atof(optval);
264 } else if (!strcasecmp(optname, "equiv_rave") && optval) {
265 b->equiv_rave = atof(optval);
266 } else if (!strcasecmp(optname, "both_colors")) {
267 b->both_colors = true;
268 } else if (!strcasecmp(optname, "sylvain_rave")) {
269 b->sylvain_rave = !optval || *optval == '1';
270 } else if (!strcasecmp(optname, "check_nakade")) {
271 b->check_nakade = !optval || *optval == '1';
272 } else if (!strcasecmp(optname, "root_rave") && optval) {
273 b->root_rave = atof(optval);
274 } else {
275 fprintf(stderr, "ucb1amaf: Invalid policy argument %s or missing value\n",
276 optname);
277 exit(1);
282 return p;