11 #include "uct/internal.h"
14 /* This implements the UCB1 policy with an additional AMAF heuristic. */
/* Private settings of the UCB1+AMAF policy; the functions in this file
 * obtain it from p->data.
 * NOTE(review): only urg_randoma and urg_randomm are declared in this view,
 * yet the file also reads explore_p, fpu, explore_p_rave, equiv_rave,
 * both_colors and check_nakade through the same pointer, and no closing
 * brace is visible - the declaration looks truncated; confirm against the
 * complete source. */
16 struct ucb1_policy_amaf
{
17 /* This is what the Modification of UCT with Patterns in Monte Carlo Go
18 * paper calls 'p'. Original UCB has this on 2, but this seems to
19 * produce way too wide searches; reduce this to get deeper and
20 * narrower readouts - try 0.2. */
/* NOTE(review): the member the comment above describes is not visible
 * here; presumably explore_p - confirm. */
22 /* First Play Urgency - if set to less than infinity (the MoGo paper
23 * above reports 1.0 as the best), new branches are explored only
24 * if none of the existing ones has higher urgency than fpu. */
/* NOTE(review): the member the comment above describes is not visible
 * here; presumably fpu - confirm. */
/* Urgency randomisation knobs read by ucb1srave_descend(): urg_randoma
 * feeds an additive term (fast_random(urg_randoma) - urg_randoma / 2) / 1000
 * and urg_randomm a multiplicative factor
 * (fast_random(urg_randomm) + 5) / urg_randomm. */
26 int urg_randoma
, urg_randomm
;
/* Prototype only - no body for ucb1_choose appears in this view.
 * policy_ucb1amaf_init() installs it as the policy's `choose` callback.
 * Takes the policy, a tree node, a board and a stone colour and returns a
 * tree node pointer. */
34 struct tree_node
*ucb1_choose(struct uct_policy
*p
, struct tree_node
*node
, struct board
*b
, enum stone color
);
/* Prototype only - no body for ucb1_descend appears in this view, and
 * nothing visible here calls it (this policy installs ucb1srave_descend as
 * its `descend` callback instead). Its parameter list matches that of
 * ucb1srave_descend(). */
36 struct tree_node
*ucb1_descend(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, int parity
, bool allow_pass
);
/* Square-root helper for the integer game counts used by
 * ucb1srave_descend(); it keeps a static table of precomputed roots and
 * tests whether x is below the table length.
 * NOTE(review): the body is only partly visible here - the opening brace,
 * the first table entries and the statements that return a value are
 * missing from this view; presumably small x is looked up in the table
 * and larger x is computed - confirm. */
39 static inline float fast_sqrt(int x
)
41 static const float table
[] = {
/* The visible entries are sqrt(2), sqrt(3) and sqrt(4). */
44 1.41421356237309504880,
45 1.73205080756887729352,
46 2.00000000000000000000,
48 //printf("sqrt %d\n", x);
/* The right-hand side has an unsigned type, so x is converted to unsigned
 * for this comparison and a negative x never satisfies it. */
49 if (x
< sizeof(table
) / sizeof(*table
)) {
56 /* Sylvain RAVE function */
/* Choose which child of `node` to descend into. Every child gets an
 * urgency that blends its own win rate (u counters, plus prior) with its
 * AMAF win rate, weighted by beta; the children whose urgency ties for the
 * best are collected in nbest[] and one of them is returned through
 * fast_random(nbests).
 * Installed as the policy's `descend` callback by policy_ucb1amaf_init().
 * NOTE(review): the return-type line, braces, several guards and the
 * declaration of `urgency` are not visible in this view; comments below
 * describe only the statements that are. */
58 ucb1srave_descend(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, int parity
, bool allow_pass
)
60 struct ucb1_policy_amaf
*b
= p
->data
;
/* Reciprocal of equiv_rave; it scales the ngames * rgames term in beta's
 * denominator below.
 * NOTE(review): equiv_rave can be set from the option string, and a value
 * of 0 makes this a division by zero - no guard is visible. */
61 float rave_coef
= 1.0f
/ b
->equiv_rave
;
62 float nconf
= 1.f
, rconf
= 1.f
;
/* Exploration numerator for the node statistics: sqrt(log(N)) with N the
 * parent's playouts plus its prior playouts.
 * NOTE(review): with N == 0 this evaluates sqrt(log(0)); whether a guard
 * precedes this statement cannot be seen here. */
64 nconf
= sqrt(log(node
->u
.playouts
+ node
->prior
.playouts
));
/* The AMAF counterpart is computed only when RAVE exploration is enabled
 * and the parent has AMAF playouts. */
65 if (b
->explore_p_rave
> 0 && node
->amaf
.playouts
)
66 rconf
= sqrt(log(node
->amaf
.playouts
+ node
->prior
.playouts
));
68 // XXX: Stack overflow danger on big boards?
/* The tie list starts out holding the first child, so nbest[0] is set
 * even if the loop never records anything.
 * NOTE(review): the append `nbest[nbests++] = ni` further down has no
 * visible bounds check against the 512-entry capacity. */
69 struct tree_node
*nbest
[512] = { node
->children
}; int nbests
= 1;
70 float best_urgency
= -9999;
72 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
73 /* Do not consider passing early. */
/* NOTE(review): the statement controlled by this test is not visible;
 * presumably `continue`. */
74 if (likely(!allow_pass
) && unlikely(is_pass(ni
->coord
)))
77 /* TODO: Exploration? */
/* Node (n*) and AMAF (r*) game/win counters of this child. */
79 int ngames
= ni
->u
.playouts
;
80 int nwins
= ni
->u
.wins
;
81 int rgames
= ni
->amaf
.playouts
;
82 int rwins
= ni
->amaf
.wins
;
/* With amaf_prior set, the prior is folded into the AMAF counters.
 * NOTE(review): the node-counter additions that follow are presumably the
 * `else` branch, but the line joining the two is not visible. */
83 if (p
->uct
->amaf_prior
) {
84 rgames
+= ni
->prior
.playouts
;
85 rwins
+= ni
->prior
.wins
;
87 ngames
+= ni
->prior
.playouts
;
88 nwins
+= ni
->prior
.wins
;
/* When tree_parity() is negative, both win counts are replaced by their
 * complement (games - wins). */
90 if (tree_parity(tree
, parity
) < 0) {
91 nwins
= ngames
- nwins
;
92 rwins
= rgames
- rwins
;
94 float nval
= 0, rval
= 0;
/* Node value: win rate plus an exploration bonus.
 * NOTE(review): both divisions by ngames need ngames != 0; the guarding
 * condition is presumably on a line not visible here. */
96 nval
= (float) nwins
/ ngames
;
98 nval
+= b
->explore_p
* nconf
/ fast_sqrt(ngames
);
/* AMAF value: same shape; the bonus is skipped for pass moves and when
 * explore_p_rave is not positive.
 * NOTE(review): likewise needs rgames != 0; guard not visible. */
101 rval
= (float) rwins
/ rgames
;
102 if (b
->explore_p_rave
> 0 && !is_pass(ni
->coord
))
103 rval
+= b
->explore_p_rave
* rconf
/ fast_sqrt(rgames
);
109 /* At the beginning, beta is at 1 and RAVE is used.
110 * At b->equiv_rave, beta is at 1/3 and gets steeper on. */
111 float beta
= (float) rgames
/ (rgames
+ ngames
+ rave_coef
* ngames
* rgames
);
/* Debug trace of the beta computation.
 * NOTE(review): presumably compiled out in the full source; the
 * surrounding preprocessor lines are not visible. */
113 //if (node->coord == 7*11+4) // D7
114 fprintf(stderr
, "[beta %f = %d / (%d + %d + %f)]\n",
115 beta
, rgames
, rgames
, ngames
, rave_coef
* ngames
* rgames
);
/* Blend: beta weights the AMAF value, (1 - beta) the node value. */
117 urgency
= beta
* rval
+ (1 - beta
) * nval
;
124 /* assert(!u->even_eqex); */
/* Debug trace; uses a scratch board with a hard-coded size of 11 solely
 * for coord2sstr().
 * NOTE(review): the hashes are printed with %lld; their declared type is
 * not visible here - confirm it matches. */
129 struct board bb
; bb
.size
= 11;
130 //if (node->coord == 7*11+4) // D7
131 fprintf(stderr
, "%s<%lld>-%s<%lld> urgency %f (r %d / %d, n %d / %d)\n",
132 coord2sstr(ni
->parent
->coord
, &bb
), ni
->parent
->hash
,
133 coord2sstr(ni
->coord
, &bb
), ni
->hash
, urgency
,
134 rwins
, rgames
, nwins
, ngames
);
/* Optional additive noise, in thousandths, centred by subtracting
 * urg_randoma / 2.
 * NOTE(review): presumably guarded by a test on urg_randoma that is not
 * visible here. */
137 urgency
+= (float)(fast_random(b
->urg_randoma
) - b
->urg_randoma
/ 2) / 1000;
/* Optional multiplicative noise.
 * NOTE(review): divides by urg_randomm, which is 0 unless set through the
 * option string; the guard is presumably on a line not visible here. */
139 urgency
*= (float)(fast_random(b
->urg_randomm
) + 5) / b
->urg_randomm
;
/* Strictly better than the current best (beyond float epsilon): record
 * the new best urgency and empty the tie list. */
141 if (urgency
- best_urgency
> __FLT_EPSILON__
) { // urgency > best_urgency
142 best_urgency
= urgency
; nbests
= 0;
/* Equal to the best within epsilon - this also holds right after the
 * reset above, assuming that block closes before this test - so the
 * child is appended to the tie list. */
144 if (urgency
- best_urgency
> -__FLT_EPSILON__
) { // urgency >= best_urgency
145 /* We want to always choose something else than a pass
146 * in case of a tie. pass causes degenerative behaviour. */
/* NOTE(review): the body of this test is not visible; presumably it
 * drops the lone pass entry before the append below. */
147 if (nbests
== 1 && is_pass(nbest
[0]->coord
)) {
150 nbest
[nbests
++] = ni
;
/* Debug dump of the tie list, again with a scratch board of size 11. */
154 struct board bb
; bb
.size
= 11;
155 fprintf(stderr
, "[%s %d: ", coord2sstr(node
->coord
, &bb
), nbests
);
156 for (int zz
= 0; zz
< nbests
; zz
++)
157 fprintf(stderr
, "%s", coord2sstr(nbest
[zz
]->coord
, &bb
));
158 fprintf(stderr
, "]\n");
/* Pick one of the tied candidates using fast_random(nbests) as index. */
160 return nbest
[fast_random(nbests
)];
/* Record a playout result in a node's own statistics: adds `result` to
 * node->u.wins and then calls tree_update_node_value() with the amaf_prior
 * setting.
 * NOTE(review): the return-type line and braces are not visible, and the
 * embedded numbering skips a line just before the wins update - presumably
 * the matching playouts increment (update_node_amaf() has one for its
 * counters); confirm against the complete source. */
164 update_node(struct uct_policy
*p
, struct tree_node
*node
, int result
)
167 node
->u
.wins
+= result
;
168 tree_update_node_value(node
, p
->uct
->amaf_prior
);
/* Record a playout result in a node's AMAF statistics: increments
 * node->amaf.playouts, adds `result` to node->amaf.wins and then calls
 * tree_update_node_rvalue() with the amaf_prior setting.
 * NOTE(review): the return-type line and braces are not visible here. */
172 update_node_amaf(struct uct_policy
*p
, struct tree_node
*node
, int result
)
174 node
->amaf
.playouts
++;
175 node
->amaf
.wins
+= result
;
176 tree_update_node_rvalue(node
, p
->uct
->amaf_prior
);
/* Feed one playout result back into the tree. The node's own statistics
 * are updated through update_node(), each child's AMAF statistics through
 * update_node_amaf() depending on what the playout's AMAF map holds at the
 * child's coordinate, and then the walk moves to node->parent with
 * child_color flipped.
 * Installed as the policy's `update` callback by policy_ucb1amaf_init().
 * NOTE(review): the return-type line, braces, the loop that repeats the
 * walk, and the declarations of `i` and `nres` are not visible in this
 * view; comments below describe only the statements that are. */
180 ucb1amaf_update(struct uct_policy
*p
, struct tree
*tree
, struct tree_node
*node
, enum stone node_color
, enum stone player_color
, struct playout_amafmap
*map
, int result
)
182 struct ucb1_policy_amaf
*b
= p
->data
;
/* Colour of the moves stored in node's children: the opposite of
 * node_color. */
183 enum stone child_color
= stone_other(node_color
);
/* Debug trace of the path from node up through its parents; the scratch
 * board (size 9+2) exists only for coord2sstr().
 * NOTE(review): presumably compiled out in the full source. */
186 struct board bb
; bb
.size
= 9+2;
187 for (struct tree_node
*ni
= node
; ni
; ni
= ni
->parent
)
188 fprintf(stderr
, "%s ", coord2sstr(ni
->coord
, &bb
));
189 fprintf(stderr
, "[color %d] update result %d (color %d)\n",
190 node_color
, result
, player_color
);
/* Sanity check at the root: root_color must be the opposite of
 * child_color. */
194 if (node
->parent
== NULL
)
195 assert(tree
->root_color
== stone_other(child_color
));
197 update_node(p
, node
, result
);
/* If the map entry for this node's own coordinate is flagged by
 * amaf_nakade(), apply amaf_op(entry, -) to it.
 * NOTE(review): both are defined outside this view; the operator argument
 * suggests amaf_op is a macro. */
198 if (amaf_nakade(map
->map
[node
->coord
]))
199 amaf_op(map
->map
[node
->coord
], -);
201 /* This loop ignores symmetry considerations, but they should
202 * matter only at a point when AMAF doesn't help much. */
203 for (struct tree_node
*ni
= node
->children
; ni
; ni
= ni
->sibling
) {
204 assert(map
->map
[ni
->coord
] != S_OFFBOARD
);
/* NOTE(review): the statement controlled by this test is not visible;
 * presumably children whose map entry is S_NONE are skipped. */
205 if (map
->map
[ni
->coord
] == S_NONE
)
207 assert(map
->game_baselen
>= 0);
208 enum stone amaf_color
= map
->map
[ni
->coord
];
/* Nakade-flagged entry: unless check_nakade is off, scan the recorded
 * game from game_baselen to gamelen for a move at this coordinate by
 * child_color; i == gamelen means none was found.
 * NOTE(review): the statements controlled by the !check_nakade test, by
 * the inner match and by the i == gamelen test are not visible. */
209 if (amaf_nakade(map
->map
[ni
->coord
])) {
210 if (!b
->check_nakade
)
212 /* We don't care to implement both_colors
213 * properly since it sucks anyway. */
215 for (i
= map
->game_baselen
; i
< map
->gamelen
; i
++)
216 if (map
->game
[i
].coord
== ni
->coord
217 && map
->game
[i
].color
== child_color
)
219 if (i
== map
->gamelen
)
221 amaf_color
= child_color
;
/* NOTE(review): the body of this colour-mismatch test and the
 * definition of nres are not visible here. */
225 if (amaf_color
!= child_color
) {
230 /* For child_color != player_color, we still want
231 * to record the result unmodified; in that case,
232 * we will correctly negate them at the descend phase. */
234 update_node_amaf(p
, ni
, nres
);
/* Debug trace.
 * NOTE(review): the format string has eight conversions but only seven
 * arguments follow it, so the last %d has no matching argument; fix
 * before enabling this line. */
237 fprintf(stderr
, "* %s<%lld> -> %s<%lld> [%d %d => %d/%d]\n", coord2sstr(node
->coord
, &bb
), node
->hash
, coord2sstr(ni
->coord
, &bb
), ni
->hash
, player_color
, child_color
, result
);
/* NOTE(review): the body of this non-pass test is not visible. */
241 if (!is_pass(node
->coord
)) {
/* Step to the parent and flip the child colour for the next level. */
244 node
= node
->parent
; child_color
= stone_other(child_color
);
/* Build the UCB1+AMAF policy: allocates the policy object and its private
 * settings, installs the descend/choose/update callbacks, sets defaults
 * and then applies options parsed from `arg`.
 * Options are separated by ':' and written as name or name=value; `arg` is
 * modified in place while it is split.
 * NOTE(review): the return-type line, braces, the option loop header and
 * the end of the function (including whatever it returns) are not visible
 * in this view. */
250 policy_ucb1amaf_init(struct uct
*u
, char *arg
)
/* NOTE(review): neither calloc() result is checked before it is
 * dereferenced below. */
252 struct uct_policy
*p
= calloc(1, sizeof(*p
));
253 struct ucb1_policy_amaf
*b
= calloc(1, sizeof(*b
));
256 p
->descend
= ucb1srave_descend
;
257 p
->choose
= ucb1_choose
;
258 p
->update
= ucb1amaf_update
;
259 p
->wants_amaf
= true;
261 // RAVE: 0.2vs0: 40% (+-7.3) 0.1vs0: 54.7% (+-3.5)
/* Defaults visible here; members not assigned stay at calloc's zero
 * unless set on lines not visible in this view. */
263 b
->explore_p_rave
= 0.01;
264 b
->equiv_rave
= 3000;
266 b
->check_nakade
= true;
269 char *optspec
, *next
= arg
;
/* Advance to the next ':' (or the end of the string) and terminate the
 * current option there. */
272 next
+= strcspn(next
, ":");
273 if (*next
) { *next
++ = 0; } else { *next
= 0; }
/* Split name=value at the first '='; optval stays NULL when there is
 * none. */
275 char *optname
= optspec
;
276 char *optval
= strchr(optspec
, '=');
277 if (optval
) *optval
++ = 0;
/* NOTE(review): unlike the value-taking options below, this branch does
 * not test optval, so "explore_p" given without "=value" passes NULL to
 * atof(). */
279 if (!strcasecmp(optname
, "explore_p")) {
280 b
->explore_p
= atof(optval
);
281 } else if (!strcasecmp(optname
, "fpu") && optval
) {
282 b
->fpu
= atof(optval
);
283 } else if (!strcasecmp(optname
, "urg_randoma") && optval
) {
284 b
->urg_randoma
= atoi(optval
);
285 } else if (!strcasecmp(optname
, "urg_randomm") && optval
) {
286 b
->urg_randomm
= atoi(optval
);
287 } else if (!strcasecmp(optname
, "explore_p_rave") && optval
) {
288 b
->explore_p_rave
= atof(optval
);
289 } else if (!strcasecmp(optname
, "equiv_rave") && optval
) {
290 b
->equiv_rave
= atof(optval
);
/* Flag options: both_colors is set whenever it is named; check_nakade is
 * true without a value or when the value starts with '1'. */
291 } else if (!strcasecmp(optname
, "both_colors")) {
292 b
->both_colors
= true;
293 } else if (!strcasecmp(optname
, "check_nakade")) {
294 b
->check_nakade
= !optval
|| *optval
== '1';
/* NOTE(review): presumably the final `else` branch; the line introducing
 * it and whatever follows the message are not visible. */
296 fprintf(stderr
, "ucb1: Invalid policy argument %s or missing value\n", optname
);
/* A negative explore_p_rave means "use the same coefficient as
 * explore_p". */
301 if (b
->explore_p_rave
< 0) b
->explore_p_rave
= b
->explore_p
;