Merge branch 'master' of git+ssh://repo.or.cz/srv/git/pachi
[pachi/json.git] / uct / policy / ucb1.c
blob1d33900bb967898f3d49f93bea3b1f426aa3b7fc
1 #include <assert.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "move.h"
10 #include "random.h"
11 #include "uct/internal.h"
12 #include "uct/tree.h"
14 /* This implements the basic UCB1 policy. */
/* Tunable parameters of the UCB1 tree policy. */
struct ucb1_policy {
	/* Exploration coefficient; the "Modification of UCT with Patterns
	 * in Monte Carlo Go" paper names it 'p'. Original UCB uses 2 here,
	 * but that seems to search far too widely; lowering it gives
	 * deeper and narrower readouts - try 0.2. */
	float explore_p;

	/* First Play Urgency - when below infinity (the MoGo paper above
	 * reports 1.0 as the best), a new branch is explored only if no
	 * existing branch has urgency higher than fpu. */
	float fpu;

	/* Equivalent experience for prior knowledge. MoGo paper recommends
	 * 50 playouts per source. One value per prior source: */
	int eqex;        /* base value; also used for the own-eye prior */
	int even_eqex;   /* Q_{even} prior */
	int gp_eqex;     /* Q_{grandparent} prior */
	int policy_eqex; /* Q_{playout-policy} prior */

	/* Random perturbation applied to each child's urgency during
	 * descent: urg_randoma is additive, urg_randomm multiplicative.
	 * Zero disables the respective perturbation. */
	int urg_randoma;
	int urg_randomm;
};
33 struct tree_node *
34 ucb1_choose(struct uct_policy *p, struct tree_node *node, struct board *b, enum stone color)
36 struct tree_node *nbest = NULL;
37 for (struct tree_node *ni = node->children; ni; ni = ni->sibling)
38 // we compare playouts and choose the best-explored
39 // child; comparing values is more brittle
40 if (!nbest || ni->u.playouts > nbest->u.playouts) {
41 /* Play pass only if we can afford scoring */
42 if (is_pass(ni->coord)) {
43 float score = board_official_score(b);
44 if (color == S_BLACK)
45 score = -score;
46 //fprintf(stderr, "%d score %f\n", color, score);
47 if (score <= 0)
48 continue;
50 nbest = ni;
52 return nbest;
56 struct tree_node *
57 ucb1_descend(struct uct_policy *p, struct tree *tree, struct tree_node *node, int parity, bool allow_pass)
59 /* We want to count in the prior stats here after all. Otherwise,
60 * nodes with positive prior will get explored _LESS_ since the
61 * urgency will be always higher; even with normal FPU because
62 * of the explore coefficient. */
64 struct ucb1_policy *b = p->data;
65 float xpl = log(node->u.playouts + node->prior.playouts) * b->explore_p;
67 // XXX: Stack overflow danger on big boards?
68 struct tree_node *nbest[512] = { node->children }; int nbests = 1;
69 float best_urgency = -9999;
71 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
72 /* Do not consider passing early. */
73 if (likely(!allow_pass) && unlikely(is_pass(ni->coord)))
74 continue;
75 int uct_playouts = ni->u.playouts + ni->prior.playouts;
76 ni->prior.value = (float)ni->prior.wins / ni->prior.playouts;
78 /* XXX: We later compare urgency with best_urgency; this can
79 * be difficult given that urgency can be in register with
80 * higher precision than best_urgency, thus even though
81 * the numbers are in fact the same, urgency will be
82 * slightly higher (or lower). Thus, we declare urgency
83 * as volatile, attempting to force the compiler to keep
84 * everything as a float. Ideally, we should do some random
85 * __FLT_EPSILON__ magic instead. */
86 volatile float urgency = uct_playouts ? (parity > 0 ? ni->u.value : 1 - ni->u.value) + sqrt(xpl / uct_playouts) : b->fpu;
88 #if 0
90 struct board b2; b2.size = 9+2;
91 fprintf(stderr, "[%s -> %s] UCB1 urgency %f (%f + %f : %f)\n", coord2sstr(node->coord, &b2), coord2sstr(ni->coord, &b2), urgency, ni->u.value, sqrt(xpl / ni->u.playouts), b->fpu);
93 #endif
94 if (b->urg_randoma)
95 urgency += (float)(fast_random(b->urg_randoma) - b->urg_randoma / 2) / 1000;
96 if (b->urg_randomm)
97 urgency *= (float)(fast_random(b->urg_randomm) + 5) / b->urg_randomm;
98 if (urgency > best_urgency) {
99 best_urgency = urgency; nbests = 0;
101 if (urgency >= best_urgency) {
102 /* We want to always choose something else than a pass
103 * in case of a tie. pass causes degenerative behaviour. */
104 if (nbests == 1 && is_pass(nbest[0]->coord)) {
105 nbests--;
107 nbest[nbests++] = ni;
110 return nbest[fast_random(nbests)];
113 void
114 ucb1_prior(struct uct_policy *p, struct tree *tree, struct tree_node *node, struct board *b, enum stone color, int parity)
116 /* Initialization of UCT values based on prior knowledge */
117 struct ucb1_policy *pp = p->data;
119 /* Q_{even} */
120 /* This may be dubious for normal UCB1 but is essential for
121 * reading stability of RAVE, it appears. */
122 if (pp->even_eqex) {
123 node->prior.playouts += pp->even_eqex;
124 node->prior.wins += pp->even_eqex / 2;
127 /* Discourage playing into our own eyes. However, we cannot
128 * completely prohibit it:
129 * ######
130 * ...XX.
131 * XOOOXX
132 * X.OOOO
133 * .XXXX. */
134 if (board_is_one_point_eye(b, &node->coord, color)) {
135 node->prior.playouts += 0;
136 node->prior.wins += pp->eqex;
139 /* Q_{grandparent} */
140 if (pp->gp_eqex && node->parent && node->parent->parent && node->parent->parent->parent) {
141 struct tree_node *gpp = node->parent->parent->parent;
142 for (struct tree_node *ni = gpp->children; ni; ni = ni->sibling) {
143 /* Be careful not to emphasize too random results. */
144 if (ni->coord == node->coord && ni->u.playouts > pp->gp_eqex) {
145 node->prior.playouts += pp->gp_eqex;
146 node->prior.wins += pp->gp_eqex * ni->u.wins / ni->u.playouts;
147 node->hints |= 1;
152 /* Q_{playout-policy} */
153 if (pp->policy_eqex) {
154 float assess = NAN;
155 struct playout_policy *playout = p->uct->playout;
156 if (playout->assess) {
157 struct move m = { node->coord, color };
158 assess = playout->assess(playout, b, &m);
160 if (!isnan(assess)) {
161 if (parity < 0) {
162 /* Good moves for enemy are losses for us.
163 * We will properly maximize this in the UCB1
164 * decision. */
165 assess = 1 - assess;
167 node->prior.playouts += pp->policy_eqex;
168 node->prior.wins += pp->policy_eqex * assess;
169 node->hints |= 2;
173 if (node->prior.playouts) {
174 node->prior.value = (float) node->prior.wins / node->prior.playouts;
175 tree_update_node_value(node);
178 //fprintf(stderr, "%s,%s prior: %d/%d = %f (%f)\n", coord2sstr(node->parent->coord, b), coord2sstr(node->coord, b), node->prior.wins, node->prior.playouts, node->prior.value, assess);
181 void
182 ucb1_update(struct uct_policy *p, struct tree *tree, struct tree_node *node, enum stone node_color, enum stone player_color, struct playout_amafmap *map, int result)
184 /* It is enough to iterate by a single chain; we will
185 * update all the preceding positions properly since
186 * they had to all occur in all branches, only in
187 * different order. */
188 for (; node; node = node->parent) {
189 node->u.playouts++;
190 node->u.wins += result;
191 tree_update_node_value(node);
196 struct uct_policy *
197 policy_ucb1_init(struct uct *u, char *arg)
199 struct uct_policy *p = calloc(1, sizeof(*p));
200 struct ucb1_policy *b = calloc(1, sizeof(*b));
201 p->uct = u;
202 p->data = b;
203 p->descend = ucb1_descend;
204 p->choose = ucb1_choose;
205 p->update = ucb1_update;
207 b->explore_p = 0.2;
208 b->fpu = 1.1; //INFINITY;
209 b->even_eqex = 0;
210 b->gp_eqex = b->policy_eqex = -1;
211 b->eqex = 0; //50;
213 if (arg) {
214 char *optspec, *next = arg;
215 while (*next) {
216 optspec = next;
217 next += strcspn(next, ":");
218 if (*next) { *next++ = 0; } else { *next = 0; }
220 char *optname = optspec;
221 char *optval = strchr(optspec, '=');
222 if (optval) *optval++ = 0;
224 if (!strcasecmp(optname, "explore_p") && optval) {
225 b->explore_p = atof(optval);
226 } else if (!strcasecmp(optname, "prior")) {
227 if (optval)
228 b->eqex = atoi(optval);
229 } else if (!strcasecmp(optname, "prior_even") && optval) {
230 b->even_eqex = atoi(optval);
231 } else if (!strcasecmp(optname, "prior_gp") && optval) {
232 b->gp_eqex = atoi(optval);
233 } else if (!strcasecmp(optname, "prior_policy") && optval) {
234 b->policy_eqex = atoi(optval);
235 } else if (!strcasecmp(optname, "fpu") && optval) {
236 b->fpu = atof(optval);
237 } else if (!strcasecmp(optname, "urg_randoma") && optval) {
238 b->urg_randoma = atoi(optval);
239 } else if (!strcasecmp(optname, "urg_randomm") && optval) {
240 b->urg_randomm = atoi(optval);
241 } else {
242 fprintf(stderr, "ucb1: Invalid policy argument %s or missing value\n", optname);
247 if (b->eqex) p->prior = ucb1_prior;
248 if (b->even_eqex < 0) b->even_eqex = b->eqex;
249 if (b->gp_eqex < 0) b->gp_eqex = b->eqex;
250 if (b->policy_eqex < 0) b->policy_eqex = b->eqex;
252 return p;