UCT: Add policy= parameter, explore_p is now policy argument
[pachi.git] / uct / uct.c
blob d1b82c3e6deefa06e4d5a90b427f554c5dd3f93c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
5 #define DEBUG
7 #include "debug.h"
8 #include "board.h"
9 #include "move.h"
10 #include "playout.h"
11 #include "montecarlo/hint.h"
12 #include "uct/internal.h"
13 #include "uct/tree.h"
14 #include "uct/uct.h"
struct uct;
struct uct_policy;

/* Constructor of the UCB1 tree policy (implemented in another file).
 * @arg is the policy's own option string -- the part after '+' in
 * "policy=ucb1+..." -- or NULL when no such argument was given. */
struct uct_policy *policy_ucb1_init(struct uct *u, char *arg);

/* Defaults for the number of playouts run per genmove (games=) and for the
 * game length limit handed to play_random_game() (gamelen=). */
enum {
	MC_GAMES = 40000,
	MC_GAMELEN = 400,
};
23 static coord_t
24 domainhint_policy(void *playout_policy, struct board *b, enum stone my_color)
26 struct uct *u = playout_policy;
27 return domain_hint(&u->mc, b, my_color);
30 static int
31 uct_playout(struct uct *u, struct board *b, enum stone color, struct tree *t)
33 struct board b2;
34 board_copy(&b2, b);
36 /* Walk the tree until we find a leaf, then expand it and do
37 * a random playout. */
38 struct tree_node *n = t->root;
39 enum stone orig_color = color;
40 int result;
41 int pass_limit = (b2.size - 2) * (b2.size - 2) / 2;
42 int passes = is_pass(b->last_move.coord);
43 if (UDEBUGL(8))
44 fprintf(stderr, "--- UCT walk\n");
45 for (; pass; color = stone_other(color)) {
46 if (tree_leaf_node(n)) {
47 if (n->playouts >= u->expand_p)
48 tree_expand_node(t, n, &b2);
50 result = play_random_game(&b2, stone_other(color), u->gamelen, domainhint_policy, u);
51 if (orig_color == color && result >= 0)
52 result = !result;
53 if (UDEBUGL(7))
54 fprintf(stderr, "[%d..%d] %s playout result %d\n", orig_color, color, coord2sstr(n->coord, t->board), result);
55 break;
58 n = u->policy->descend(u->policy, t, n, (color == orig_color ? 1 : -1), pass_limit);
59 if (UDEBUGL(7))
60 fprintf(stderr, "-- UCT sent us to [%s] %f\n", coord2sstr(n->coord, t->board), n->value);
61 struct move m = { n->coord, color };
62 int res = board_play(&b2, &m);
63 if (res < 0 || (!is_pass(m.coord) && !group_at(&b2, m.coord)) /* suicide */
64 || b2.superko_violation) {
65 if (UDEBUGL(6))
66 fprintf(stderr, "deleting invalid node %d,%d\n", coord_x(n->coord,b), coord_y(n->coord,b));
67 tree_delete_node(n);
68 board_done_noalloc(&b2);
69 return -1;
72 if (is_pass(n->coord)) {
73 passes++;
74 if (passes >= 2) {
75 float score = board_official_score(&b2);
76 result = (orig_color == S_BLACK) ? score < 0 : score > 0;
77 if (UDEBUGL(5))
78 fprintf(stderr, "[%d..%d] %s playout result %d (W %f)\n", orig_color, color, coord2sstr(n->coord, t->board), result, score);
79 if (UDEBUGL(6))
80 board_print(&b2, stderr);
81 break;
83 } else {
84 passes = 0;
88 if (result >= 0)
89 tree_uct_update(n, result);
90 board_done_noalloc(&b2);
91 return result;
94 static coord_t *
95 uct_genmove(struct engine *e, struct board *b, enum stone color)
97 struct uct *u = e->data;
99 if (!u->t) {
100 tree_init:
101 u->t = tree_init(b);
102 //board_print(b, stderr);
103 } else {
104 /* XXX: We hope that the opponent didn't suddenly play
105 * several moves in the row. */
106 for (struct tree_node *ni = u->t->root->children; ni; ni = ni->sibling)
107 if (ni->coord == b->last_move.coord) {
108 tree_promote_node(u->t, ni);
109 goto promoted;
111 fprintf(stderr, "CANNOT FIND NODE TO PROMOTE!\n");
112 tree_done(u->t);
113 goto tree_init;
114 promoted:;
117 int i;
118 for (i = 0; i < u->games; i++) {
119 int result = uct_playout(u, b, color, u->t);
120 if (result < 0) {
121 /* Tree descent has hit invalid move. */
122 continue;
125 if (i > 0 && !(i % 1000)) {
126 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
127 if (best && best->playouts >= 500 && best->value >= u->loss_threshold)
128 break;
132 if (UDEBUGL(2))
133 tree_dump(u->t);
135 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
136 if (!best) {
137 tree_done(u->t); u->t = NULL;
138 return coord_copy(pass);
140 if (UDEBUGL(1))
141 fprintf(stderr, "*** WINNER is %d,%d with score %1.4f (%d games)\n", coord_x(best->coord, b), coord_y(best->coord, b), best->value, i);
142 if (best->value < u->resign_ratio && !is_pass(best->coord)) {
143 tree_done(u->t); u->t = NULL;
144 return coord_copy(resign);
146 tree_promote_node(u->t, best);
147 return coord_copy(best->coord);
151 struct uct *
152 uct_state_init(char *arg)
154 struct uct *u = calloc(1, sizeof(struct uct));
156 u->debug_level = 1;
157 u->games = MC_GAMES;
158 u->gamelen = MC_GAMELEN;
159 u->expand_p = 2;
160 u->mc.capture_rate = 100;
161 u->mc.atari_rate = 100;
162 u->mc.cut_rate = 0;
163 // Looking at the actual playouts, this just encourages MC to make
164 // stupid shapes.
165 u->mc.local_rate = 0;
167 if (arg) {
168 char *optspec, *next = arg;
169 while (*next) {
170 optspec = next;
171 next += strcspn(next, ",");
172 if (*next) { *next++ = 0; } else { *next = 0; }
174 char *optname = optspec;
175 char *optval = strchr(optspec, '=');
176 if (optval) *optval++ = 0;
178 if (!strcasecmp(optname, "debug")) {
179 if (optval)
180 u->debug_level = atoi(optval);
181 else
182 u->debug_level++;
183 } else if (!strcasecmp(optname, "games") && optval) {
184 u->games = atoi(optval);
185 } else if (!strcasecmp(optname, "gamelen") && optval) {
186 u->gamelen = atoi(optval);
187 } else if (!strcasecmp(optname, "expand_p") && optval) {
188 u->expand_p = atoi(optval);
189 } else if (!strcasecmp(optname, "policy") && optval) {
190 char *policyarg = strchr(optval, '+');
191 if (policyarg)
192 *policyarg++ = 0;
193 if (!strcasecmp(optval, "ucb1")) {
194 u->policy = policy_ucb1_init(u, policyarg);
196 } else if (!strcasecmp(optname, "pure")) {
197 u->mc.capture_rate = u->mc.local_rate = u->mc.cut_rate = 0;
198 } else if (!strcasecmp(optname, "capturerate") && optval) {
199 u->mc.capture_rate = atoi(optval);
200 } else if (!strcasecmp(optname, "atarirate") && optval) {
201 u->mc.atari_rate = atoi(optval);
202 } else if (!strcasecmp(optname, "localrate") && optval) {
203 u->mc.local_rate = atoi(optval);
204 } else if (!strcasecmp(optname, "cutrate") && optval) {
205 u->mc.cut_rate = atoi(optval);
206 } else {
207 fprintf(stderr, "uct: Invalid engine argument %s or missing value\n", optname);
212 u->resign_ratio = 0.2; /* Resign when most games are lost. */
213 u->loss_threshold = 0.95; /* Stop reading if after at least 500 playouts this is best value. */
214 u->mc.debug_level = u->debug_level;
215 u->policy = policy_ucb1_init(u, NULL);
217 return u;
221 struct engine *
222 engine_uct_init(char *arg)
224 struct uct *u = uct_state_init(arg);
225 struct engine *e = calloc(1, sizeof(struct engine));
226 e->name = "UCT Engine";
227 e->comment = "I'm playing UCT. When we both pass, I will consider all the stones on the board alive. If you are reading this, write 'yes'. Please bear with me at the game end, I need to fill the whole board; if you help me, we will both be happier. Filling the board will not lose points (NZ rules).";
228 e->genmove = uct_genmove;
229 e->data = u;
231 return e;