UCT: Improve invalid node condition
[pachi.git] / uct / uct.c
blobfdc6def225a71716878143bcc673b0a4185cfb03
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
5 #define DEBUG
7 #include "debug.h"
8 #include "board.h"
9 #include "move.h"
10 #include "playout.h"
11 #include "montecarlo/hint.h"
12 #include "montecarlo/internal.h"
13 #include "uct/tree.h"
14 #include "uct/uct.h"
/* Default maximum number of playouts per genmove (u->games);
 * overridable with the "games=N" engine option. */
#define MC_GAMES 40000
/* Default game length limit handed to play_random_game() (u->gamelen);
 * overridable with the "gamelen=N" engine option. */
#define MC_GAMELEN 400
22 /* Internal engine state. */
23 struct uct {
24 int debug_level;
25 int games, gamelen;
26 float resign_ratio;
27 float loss_threshold;
28 float explore_p;
29 int expand_p;
31 struct montecarlo mc;
32 struct tree *t;
/* Debug-level check for the UCT engine. Note that the expansion refers to
 * a variable named `u` (a struct uct *) that must be in scope wherever the
 * macro is used. The actual comparison is delegated to DEBUGL_() from
 * debug.h -- presumably "debug_level >= n"; confirm there. */
#define UDEBUGL(n) DEBUGL_(u->debug_level, n)
38 static coord_t
39 domainhint_policy(void *playout_policy, struct board *b, enum stone my_color)
41 struct uct *u = playout_policy;
42 return domain_hint(&u->mc, b, my_color);
45 static int
46 uct_playout(struct uct *u, struct board *b, enum stone color, struct tree *t)
48 struct board b2;
49 board_copy(&b2, b);
51 /* Walk the tree until we find a leaf, then expand it and do
52 * a random playout. */
53 struct tree_node *n = t->root;
54 enum stone orig_color = color;
55 int result;
56 int passes = 0;
57 if (UDEBUGL(8))
58 fprintf(stderr, "--- UCT walk\n");
59 for (; pass; color = stone_other(color)) {
60 if (tree_leaf_node(n)) {
61 if (n->playouts >= u->expand_p)
62 tree_expand_node(t, n, &b2);
64 struct move m = { n->coord, color };
65 result = play_random_game(&b2, &m, u->gamelen, domainhint_policy, u);
66 if (orig_color != color && result >= 0)
67 result = !result;
68 if (UDEBUGL(7))
69 fprintf(stderr, "[%d..%d] %s playout result %d\n", orig_color, color, coord2sstr(n->coord, t->board), result);
70 break;
73 n = tree_uct_descend(t, n, (color == orig_color ? 1 : -1), b2.moves > (b2.size2 - 2) / 2);
74 if (UDEBUGL(7))
75 fprintf(stderr, "-- UCT sent us to [%s] %f\n", coord2sstr(n->coord, t->board), n->value);
76 struct move m = { n->coord, color };
77 int res = board_play(&b2, &m);
78 if (res < 0 || (!is_pass(m.coord) && !group_at(&b2, m.coord)) /* suicide */) {
79 if (UDEBUGL(6))
80 fprintf(stderr, "deleting invalid node %d,%d\n", coord_x(n->coord,b), coord_y(n->coord,b));
81 tree_delete_node(n);
82 board_done_noalloc(&b2);
83 return -1;
86 if (is_pass(n->coord)) {
87 passes++;
88 if (passes >= 2) {
89 float score = board_fast_score(&b2) > 0;
90 result = (orig_color == S_BLACK) ? score < 0 : score > 0;
91 if (UDEBUGL(5))
92 fprintf(stderr, "[%d..%d] %s playout result %d (W %f)\n", orig_color, color, coord2sstr(n->coord, t->board), result, score);
93 if (UDEBUGL(6))
94 board_print(&b2, stderr);
95 break;
97 } else {
98 passes = 0;
102 if (result >= 0)
103 tree_uct_update(n, result);
104 board_done_noalloc(&b2);
105 return result;
108 static coord_t *
109 uct_genmove(struct engine *e, struct board *b, enum stone color)
111 struct uct *u = e->data;
113 if (!u->t) {
114 tree_init:
115 u->t = tree_init(b);
116 u->t->explore_p = u->explore_p;
117 } else {
118 /* XXX: We hope that the opponent didn't suddenly play
119 * several moves in the row. */
120 for (struct tree_node *ni = u->t->root->children; ni; ni = ni->sibling)
121 if (ni->coord == b->last_move.coord) {
122 tree_promote_node(u->t, ni);
123 goto promoted;
125 fprintf(stderr, "CANNOT FIND NODE TO PROMOTE!\n");
126 tree_done(u->t);
127 goto tree_init;
128 promoted:;
131 int i;
132 for (i = 0; i < u->games; i++) {
133 int result = uct_playout(u, b, color, u->t);
134 if (result < 0) {
135 /* Tree descent has hit invalid move. */
136 continue;
139 if (i > 0 && !(i % 1000)) {
140 struct tree_node *best = tree_best_child(u->t->root);
141 if (best && best->playouts >= 100 && best->value >= u->loss_threshold)
142 break;
146 if (UDEBUGL(2))
147 tree_dump(u->t);
149 struct tree_node *best = tree_best_child(u->t->root);
150 if (!best) {
151 tree_done(u->t); u->t = NULL;
152 return coord_copy(pass);
154 if (UDEBUGL(1))
155 fprintf(stderr, "*** WINNER is %d,%d with score %1.4f (%d games)\n", coord_x(best->coord, b), coord_y(best->coord, b), best->value, i);
156 if (best->value < u->resign_ratio && !is_pass(best->coord)) {
157 tree_done(u->t); u->t = NULL;
158 return coord_copy(resign);
160 tree_promote_node(u->t, best);
161 return coord_copy(best->coord);
165 struct uct *
166 uct_state_init(char *arg)
168 struct uct *u = calloc(1, sizeof(struct uct));
170 u->debug_level = 1;
171 u->games = MC_GAMES;
172 u->gamelen = MC_GAMELEN;
173 u->explore_p = 0.2;
174 u->expand_p = 2;
175 u->mc.capture_rate = 100;
176 u->mc.atari_rate = 100;
177 u->mc.cut_rate = 50;
178 // Looking at the actual playouts, this just encourages MC to make
179 // stupid shapes.
180 u->mc.local_rate = 0;
182 if (arg) {
183 char *optspec, *next = arg;
184 while (*next) {
185 optspec = next;
186 next += strcspn(next, ",");
187 if (*next) { *next++ = 0; } else { *next = 0; }
189 char *optname = optspec;
190 char *optval = strchr(optspec, '=');
191 if (optval) *optval++ = 0;
193 if (!strcasecmp(optname, "debug")) {
194 if (optval)
195 u->debug_level = atoi(optval);
196 else
197 u->debug_level++;
198 } else if (!strcasecmp(optname, "games") && optval) {
199 u->games = atoi(optval);
200 } else if (!strcasecmp(optname, "gamelen") && optval) {
201 u->gamelen = atoi(optval);
202 } else if (!strcasecmp(optname, "explore_p") && optval) {
203 u->explore_p = atof(optval);
204 } else if (!strcasecmp(optname, "expand_p") && optval) {
205 u->expand_p = atoi(optval);
206 } else if (!strcasecmp(optname, "pure")) {
207 u->mc.capture_rate = u->mc.local_rate = u->mc.cut_rate = 0;
208 } else if (!strcasecmp(optname, "capturerate") && optval) {
209 u->mc.capture_rate = atoi(optval);
210 } else if (!strcasecmp(optname, "atarirate") && optval) {
211 u->mc.atari_rate = atoi(optval);
212 } else if (!strcasecmp(optname, "localrate") && optval) {
213 u->mc.local_rate = atoi(optval);
214 } else if (!strcasecmp(optname, "cutrate") && optval) {
215 u->mc.cut_rate = atoi(optval);
216 } else {
217 fprintf(stderr, "uct: Invalid engine argument %s or missing value\n", optname);
222 u->resign_ratio = 0.2; /* Resign when most games are lost. */
223 u->loss_threshold = 0.9; /* Stop reading if after at least 1000 games this is best value. */
224 u->mc.debug_level = u->debug_level;
226 return u;
230 struct engine *
231 engine_uct_init(char *arg)
233 struct uct *u = uct_state_init(arg);
234 struct engine *e = calloc(1, sizeof(struct engine));
235 e->name = "UCT Engine";
236 e->comment = "I'm playing UCT. When we both pass, I will consider all the stones on the board alive. If you are reading this, write 'yes'. Please bear with me at the game end, I need to fill the whole board; if you help me, we will both be happier. Filling the board will not lose points (NZ rules).";
237 e->genmove = uct_genmove;
238 e->data = u;
240 return e;