UCT TreePool: Introduce; replays best significant children during playout
[pachi/derm.git] / uct / prior.c
blob0bed606197e1be4a70c24bef5dc278e23e32a9bf
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "joseki/base.h"
10 #include "move.h"
11 #include "random.h"
12 #include "tactics/util.h"
13 #include "uct/internal.h"
14 #include "uct/plugins.h"
15 #include "uct/prior.h"
16 #include "uct/tree.h"
/* Applying heuristic values to the tree nodes, skewing the reading in
 * the most interesting directions. */
21 /* TODO: Introduce foreach_fpoint() to iterate only over non-occupied
22 * positions. */
/* Tunables of the prior-knowledge heuristics; built by uct_prior_init()
 * and released by uct_prior_done(). */
struct uct_prior {
	/* Equivalent experience for prior knowledge. MoGo paper recommends
	 * 50 playouts per source; in practice, esp. with RAVE, about 6
	 * playouts per source seems best. */
	int eqex;

	/* Per-source equivalent experience. While options are parsed, a
	 * negative value means "hundredths of eqex"; uct_prior_init()
	 * resolves it to an absolute playout count. Zero disables the
	 * source in uct_prior(). */
	int even_eqex;
	int policy_eqex;
	int b19_eqex;
	int eye_eqex;
	int ko_eqex;
	int plugin_eqex;
	int joseki_eqex;

	/* Common fate graph distance bonuses: cfgd_eqex holds cfgdn + 1
	 * entries indexed by distance; entry 0 is unused and set to 0. */
	int cfgdn;
	int *cfgd_eqex;
};
33 void
34 uct_prior_even(struct uct *u, struct tree_node *node, struct prior_map *map)
36 /* Q_{even} */
37 /* This may be dubious for normal UCB1 but is essential for
38 * reading stability of RAVE, it appears. */
39 add_prior_value(map, pass, 0.5, u->prior->even_eqex);
40 foreach_free_point(map->b) {
41 if (!map->consider[c])
42 continue;
43 add_prior_value(map, c, 0.5, u->prior->even_eqex);
44 } foreach_free_point_end;
47 void
48 uct_prior_eye(struct uct *u, struct tree_node *node, struct prior_map *map)
50 /* Discourage playing into our own eyes. However, we cannot
51 * completely prohibit it:
52 * #######
53 * ...XX.#
54 * XOOOXX#
55 * X.OOOO#
56 * .XXXX.# */
57 foreach_free_point(map->b) {
58 if (!map->consider[c])
59 continue;
60 if (!board_is_one_point_eye(map->b, c, map->to_play))
61 continue;
62 add_prior_value(map, c, 0, u->prior->eye_eqex);
63 } foreach_free_point_end;
66 void
67 uct_prior_ko(struct uct *u, struct tree_node *node, struct prior_map *map)
69 /* Favor fighting ko, if we took it le 10 moves ago. */
70 coord_t ko = map->b->last_ko.coord;
71 if (is_pass(ko) || map->b->moves - map->b->last_ko_age > 10 || !map->consider[ko])
72 return;
73 // fprintf(stderr, "prior ko-fight @ %s %s\n", stone2str(map->to_play), coord2sstr(ko, map->b));
74 add_prior_value(map, ko, 1, u->prior->ko_eqex);
77 void
78 uct_prior_b19(struct uct *u, struct tree_node *node, struct prior_map *map)
80 /* Q_{b19} */
81 /* Specific hints for 19x19 board - priors for certain edge distances. */
82 foreach_free_point(map->b) {
83 if (!map->consider[c])
84 continue;
85 int d = coord_edge_distance(c, map->b);
86 if (d != 0 && d != 2)
87 continue;
88 /* The bonus applies only with no stones in immediate
89 * vincinity. */
90 if (board_stone_radar(map->b, c, 2))
91 continue;
92 /* First line: 0 */
93 /* Third line: 1 */
94 add_prior_value(map, c, d == 2, u->prior->b19_eqex);
95 } foreach_free_point_end;
98 void
99 uct_prior_playout(struct uct *u, struct tree_node *node, struct prior_map *map)
101 /* Q_{playout-policy} */
102 if (u->playout->assess)
103 u->playout->assess(u->playout, map, u->prior->policy_eqex);
106 void
107 uct_prior_cfgd(struct uct *u, struct tree_node *node, struct prior_map *map)
109 /* Q_{common_fate_graph_distance} */
110 /* Give bonus to moves local to the last move, where "local" means
111 * local in terms of groups, not just manhattan distance. */
112 if (is_pass(map->b->last_move.coord) || is_resign(map->b->last_move.coord))
113 return;
115 foreach_free_point(map->b) {
116 if (!map->consider[c])
117 continue;
118 if (map->distances[c] > u->prior->cfgdn)
119 continue;
120 assert(map->distances[c] != 0);
121 int bonus = u->prior->cfgd_eqex[map->distances[c]];
122 add_prior_value(map, c, 1, bonus);
123 } foreach_free_point_end;
126 void
127 uct_prior_joseki(struct uct *u, struct tree_node *node, struct prior_map *map)
129 /* Q_{joseki} */
130 if (!u->jdict)
131 return;
132 for (int i = 0; i < 4; i++) {
133 hash_t h = map->b->qhash[i] & joseki_hash_mask;
134 coord_t *cc = u->jdict->patterns[h].moves[map->to_play - 1];
135 if (!cc) continue;
136 for (; !is_pass(*cc); cc++) {
137 if (coord_quadrant(*cc, map->b) != i)
138 continue;
139 add_prior_value(map, *cc, 1.0, u->prior->joseki_eqex);
144 void
145 uct_prior(struct uct *u, struct tree_node *node, struct prior_map *map)
147 if (u->prior->even_eqex)
148 uct_prior_even(u, node, map);
149 if (u->prior->eye_eqex)
150 uct_prior_eye(u, node, map);
151 if (u->prior->ko_eqex)
152 uct_prior_ko(u, node, map);
153 if (u->prior->b19_eqex)
154 uct_prior_b19(u, node, map);
155 if (u->prior->policy_eqex)
156 uct_prior_playout(u, node, map);
157 if (u->prior->cfgd_eqex)
158 uct_prior_cfgd(u, node, map);
159 if (u->prior->joseki_eqex)
160 uct_prior_joseki(u, node, map);
161 if (u->prior->plugin_eqex)
162 plugin_prior(u->plugins, node, map, u->prior->plugin_eqex);
165 struct uct_prior *
166 uct_prior_init(char *arg, struct board *b)
168 struct uct_prior *p = calloc2(1, sizeof(struct uct_prior));
170 p->even_eqex = p->policy_eqex = p->b19_eqex = p->eye_eqex = p->ko_eqex = p->plugin_eqex = -100;
171 p->joseki_eqex = -200;
172 p->cfgdn = -1;
174 /* Even number! */
175 p->eqex = board_size(b)-2 >= 19 ? 20 : 14;
177 if (arg) {
178 char *optspec, *next = arg;
179 while (*next) {
180 optspec = next;
181 next += strcspn(next, ":");
182 if (*next) { *next++ = 0; } else { *next = 0; }
184 char *optname = optspec;
185 char *optval = strchr(optspec, '=');
186 if (optval) *optval++ = 0;
188 if (!strcasecmp(optname, "eqex") && optval) {
189 p->eqex = atoi(optval);
191 /* In the following settings, you can use negative
192 * numbers to give the hundredths of default eqex.
193 * E.g. -100 is default eqex, -50 is half of the
194 * default eqex, -200 is double the default eqex. */
195 } else if (!strcasecmp(optname, "even") && optval) {
196 p->even_eqex = atoi(optval);
197 } else if (!strcasecmp(optname, "policy") && optval) {
198 p->policy_eqex = atoi(optval);
199 } else if (!strcasecmp(optname, "b19") && optval) {
200 p->b19_eqex = atoi(optval);
201 } else if (!strcasecmp(optname, "cfgd") && optval) {
202 /* cfgd=3%40%20%20 - 3 levels; immediate libs
203 * of last move => 40 wins, their neighbors
204 * 20 wins, 2nd-level neighbors 20 wins;
205 * neighbors are group-transitive. */
206 p->cfgdn = atoi(optval); optval += strcspn(optval, "%");
207 p->cfgd_eqex = calloc2(p->cfgdn + 1, sizeof(*p->cfgd_eqex));
208 p->cfgd_eqex[0] = 0;
209 int i;
210 for (i = 1; *optval; i++, optval += strcspn(optval, "%")) {
211 optval++;
212 p->cfgd_eqex[i] = atoi(optval);
214 if (i != p->cfgdn + 1) {
215 fprintf(stderr, "uct: Missing prior cfdn level %d/%d\n", i, p->cfgdn);
216 exit(1);
219 } else if (!strcasecmp(optname, "joseki") && optval) {
220 p->joseki_eqex = atoi(optval);
221 } else if (!strcasecmp(optname, "eye") && optval) {
222 p->eye_eqex = atoi(optval);
223 } else if (!strcasecmp(optname, "ko") && optval) {
224 p->ko_eqex = atoi(optval);
225 } else if (!strcasecmp(optname, "plugin") && optval) {
226 /* Unlike others, this is just a *recommendation*. */
227 p->plugin_eqex = atoi(optval);
228 } else {
229 fprintf(stderr, "uct: Invalid prior argument %s or missing value\n", optname);
230 exit(1);
235 if (p->even_eqex < 0) p->even_eqex = p->eqex * -p->even_eqex / 100;
236 if (p->policy_eqex < 0) p->policy_eqex = p->eqex * -p->policy_eqex / 100;
237 if (p->b19_eqex < 0) p->b19_eqex = p->eqex * -p->b19_eqex / 100;
238 if (p->eye_eqex < 0) p->eye_eqex = p->eqex * -p->eye_eqex / 100;
239 if (p->ko_eqex < 0) p->ko_eqex = p->eqex * -p->ko_eqex / 100;
240 if (p->joseki_eqex < 0) p->joseki_eqex = p->eqex * -p->joseki_eqex / 100;
241 if (p->plugin_eqex < 0) p->plugin_eqex = p->eqex * -p->plugin_eqex / 100;
243 if (p->cfgdn < 0) {
244 int bonuses[] = { 0, 2*p->eqex, p->eqex, p->eqex };
245 p->cfgdn = 3;
246 p->cfgd_eqex = calloc2(p->cfgdn + 1, sizeof(*p->cfgd_eqex));
247 memcpy(p->cfgd_eqex, bonuses, sizeof(bonuses));
249 if (p->cfgdn > TREE_NODE_D_MAX) {
250 fprintf(stderr, "uct: CFG distances only up to %d available\n", TREE_NODE_D_MAX);
251 exit(1);
254 return p;
257 void
258 uct_prior_done(struct uct_prior *p)
260 assert(p->cfgd_eqex);
261 free(p->cfgd_eqex);
262 free(p);