11 #include "uct/internal.h"
13 #include "uct/policy/generic.h"
15 /* This implements the UCB1 policy with an extra AMAF heuristic. */
/*
 * Configuration of the UCB1+AMAF policy. The policy callbacks read it back
 * from p->data (`struct ucb1_policy_amaf *b = p->data;`).
 * NOTE(review): this view does not show every field. check_nakade,
 * sylvain_rave, crit_rave, crit_negative and crit_amaf are used by the
 * functions of this policy, but their declarations are not visible here.
 */
17 struct ucb1_policy_amaf
{
/* RAVE equivalence parameter used to compute beta in ucb1rave_evaluate().
 * The visible init code defaults it to 3000; the "equiv_rave" option
 * overrides it. */
18 unsigned int equiv_rave
;
21 /* Coefficient of local tree values embedded in RAVE. */
/* Scales the local-tree playout count (lnode->u.playouts). It is only used
 * when > 0. Visible init default: 0.75. */
22 floating_t ltree_rave
;
23 /* Coefficient of criticality embedded in RAVE. */
/* NOTE(review): the comment above describes crit_rave, whose declaration is
 * not visible in this view. crit_min_playouts is the playout count a node
 * must exceed (node->u.playouts > crit_min_playouts) before the criticality
 * heuristic is applied. Visible init default: 2000. */
25 int crit_min_playouts
;
/*
 * fast_sqrt(): square root of an unsigned integer, using a precomputed table
 * for small arguments.
 *
 * table[i] holds sqrt(i) for i = 0..63 (64 entries). The visible code checks
 * whether x falls inside the table.
 * NOTE(review): the rest of the body (the lookup and the path for x >= 64)
 * is not visible in this view. It presumably returns table[x] in range and
 * computes sqrt(x) otherwise; confirm.
 * ucb1rave_descend() uses it to compute 1/sqrt(child playouts).
 */
31 static inline floating_t
fast_sqrt(unsigned int x
)
33 static const floating_t table
[] = {
/* table[i] == sqrt(i), i = 0..63 */
34 0, 1, 1.41421356237309504880, 1.73205080756887729352,
35 2.00000000000000000000, 2.23606797749978969640,
36 2.44948974278317809819, 2.64575131106459059050,
37 2.82842712474619009760, 3.00000000000000000000,
38 3.16227766016837933199, 3.31662479035539984911,
39 3.46410161513775458705, 3.60555127546398929311,
40 3.74165738677394138558, 3.87298334620741688517,
41 4.00000000000000000000, 4.12310562561766054982,
42 4.24264068711928514640, 4.35889894354067355223,
43 4.47213595499957939281, 4.58257569495584000658,
44 4.69041575982342955456, 4.79583152331271954159,
45 4.89897948556635619639, 5.00000000000000000000,
46 5.09901951359278483002, 5.19615242270663188058,
47 5.29150262212918118100, 5.38516480713450403125,
48 5.47722557505166113456, 5.56776436283002192211,
49 5.65685424949238019520, 5.74456264653802865985,
50 5.83095189484530047087, 5.91607978309961604256,
51 6.00000000000000000000, 6.08276253029821968899,
52 6.16441400296897645025, 6.24499799839839820584,
53 6.32455532033675866399, 6.40312423743284868648,
54 6.48074069840786023096, 6.55743852430200065234,
55 6.63324958071079969822, 6.70820393249936908922,
56 6.78232998312526813906, 6.85565460040104412493,
57 6.92820323027550917410, 7.00000000000000000000,
58 7.07106781186547524400, 7.14142842854284999799,
59 7.21110255092797858623, 7.28010988928051827109,
60 7.34846922834953429459, 7.41619848709566294871,
61 7.48331477354788277116, 7.54983443527074969723,
62 7.61577310586390828566, 7.68114574786860817576,
63 7.74596669241483377035, 7.81024967590665439412,
64 7.87400787401181101968, 7.93725393319377177150,
/* In range of the table: element count = sizeof(table) / sizeof(*table). */
66 if (x
< sizeof(table
) / sizeof(*table
)) {
/* Debug switch for the ltree/crit trace output in ucb1rave_evaluate().
 * "if (0)" keeps the fprintf calls compiled, and therefore type-checked,
 * but never executes them. */
73 #define LTREE_DEBUG if (0)
/*
 * ucb1rave_evaluate(): urgency of one candidate child during descent.
 *
 * The child's direct statistics (n = node->u) and its AMAF/RAVE statistics
 * (r = node->amaf) are blended as value = beta * r.value + (1 - beta) * n.value.
 * Priors, local-tree statistics and a criticality term may also be involved.
 * NOTE(review): the statements that fold the local-tree (l) and
 * criticality (c) stats into r are not visible in this view.
 *
 * @p        policy; p->data is the struct ucb1_policy_amaf configuration
 * @tree     search tree
 * @descent  descent->node is the candidate child; descent->lnode is its
 *           local-tree counterpart (checked for NULL before use)
 * @parity   forwarded to tree_node_get_value() to orient the result
 *
 * Side effect: descent->value receives the blended value and the combined
 * playout count r.playouts + n.playouts.
 * Returns tree_node_get_value(tree, parity, value).
 */
74 static floating_t
inline
75 ucb1rave_evaluate(struct uct_policy
*p
, struct tree
*tree
, struct uct_descent
*descent
, int parity
)
77 struct ucb1_policy_amaf
*b
= p
->data
;
78 struct tree_node
*node
= descent
->node
;
79 struct tree_node
*lnode
= descent
->lnode
;
81 struct move_stats n
= node
->u
, r
= node
->amaf
;
/* With amaf_prior, the prior statistics are merged into the RAVE stats r.
 * NOTE(review): source line 84 is not visible here. It presumably opens an
 * else branch so that priors are merged into n otherwise; confirm. */
82 if (p
->uct
->amaf_prior
) {
83 stats_merge(&r
, &node
->prior
);
85 stats_merge(&n
, &node
->prior
);
88 /* Local tree heuristics. */
/* The local-tree playout count is rescaled by
 * ltree_rave / LTREE_PLAYOUTS_MULTIPLIER. It is cast to floating_t first
 * so the scaling is not done in integer arithmetic. */
89 if (p
->uct
->local_tree
&& b
->ltree_rave
> 0 && lnode
) {
90 struct move_stats l
= lnode
->u
;
91 l
.playouts
= ((floating_t
) l
.playouts
) * b
->ltree_rave
/ LTREE_PLAYOUTS_MULTIPLIER
;
92 LTREE_DEBUG
fprintf(stderr
, "[ltree] adding [%s] %f%%%d to [%s] RAVE %f%%%d\n",
93 coord2sstr(lnode
->coord
, tree
->board
), l
.value
, l
.playouts
,
94 coord2sstr(node
->coord
, tree
->board
), r
.value
, r
.playouts
);
98 /* Criticality heuristics. */
/* This only applies once the node has more than crit_min_playouts direct
 * playouts. Negative criticality is used only when crit_negative is set.
 * The pseudo-stats c have value tree_node_get_value(tree, parity, 1.0f)
 * and weight crit * r.playouts * crit_rave. */
99 if (b
->crit_rave
> 0 && node
->u
.playouts
> b
->crit_min_playouts
) {
100 floating_t crit
= tree_node_criticality(tree
, node
);
101 if (b
->crit_negative
|| crit
> 0) {
102 struct move_stats c
= {
103 .value
= tree_node_get_value(tree
, parity
, 1.0f
),
104 .playouts
= crit
* r
.playouts
* b
->crit_rave
/* NOTE(review): the format string below has five conversions, but only
 * three arguments are visible. Source line 107 is missing from this view;
 * confirm that it passes c's value and playouts first. */
106 LTREE_DEBUG
fprintf(stderr
, "[crit] adding %f%%%d to [%s] RAVE %f%%%d\n",
108 coord2sstr(node
->coord
, tree
->board
), r
.value
, r
.playouts
);
114 floating_t value
= 0;
117 /* At the beginning, beta is at 1 and RAVE is used.
118 * At b->equiv_rave, beta is at 1/3 and gets steeper on. */
/* sylvain_rave: beta = r / (r + n + n * r / equiv_rave).
 * Otherwise:    beta = sqrt(equiv_rave / (3 * parent playouts + equiv_rave)). */
120 if (b
->sylvain_rave
) {
121 beta
= (floating_t
) r
.playouts
/ (r
.playouts
+ n
.playouts
122 + (floating_t
) n
.playouts
* r
.playouts
/ b
->equiv_rave
);
124 /* XXX: This can be cached in descend; but we don't use this by default. */
/* NOTE(review): equiv_rave is unsigned int and playouts appears to be an
 * integer (it is printed with %d above). If so, this quotient is integer
 * division and truncates to 0 whenever the parent has any playouts,
 * which forces beta to 0. Casting the numerator to floating_t would fix
 * this; confirm and fix. */
125 beta
= sqrt(b
->equiv_rave
/ (3 * node
->parent
->u
.playouts
+ b
->equiv_rave
));
128 value
= beta
* r
.value
+ (1.f
- beta
) * n
.value
;
132 } else if (r
.playouts
) {
/* Record the blended estimate on the descent, then orient it by parity. */
135 descent
->value
.playouts
= r
.playouts
+ n
.playouts
;
136 descent
->value
.value
= value
;
137 return tree_node_get_value(tree
, parity
, value
);
/*
 * ucb1rave_descend(): choose the child of descent->node to descend into.
 *
 * Each candidate's urgency is ucb1rave_evaluate() plus an exploration bonus
 * explore_p * nconf * coef, where:
 *   explore_p = 2.0 when fast_random(3 + depth) == 0, else 0.0
 *   nconf     = max(1, sqrt(log(max(1, node playouts + prior playouts))))
 *   coef      = 1 / sqrt(child playouts), or 2.0 for an unvisited child
 * The urgency before the bonus is also stored in ni->last_urgency,
 * re-oriented via tree_node_get_value().
 * The best child is chosen through uctd_set_best_child() and
 * uctd_get_best_child().
 * NOTE(review): the return-type line (source line 140) and the function
 * braces are not visible in this view.
 */
141 ucb1rave_descend(struct uct_policy
*p
, struct tree
*tree
, struct uct_descent
*descent
, int parity
, bool allow_pass
)
143 /* struct ucb1_policy_amaf *b = p->data; */
/* The exploration bonus is randomized per call: it is either 2.0 or off. */
144 floating_t explore_p
= (fast_random(3 + descent
->node
->depth
) == 0) ? 2.0f
: 0.0f
;
145 floating_t nconf
= 1.f
;
/* Clamp playouts to >= 1 so log() is never applied to 0. */
147 int playouts
= descent
->node
->u
.playouts
+ descent
->node
->prior
.playouts
;
148 if (playouts
< 1) playouts
= 1;
149 nconf
= sqrt(log(playouts
));
150 if (nconf
< 1.0f
) nconf
= 1.0f
;
153 uctd_try_node_children(tree
, descent
, allow_pass
, parity
, p
->uct
->tenuki_d
, di
, urgency
) {
154 struct tree_node
*ni
= di
.node
;
155 urgency
= ucb1rave_evaluate(p
, tree
, &di
, parity
);
156 ni
->last_urgency
= tree_node_get_value(tree
, parity
, urgency
); // convert it back to a black-oriented value
159 /* It is probably safe to include passes, although passes will still be generated without the extra urgency. */
160 /* Infinite first-play urgency will somehow break things... */
161 floating_t coef
= (ni
->u
.playouts
> 0) ? 1.0f
/ fast_sqrt(ni
->u
.playouts
) : 2.0f
;
162 urgency
+= explore_p
* nconf
* coef
;
164 /* fprintf(stderr, "[%s] urgency=%0.3f\n", coord2sstr(ni->coord, tree->board), urgency); */
165 } uctd_set_best_child(di
, urgency
);
167 uctd_get_best_child(descent
);
/*
 * ucb1amaf_update(): propagate one playout result into the tree.
 *
 * Starting at @node, the visible statements:
 *   - record the result in node->u;
 *   - unless crit_amaf is set, record final-board ownership of node->coord
 *     (winner_owner / black_owner);
 *   - for each child ni whose AMAF-map color works out to child_color,
 *     record the result in ni->amaf; with crit_amaf set, also record
 *     ownership statistics on ni;
 *   - move to node->parent and flip child_color.
 *
 * winner_color is S_BLACK when result > 0.5, otherwise S_WHITE.
 * NOTE(review): the declaration of `result` (source line 175), the enclosing
 * loop header, the declaration of `i`, and several braces and
 * skip/continue statements are not visible in this view. The summary above
 * is limited to what the visible statements show.
 */
172 ucb1amaf_update(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
,
173 enum stone node_color
, enum stone player_color
,
174 struct playout_amafmap
*map
, struct board
*final_board
,
177 struct ucb1_policy_amaf
*b
= p
->data
;
178 enum stone winner_color
= result
> 0.5 ? S_BLACK
: S_WHITE
;
179 enum stone child_color
= stone_other(node_color
);
/* Debug trace. The preprocessor guard around it is not visible here;
 * presumably it is disabled. bb.size = 9+2 presumably means a 9x9 board
 * plus border, used only so coord2sstr() can format coordinates.
 * NOTE(review): `result` is compared with 0.5 above and printed with %f
 * further down, but it is printed with %d here. That is a mismatched
 * conversion specifier if this trace is ever enabled. */
182 struct board bb
; bb
.size
= 9+2;
183 for (struct tree_node
*ni
= node
; ni
; ni
= ni
->parent
)
184 fprintf(stderr
, "%s ", coord2sstr(ni
->coord
, &bb
));
185 fprintf(stderr
, "[color %d] update result %d (color %d)\n",
186 node_color
, result
, player_color
);
/* Sanity check: at the root, child_color must be the opposite of
 * tree->root_color. */
190 if (node
->parent
== NULL
)
191 assert(tree
->root_color
== stone_other(child_color
));
193 if (!b
->crit_amaf
&& !is_pass(node
->coord
)) {
194 stats_add_result(&node
->winner_owner
, board_at(final_board
, node
->coord
) == winner_color
? 1.0 : 0.0, 1);
195 stats_add_result(&node
->black_owner
, board_at(final_board
, node
->coord
) == S_BLACK
? 1.0 : 0.0, 1);
197 stats_add_result(&node
->u
, result
, 1);
/* amaf_nakade() and amaf_op() are defined elsewhere. This applies
 * amaf_op(..., -) to the map entry for node->coord when it is
 * nakade-flagged. */
198 if (amaf_nakade(map
->map
[node
->coord
]))
199 amaf_op(map
->map
[node
->coord
], -);
201 /* This loop ignores symmetry considerations, but they should
202 * matter only at a point when AMAF doesn't help much. */
203 assert(map
->game_baselen
>= 0);
204 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
205 enum stone amaf_color
= map
->map
[ni
->coord
];
206 assert(amaf_color
!= S_OFFBOARD
);
207 if (amaf_color
== S_NONE
)
/* Nakade-flagged entries: the statements on the missing lines are not
 * visible. Presumably the entry is skipped when check_nakade is off.
 * Otherwise it is credited to child_color only if child_color actually
 * played ni->coord in map->game[game_baselen .. gamelen-1]. */
209 if (amaf_nakade(map
->map
[ni
->coord
])) {
210 if (!b
->check_nakade
)
213 for (i
= map
->game_baselen
; i
< map
->gamelen
; i
++)
214 if (map
->game
[i
].coord
== ni
->coord
215 && map
->game
[i
].color
== child_color
)
217 if (i
== map
->gamelen
)
219 amaf_color
= child_color
;
222 floating_t nres
= result
;
223 if (amaf_color
!= child_color
) {
226 /* For child_color != player_color, we still want
227 * to record the result unmodified; in that case,
228 * we will correctly negate them at the descend phase. */
/* NOTE(review): this condition tests is_pass(node->coord), but the body
 * reads board_at(final_board, ni->coord). If ni can be a pass child,
 * is_pass(ni->coord) may have been intended; confirm. */
230 if (b
->crit_amaf
&& !is_pass(node
->coord
)) {
231 stats_add_result(&ni
->winner_owner
, board_at(final_board
, ni
->coord
) == winner_color
? 1.0 : 0.0, 1);
232 stats_add_result(&ni
->black_owner
, board_at(final_board
, ni
->coord
) == S_BLACK
? 1.0 : 0.0, 1);
234 stats_add_result(&ni
->amaf
, nres
, 1);
237 struct board bb
; bb
.size
= 9+2;
238 fprintf(stderr
, "* %s<%"PRIhash
"> -> %s<%"PRIhash
"> [%d/%f => %d/%f]\n",
239 coord2sstr(node
->coord
, &bb
), node
->hash
,
240 coord2sstr(ni
->coord
, &bb
), ni
->hash
,
241 player_color
, result
, child_color
, nres
);
245 if (!is_pass(node
->coord
)) {
/* Step one level up the tree; colors alternate per level. */
248 node
= node
->parent
; child_color
= stone_other(child_color
);
254 policy_ucb1amaf_init(struct uct
*u
, char *arg
)
256 struct uct_policy
*p
= calloc2(1, sizeof(*p
));
257 struct ucb1_policy_amaf
*b
= calloc2(1, sizeof(*b
));
260 p
->choose
= uctp_generic_choose
;
261 p
->winner
= uctp_generic_winner
;
262 p
->evaluate
= ucb1rave_evaluate
;
263 p
->descend
= ucb1rave_descend
;
264 p
->update
= ucb1amaf_update
;
265 p
->wants_amaf
= true;
267 b
->equiv_rave
= 3000;
268 b
->check_nakade
= true;
269 b
->sylvain_rave
= true;
270 b
->ltree_rave
= 0.75f
;
273 b
->crit_min_playouts
= 2000;
274 b
->crit_negative
= 1;
278 char *optspec
, *next
= arg
;
281 next
+= strcspn(next
, ":");
282 if (*next
) { *next
++ = 0; } else { *next
= 0; }
284 char *optname
= optspec
;
285 char *optval
= strchr(optspec
, '=');
286 if (optval
) *optval
++ = 0;
288 if (!strcasecmp(optname
, "equiv_rave") && optval
) {
289 b
->equiv_rave
= atof(optval
);
290 } else if (!strcasecmp(optname
, "sylvain_rave")) {
291 b
->sylvain_rave
= !optval
|| *optval
== '1';
292 } else if (!strcasecmp(optname
, "check_nakade")) {
293 b
->check_nakade
= !optval
|| *optval
== '1';
294 } else if (!strcasecmp(optname
, "ltree_rave") && optval
) {
295 b
->ltree_rave
= atof(optval
);
296 } else if (!strcasecmp(optname
, "crit_rave") && optval
) {
297 b
->crit_rave
= atof(optval
);
298 } else if (!strcasecmp(optname
, "crit_min_playouts") && optval
) {
299 b
->crit_min_playouts
= atoi(optval
);
300 } else if (!strcasecmp(optname
, "crit_negative")) {
301 b
->crit_negative
= !optval
|| *optval
== '1';
302 } else if (!strcasecmp(optname
, "crit_amaf")) {
303 b
->crit_amaf
= !optval
|| *optval
== '1';
305 fprintf(stderr
, "ucb1amaf: Invalid policy argument %s or missing value\n",