fast_frandom() always returns float, not floating_t
[pachi/t.git] / uct / prior.c
blob 29c8c7903086113f751452544a93c8ae9a335884
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #define DEBUG
8 #include "board.h"
9 #include "debug.h"
10 #include "joseki/base.h"
11 #include "move.h"
12 #include "random.h"
13 #include "tactics/ladder.h"
14 #include "tactics/util.h"
15 #include "uct/internal.h"
16 #include "uct/plugins.h"
17 #include "uct/prior.h"
18 #include "uct/tree.h"
/* Apply heuristic values to the tree nodes, skewing the reading towards
 * the most interesting directions. */
23 /* TODO: Introduce foreach_fpoint() to iterate only over non-occupied
24 * positions. */
/* Prior-knowledge configuration of the UCT engine, filled in by
 * uct_prior_init() and consulted by uct_prior(). */
struct uct_prior {
	/* Equivalent experience for prior knowledge. MoGo paper recommends
	 * 50 playouts per source; in practice, esp. with RAVE, about 6
	 * playouts per source seems best. */
	int eqex;

	/* Equivalent experience of the individual prior sources.  A zero
	 * value disables the corresponding source in uct_prior(). */
	int even_eqex;
	int policy_eqex;
	int b19_eqex;
	int eye_eqex;
	int ko_eqex;
	int plugin_eqex;
	int joseki_eqex;
	int pattern_eqex;

	/* Common-fate-graph distance prior: cfgd_eqex holds cfgdn + 1
	 * entries; cfgd_eqex[d] is the bonus for points at distance d
	 * (entry 0 is 0 and never used). */
	int cfgdn;
	int *cfgd_eqex;

	/* Whether uct_prior() removes ladder moves from consideration. */
	bool prune_ladders;
};
36 void
37 uct_prior_even(struct uct *u, struct tree_node *node, struct prior_map *map)
39 /* Q_{even} */
40 /* This may be dubious for normal UCB1 but is essential for
41 * reading stability of RAVE, it appears. */
42 add_prior_value(map, pass, 0.5, u->prior->even_eqex);
43 foreach_free_point(map->b) {
44 if (!map->consider[c])
45 continue;
46 add_prior_value(map, c, 0.5, u->prior->even_eqex);
47 } foreach_free_point_end;
50 void
51 uct_prior_eye(struct uct *u, struct tree_node *node, struct prior_map *map)
53 /* Discourage playing into our own eyes. However, we cannot
54 * completely prohibit it:
55 * #######
56 * ...XX.#
57 * XOOOXX#
58 * X.OOOO#
59 * .XXXX.# */
60 foreach_free_point(map->b) {
61 if (!map->consider[c])
62 continue;
63 if (!board_is_one_point_eye(map->b, c, map->to_play))
64 continue;
65 add_prior_value(map, c, 0, u->prior->eye_eqex);
66 } foreach_free_point_end;
69 void
70 uct_prior_ko(struct uct *u, struct tree_node *node, struct prior_map *map)
72 /* Favor fighting ko, if we took it le 10 moves ago. */
73 coord_t ko = map->b->last_ko.coord;
74 if (is_pass(ko) || map->b->moves - map->b->last_ko_age > 10 || !map->consider[ko])
75 return;
76 // fprintf(stderr, "prior ko-fight @ %s %s\n", stone2str(map->to_play), coord2sstr(ko, map->b));
77 add_prior_value(map, ko, 1, u->prior->ko_eqex);
80 void
81 uct_prior_b19(struct uct *u, struct tree_node *node, struct prior_map *map)
83 /* Q_{b19} */
84 /* Specific hints for 19x19 board - priors for certain edge distances. */
85 foreach_free_point(map->b) {
86 if (!map->consider[c])
87 continue;
88 int d = coord_edge_distance(c, map->b);
89 if (d != 0 && d != 2)
90 continue;
91 /* The bonus applies only with no stones in immediate
92 * vincinity. */
93 if (board_stone_radar(map->b, c, 2))
94 continue;
95 /* First line: 0 */
96 /* Third line: 1 */
97 add_prior_value(map, c, d == 2, u->prior->b19_eqex);
98 } foreach_free_point_end;
101 void
102 uct_prior_playout(struct uct *u, struct tree_node *node, struct prior_map *map)
104 /* Q_{playout-policy} */
105 if (u->playout->assess)
106 u->playout->assess(u->playout, map, u->prior->policy_eqex);
109 void
110 uct_prior_cfgd(struct uct *u, struct tree_node *node, struct prior_map *map)
112 /* Q_{common_fate_graph_distance} */
113 /* Give bonus to moves local to the last move, where "local" means
114 * local in terms of groups, not just manhattan distance. */
115 if (is_pass(map->b->last_move.coord) || is_resign(map->b->last_move.coord))
116 return;
118 foreach_free_point(map->b) {
119 if (!map->consider[c])
120 continue;
121 if (map->distances[c] > u->prior->cfgdn)
122 continue;
123 assert(map->distances[c] != 0);
124 int bonus = u->prior->cfgd_eqex[map->distances[c]];
125 add_prior_value(map, c, 1, bonus);
126 } foreach_free_point_end;
129 void
130 uct_prior_joseki(struct uct *u, struct tree_node *node, struct prior_map *map)
132 /* Q_{joseki} */
133 if (!u->jdict)
134 return;
135 for (int i = 0; i < 4; i++) {
136 hash_t h = map->b->qhash[i] & joseki_hash_mask;
137 coord_t *cc = u->jdict->patterns[h].moves[map->to_play - 1];
138 if (!cc) continue;
139 for (; !is_pass(*cc); cc++) {
140 if (coord_quadrant(*cc, map->b) != i)
141 continue;
142 add_prior_value(map, *cc, 1.0, u->prior->joseki_eqex);
147 void
148 uct_prior_pattern(struct uct *u, struct tree_node *node, struct prior_map *map)
150 /* Q_{pattern} */
151 if (!u->pat.pd)
152 return;
154 struct board *b = map->b;
155 struct pattern pats[b->flen];
156 floating_t probs[b->flen];
157 pattern_rate_moves(&u->pat, b, map->to_play, pats, probs);
158 if (UDEBUGL(5)) {
159 fprintf(stderr, "Pattern prior at node %s\n", coord2sstr(node->coord, b));
160 board_print(b, stderr);
163 for (int f = 0; f < b->flen; f++) {
164 if (isnan(probs[f]) || probs[f] < 0.001)
165 continue;
166 assert(!is_pass(b->f[f]));
167 if (UDEBUGL(5)) {
168 char s[256]; pattern2str(s, &pats[f]);
169 fprintf(stderr, "\t%s: %.3f %s\n", coord2sstr(b->f[f], b), probs[f], s);
171 add_prior_value(map, b->f[f], 1.0, sqrt(probs[f]) * u->prior->pattern_eqex);
175 void
176 uct_prior(struct uct *u, struct tree_node *node, struct prior_map *map)
178 if (u->prior->prune_ladders && !board_playing_ko_threat(map->b)) {
179 foreach_free_point(map->b) {
180 if (!map->consider[c])
181 continue;
182 group_t atari_neighbor = board_get_atari_neighbor(map->b, c, map->to_play);
183 if (atari_neighbor && is_ladder(map->b, c, atari_neighbor, true)) {
184 if (UDEBUGL(5))
185 fprintf(stderr, "Pruning ladder move %s\n", coord2sstr(c, map->b));
186 map->consider[c] = false;
188 } foreach_free_point_end;
191 if (u->prior->even_eqex)
192 uct_prior_even(u, node, map);
193 if (u->prior->eye_eqex)
194 uct_prior_eye(u, node, map);
195 if (u->prior->ko_eqex)
196 uct_prior_ko(u, node, map);
197 if (u->prior->b19_eqex)
198 uct_prior_b19(u, node, map);
199 if (u->prior->policy_eqex)
200 uct_prior_playout(u, node, map);
201 if (u->prior->cfgd_eqex)
202 uct_prior_cfgd(u, node, map);
203 if (u->prior->joseki_eqex)
204 uct_prior_joseki(u, node, map);
205 if (u->prior->pattern_eqex)
206 uct_prior_pattern(u, node, map);
207 if (u->prior->plugin_eqex)
208 plugin_prior(u->plugins, node, map, u->prior->plugin_eqex);
211 struct uct_prior *
212 uct_prior_init(char *arg, struct board *b, struct uct *u)
214 struct uct_prior *p = calloc2(1, sizeof(struct uct_prior));
216 p->even_eqex = p->policy_eqex = p->b19_eqex = p->eye_eqex = p->ko_eqex = p->plugin_eqex = -100;
217 /* FIXME: Optimal pattern_eqex is about -1000 with small playout counts
218 * but only -400 on a cluster. We need a better way to set the default
219 * here. */
220 p->pattern_eqex = -400;
221 p->joseki_eqex = -200;
222 p->cfgdn = -1;
224 /* Even number! */
225 p->eqex = board_large(b) ? 20 : 14;
227 p->prune_ladders = true;
229 if (arg) {
230 char *optspec, *next = arg;
231 while (*next) {
232 optspec = next;
233 next += strcspn(next, ":");
234 if (*next) { *next++ = 0; } else { *next = 0; }
236 char *optname = optspec;
237 char *optval = strchr(optspec, '=');
238 if (optval) *optval++ = 0;
240 if (!strcasecmp(optname, "eqex") && optval) {
241 p->eqex = atoi(optval);
243 /* In the following settings, you can use negative
244 * numbers to give the hundredths of default eqex.
245 * E.g. -100 is default eqex, -50 is half of the
246 * default eqex, -200 is double the default eqex. */
247 } else if (!strcasecmp(optname, "even") && optval) {
248 p->even_eqex = atoi(optval);
249 } else if (!strcasecmp(optname, "policy") && optval) {
250 p->policy_eqex = atoi(optval);
251 } else if (!strcasecmp(optname, "b19") && optval) {
252 p->b19_eqex = atoi(optval);
253 } else if (!strcasecmp(optname, "cfgd") && optval) {
254 /* cfgd=3%40%20%20 - 3 levels; immediate libs
255 * of last move => 40 wins, their neighbors
256 * 20 wins, 2nd-level neighbors 20 wins;
257 * neighbors are group-transitive. */
258 p->cfgdn = atoi(optval); optval += strcspn(optval, "%");
259 p->cfgd_eqex = calloc2(p->cfgdn + 1, sizeof(*p->cfgd_eqex));
260 p->cfgd_eqex[0] = 0;
261 int i;
262 for (i = 1; *optval; i++, optval += strcspn(optval, "%")) {
263 optval++;
264 p->cfgd_eqex[i] = atoi(optval);
266 if (i != p->cfgdn + 1) {
267 fprintf(stderr, "uct: Missing prior cfdn level %d/%d\n", i, p->cfgdn);
268 exit(1);
271 } else if (!strcasecmp(optname, "joseki") && optval) {
272 p->joseki_eqex = atoi(optval);
273 } else if (!strcasecmp(optname, "eye") && optval) {
274 p->eye_eqex = atoi(optval);
275 } else if (!strcasecmp(optname, "ko") && optval) {
276 p->ko_eqex = atoi(optval);
277 } else if (!strcasecmp(optname, "pattern") && optval) {
278 /* Pattern-based prior eqex. */
279 /* Note that this prior is still going to be
280 * used only if you have downloaded or
281 * generated the pattern files! */
282 p->pattern_eqex = atoi(optval);
283 } else if (!strcasecmp(optname, "plugin") && optval) {
284 /* Unlike others, this is just a *recommendation*. */
285 p->plugin_eqex = atoi(optval);
286 } else if (!strcasecmp(optname, "prune_ladders")) {
287 p->prune_ladders = !optval || atoi(optval);
288 } else {
289 fprintf(stderr, "uct: Invalid prior argument %s or missing value\n", optname);
290 exit(1);
295 if (p->even_eqex < 0) p->even_eqex = p->eqex * -p->even_eqex / 100;
296 if (p->policy_eqex < 0) p->policy_eqex = p->eqex * -p->policy_eqex / 100;
297 if (p->b19_eqex < 0) p->b19_eqex = p->eqex * -p->b19_eqex / 100;
298 if (p->eye_eqex < 0) p->eye_eqex = p->eqex * -p->eye_eqex / 100;
299 if (p->ko_eqex < 0) p->ko_eqex = p->eqex * -p->ko_eqex / 100;
300 if (p->joseki_eqex < 0) p->joseki_eqex = p->eqex * -p->joseki_eqex / 100;
301 if (p->pattern_eqex < 0) p->pattern_eqex = p->eqex * -p->pattern_eqex / 100;
302 if (p->plugin_eqex < 0) p->plugin_eqex = p->eqex * -p->plugin_eqex / 100;
304 if (p->cfgdn < 0) {
305 static int large_bonuses[] = { 0, 55, 50, 15 };
306 static int small_bonuses[] = { 0, 45, 40, 15 };
307 p->cfgdn = 3;
308 p->cfgd_eqex = calloc2(p->cfgdn + 1, sizeof(*p->cfgd_eqex));
309 memcpy(p->cfgd_eqex, board_large(b) ? large_bonuses : small_bonuses, sizeof(large_bonuses));
311 if (p->cfgdn > TREE_NODE_D_MAX) {
312 fprintf(stderr, "uct: CFG distances only up to %d available\n", TREE_NODE_D_MAX);
313 exit(1);
316 if (p->pattern_eqex)
317 u->want_pat = true;
319 return p;
322 void
323 uct_prior_done(struct uct_prior *p)
325 assert(p->cfgd_eqex);
326 free(p->cfgd_eqex);
327 free(p);