17 #include "uct/internal.h"
22 // This should become a dynamic parameter. 7G is suitable for 20k*23 threads.
23 #define MAX_NODE_SIZES (7L*1024*1024*1024)
26 uct_get_extra_komi(struct uct
*u
, struct board
*b
)
28 float extra_komi
= board_effective_handicap(b
) * (u
->dynkomi
- b
->moves
) / u
->dynkomi
;
33 uct_progress_status(struct uct
*u
, struct tree
*t
, enum stone color
, int playouts
)
39 struct tree_node
*best
= u
->policy
->choose(u
->policy
, t
->root
, t
->board
, color
);
41 fprintf(stderr
, "... No moves left\n");
44 fprintf(stderr
, "[%d] ", playouts
);
45 fprintf(stderr
, "best %f ", tree_node_get_value(t
, 1, best
->u
.value
));
48 fprintf(stderr
, "deepest % 2d ", t
->max_depth
- t
->root
->depth
);
51 fprintf(stderr
, "| seq ");
52 for (int depth
= 0; depth
< 6; depth
++) {
53 if (best
&& best
->u
.playouts
>= 25) {
54 fprintf(stderr
, "%3s ", coord2sstr(best
->coord
, t
->board
));
55 best
= u
->policy
->choose(u
->policy
, best
, t
->board
, color
);
62 fprintf(stderr
, "| can ");
64 struct tree_node
*can
[cans
];
65 memset(can
, 0, sizeof(can
));
66 best
= t
->root
->children
;
69 while ((!can
[c
] || best
->u
.playouts
> can
[c
]->u
.playouts
) && ++c
< cans
);
70 for (int d
= 0; d
< c
; d
++) can
[d
] = can
[d
+ 1];
71 if (c
> 0) can
[c
- 1] = best
;
76 fprintf(stderr
, "%3s(%.3f) ",
77 coord2sstr(can
[cans
]->coord
, t
->board
),
78 tree_node_get_value(t
, 1, can
[cans
]->u
.value
));
84 fprintf(stderr
, "\n");
89 uct_leaf_node(struct uct
*u
, struct board
*b
, enum stone player_color
,
90 struct playout_amafmap
*amaf
,
91 struct tree
*t
, struct tree_node
*n
, enum stone node_color
,
94 enum stone next_color
= stone_other(node_color
);
95 int parity
= (next_color
== player_color
? 1 : -1);
96 if (n
->u
.playouts
>= u
->expand_p
&& t
->node_sizes
< MAX_NODE_SIZES
) {
97 // fprintf(stderr, "expanding %s (%p ^-%p)\n", coord2sstr(n->coord, b), n, n->parent);
98 if (!u
->parallel_tree
) {
99 /* Single-threaded, life is easy. */
100 tree_expand_node(t
, n
, b
, next_color
, u
, parity
);
102 /* We need to make sure only one thread expands
103 * the node. If we are unlucky enough for two
104 * threads to meet in the same node, the latter
105 * one will simply do another simulation from
106 * the node itself, no big deal. */
107 if (!__sync_lock_test_and_set(&n
->is_expanded
, 1) &&
108 t
->node_sizes
< MAX_NODE_SIZES
) {
109 assert(tree_leaf_node(n
));
110 tree_expand_node(t
, n
, b
, next_color
, u
, parity
);
115 fprintf(stderr
, "%s*-- UCT playout #%d start [%s] %f\n",
116 spaces
, n
->u
.playouts
, coord2sstr(n
->coord
, t
->board
),
117 tree_node_get_value(t
, parity
, n
->u
.value
));
119 struct playout_setup ps
= { .gamelen
= u
->gamelen
};
120 int result
= play_random_game(&ps
, b
, next_color
,
121 u
->playout_amaf
? amaf
: NULL
,
122 &u
->ownermap
, u
->playout
);
123 if (next_color
== S_WHITE
) {
124 /* We need the result from black's perspective. */
128 fprintf(stderr
, "%s -- [%d..%d] %s random playout result %d\n",
129 spaces
, player_color
, next_color
, coord2sstr(n
->coord
, t
->board
), result
);
135 scale_value(struct uct
*u
, struct board
*b
, int result
)
137 float rval
= result
> 0;
139 int vp
= u
->val_points
;
141 vp
= board_size(b
) - 1; vp
*= vp
; vp
*= 2;
144 float sval
= (float) abs(result
) / vp
;
145 sval
= sval
> 1 ? 1 : sval
;
146 if (result
< 0) sval
= 1 - sval
;
148 rval
+= u
->val_scale
* sval
;
150 rval
= (1 - u
->val_scale
) * rval
+ u
->val_scale
* sval
;
151 // fprintf(stderr, "score %d => sval %f, rval %f\n", result, sval, rval);
158 uct_playout(struct uct
*u
, struct board
*b
, enum stone player_color
, struct tree
*t
)
163 struct playout_amafmap
*amaf
= NULL
;
164 if (u
->policy
->wants_amaf
) {
165 amaf
= calloc(1, sizeof(*amaf
));
166 amaf
->map
= calloc(board_size2(&b2
) + 1, sizeof(*amaf
->map
));
167 amaf
->map
++; // -1 is pass
170 /* Walk the tree until we find a leaf, then expand it and do
171 * a random playout. */
172 struct tree_node
*n
= t
->root
;
173 enum stone node_color
= stone_other(player_color
);
174 assert(node_color
== t
->root_color
);
176 void *dstate
= NULL
, *dstater
= NULL
;
179 int pass_limit
= (board_size(&b2
) - 2) * (board_size(&b2
) - 2) / 2;
180 int passes
= is_pass(b
->last_move
.coord
) && b
->moves
> 0;
184 static char spaces
[] = "\0 ";
187 fprintf(stderr
, "--- UCT walk with color %d\n", player_color
);
189 while (!tree_leaf_node(n
) && passes
< 2) {
190 spaces
[depth
++] = ' '; spaces
[depth
] = 0;
192 /* Parity is chosen already according to the child color, since
193 * it is applied to children. */
194 node_color
= stone_other(node_color
);
195 int parity
= (node_color
== player_color
? 1 : -1);
196 n
= (!u
->random_policy_chance
|| fast_random(u
->random_policy_chance
))
197 ? u
->policy
->descend(u
->policy
, &dstate
, t
, n
, parity
, pass_limit
)
198 : u
->random_policy
->descend(u
->random_policy
, &dstater
, t
, n
, parity
, pass_limit
);
200 assert(n
== t
->root
|| n
->parent
);
202 fprintf(stderr
, "%s+-- UCT sent us to [%s:%d] %f\n",
203 spaces
, coord2sstr(n
->coord
, t
->board
), n
->coord
,
204 tree_node_get_value(t
, parity
, n
->u
.value
));
206 /* Add virtual loss if we need to; this is used to discourage
207 * other threads from visiting this node in case of multiple
208 * threads doing the tree search. */
210 stats_add_result(&n
->u
, tree_parity(t
, parity
) > 0 ? 0 : 1, 1);
212 assert(n
->coord
>= -1);
213 if (amaf
&& !is_pass(n
->coord
)) {
214 if (amaf
->map
[n
->coord
] == S_NONE
|| amaf
->map
[n
->coord
] == node_color
) {
215 amaf
->map
[n
->coord
] = node_color
;
216 } else { // XXX: Respect amaf->record_nakade
217 amaf_op(amaf
->map
[n
->coord
], +);
219 amaf
->game
[amaf
->gamelen
].coord
= n
->coord
;
220 amaf
->game
[amaf
->gamelen
].color
= node_color
;
222 assert(amaf
->gamelen
< sizeof(amaf
->game
) / sizeof(amaf
->game
[0]));
225 struct move m
= { n
->coord
, node_color
};
226 int res
= board_play(&b2
, &m
);
228 if (res
< 0 || (!is_pass(m
.coord
) && !group_at(&b2
, m
.coord
)) /* suicide */
229 || b2
.superko_violation
) {
231 for (struct tree_node
*ni
= n
; ni
; ni
= ni
->parent
)
232 fprintf(stderr
, "%s<%"PRIhash
"> ", coord2sstr(ni
->coord
, t
->board
), ni
->hash
);
233 fprintf(stderr
, "marking invalid %s node %d,%d res %d group %d spk %d\n",
234 stone2str(node_color
), coord_x(n
->coord
,b
), coord_y(n
->coord
,b
),
235 res
, group_at(&b2
, m
.coord
), b2
.superko_violation
);
237 n
->hints
|= TREE_HINT_INVALID
;
242 if (is_pass(n
->coord
))
249 amaf
->game_baselen
= amaf
->gamelen
;
250 amaf
->record_nakade
= u
->playout_amaf_nakade
;
253 if (u
->dynkomi
> b2
.moves
&& (player_color
& u
->dynkomi_mask
))
254 b2
.komi
+= uct_get_extra_komi(u
, &b2
);
257 /* XXX: No dead groups support. */
258 float score
= board_official_score(&b2
, NULL
);
259 /* Result from black's perspective (no matter who
260 * the player; black's perspective is always
261 * what the tree stores. */
262 result
= - (score
* 2);
265 fprintf(stderr
, "[%d..%d] %s p-p scoring playout result %d (W %f)\n",
266 player_color
, node_color
, coord2sstr(n
->coord
, t
->board
), result
, score
);
268 board_print(&b2
, stderr
);
270 board_ownermap_fill(&u
->ownermap
, &b2
);
272 } else { assert(u
->parallel_tree
|| tree_leaf_node(n
));
273 /* In case of parallel tree search, the assertion might
274 * not hold if two threads chew on the same node. */
275 result
= uct_leaf_node(u
, &b2
, player_color
, amaf
, t
, n
, node_color
, spaces
);
278 if (amaf
&& u
->playout_amaf_cutoff
) {
279 int cutoff
= amaf
->game_baselen
;
280 cutoff
+= (amaf
->gamelen
- amaf
->game_baselen
) * u
->playout_amaf_cutoff
/ 100;
281 /* Now, reconstruct the amaf map. */
282 memset(amaf
->map
, 0, board_size2(&b2
) * sizeof(*amaf
->map
));
283 for (int i
= 0; i
< cutoff
; i
++) {
284 coord_t coord
= amaf
->game
[i
].coord
;
285 enum stone color
= amaf
->game
[i
].color
;
286 if (amaf
->map
[coord
] == S_NONE
|| amaf
->map
[coord
] == color
) {
287 amaf
->map
[coord
] = color
;
288 /* Nakade always recorded for in-tree part */
289 } else if (amaf
->record_nakade
|| i
<= amaf
->game_baselen
) {
290 amaf_op(amaf
->map
[n
->coord
], +);
295 assert(n
== t
->root
|| n
->parent
);
297 float rval
= scale_value(u
, b
, result
);
299 u
->policy
->update(u
->policy
, t
, n
, node_color
, player_color
, amaf
, rval
);
301 if (u
->root_heuristic
&& n
->parent
) {
303 t
->chvals
= calloc(board_size2(b
), sizeof(t
->chvals
[0]));
304 t
->chchvals
= calloc(board_size2(b
), sizeof(t
->chchvals
[0]));
307 /* Possibly transform the rval appropriately. */
308 rval
= stats_temper_value(rval
, n
->parent
->u
.value
, u
->root_heuristic
);
310 struct tree_node
*ni
= n
;
311 while (ni
->parent
->parent
&& ni
->parent
->parent
->parent
)
313 if (ni
->parent
->parent
) {
314 if (likely(!is_pass(ni
->coord
)))
315 stats_add_result(&t
->chchvals
[ni
->coord
], rval
, 1);
318 assert(ni
->parent
&& !ni
->parent
->parent
);
320 if (likely(!is_pass(ni
->coord
)))
321 stats_add_result(&t
->chvals
[ni
->coord
], rval
, 1);
326 /* We need to undo the virtual loss we added during descend. */
327 if (u
->virtual_loss
) {
328 int parity
= (node_color
== player_color
? 1 : -1);
329 for (; n
->parent
; n
= n
->parent
) {
330 stats_rm_result(&n
->u
, tree_parity(t
, parity
) > 0 ? 0 : 1, 1);
335 if (dstater
) free(dstater
);
336 if (dstate
) free(dstate
);
341 board_done_noalloc(&b2
);
346 uct_playouts(struct uct
*u
, struct board
*b
, enum stone color
, struct tree
*t
, int games
)
348 /* Should we print progress info? In case all threads work on the same
349 * tree, only the first thread does. */
350 #define ok_to_talk (!u->parallel_tree || !thread_id)
353 for (i
= 0; i
< games
; i
++) {
354 int result
= uct_playout(u
, b
, color
, t
);
356 /* Tree descent has hit invalid move. */
360 if (unlikely(ok_to_talk
&& i
> 0 && !(i
% 10000))) {
361 uct_progress_status(u
, t
, color
, i
);
364 if (i
> 0 && !(i
% 500)) {
365 struct tree_node
*best
= u
->policy
->choose(u
->policy
, t
->root
, b
, color
);
366 if (best
&& ((best
->u
.playouts
>= 2000 && tree_node_get_value(t
, 1, best
->u
.value
) >= u
->loss_threshold
)
367 || (best
->u
.playouts
>= 500 && tree_node_get_value(t
, 1, best
->u
.value
) >= 0.95)))
373 fprintf(stderr
, "<halting early, %d games skipped>\n", games
- i
);
379 uct_progress_status(u
, t
, color
, i
);
381 tree_dump(t
, u
->dumpthres
);