Now prints last_urgency to help debugging. Seems that the problematic moves in 11071...
[pachi/pachi-r6144.git] / uct / policy / ucb1amaf.c
blobd4f63e004fee2a9afc1b94e2809c63a37c479805
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "uct/internal.h"
12 #include "uct/tree.h"
13 #include "uct/policy/generic.h"
15 /* This implements the UCB1 policy with an extra AMAF heuristics. */
17 struct ucb1_policy_amaf {
18 unsigned int equiv_rave;
19 bool check_nakade;
20 bool sylvain_rave;
21 /* Coefficient of local tree values embedded in RAVE. */
22 floating_t ltree_rave;
23 /* Coefficient of criticality embedded in RAVE. */
24 floating_t crit_rave;
25 int crit_min_playouts;
26 bool crit_negative;
27 bool crit_amaf;
31 static inline floating_t fast_sqrt(unsigned int x)
33 static const floating_t table[] = {
34 0, 1, 1.41421356237309504880, 1.73205080756887729352,
35 2.00000000000000000000, 2.23606797749978969640,
36 2.44948974278317809819, 2.64575131106459059050,
37 2.82842712474619009760, 3.00000000000000000000,
38 3.16227766016837933199, 3.31662479035539984911,
39 3.46410161513775458705, 3.60555127546398929311,
40 3.74165738677394138558, 3.87298334620741688517,
41 4.00000000000000000000, 4.12310562561766054982,
42 4.24264068711928514640, 4.35889894354067355223,
43 4.47213595499957939281, 4.58257569495584000658,
44 4.69041575982342955456, 4.79583152331271954159,
45 4.89897948556635619639, 5.00000000000000000000,
46 5.09901951359278483002, 5.19615242270663188058,
47 5.29150262212918118100, 5.38516480713450403125,
48 5.47722557505166113456, 5.56776436283002192211,
49 5.65685424949238019520, 5.74456264653802865985,
50 5.83095189484530047087, 5.91607978309961604256,
51 6.00000000000000000000, 6.08276253029821968899,
52 6.16441400296897645025, 6.24499799839839820584,
53 6.32455532033675866399, 6.40312423743284868648,
54 6.48074069840786023096, 6.55743852430200065234,
55 6.63324958071079969822, 6.70820393249936908922,
56 6.78232998312526813906, 6.85565460040104412493,
57 6.92820323027550917410, 7.00000000000000000000,
58 7.07106781186547524400, 7.14142842854284999799,
59 7.21110255092797858623, 7.28010988928051827109,
60 7.34846922834953429459, 7.41619848709566294871,
61 7.48331477354788277116, 7.54983443527074969723,
62 7.61577310586390828566, 7.68114574786860817576,
63 7.74596669241483377035, 7.81024967590665439412,
64 7.87400787401181101968, 7.93725393319377177150,
66 if (x < sizeof(table) / sizeof(*table)) {
67 return table[x];
68 } else {
69 return sqrt(x);
73 #define LTREE_DEBUG if (0)
74 static floating_t inline
75 ucb1rave_evaluate(struct uct_policy *p, struct tree *tree, struct uct_descent *descent, int parity)
77 struct ucb1_policy_amaf *b = p->data;
78 struct tree_node *node = descent->node;
79 struct tree_node *lnode = descent->lnode;
81 struct move_stats n = node->u, r = node->amaf;
82 if (p->uct->amaf_prior) {
83 stats_merge(&r, &node->prior);
84 } else {
85 stats_merge(&n, &node->prior);
88 /* Local tree heuristics. */
89 if (p->uct->local_tree && b->ltree_rave > 0 && lnode) {
90 struct move_stats l = lnode->u;
91 l.playouts = ((floating_t) l.playouts) * b->ltree_rave / LTREE_PLAYOUTS_MULTIPLIER;
92 LTREE_DEBUG fprintf(stderr, "[ltree] adding [%s] %f%%%d to [%s] RAVE %f%%%d\n",
93 coord2sstr(lnode->coord, tree->board), l.value, l.playouts,
94 coord2sstr(node->coord, tree->board), r.value, r.playouts);
95 stats_merge(&r, &l);
98 /* Criticality heuristics. */
99 if (b->crit_rave > 0 && node->u.playouts > b->crit_min_playouts) {
100 floating_t crit = tree_node_criticality(tree, node);
101 if (b->crit_negative || crit > 0) {
102 struct move_stats c = {
103 .value = tree_node_get_value(tree, parity, 1.0f),
104 .playouts = crit * r.playouts * b->crit_rave
106 LTREE_DEBUG fprintf(stderr, "[crit] adding %f%%%d to [%s] RAVE %f%%%d\n",
107 c.value, c.playouts,
108 coord2sstr(node->coord, tree->board), r.value, r.playouts);
109 stats_merge(&r, &c);
114 floating_t value = 0;
115 if (n.playouts) {
116 if (r.playouts) {
117 /* At the beginning, beta is at 1 and RAVE is used.
118 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
119 floating_t beta;
120 if (b->sylvain_rave) {
121 beta = (floating_t) r.playouts / (r.playouts + n.playouts
122 + (floating_t) n.playouts * r.playouts / b->equiv_rave);
123 } else {
124 /* XXX: This can be cached in descend; but we don't use this by default. */
125 beta = sqrt(b->equiv_rave / (3 * node->parent->u.playouts + b->equiv_rave));
128 value = beta * r.value + (1.f - beta) * n.value;
129 } else {
130 value = n.value;
132 } else if (r.playouts) {
133 value = r.value;
135 descent->value.playouts = r.playouts + n.playouts;
136 descent->value.value = value;
137 return tree_node_get_value(tree, parity, value);
140 void
141 ucb1rave_descend(struct uct_policy *p, struct tree *tree, struct uct_descent *descent, int parity, bool allow_pass)
143 /* struct ucb1_policy_amaf *b = p->data; */
144 floating_t explore_p = (fast_random(3 + descent->node->depth) == 0) ? 2.0f : 0.0f;
145 floating_t nconf = 1.f;
146 if (explore_p > 0) {
147 int playouts = descent->node->u.playouts + descent->node->prior.playouts;
148 if (playouts < 1) playouts = 1;
149 nconf = sqrt(log(playouts));
150 if (nconf < 1.0f) nconf = 1.0f;
153 uctd_try_node_children(tree, descent, allow_pass, parity, p->uct->tenuki_d, di, urgency) {
154 struct tree_node *ni = di.node;
155 urgency = ucb1rave_evaluate(p, tree, &di, parity);
156 ni->last_urgency = tree_node_get_value(tree, parity, urgency); // convert it back to a black-oriented value
158 if (explore_p > 0) {
159 /* It is probably safe to include passes, although passes will still be generated without the extra urgency. */
160 /* Infinite first-play urgency will somehow break things... */
161 floating_t coef = (ni->u.playouts > 0) ? 1.0f / fast_sqrt(ni->u.playouts) : 2.0f;
162 urgency += explore_p * nconf * coef;
164 /* fprintf(stderr, "[%s] urgency=%0.3f\n", coord2sstr(ni->coord, tree->board), urgency); */
165 } uctd_set_best_child(di, urgency);
167 uctd_get_best_child(descent);
171 void
172 ucb1amaf_update(struct uct_policy *p, struct tree *tree, struct tree_node *node,
173 enum stone node_color, enum stone player_color,
174 struct playout_amafmap *map, struct board *final_board,
175 floating_t result)
177 struct ucb1_policy_amaf *b = p->data;
178 enum stone winner_color = result > 0.5 ? S_BLACK : S_WHITE;
179 enum stone child_color = stone_other(node_color);
181 #if 0
182 struct board bb; bb.size = 9+2;
183 for (struct tree_node *ni = node; ni; ni = ni->parent)
184 fprintf(stderr, "%s ", coord2sstr(ni->coord, &bb));
185 fprintf(stderr, "[color %d] update result %d (color %d)\n",
186 node_color, result, player_color);
187 #endif
189 while (node) {
190 if (node->parent == NULL)
191 assert(tree->root_color == stone_other(child_color));
193 if (!b->crit_amaf && !is_pass(node->coord)) {
194 stats_add_result(&node->winner_owner, board_at(final_board, node->coord) == winner_color ? 1.0 : 0.0, 1);
195 stats_add_result(&node->black_owner, board_at(final_board, node->coord) == S_BLACK ? 1.0 : 0.0, 1);
197 stats_add_result(&node->u, result, 1);
198 if (amaf_nakade(map->map[node->coord]))
199 amaf_op(map->map[node->coord], -);
201 /* This loop ignores symmetry considerations, but they should
202 * matter only at a point when AMAF doesn't help much. */
203 assert(map->game_baselen >= 0);
204 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
205 enum stone amaf_color = map->map[ni->coord];
206 assert(amaf_color != S_OFFBOARD);
207 if (amaf_color == S_NONE)
208 continue;
209 if (amaf_nakade(map->map[ni->coord])) {
210 if (!b->check_nakade)
211 continue;
212 unsigned int i;
213 for (i = map->game_baselen; i < map->gamelen; i++)
214 if (map->game[i].coord == ni->coord
215 && map->game[i].color == child_color)
216 break;
217 if (i == map->gamelen)
218 continue;
219 amaf_color = child_color;
222 floating_t nres = result;
223 if (amaf_color != child_color) {
224 continue;
226 /* For child_color != player_color, we still want
227 * to record the result unmodified; in that case,
228 * we will correctly negate them at the descend phase. */
230 if (b->crit_amaf && !is_pass(node->coord)) {
231 stats_add_result(&ni->winner_owner, board_at(final_board, ni->coord) == winner_color ? 1.0 : 0.0, 1);
232 stats_add_result(&ni->black_owner, board_at(final_board, ni->coord) == S_BLACK ? 1.0 : 0.0, 1);
234 stats_add_result(&ni->amaf, nres, 1);
236 #if 0
237 struct board bb; bb.size = 9+2;
238 fprintf(stderr, "* %s<%"PRIhash"> -> %s<%"PRIhash"> [%d/%f => %d/%f]\n",
239 coord2sstr(node->coord, &bb), node->hash,
240 coord2sstr(ni->coord, &bb), ni->hash,
241 player_color, result, child_color, nres);
242 #endif
245 if (!is_pass(node->coord)) {
246 map->game_baselen--;
248 node = node->parent; child_color = stone_other(child_color);
253 struct uct_policy *
254 policy_ucb1amaf_init(struct uct *u, char *arg)
256 struct uct_policy *p = calloc2(1, sizeof(*p));
257 struct ucb1_policy_amaf *b = calloc2(1, sizeof(*b));
258 p->uct = u;
259 p->data = b;
260 p->choose = uctp_generic_choose;
261 p->winner = uctp_generic_winner;
262 p->evaluate = ucb1rave_evaluate;
263 p->descend = ucb1rave_descend;
264 p->update = ucb1amaf_update;
265 p->wants_amaf = true;
267 b->equiv_rave = 3000;
268 b->check_nakade = true;
269 b->sylvain_rave = true;
270 b->ltree_rave = 0.75f;
272 b->crit_rave = 1.0f;
273 b->crit_min_playouts = 2000;
274 b->crit_negative = 1;
275 b->crit_amaf = 0;
277 if (arg) {
278 char *optspec, *next = arg;
279 while (*next) {
280 optspec = next;
281 next += strcspn(next, ":");
282 if (*next) { *next++ = 0; } else { *next = 0; }
284 char *optname = optspec;
285 char *optval = strchr(optspec, '=');
286 if (optval) *optval++ = 0;
288 if (!strcasecmp(optname, "equiv_rave") && optval) {
289 b->equiv_rave = atof(optval);
290 } else if (!strcasecmp(optname, "sylvain_rave")) {
291 b->sylvain_rave = !optval || *optval == '1';
292 } else if (!strcasecmp(optname, "check_nakade")) {
293 b->check_nakade = !optval || *optval == '1';
294 } else if (!strcasecmp(optname, "ltree_rave") && optval) {
295 b->ltree_rave = atof(optval);
296 } else if (!strcasecmp(optname, "crit_rave") && optval) {
297 b->crit_rave = atof(optval);
298 } else if (!strcasecmp(optname, "crit_min_playouts") && optval) {
299 b->crit_min_playouts = atoi(optval);
300 } else if (!strcasecmp(optname, "crit_negative")) {
301 b->crit_negative = !optval || *optval == '1';
302 } else if (!strcasecmp(optname, "crit_amaf")) {
303 b->crit_amaf = !optval || *optval == '1';
304 } else {
305 fprintf(stderr, "ucb1amaf: Invalid policy argument %s or missing value\n",
306 optname);
307 exit(1);
312 return p;