12 #include "tactics/goals.h"
13 #include "tactics/util.h"
14 #include "uct/internal.h"
16 #include "uct/policy/generic.h"
18 /* This implements the UCB1 policy extended with the AMAF (RAVE) heuristic. */
/* Tuning parameters of the UCB1 + AMAF (RAVE) tree policy. The policy
 * functions in this file read it through p->data, and
 * policy_ucb1amaf_init() fills in the defaults and the parsed options.
 * NOTE(review): this copy of the definition is incomplete - explore_p,
 * fpu, virtual_win, root_virtual_win, sylvain_rave, distance_rave,
 * threat_rave, crit_rave, crit_negative, crit_negflip, crit_amaf and
 * crit_lvalue are used elsewhere in the file but are not declared here,
 * and the closing brace of the struct is missing. */
20 struct ucb1_policy_amaf
{
21 /* This is what the Modification of UCT with Patterns in Monte Carlo Go
22 * paper calls 'p'. Original UCB has this on 2, but this seems to
23 * produce way too wide searches; reduce this to get deeper and
24 * narrower readouts - try 0.2. */
/* NOTE(review): policy_ucb1amaf_init() parses the "explore_p" option with
 * atof(optval) without the "&& optval" check the other numeric options
 * have, so a bare "explore_p" argument passes NULL to atof(). */
26 /* In distributed mode, encourage different slaves to work on different
27 * parts of the tree by adding virtual wins to different nodes. */
/* The virtual-win bonus is only given to children with more than this
 * many playouts (see ucb1rave_descend). */
30 int vwin_min_playouts
;
31 /* First Play Urgency - if set to less than infinity (the MoGo paper
32 * above reports 1.0 as the best), new branches are explored only
33 * if none of the existing ones has higher urgency than fpu. */
/* Equivalence parameter of the RAVE weight beta in ucb1rave_evaluate();
 * the default is 4000 on large boards and 3000 otherwise. */
35 unsigned int equiv_rave
;
37 /* Give more weight to moves played earlier. */
39 /* Give 0 or negative rave bonus to ko threats before taking the ko.
40 1=normal bonus, 0=no bonus, -1=invert rave bonus, -2=double penalty... */
42 /* Coefficient of local tree values embedded in RAVE. */
43 floating_t ltree_rave
;
44 /* Coefficient of criticality embedded in RAVE. */
/* Criticality is only used for nodes with more than crit_min_playouts
 * playouts or, when crit_plthres_coef > 0, with more than
 * crit_plthres_coef times the root's playouts. */
46 int crit_min_playouts
;
47 floating_t crit_plthres_coef
;
52 /* Coefficient of tactical rating embedded in RAVE. */
53 floating_t libmap_rave
;
/* Square root of an unsigned integer, used by the UCB exploration term in
 * ucb1rave_descend(). Arguments smaller than the table size are answered
 * from a precomputed table (table[i] == sqrt(i) for i = 0..63).
 * NOTE(review): this copy is incomplete - the function braces, the
 * table's closing initializer brace, the table-lookup return and the
 * fallback for larger x (presumably a call to sqrt()) are not visible;
 * confirm against the original before relying on it. */
57 static inline floating_t
fast_sqrt(unsigned int x
)
59 static const floating_t table
[] = {
/* table[i] == sqrt(i), i = 0..63 */
60 0, 1, 1.41421356237309504880, 1.73205080756887729352,
61 2.00000000000000000000, 2.23606797749978969640,
62 2.44948974278317809819, 2.64575131106459059050,
63 2.82842712474619009760, 3.00000000000000000000,
64 3.16227766016837933199, 3.31662479035539984911,
65 3.46410161513775458705, 3.60555127546398929311,
66 3.74165738677394138558, 3.87298334620741688517,
67 4.00000000000000000000, 4.12310562561766054982,
68 4.24264068711928514640, 4.35889894354067355223,
69 4.47213595499957939281, 4.58257569495584000658,
70 4.69041575982342955456, 4.79583152331271954159,
71 4.89897948556635619639, 5.00000000000000000000,
72 5.09901951359278483002, 5.19615242270663188058,
73 5.29150262212918118100, 5.38516480713450403125,
74 5.47722557505166113456, 5.56776436283002192211,
75 5.65685424949238019520, 5.74456264653802865985,
76 5.83095189484530047087, 5.91607978309961604256,
77 6.00000000000000000000, 6.08276253029821968899,
78 6.16441400296897645025, 6.24499799839839820584,
79 6.32455532033675866399, 6.40312423743284868648,
80 6.48074069840786023096, 6.55743852430200065234,
81 6.63324958071079969822, 6.70820393249936908922,
82 6.78232998312526813906, 6.85565460040104412493,
83 6.92820323027550917410, 7.00000000000000000000,
84 7.07106781186547524400, 7.14142842854284999799,
85 7.21110255092797858623, 7.28010988928051827109,
86 7.34846922834953429459, 7.41619848709566294871,
87 7.48331477354788277116, 7.54983443527074969723,
88 7.61577310586390828566, 7.68114574786860817576,
89 7.74596669241483377035, 7.81024967590665439412,
90 7.87400787401181101968, 7.93725393319377177150,
/* Fast path: x indexes the table directly when it is in range. */
92 if (x
< sizeof(table
) / sizeof(*table
)) {
/* Debug-trace guard used by ucb1rave_evaluate(): a statement prefixed
 * with URAVE_DEBUG expands to "if (0) statement", so the trace is still
 * compiled and type-checked but never runs; change it to if (1) to enable
 * the traces. Because it is a bare if, it would pair with a following
 * else, so only put it in front of standalone statements. */
99 #define URAVE_DEBUG if (0)
/* Evaluate the child descent->node during UCT descent.
 * The tree statistics n (node->u) and the AMAF/RAVE statistics r
 * (node->amaf) are blended as value = beta * r.value + (1 - beta) * n.value.
 * The value and the combined playout count r.playouts + n.playouts are
 * stored in descent->value. The function returns the value as seen by
 * `parity`, via tree_node_get_value().
 * Before blending:
 *  - node->prior is merged into r when p->uct->amaf_prior is set. This
 *    copy also shows it being merged into n; the branch structure
 *    between the two merges is not visible.
 *  - The local-tree, criticality and liberty-map heuristics each build
 *    extra move_stats, weighted by ltree_rave, crit_rave and libmap_rave.
 *    Their debug traces say these are added to the RAVE statistics, but
 *    the merge statements are not visible.
 * NOTE(review): this copy is incomplete - the function braces, the
 * declaration of beta, the conditions choosing between the blended,
 * tree-only and RAVE-only values, and several closing braces are missing;
 * confirm against the original. */
100 static inline floating_t
101 ucb1rave_evaluate(struct uct_policy
*p
, struct tree
*tree
, struct uct_descent
*descent
, int parity
)
103 struct ucb1_policy_amaf
*b
= p
->data
;
104 struct tree_node
*node
= descent
->node
;
105 struct tree_node
*lnode
= descent
->lnode
;
107 struct move_stats n
= node
->u
, r
= node
->amaf
;
108 if (p
->uct
->amaf_prior
) {
109 stats_merge(&r
, &node
->prior
);
111 stats_merge(&n
, &node
->prior
);
114 /* Local tree heuristics. */
115 assert(!lnode
|| lnode
->parent
);
116 if (p
->uct
->local_tree
&& b
->ltree_rave
> 0 && lnode
117 && (p
->uct
->local_tree_rootchoose
|| lnode
->parent
->parent
)) {
118 struct move_stats l
= lnode
->u
;
/* Weight the local-tree evidence by ltree_rave (and divide out
 * LTREE_PLAYOUTS_MULTIPLIER). */
119 l
.playouts
= ((floating_t
) l
.playouts
) * b
->ltree_rave
/ LTREE_PLAYOUTS_MULTIPLIER
;
120 URAVE_DEBUG
fprintf(stderr
, "[ltree] adding [%s] %f%%%d to [%s] RAVE %f%%%d\n",
121 coord2sstr(node_coord(lnode
), tree
->board
), l
.value
, l
.playouts
,
122 coord2sstr(node_coord(node
), tree
->board
), r
.value
, r
.playouts
);
126 /* Criticality heuristics. */
/* Only for nodes with enough playouts: more than crit_plthres_coef times
 * the root's playouts when that coefficient is positive, otherwise more
 * than crit_min_playouts. */
127 if (b
->crit_rave
> 0 && (b
->crit_plthres_coef
> 0
128 ? node
->u
.playouts
> tree
->root
->u
.playouts
* b
->crit_plthres_coef
129 : node
->u
.playouts
> b
->crit_min_playouts
)) {
130 floating_t crit
= tree_node_criticality(tree
, node
);
131 if (b
->crit_negative
|| crit
> 0) {
132 floating_t val
= 1.0f
;
133 if (b
->crit_negflip
&& crit
< 0) {
/* NOTE(review): the body of the crit_negflip branch and its closing
 * brace are not visible in this copy. */
137 struct move_stats c
= {
138 .value
= tree_node_get_value(tree
, parity
, val
),
/* NOTE(review): this weight is proportional to crit, so with
 * crit_negative set and crit < 0 it is negative unless the (not visible)
 * crit_negflip branch flips crit - confirm that is intended. */
139 .playouts
= crit
* r
.playouts
* b
->crit_rave
141 URAVE_DEBUG
fprintf(stderr
, "[crit] adding %f%%%d to [%s] RAVE %f%%%d\n",
143 coord2sstr(node_coord(node
), tree
->board
), r
.value
, r
.playouts
);
148 /* Tactical rating (liberty map) heuristics. */
149 if (b
->libmap_rave
> 0 && tree
->board
->libmap
) {
150 /* We look at tactical rating of a move relative to
152 /* XXX: We should rather record hashes pertaining this move
153 * in the tree. We entirely miss counter-atari information. */
154 enum stone color
= tree_node_color(tree
, node
);
155 struct move m
= { .coord
= node
->coord
, .color
= color
};
156 struct move_stats l
= libmap_board_move_stats(descent
->board
->libmap
, descent
->board
, m
);
/* Use the libmap rating only if it has data: convert its value to this
 * descent's parity and weight its playouts by libmap_rave. */
157 if (l
.playouts
> 0) {
158 l
.value
= tree_node_get_value(tree
, parity
, l
.value
);
159 l
.playouts
*= b
->libmap_rave
;
161 URAVE_DEBUG
fprintf(stderr
, "[libmap] adding %f%%%d to [%s %s] RAVE %f%%%d\n",
162 l
.value
, l
.playouts
, stone2str(color
),
163 coord2sstr(node
->coord
, descent
->board
), r
.value
, r
.playouts
);
169 floating_t value
= 0;
172 /* At the beginning, beta is at 1 and RAVE is used.
173 * When both r and n reach b->equiv_rave playouts, the Sylvain formula gives beta = 1/3, and beta keeps falling from there. */
/* sylvain_rave: beta = r / (r + n + n * r / equiv_rave), so beta is 1
 * while n is 0 and falls as tree playouts accumulate. */
175 if (b
->sylvain_rave
) {
176 beta
= (floating_t
) r
.playouts
/ (r
.playouts
+ n
.playouts
177 + (floating_t
) n
.playouts
* r
.playouts
/ b
->equiv_rave
);
179 /* XXX: This can be cached in descend; but we don't use this by default. */
/* NOTE(review): b->equiv_rave is unsigned int and u.playouts appears to
 * be an integer (playouts are printed with %d above), so the quotient
 * below is integer division: beta is 0 whenever the parent has any
 * playouts, and 1 otherwise. A floating-point division
 * ((floating_t) b->equiv_rave / ...) is probably intended. Even then,
 * parent playouts == equiv_rave gives beta = 1/2 here, not 1/3. */
180 beta
= sqrt(b
->equiv_rave
/ (3 * node
->parent
->u
.playouts
+ b
->equiv_rave
));
183 value
= beta
* r
.value
+ (1.f
- beta
) * n
.value
;
184 URAVE_DEBUG
fprintf(stderr
, "\t%s value = %f * %f + (1 - %f) * %f (prior %f)\n",
185 coord2sstr(node_coord(node
), tree
->board
), beta
, r
.value
, beta
, n
.value
, node
->prior
.value
);
188 URAVE_DEBUG
fprintf(stderr
, "\t%s value = %f (prior %f)\n",
189 coord2sstr(node_coord(node
), tree
->board
), n
.value
, node
->prior
.value
);
191 } else if (r
.playouts
) {
193 URAVE_DEBUG
fprintf(stderr
, "\t%s value = rave %f (prior %f)\n",
194 coord2sstr(node_coord(node
), tree
->board
), r
.value
, node
->prior
.value
);
/* Expose the blended value and the combined playout count to the caller. */
196 descent
->value
.playouts
= r
.playouts
+ n
.playouts
;
197 descent
->value
.value
= value
;
198 return tree_node_get_value(tree
, parity
, value
);
/* Descend one level from descent->node.
 * Each candidate child, enumerated by the uctd_try_node_children helper,
 * is scored as:
 *  - ucb1rave_evaluate(),
 *  - plus the optional distributed-mode virtual-win bonus,
 *  - plus the UCB1 exploration term.
 * uctd_set_best_child() / uctd_get_best_child() then presumably record
 * the child with the highest urgency in `descent`.
 * NOTE(review): the following are not visible in this copy:
 *  - the return type line and the function braces,
 *  - the declaration of vwin,
 *  - the origin of `child`,
 *  - the statements for children without any statistics (presumably
 *    first play urgency b->fpu). */
202 ucb1rave_descend(struct uct_policy
*p
, struct tree
*tree
, struct uct_descent
*descent
, int parity
, bool allow_pass
)
204 struct ucb1_policy_amaf
*b
= p
->data
;
/* Exploration scale sqrt(log(N)), where N is the parent's playouts
 * including prior playouts; only computed when explore_p > 0. */
205 floating_t nconf
= 1.f
;
206 if (b
->explore_p
> 0)
207 nconf
= sqrt(log(descent
->node
->u
.playouts
+ descent
->node
->prior
.playouts
));
208 struct uct
*u
= p
->uct
;
/* Virtual wins apply only when this engine is a slave in a distributed
 * search; the root uses its own bonus size. */
210 if (u
->max_slaves
> 0 && u
->slave_index
>= 0)
211 vwin
= descent
->node
== tree
->root
? b
->root_virtual_win
: b
->virtual_win
;
214 uctd_try_node_children(tree
, descent
, allow_pass
, parity
, u
->tenuki_d
, di
, urgency
) {
215 struct tree_node
*ni
= di
.node
;
216 urgency
= ucb1rave_evaluate(p
, tree
, &di
, parity
);
218 /* In distributed mode, encourage different slaves to work on different
219 * parts of the tree. We rely on the fact that children (if they exist)
220 * are the same and in the same order in all slaves. */
/* NOTE(review): `child` is presumably the child index provided by
 * uctd_try_node_children - confirm. vwin's type is not visible; if vwin
 * and u.playouts are both integers, vwin / (playouts + vwin) is integer
 * division and the bonus is always 0 here (playouts > vwin_min_playouts
 * already exceeds vwin with the defaults 1000 and 30). */
221 if (vwin
> 0 && ni
->u
.playouts
> b
->vwin_min_playouts
&& (child
- u
->slave_index
) % u
->max_slaves
== 0)
222 urgency
+= vwin
/ (ni
->u
.playouts
+ vwin
);
/* UCB1 exploration term for visited children:
 * explore_p * sqrt(log N) / sqrt(n_i). */
224 if (ni
->u
.playouts
> 0 && b
->explore_p
> 0) {
225 urgency
+= b
->explore_p
* nconf
/ fast_sqrt(ni
->u
.playouts
);
227 } else if (ni
->u
.playouts
+ ni
->amaf
.playouts
+ ni
->prior
.playouts
== 0) {
228 /* assert(!u->even_eqex); */
/* NOTE(review): no urgency is assigned to completely unexplored children
 * in this copy; presumably the missing lines set it to b->fpu - confirm. */
231 } uctd_set_best_child(di
, urgency
);
233 uctd_get_best_child(descent
);
237 /* Return the length of the current ko (number of moves up to the last ko capture),
238 * 0 if the sequence is empty or doesn't start with a ko capture. Example, after B captures a ko:
240 * W plays a ko threat
241 * B answers ko threat
242 * W re-captures the ko <- return 4
243 * B plays a ko threat
244 * W connects the ko */
/* A ko cycle is capture, threat, answer, so consecutive recaptures are
 * three moves apart: the scan checks ko_capture_map[0], [3], [6], ... and
 * stops at the first index that is not a ko capture or lies beyond
 * map_length. */
245 static inline int ko_length(bool *ko_capture_map
, int map_length
)
247 if (map_length
<= 0 || !ko_capture_map
[0]) return 0;
/* NOTE(review): the declaration of `length` and the final return are not
 * visible in this copy; the example above implies length starts at 1 and
 * is returned as-is. */
249 while (length
+ 2 < map_length
&& ko_capture_map
[length
+ 2]) length
+= 3;
/* Back-propagate one playout result.
 * For `node`:
 *  - add `result` to node->u;
 *  - unless crit_amaf is set or the move is a pass, update its
 *    criticality ownership statistics (winner_owner, black_owner).
 * For every non-pass child whose coordinate was first played later in
 * the playout at an even distance (i.e. by the same color), add the
 * result to its AMAF statistics (ni->amaf). The weight can be changed by
 * the ko-threat rule (threat_rave) and by distance_rave.
 * NOTE(review): this copy is incomplete - the return type, the `result`
 * parameter, the declarations of `move` and `weight`, the loop that
 * presumably walks from `node` toward the root, and several braces are
 * not visible; confirm against the original. */
254 ucb1amaf_update(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
,
255 enum stone node_color
, enum stone player_color
,
256 struct playout_amafmap
*map
, struct board
*final_board
,
259 struct ucb1_policy_amaf
*b
= p
->data
;
260 enum stone winner_color
= result
> 0.5 ? S_BLACK
: S_WHITE
;
// NOTE(review): in this copy the comment that starts on the next line is
// missing its closing comment delimiter. As a result, the declarations of
// first_map/first_move and the debug trace after them are commented out,
// up to the delimiter that ends the "Initialize first_move" comment.
// Restore the terminator, and the #if 0 that presumably wrapped the
// 9x9-only debug trace. That trace also prints `result` with %d; result
// is compared with 0.5 above and so is presumably floating-point, which
// would make that a format mismatch.
262 /* Record of the random playout - for each intersection coord,
263 * first_move[coord] is the index into map->game of the first move
264 * at this coordinate, or INT_MAX if the move was not played.
265 * The parity gives the color of this move.
267 int first_map
[board_size2(final_board
)+1];
268 int *first_move
= &first_map
[1]; // +1 for pass
271 struct board bb
; bb
.size
= 9+2;
272 for (struct tree_node
*ni
= node
; ni
; ni
= ni
->parent
)
273 fprintf(stderr
, "%s ", coord2sstr(node_coord(ni
), &bb
));
274 fprintf(stderr
, "[color %d] update result %d (color %d)\n",
275 node_color
, result
, player_color
);
278 /* Initialize first_move */
279 for (int i
= pass
; i
< board_size2(final_board
); i
++) first_move
[i
] = INT_MAX
;
281 assert(map
->gamelen
> 0);
/* Scan the game from its end back to game_baselen, so that first_move[c]
 * ends up as the earliest index (>= game_baselen) at which c was played. */
282 for (move
= map
->gamelen
- 1; move
>= map
->game_baselen
; move
--)
283 first_move
[map
->game
[move
]] = move
;
/* Criticality ownership statistics for this node's own move (skipped for
 * passes and when crit_amaf is set). */
286 if (!b
->crit_amaf
&& !is_pass(node_coord(node
))) {
287 stats_add_result(&node
->winner_owner
, board_local_value(b
->crit_lvalue
, final_board
, node_coord(node
), winner_color
), 1);
288 stats_add_result(&node
->black_owner
, board_local_value(b
->crit_lvalue
, final_board
, node_coord(node
), S_BLACK
), 1);
290 stats_add_result(&node
->u
, result
, 1);
/* Ko information for the moves after this node's move (index move+1).
 * With threat_rave > 0, max_threat_dist is -1 and the ko-threat rule
 * below never applies. */
292 bool *ko_capture_map
= &map
->is_ko_capture
[move
+1];
293 int max_threat_dist
= b
->threat_rave
<= 0 ? ko_length(ko_capture_map
, map
->gamelen
- (move
+1)) : -1;
295 /* This loop ignores symmetry considerations, but they should
296 * matter only at a point when AMAF doesn't help much. */
297 assert(map
->game_baselen
>= 0);
298 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
299 if (is_pass(node_coord(ni
))) continue;
301 /* Use the child move only if it was first played by the same color. */
302 int first
= first_move
[node_coord(ni
)];
303 if (first
== INT_MAX
) continue;
304 assert(first
> move
&& first
< map
->gamelen
);
305 int distance
= first
- (move
+ 1);
/* An odd distance means the opponent of the child's mover played it. */
306 if (distance
& 1) continue;
309 floating_t res
= result
;
// NOTE(review): in this copy the comment that starts on the next line is
// missing its closing comment delimiter. The ko-threat and distance_rave
// weighting code after it is therefore commented out, up to the delimiter
// that ends the "Give more weight" comment. The declaration of `weight`
// is also not visible.
311 /* Don't give amaf bonus to a ko threat before taking the ko.
312 * http://www.grappa.univ-lille3.fr/~coulom/Aja_PhD_Thesis.pdf
314 if (distance
<= max_threat_dist
&& distance
% 6 == 4) {
315 weight
= - b
->threat_rave
;
317 } else if (b
->distance_rave
!= 0) {
318 /* Give more weight to moves played earlier */
319 weight
+= b
->distance_rave
* (map
->gamelen
- first
) / (map
->gamelen
- move
);
321 stats_add_result(&ni
->amaf
, res
, weight
);
/* NOTE(review): the condition guarding these per-child ownership updates
 * is not visible in this copy (presumably crit_amaf) - confirm. */
324 stats_add_result(&ni
->winner_owner
, board_local_value(b
->crit_lvalue
, final_board
, node_coord(ni
), winner_color
), 1);
325 stats_add_result(&ni
->black_owner
, board_local_value(b
->crit_lvalue
, final_board
, node_coord(ni
), S_BLACK
), 1);
/* NOTE(review): debug trace with a hard-coded 9x9 board (size 9+2). No
 * surrounding #if 0 / #endif is visible in this copy - confirm it is
 * compiled out, otherwise it prints on every AMAF update. */
328 struct board bb
; bb
.size
= 9+2;
329 fprintf(stderr
, "* %s<%"PRIhash
"> -> %s<%"PRIhash
"> [%d/%f => %d/%f]\n",
330 coord2sstr(node_coord(node
), &bb
), node
->hash
,
331 coord2sstr(node_coord(ni
), &bb
), ni
->hash
,
332 player_color
, result
, move
, res
);
/* Presumably one step toward the root: the node's own move must sit at
 * game index `move`, and it becomes the first occurrence of that
 * coordinate for the ancestors' AMAF updates. The statements that
 * advance node and move are not visible in this copy. */
336 assert(move
>= 0 && map
->game
[move
] == node_coord(node
) && first_move
[node_coord(node
)] > move
);
337 first_move
[node_coord(node
)] = move
;
346 policy_ucb1amaf_init(struct uct
*u
, char *arg
, struct board
*board
)
348 struct uct_policy
*p
= calloc2(1, sizeof(*p
));
349 struct ucb1_policy_amaf
*b
= calloc2(1, sizeof(*b
));
352 p
->choose
= uctp_generic_choose
;
353 p
->winner
= uctp_generic_winner
;
354 p
->evaluate
= ucb1rave_evaluate
;
355 p
->descend
= ucb1rave_descend
;
356 p
->update
= ucb1amaf_update
;
357 p
->wants_amaf
= true;
360 b
->equiv_rave
= board_large(board
) ? 4000 : 3000;
362 b
->sylvain_rave
= true;
363 b
->distance_rave
= 3;
365 b
->ltree_rave
= 0.75f
;
368 b
->crit_min_playouts
= 2000;
369 b
->crit_negative
= 1;
373 b
->root_virtual_win
= 30;
374 b
->vwin_min_playouts
= 1000;
377 char *optspec
, *next
= arg
;
380 next
+= strcspn(next
, ":");
381 if (*next
) { *next
++ = 0; } else { *next
= 0; }
383 char *optname
= optspec
;
384 char *optval
= strchr(optspec
, '=');
385 if (optval
) *optval
++ = 0;
387 if (!strcasecmp(optname
, "explore_p")) {
388 b
->explore_p
= atof(optval
);
389 } else if (!strcasecmp(optname
, "fpu") && optval
) {
390 b
->fpu
= atof(optval
);
391 } else if (!strcasecmp(optname
, "equiv_rave") && optval
) {
392 b
->equiv_rave
= atof(optval
);
393 } else if (!strcasecmp(optname
, "sylvain_rave")) {
394 b
->sylvain_rave
= !optval
|| *optval
== '1';
395 } else if (!strcasecmp(optname
, "distance_rave") && optval
) {
396 b
->distance_rave
= atoi(optval
);
397 } else if (!strcasecmp(optname
, "threat_rave") && optval
) {
398 b
->threat_rave
= atoi(optval
);
399 } else if (!strcasecmp(optname
, "ltree_rave") && optval
) {
400 b
->ltree_rave
= atof(optval
);
401 } else if (!strcasecmp(optname
, "crit_rave") && optval
) {
402 b
->crit_rave
= atof(optval
);
403 } else if (!strcasecmp(optname
, "crit_min_playouts") && optval
) {
404 b
->crit_min_playouts
= atoi(optval
);
405 } else if (!strcasecmp(optname
, "crit_plthres_coef") && optval
) {
406 b
->crit_plthres_coef
= atof(optval
);
407 } else if (!strcasecmp(optname
, "crit_negative")) {
408 b
->crit_negative
= !optval
|| *optval
== '1';
409 } else if (!strcasecmp(optname
, "crit_negflip")) {
410 b
->crit_negflip
= !optval
|| *optval
== '1';
411 } else if (!strcasecmp(optname
, "crit_amaf")) {
412 b
->crit_amaf
= !optval
|| *optval
== '1';
413 } else if (!strcasecmp(optname
, "crit_lvalue")) {
414 b
->crit_lvalue
= !optval
|| *optval
== '1';
415 } else if (!strcasecmp(optname
, "libmap_rave") && optval
) {
416 b
->libmap_rave
= atof(optval
);
417 } else if (!strcasecmp(optname
, "virtual_win") && optval
) {
418 b
->virtual_win
= atoi(optval
);
419 } else if (!strcasecmp(optname
, "root_virtual_win") && optval
) {
420 b
->root_virtual_win
= atoi(optval
);
421 } else if (!strcasecmp(optname
, "vwin_min_playouts") && optval
) {
422 b
->vwin_min_playouts
= atoi(optval
);
424 fprintf(stderr
, "ucb1amaf: Invalid policy argument %s or missing value\n",