11 #include "uct/internal.h"
14 /* This implements the UCB1 policy with an additional AMAF heuristic. */
/* Tunable parameters of the UCB1+AMAF policy.
 * The descend and update functions in this file obtain this structure
 * through p->data; policy_ucb1amaf_init() fills in the defaults and the
 * user-supplied overrides.
 * NOTE(review): explore_p, fpu, explore_p_rave and equiv_rave are used
 * elsewhere in this file, but their declarations (and the closing brace
 * of the struct) are not visible in this excerpt. */
16 struct ucb1_policy_amaf
{
17 /* This is what the Modification of UCT with Patterns in Monte Carlo Go
18 * paper calls 'p'. Original UCB has this on 2, but this seems to
19 * produce way too wide searches; reduce this to get deeper and
20 * narrower readouts - try 0.2. */
22 /* First Play Urgency - if set to less than infinity (the MoGo paper
23 * above reports 1.0 as the best), new branches are explored only
24 * if none of the existing ones has higher urgency than fpu. */
26 /* Equivalent experience for prior knowledge. MoGo paper recommends
27 * 50 playouts per source. */
/* even_eqex, gp_eqex and policy_eqex take the value of eqex when they are
 * still negative at the end of policy_ucb1amaf_init(). */
28 int eqex
, even_eqex
, gp_eqex
, policy_eqex
;
/* Random perturbation of the computed urgency: urg_randoma feeds an
 * additive term, (fast_random(urg_randoma) - urg_randoma / 2) / 1000,
 * and urg_randomm a multiplicative one,
 * (fast_random(urg_randomm) + 5) / urg_randomm. */
29 int urg_randoma
, urg_randomm
;
/* rave_prior: when set, ucb1orave_descend() adds the prior wins/playouts
 * to the AMAF statistics.
 * both_colors: consulted by ucb1amaf_update() for children whose color
 * differs from player_color. */
32 bool rave_prior
, both_colors
;
/* Hook that policy_ucb1amaf_init() installs as p->choose.
 * Only the declaration is visible in this excerpt. */
36 struct tree_node
*ucb1_choose(struct uct_policy
*p
, struct tree_node
*node
, struct board
*b
, enum stone color
);
/* Descend hook selected by the "rave=0" policy option.
 * ucb1amaf_update() tags nodes with NODE_HINT_NOAMAF whenever a descend
 * function other than this one is active.
 * Only the declaration is visible in this excerpt. */
38 struct tree_node
*ucb1_descend(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, int parity
, bool allow_pass
);
/* Hook that policy_ucb1amaf_init() installs as p->prior when the eqex
 * parameter is non-zero.
 * Only the declaration is visible in this excerpt. */
40 void ucb1_prior(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, struct board
*b
, enum stone color
, int parity
);
/* Descend step of the "original" RAVE variant (policy option rave=o).
 *
 * Walks the children of node and computes, for each, an urgency that
 * blends a UCB term on the regular statistics (uctp) with a UCB term on
 * the AMAF statistics (ravep):
 *     urgency = beta * ravep + (1 - beta) * uctp
 * where beta = sqrt(equiv_rave / (3 * N + equiv_rave)) and N is the
 * parent's regular plus prior playout count. A child without regular
 * playouts gets b->fpu instead. With parity <= 0 the stored values are
 * flipped to 1 - value. The urgency is then perturbed using urg_randoma
 * and urg_randomm and compared against the best urgency seen so far.
 *
 * Side effect: ni->amaf.value and ni->prior.value are overwritten for
 * the children that reach those assignments.
 *
 * NOTE(review): this excerpt is missing lines (return type, braces, the
 * statement guarded by the pass check, any guards around the random
 * perturbation, the nbest update and the return) - confirm against the
 * full file.
 * NOTE(review): the divisions by amaf_playouts, uct_playouts and
 * ni->prior.playouts have no visible zero guard, and log() is applied to
 * playout sums that may be zero. */
43 /* Original RAVE function */
45 ucb1orave_descend(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, int parity
, bool allow_pass
)
47 /* We want to count in the prior stats here after all. Otherwise,
48 * nodes with positive prior will get explored _LESS_ since the
49 * urgency will be always higher; even with normal FPU because
50 * of the explore coefficient. */
52 struct ucb1_policy_amaf
*b
= p
->data
;
/* Exploration numerators, shared by all children of this node. */
53 float xpl
= log(node
->u
.playouts
+ node
->prior
.playouts
) * b
->explore_p
;
54 float xpl_rave
= log(node
->amaf
.playouts
+ (b
->rave_prior
? node
->prior
.playouts
: 0)) * b
->explore_p_rave
;
/* Weight of the AMAF term; shrinks as the parent gathers playouts. */
55 float beta
= sqrt((float)b
->equiv_rave
/ (3 * (node
->u
.playouts
+ node
->prior
.playouts
) + b
->equiv_rave
));
57 struct tree_node
*nbest
= node
->children
;
58 float best_urgency
= -9999;
59 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
60 /* Do not consider passing early. */
61 if (likely(!allow_pass
) && unlikely(is_pass(ni
->coord
)))
63 int amaf_wins
= ni
->amaf
.wins
+ (b
->rave_prior
? ni
->prior
.wins
: 0);
64 int amaf_playouts
= ni
->amaf
.playouts
+ (b
->rave_prior
? ni
->prior
.playouts
: 0);
65 int uct_playouts
= ni
->u
.playouts
+ ni
->prior
.playouts
;
66 ni
->amaf
.value
= (float)amaf_wins
/ amaf_playouts
;
67 ni
->prior
.value
= (float)ni
->prior
.wins
/ ni
->prior
.playouts
;
68 float uctp
= (parity
> 0 ? ni
->u
.value
: 1 - ni
->u
.value
) + sqrt(xpl
/ uct_playouts
);
69 float ravep
= (parity
> 0 ? ni
->amaf
.value
: 1 - ni
->amaf
.value
) + sqrt(xpl_rave
/ amaf_playouts
);
70 float urgency
= ni
->u
.playouts
? beta
* ravep
+ (1 - beta
) * uctp
: b
->fpu
;
71 // fprintf(stderr, "uctp %f (uct %d/%d) ravep %f (xpl %f amaf %d/%d) beta %f => %f\n", uctp, ni->u.wins, ni->u.playouts, ravep, xpl_rave, amaf_wins, amaf_playouts, beta, urgency);
/* Additive, then multiplicative random perturbation of the urgency. */
73 urgency
+= (float)(fast_random(b
->urg_randoma
) - b
->urg_randoma
/ 2) / 1000;
75 urgency
*= (float)(fast_random(b
->urg_randomm
) + 5) / b
->urg_randomm
;
76 if (urgency
> best_urgency
) {
77 best_urgency
= urgency
;
/* Square root helper taking an int argument and returning a float.
 * Arguments below the number of table entries are tested for first; the
 * visible table entries run from sqrt(2) up to sqrt(63). Other arguments
 * are handled by code that starts from "base", a power of four that the
 * mask tests below shift down according to which bits of x are set.
 * NOTE(review): several lines of this function are missing from this
 * excerpt (the first table entries, the table lookup itself, the rest of
 * the computation using "base", the declaration of y and the return
 * statements) - confirm against the full file before relying on this
 * description. */
84 float fast_sqrt(int x
)
86 static const float table
[] = {
89 1.41421356237309504880,
90 1.73205080756887729352,
91 2.00000000000000000000,
93 2.23606797749978969640,
94 2.44948974278317809819,
95 2.64575131106459059050,
96 2.82842712474619009760,
97 3.00000000000000000000,
98 3.16227766016837933199,
99 3.31662479035539984911,
100 3.46410161513775458705,
101 3.60555127546398929311,
102 3.74165738677394138558,
103 3.87298334620741688517,
104 4.00000000000000000000,
105 4.12310562561766054982,
106 4.24264068711928514640,
107 4.35889894354067355223,
108 4.47213595499957939281,
109 4.58257569495584000658,
110 4.69041575982342955456,
111 4.79583152331271954159,
112 4.89897948556635619639,
113 5.00000000000000000000,
114 5.09901951359278483002,
115 5.19615242270663188058,
116 5.29150262212918118100,
117 5.38516480713450403125,
118 5.47722557505166113456,
119 5.56776436283002192211,
120 5.65685424949238019520,
121 5.74456264653802865985,
122 5.83095189484530047087,
123 5.91607978309961604256,
124 6.00000000000000000000,
125 6.08276253029821968899,
126 6.16441400296897645025,
127 6.24499799839839820584,
128 6.32455532033675866399,
129 6.40312423743284868648,
130 6.48074069840786023096,
131 6.55743852430200065234,
132 6.63324958071079969822,
133 6.70820393249936908922,
134 6.78232998312526813906,
135 6.85565460040104412493,
136 6.92820323027550917410,
137 7.00000000000000000000,
138 7.07106781186547524400,
139 7.14142842854284999799,
140 7.21110255092797858623,
141 7.28010988928051827109,
142 7.34846922834953429459,
143 7.41619848709566294871,
144 7.48331477354788277116,
145 7.54983443527074969723,
146 7.61577310586390828566,
147 7.68114574786860817576,
148 7.74596669241483377035,
149 7.81024967590665439412,
150 7.87400787401181101968,
151 7.93725393319377177150,
154 //printf("sqrt %d\n", x);
/* NOTE(review): x is a signed int compared against a size_t expression,
 * so it is converted to unsigned first; a negative x would therefore not
 * satisfy this test - confirm callers never pass one. */
155 if (x
< sizeof(table
) / sizeof(*table
)) {
/* 1 << (bits in int - 2): the starting value for "base". */
161 int base
= 1 << (sizeof(int) * 8 - 2);
162 if ((x
& 0xFFFF0000) == 0) base
>>= 16;
163 if ((x
& 0xFF00FF00) == 0) base
>>= 8;
164 if ((x
& 0xF0F0F0F0) == 0) base
>>= 4;
165 if ((x
& 0xCCCCCCCC) == 0) base
>>= 2;
166 // "base" starts at the highest power of four <= the argument.
176 printf("sqrt %d = %d\n", x
, y
);
/* Descend step of the "Sylvain" RAVE variant (policy option rave=s; also
 * the descend function that policy_ucb1amaf_init() installs by default).
 *
 * For every child, regular statistics (ngames/nwins) and AMAF statistics
 * (rgames/rwins) are collected and turned into the values nval and rval.
 * When explore_p (resp. explore_p_rave) is positive, an exploration term
 * explore_p * conf / fast_sqrt(games) is added, with
 * conf = sqrt(log(parent regular playouts + parent prior playouts)).
 * The two values are blended:
 *     beta    = rgames / (rgames + ngames + ngames * rgames / equiv_rave)
 *     urgency = beta * rval + (1 - beta) * nval
 * The urgency is perturbed using urg_randoma and urg_randomm and compared
 * with >= against the best urgency seen so far.
 *
 * NOTE(review): this excerpt is missing lines, among them the return
 * type, the declarations of conf and urgency, the conditions that choose
 * whether prior statistics are added to the AMAF or to the regular
 * counters, the condition around the win-count flip, any zero-playout
 * guards, the nbest update and the return. The two fprintf() calls and
 * the assert may sit inside disabled or conditional sections that are
 * not visible here. The comment before the final comparison is also cut
 * off (its closing line is missing). Confirm all of this against the
 * full file. */
182 /* Sylvain RAVE function */
184 ucb1srave_descend(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, int parity
, bool allow_pass
)
186 struct ucb1_policy_amaf
*b
= p
->data
;
187 float rave_coef
= 1.0f
/ b
->equiv_rave
;
189 if (b
->explore_p
> 0 || b
->explore_p_rave
> 0)
190 conf
= sqrt(log(node
->u
.playouts
+ node
->prior
.playouts
));
192 struct tree_node
*nbest
= node
->children
;
193 float best_urgency
= -9999;
194 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
195 /* Do not consider passing early. */
196 if (likely(!allow_pass
) && unlikely(is_pass(ni
->coord
)))
199 /* TODO: Exploration? */
201 int ngames
= ni
->u
.playouts
;
202 int nwins
= ni
->u
.wins
;
203 int rgames
= ni
->amaf
.playouts
;
204 int rwins
= ni
->amaf
.wins
;
/* Prior statistics are folded into the AMAF counters ... */
206 rgames
+= ni
->prior
.playouts
;
207 rwins
+= ni
->prior
.wins
;
/* ... and into the regular counters (selecting conditions not visible). */
209 ngames
+= ni
->prior
.playouts
;
210 nwins
+= ni
->prior
.wins
;
/* Turn win counts into loss counts (guarding condition not visible). */
213 nwins
= ngames
- nwins
;
214 rwins
= rgames
- rwins
;
216 float nval
= 0, rval
= 0;
218 nval
= (float) nwins
/ ngames
;
219 if (b
->explore_p
> 0)
220 nval
+= b
->explore_p
* conf
/ fast_sqrt(ngames
);
223 rval
= (float) rwins
/ rgames
;
224 if (b
->explore_p_rave
> 0)
225 rval
+= b
->explore_p_rave
* conf
/ fast_sqrt(rgames
);
231 /* At the beginning, beta is at 1 and RAVE is used.
232 * At b->equiv_rave, beta is at 1/3 and gets steeper on. */
233 float beta
= (float) rgames
/ (rgames
+ ngames
+ rave_coef
* ngames
* rgames
);
235 //if (node->coord == 7*11+4) // D7
236 fprintf(stderr
, "[beta %f = %d / (%d + %d + %f)]\n",
237 beta
, rgames
, rgames
, ngames
, rave_coef
* ngames
* rgames
);
239 urgency
= beta
* rval
+ (1 - beta
) * nval
;
246 assert(!b
->even_eqex
);
251 struct board bb
; bb
.size
= 11;
252 //if (node->coord == 7*11+4) // D7
253 fprintf(stderr
, "%s<%lld> urgency %f (r %d / %d, n %d / %d)\n",
254 coord2sstr(ni
->coord
, &bb
), ni
->hash
, urgency
, rwins
, rgames
, nwins
, ngames
);
257 urgency
+= (float)(fast_random(b
->urg_randoma
) - b
->urg_randoma
/ 2) / 1000;
259 urgency
*= (float)(fast_random(b
->urg_randomm
) + 5) / b
->urg_randomm
;
260 /* The >= is important since we will always choose something
261 * else than a pass in case of a tie. pass causes degenerative
263 if (urgency
>= best_urgency
) {
264 best_urgency
= urgency
;
/* Folds one playout result into a node's regular statistics: result is
 * added to node->u.wins and tree_update_node_value(node) is then called.
 * p is not used by the visible statements.
 * NOTE(review): no increment of node->u.playouts is visible in this
 * excerpt (compare update_node_amaf(), which bumps amaf.playouts); a
 * line appears to be missing here - confirm against the full file. */
272 update_node(struct uct_policy
*p
, struct tree_node
*node
, int result
)
275 node
->u
.wins
+= result
;
276 tree_update_node_value(node
);
/* Folds one playout result into a node's AMAF statistics: the AMAF
 * playout counter is incremented, result is added to the AMAF win
 * counter, and tree_update_node_value(node) is then called.
 * p is not used by the visible statements. */
279 update_node_amaf(struct uct_policy
*p
, struct tree_node
*node
, int result
)
281 node
->amaf
.playouts
++;
282 node
->amaf
.wins
+= result
;
283 tree_update_node_value(node
);
/* Propagates one playout result from node upwards along the ->parent
 * links, updating regular and AMAF statistics on the way. child_color
 * starts as the opposite of node_color and is flipped with stone_other()
 * at every step up.
 *
 * At each level:
 *  - if the active descend function is not ucb1_descend, the node is
 *    tagged with NODE_HINT_NOAMAF;
 *  - update_node() records result on the node itself (the same result
 *    value is used at every level);
 *  - if the map entry for the node's own coordinate satisfies
 *    amaf_nakade(), amaf_op(entry, -) is applied to it;
 *  - the children are walked and update_node_amaf() is called with
 *    result when child_color equals player_color and !result otherwise.
 *
 * Two tests precede that call: the map entry differing from child_color
 * or being flagged by amaf_nakade(), and child_color differing from
 * player_color while both_colors is unset.
 * NOTE(review): the statements guarded by those two tests are missing
 * from this excerpt - presumably they skip the child; confirm against
 * the full file. The fprintf() debugging output and the local
 * "struct board bb" may sit inside a disabled section that is not
 * visible here. */
287 ucb1amaf_update(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, enum stone node_color
, enum stone player_color
, struct playout_amafmap
*map
, int result
)
289 struct ucb1_policy_amaf
*b
= p
->data
;
290 enum stone child_color
= stone_other(node_color
);
293 struct board bb
; bb
.size
= 9+2;
294 for (struct tree_node
*ni
= node
; ni
; ni
= ni
->parent
)
295 fprintf(stderr
, "%s ", coord2sstr(ni
->coord
, &bb
));
296 fprintf(stderr
, "update color %d result %d\n", player_color
, result
);
299 for (; node
; node
= node
->parent
, child_color
= stone_other(child_color
)) {
300 if (p
->descend
!= ucb1_descend
)
301 node
->hints
|= NODE_HINT_NOAMAF
; /* Rave, different update function */
302 update_node(p
, node
, result
);
303 if (amaf_nakade(map
->map
[node
->coord
]))
304 amaf_op(map
->map
[node
->coord
], -);
305 /* This loop ignores symmetry considerations, but they should
306 * matter only at a point when AMAF doesn't help much. */
307 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
308 assert(map
->map
[ni
->coord
] != S_OFFBOARD
);
309 if (map
->map
[ni
->coord
] != child_color
310 || amaf_nakade(map
->map
[ni
->coord
]))
312 if (child_color
!= player_color
&& !b
->both_colors
)
316 fprintf(stderr
, "* %s<%lld> -> %s<%lld> [%d %d => %d]\n", coord2sstr(node
->coord
, &bb
), node
->hash
, coord2sstr(ni
->coord
, &bb
), ni
->hash
, child_color
, child_color
== player_color
? result
: !result
);
318 if (p
->descend
!= ucb1_descend
)
319 ni
->hints
|= NODE_HINT_NOAMAF
; /* Rave, different update function */
320 update_node_amaf(p
, ni
, child_color
== player_color
? result
: !result
);
/* Allocates and configures the UCB1+AMAF policy.
 *
 * u:   not referenced by the visible statements.
 * arg: option string. Options are separated by ':' and written either
 *      as "name" or as "name=value"; the string is modified in place
 *      (separators are overwritten with NUL bytes).
 *
 * Defaults set before parsing: descend = ucb1srave_descend,
 * choose = ucb1_choose, update = ucb1amaf_update, wants_amaf = true,
 * explore_p_rave = -1, equiv_rave = 3000, even_eqex = policy_eqex = -1.
 *
 * Recognized option names: explore_p, prior, prior_even, prior_gp,
 * prior_policy, fpu, urg_randoma, urg_randomm, rave (first character of
 * the value: '0' = ucb1_descend, 'o' = ucb1orave_descend,
 * 's' = ucb1srave_descend), explore_p_rave, equiv_rave, rave_prior,
 * both_colors. Anything else produces the "Invalid policy argument"
 * message on stderr.
 *
 * After parsing: p->prior is set to ucb1_prior when eqex is non-zero;
 * negative even_eqex, gp_eqex and policy_eqex take the value of eqex;
 * a negative explore_p_rave takes the value of explore_p.
 *
 * NOTE(review): this excerpt is missing lines, including the return
 * type, the remaining default assignments, the option loop header, any
 * NULL check on arg, whatever follows the error message, and the end of
 * the function with its return statement - confirm against the full
 * file.
 * NOTE(review): neither calloc() result is checked in the visible
 * lines. */
327 policy_ucb1amaf_init(struct uct
*u
, char *arg
)
329 struct uct_policy
*p
= calloc(1, sizeof(*p
));
330 struct ucb1_policy_amaf
*b
= calloc(1, sizeof(*b
));
333 p
->descend
= ucb1srave_descend
;
334 p
->choose
= ucb1_choose
;
335 p
->update
= ucb1amaf_update
;
336 p
->wants_amaf
= true;
338 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
340 b
->explore_p_rave
= -1;
341 b
->equiv_rave
= 3000;
343 // gp: 14 vs 0: 44% (+-3.5)
345 b
->even_eqex
= b
->policy_eqex
= -1;
349 char *optspec
, *next
= arg
;
/* Cut the current option off at the next ':' (or at the end of string). */
352 next
+= strcspn(next
, ":");
353 if (*next
) { *next
++ = 0; } else { *next
= 0; }
/* Split "name=value"; optval stays NULL when there is no '='. */
355 char *optname
= optspec
;
356 char *optval
= strchr(optspec
, '=');
357 if (optval
) *optval
++ = 0;
/* NOTE(review): unlike most options below, "explore_p" is not guarded by
 * "&& optval" here, so giving it without "=value" hands NULL to atof().
 * The "prior" branch may have the same problem; a line between its test
 * and its atoi() call is missing from this excerpt. */
359 if (!strcasecmp(optname
, "explore_p")) {
360 b
->explore_p
= atof(optval
);
361 } else if (!strcasecmp(optname
, "prior")) {
363 b
->eqex
= atoi(optval
);
364 } else if (!strcasecmp(optname
, "prior_even") && optval
) {
365 b
->even_eqex
= atoi(optval
);
366 } else if (!strcasecmp(optname
, "prior_gp") && optval
) {
367 b
->gp_eqex
= atoi(optval
);
368 } else if (!strcasecmp(optname
, "prior_policy") && optval
) {
369 b
->policy_eqex
= atoi(optval
);
370 } else if (!strcasecmp(optname
, "fpu") && optval
) {
371 b
->fpu
= atof(optval
);
372 } else if (!strcasecmp(optname
, "urg_randoma") && optval
) {
373 b
->urg_randoma
= atoi(optval
);
374 } else if (!strcasecmp(optname
, "urg_randomm") && optval
) {
375 b
->urg_randomm
= atoi(optval
);
376 } else if (!strcasecmp(optname
, "rave")) {
377 if (optval
&& *optval
== '0')
378 p
->descend
= ucb1_descend
;
379 else if (optval
&& *optval
== 'o')
380 p
->descend
= ucb1orave_descend
;
381 else if (optval
&& *optval
== 's')
382 p
->descend
= ucb1srave_descend
;
383 } else if (!strcasecmp(optname
, "explore_p_rave") && optval
) {
384 b
->explore_p_rave
= atof(optval
);
385 } else if (!strcasecmp(optname
, "equiv_rave") && optval
) {
386 b
->equiv_rave
= atof(optval
);
387 } else if (!strcasecmp(optname
, "rave_prior")) {
389 b
->rave_prior
= true;
390 } else if (!strcasecmp(optname
, "both_colors")) {
391 b
->both_colors
= true;
393 fprintf(stderr
, "ucb1: Invalid policy argument %s or missing value\n", optname
);
/* Resolve the values that were left unset (zero or negative) above. */
398 if (b
->eqex
) p
->prior
= ucb1_prior
;
399 if (b
->even_eqex
< 0) b
->even_eqex
= b
->eqex
;
400 if (b
->gp_eqex
< 0) b
->gp_eqex
= b
->eqex
;
401 if (b
->policy_eqex
< 0) b
->policy_eqex
= b
->eqex
;
402 if (b
->explore_p_rave
< 0) b
->explore_p_rave
= b
->explore_p
;