UCT: Start playouts with the right color!
[pachi.git] / uct / uct.c
blob80ca871ca6664ca1ca2d4d9273de8a30b561d337
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
5 #define DEBUG
7 #include "debug.h"
8 #include "board.h"
9 #include "move.h"
10 #include "playout.h"
11 #include "playout/moggy.h"
12 #include "playout/old.h"
13 #include "uct/internal.h"
14 #include "uct/tree.h"
15 #include "uct/uct.h"
17 struct uct_policy *policy_ucb1_init(struct uct *u, char *arg);
18 struct uct_policy *policy_ucb1tuned_init(struct uct *u, char *arg);
21 #define MC_GAMES 40000
22 #define MC_GAMELEN 400
25 static coord_t
26 domainhint_policy(void *playout_policy, struct board *b, enum stone my_color)
28 struct uct *u = playout_policy;
29 return u->playout(&u->mc, b, my_color);
32 static int
33 uct_playout(struct uct *u, struct board *b, enum stone color, struct tree *t)
35 struct board b2;
36 board_copy(&b2, b);
38 /* Walk the tree until we find a leaf, then expand it and do
39 * a random playout. */
40 struct tree_node *n = t->root;
41 enum stone orig_color = color;
42 int result;
43 int pass_limit = (b2.size - 2) * (b2.size - 2) / 2;
44 int passes = is_pass(b->last_move.coord);
45 if (UDEBUGL(8))
46 fprintf(stderr, "--- UCT walk with color %d\n", color);
47 for (; pass; color = stone_other(color)) {
48 if (tree_leaf_node(n)) {
49 if (n->playouts >= u->expand_p)
50 tree_expand_node(t, n, &b2);
52 result = play_random_game(&b2, color, u->gamelen, domainhint_policy, u);
53 if (orig_color != color && result >= 0)
54 result = !result;
55 if (UDEBUGL(7))
56 fprintf(stderr, "[%d..%d] %s random playout result %d\n", orig_color, color, coord2sstr(n->coord, t->board), result);
57 break;
60 n = u->policy->descend(u->policy, t, n, (color == orig_color ? 1 : -1), pass_limit);
61 if (UDEBUGL(7))
62 fprintf(stderr, "-- UCT sent us to [%s] %f\n", coord2sstr(n->coord, t->board), n->value);
63 struct move m = { n->coord, color };
64 int res = board_play(&b2, &m);
65 if (res < 0 || (!is_pass(m.coord) && !group_at(&b2, m.coord)) /* suicide */
66 || b2.superko_violation) {
67 if (UDEBUGL(6))
68 fprintf(stderr, "deleting invalid node %d,%d\n", coord_x(n->coord,b), coord_y(n->coord,b));
69 tree_delete_node(n);
70 board_done_noalloc(&b2);
71 return -1;
74 if (is_pass(n->coord)) {
75 passes++;
76 if (passes >= 2) {
77 float score = board_official_score(&b2);
78 result = (orig_color == S_BLACK) ? score < 0 : score > 0;
79 if (UDEBUGL(5))
80 fprintf(stderr, "[%d..%d] %s playout result %d (W %f)\n", orig_color, color, coord2sstr(n->coord, t->board), result, score);
81 if (UDEBUGL(6))
82 board_print(&b2, stderr);
83 break;
85 } else {
86 passes = 0;
90 if (result >= 0)
91 tree_uct_update(n, result);
92 board_done_noalloc(&b2);
93 return result;
96 static coord_t *
97 uct_genmove(struct engine *e, struct board *b, enum stone color)
99 struct uct *u = e->data;
101 if (!u->t) {
102 tree_init:
103 u->t = tree_init(b);
104 //board_print(b, stderr);
105 } else {
106 /* XXX: We hope that the opponent didn't suddenly play
107 * several moves in the row. */
108 for (struct tree_node *ni = u->t->root->children; ni; ni = ni->sibling)
109 if (ni->coord == b->last_move.coord) {
110 tree_promote_node(u->t, ni);
111 goto promoted;
113 fprintf(stderr, "CANNOT FIND NODE TO PROMOTE!\n");
114 tree_done(u->t);
115 goto tree_init;
116 promoted:;
119 int i;
120 for (i = 0; i < u->games; i++) {
121 int result = uct_playout(u, b, color, u->t);
122 if (result < 0) {
123 /* Tree descent has hit invalid move. */
124 continue;
127 if (i > 0 && !(i % 1000)) {
128 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
129 if (best && best->playouts >= 500 && best->value >= u->loss_threshold)
130 break;
134 if (UDEBUGL(2))
135 tree_dump(u->t);
137 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
138 if (!best) {
139 tree_done(u->t); u->t = NULL;
140 return coord_copy(pass);
142 if (UDEBUGL(1))
143 fprintf(stderr, "*** WINNER is %d,%d with score %1.4f (%d games)\n", coord_x(best->coord, b), coord_y(best->coord, b), best->value, i);
144 if (best->value < u->resign_ratio && !is_pass(best->coord)) {
145 tree_done(u->t); u->t = NULL;
146 return coord_copy(resign);
148 tree_promote_node(u->t, best);
149 return coord_copy(best->coord);
153 struct uct *
154 uct_state_init(char *arg)
156 struct uct *u = calloc(1, sizeof(struct uct));
158 u->debug_level = 1;
159 u->games = MC_GAMES;
160 u->gamelen = MC_GAMELEN;
161 u->expand_p = 2;
162 u->mc.capture_rate = 100;
163 u->mc.atari_rate = 100;
164 u->mc.cut_rate = 0;
165 // Looking at the actual playouts, this just encourages MC to make
166 // stupid shapes.
167 u->mc.local_rate = 0;
169 if (arg) {
170 char *optspec, *next = arg;
171 while (*next) {
172 optspec = next;
173 next += strcspn(next, ",");
174 if (*next) { *next++ = 0; } else { *next = 0; }
176 char *optname = optspec;
177 char *optval = strchr(optspec, '=');
178 if (optval) *optval++ = 0;
180 if (!strcasecmp(optname, "debug")) {
181 if (optval)
182 u->debug_level = atoi(optval);
183 else
184 u->debug_level++;
185 } else if (!strcasecmp(optname, "games") && optval) {
186 u->games = atoi(optval);
187 } else if (!strcasecmp(optname, "gamelen") && optval) {
188 u->gamelen = atoi(optval);
189 } else if (!strcasecmp(optname, "expand_p") && optval) {
190 u->expand_p = atoi(optval);
191 } else if (!strcasecmp(optname, "policy") && optval) {
192 char *policyarg = strchr(optval, ':');
193 if (policyarg)
194 *policyarg++ = 0;
195 if (!strcasecmp(optval, "ucb1")) {
196 u->policy = policy_ucb1_init(u, policyarg);
197 } else if (!strcasecmp(optval, "ucb1tuned")) {
198 u->policy = policy_ucb1tuned_init(u, policyarg);
200 } else if (!strcasecmp(optname, "playout") && optval) {
201 char *playoutarg = strchr(optval, ':');
202 if (playoutarg)
203 *playoutarg++ = 0;
204 if (!strcasecmp(optval, "old")) {
205 u->playout = playout_old;
206 } else if (!strcasecmp(optval, "moggy")) {
207 u->playout = playout_moggy;
209 } else if (!strcasecmp(optname, "pure")) {
210 u->mc.capture_rate = u->mc.local_rate = u->mc.cut_rate = 0;
211 } else if (!strcasecmp(optname, "capturerate") && optval) {
212 u->mc.capture_rate = atoi(optval);
213 } else if (!strcasecmp(optname, "atarirate") && optval) {
214 u->mc.atari_rate = atoi(optval);
215 } else if (!strcasecmp(optname, "localrate") && optval) {
216 u->mc.local_rate = atoi(optval);
217 } else if (!strcasecmp(optname, "cutrate") && optval) {
218 u->mc.cut_rate = atoi(optval);
219 } else {
220 fprintf(stderr, "uct: Invalid engine argument %s or missing value\n", optname);
225 u->resign_ratio = 0.2; /* Resign when most games are lost. */
226 u->loss_threshold = 0.95; /* Stop reading if after at least 500 playouts this is best value. */
227 u->mc.debug_level = u->debug_level;
228 if (!u->policy)
229 u->policy = policy_ucb1_init(u, NULL);
230 if (!u->playout)
231 u->playout = playout_old;
233 return u;
237 struct engine *
238 engine_uct_init(char *arg)
240 struct uct *u = uct_state_init(arg);
241 struct engine *e = calloc(1, sizeof(struct engine));
242 e->name = "UCT Engine";
243 e->comment = "I'm playing UCT. When we both pass, I will consider all the stones on the board alive. If you are reading this, write 'yes'. Please bear with me at the game end, I need to fill the whole board; if you help me, we will both be happier. Filling the board will not lose points (NZ rules).";
244 e->genmove = uct_genmove;
245 e->data = u;
247 return e;