11 #include "uct/internal.h"
14 /* This implements the UCB1 policy with an extra AMAF heuristics. */
16 struct ucb1_policy_amaf
{
17 /* This is what the Modification of UCT with Patterns in Monte Carlo Go
18 * paper calls 'p'. Original UCB has this on 2, but this seems to
19 * produce way too wide searches; reduce this to get deeper and
20 * narrower readouts - try 0.2. */
22 /* First Play Urgency - if set to less than infinity (the MoGo paper
23 * above reports 1.0 as the best), new branches are explored only
24 * if none of the existing ones has higher urgency than fpu. */
26 int urg_randoma
, urg_randomm
;
29 bool rave_prior
, both_colors
;
34 struct tree_node
*ucb1_choose(struct uct_policy
*p
, struct tree_node
*node
, struct board
*b
, enum stone color
);
36 struct tree_node
*ucb1_descend(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, int parity
, bool allow_pass
);
39 static inline float fast_sqrt(int x
)
41 static const float table
[] = {
44 1.41421356237309504880,
45 1.73205080756887729352,
46 2.00000000000000000000,
48 //printf("sqrt %d\n", x);
49 if (x
< sizeof(table
) / sizeof(*table
)) {
56 /* Sylvain RAVE function */
58 ucb1srave_descend(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, int parity
, bool allow_pass
)
60 struct ucb1_policy_amaf
*b
= p
->data
;
61 float rave_coef
= 1.0f
/ b
->equiv_rave
;
62 float nconf
= 1.f
, rconf
= 1.f
;
64 nconf
= sqrt(log(node
->u
.playouts
+ node
->prior
.playouts
));
65 if (b
->explore_p_rave
> 0 && node
->amaf
.playouts
)
66 rconf
= sqrt(log(node
->amaf
.playouts
+ node
->prior
.playouts
));
68 // XXX: Stack overflow danger on big boards?
69 struct tree_node
*nbest
[512] = { node
->children
}; int nbests
= 1;
70 float best_urgency
= -9999;
72 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
73 /* Do not consider passing early. */
74 if (likely(!allow_pass
) && unlikely(is_pass(ni
->coord
)))
77 /* TODO: Exploration? */
79 int ngames
= ni
->u
.playouts
;
80 int nwins
= ni
->u
.wins
;
81 int rgames
= ni
->amaf
.playouts
;
82 int rwins
= ni
->amaf
.wins
;
84 rgames
+= ni
->prior
.playouts
;
85 rwins
+= ni
->prior
.wins
;
87 ngames
+= ni
->prior
.playouts
;
88 nwins
+= ni
->prior
.wins
;
90 if (tree_parity(tree
, parity
) < 0) {
91 nwins
= ngames
- nwins
;
92 rwins
= rgames
- rwins
;
94 float nval
= 0, rval
= 0;
96 nval
= (float) nwins
/ ngames
;
98 nval
+= b
->explore_p
* nconf
/ fast_sqrt(ngames
);
101 rval
= (float) rwins
/ rgames
;
102 if (b
->explore_p_rave
> 0 && !is_pass(ni
->coord
))
103 rval
+= b
->explore_p_rave
* rconf
/ fast_sqrt(rgames
);
106 /* XXX: We later compare urgency with best_urgency; this can
107 * be difficult given that urgency can be in register with
108 * higher precision than best_urgency, thus even though
109 * the numbers are in fact the same, urgency will be
110 * slightly higher (or lower). Thus, we declare urgency
111 * as volatile, attempting to force the compiler to keep
112 * everything as a float. Ideally, we should do some random
113 * __FLT_EPSILON__ magic instead. */
114 volatile float urgency
;
117 /* At the beginning, beta is at 1 and RAVE is used.
118 * At b->equiv_rate, beta is at 1/3 and gets steeper on. */
119 float beta
= (float) rgames
/ (rgames
+ ngames
+ rave_coef
* ngames
* rgames
);
121 //if (node->coord == 7*11+4) // D7
122 fprintf(stderr
, "[beta %f = %d / (%d + %d + %f)]\n",
123 beta
, rgames
, rgames
, ngames
, rave_coef
* ngames
* rgames
);
125 urgency
= beta
* rval
+ (1 - beta
) * nval
;
132 /* assert(!u->even_eqex); */
137 struct board bb
; bb
.size
= 11;
138 //if (node->coord == 7*11+4) // D7
139 fprintf(stderr
, "%s<%lld>-%s<%lld> urgency %f (r %d / %d, n %d / %d)\n",
140 coord2sstr(ni
->parent
->coord
, &bb
), ni
->parent
->hash
,
141 coord2sstr(ni
->coord
, &bb
), ni
->hash
, urgency
,
142 rwins
, rgames
, nwins
, ngames
);
145 urgency
+= (float)(fast_random(b
->urg_randoma
) - b
->urg_randoma
/ 2) / 1000;
147 urgency
*= (float)(fast_random(b
->urg_randomm
) + 5) / b
->urg_randomm
;
149 if (urgency
> best_urgency
) {
150 best_urgency
= urgency
; nbests
= 0;
152 if (urgency
>= best_urgency
) {
153 /* We want to always choose something else than a pass
154 * in case of a tie. pass causes degenerative behaviour. */
155 if (nbests
== 1 && is_pass(nbest
[0]->coord
)) {
158 nbest
[nbests
++] = ni
;
162 struct board bb
; bb
.size
= 11;
163 fprintf(stderr
, "[%s %d: ", coord2sstr(node
->coord
, &bb
), nbests
);
164 for (int zz
= 0; zz
< nbests
; zz
++)
165 fprintf(stderr
, "%s", coord2sstr(nbest
[zz
]->coord
, &bb
));
166 fprintf(stderr
, "]\n");
168 return nbest
[fast_random(nbests
)];
172 update_node(struct uct_policy
*p
, struct tree_node
*node
, int result
)
175 node
->u
.wins
+= result
;
176 tree_update_node_value(node
);
179 update_node_amaf(struct uct_policy
*p
, struct tree_node
*node
, int result
)
181 node
->amaf
.playouts
++;
182 node
->amaf
.wins
+= result
;
183 tree_update_node_value(node
);
187 ucb1amaf_update(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, enum stone node_color
, enum stone player_color
, struct playout_amafmap
*map
, int result
)
189 struct ucb1_policy_amaf
*b
= p
->data
;
190 enum stone child_color
= stone_other(node_color
);
193 struct board bb
; bb
.size
= 9+2;
194 for (struct tree_node
*ni
= node
; ni
; ni
= ni
->parent
)
195 fprintf(stderr
, "%s ", coord2sstr(ni
->coord
, &bb
));
196 fprintf(stderr
, "[color %d] update result %d (color %d)\n",
197 node_color
, result
, player_color
);
201 if (node
->parent
== NULL
)
202 assert(tree
->root_color
== stone_other(child_color
));
204 if (p
->descend
!= ucb1_descend
)
205 node
->hints
|= NODE_HINT_NOAMAF
; /* Rave, different update function */
206 update_node(p
, node
, result
);
207 if (amaf_nakade(map
->map
[node
->coord
]))
208 amaf_op(map
->map
[node
->coord
], -);
210 /* This loop ignores symmetry considerations, but they should
211 * matter only at a point when AMAF doesn't help much. */
212 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
213 assert(map
->map
[ni
->coord
] != S_OFFBOARD
);
214 if (map
->map
[ni
->coord
] == S_NONE
)
216 assert(map
->game_baselen
>= 0);
217 enum stone amaf_color
= map
->map
[ni
->coord
];
218 if (amaf_nakade(map
->map
[ni
->coord
])) {
219 if (!b
->check_nakade
)
221 /* We don't care to implement both_colors
222 * properly since it sucks anyway. */
224 for (i
= map
->game_baselen
; i
< map
->gamelen
; i
++)
225 if (map
->game
[i
].coord
== ni
->coord
226 && map
->game
[i
].color
== child_color
)
228 if (i
== map
->gamelen
)
230 amaf_color
= child_color
;
234 if (amaf_color
!= child_color
) {
239 /* For child_color != player_color, we still want
240 * to record the result unmodified; in that case,
241 * we will correctly negate them at the descend phase. */
243 if (p
->descend
!= ucb1_descend
)
244 ni
->hints
|= NODE_HINT_NOAMAF
; /* Rave, different update function */
245 update_node_amaf(p
, ni
, nres
);
248 fprintf(stderr
, "* %s<%lld> -> %s<%lld> [%d %d => %d/%d]\n", coord2sstr(node
->coord
, &bb
), node
->hash
, coord2sstr(ni
->coord
, &bb
), ni
->hash
, player_color
, child_color
, result
);
252 if (!is_pass(node
->coord
)) {
255 node
= node
->parent
; child_color
= stone_other(child_color
);
261 policy_ucb1amaf_init(struct uct
*u
, char *arg
)
263 struct uct_policy
*p
= calloc(1, sizeof(*p
));
264 struct ucb1_policy_amaf
*b
= calloc(1, sizeof(*b
));
267 p
->descend
= ucb1srave_descend
;
268 p
->choose
= ucb1_choose
;
269 p
->update
= ucb1amaf_update
;
270 p
->wants_amaf
= true;
272 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
274 b
->explore_p_rave
= 0.01;
275 b
->equiv_rave
= 3000;
277 b
->rave_prior
= true;
278 b
->check_nakade
= true;
281 char *optspec
, *next
= arg
;
284 next
+= strcspn(next
, ":");
285 if (*next
) { *next
++ = 0; } else { *next
= 0; }
287 char *optname
= optspec
;
288 char *optval
= strchr(optspec
, '=');
289 if (optval
) *optval
++ = 0;
291 if (!strcasecmp(optname
, "explore_p")) {
292 b
->explore_p
= atof(optval
);
293 } else if (!strcasecmp(optname
, "fpu") && optval
) {
294 b
->fpu
= atof(optval
);
295 } else if (!strcasecmp(optname
, "urg_randoma") && optval
) {
296 b
->urg_randoma
= atoi(optval
);
297 } else if (!strcasecmp(optname
, "urg_randomm") && optval
) {
298 b
->urg_randomm
= atoi(optval
);
299 } else if (!strcasecmp(optname
, "rave")) {
300 if (optval
&& *optval
== '0')
301 p
->descend
= ucb1_descend
;
302 else if (optval
&& *optval
== 's')
303 p
->descend
= ucb1srave_descend
;
304 } else if (!strcasecmp(optname
, "explore_p_rave") && optval
) {
305 b
->explore_p_rave
= atof(optval
);
306 } else if (!strcasecmp(optname
, "equiv_rave") && optval
) {
307 b
->equiv_rave
= atof(optval
);
308 } else if (!strcasecmp(optname
, "rave_prior") && optval
) {
310 b
->rave_prior
= atoi(optval
);
311 } else if (!strcasecmp(optname
, "both_colors")) {
312 b
->both_colors
= true;
313 } else if (!strcasecmp(optname
, "check_nakade")) {
314 b
->check_nakade
= !optval
|| *optval
== '1';
316 fprintf(stderr
, "ucb1: Invalid policy argument %s or missing value\n", optname
);
321 if (b
->explore_p_rave
< 0) b
->explore_p_rave
= b
->explore_p
;