UCT: Add ucb1tuned policy
[pachi/peepo.git] / uct / uct.c
blobb2ff59e77c9cc1b148d1e1c7b1d63701d09f7e60
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
5 #define DEBUG
7 #include "debug.h"
8 #include "board.h"
9 #include "move.h"
10 #include "playout.h"
11 #include "montecarlo/hint.h"
12 #include "uct/internal.h"
13 #include "uct/tree.h"
14 #include "uct/uct.h"
16 struct uct_policy *policy_ucb1_init(struct uct *u, char *arg);
17 struct uct_policy *policy_ucb1tuned_init(struct uct *u, char *arg);
20 #define MC_GAMES 40000
21 #define MC_GAMELEN 400
24 static coord_t
25 domainhint_policy(void *playout_policy, struct board *b, enum stone my_color)
27 struct uct *u = playout_policy;
28 return domain_hint(&u->mc, b, my_color);
31 static int
32 uct_playout(struct uct *u, struct board *b, enum stone color, struct tree *t)
34 struct board b2;
35 board_copy(&b2, b);
37 /* Walk the tree until we find a leaf, then expand it and do
38 * a random playout. */
39 struct tree_node *n = t->root;
40 enum stone orig_color = color;
41 int result;
42 int pass_limit = (b2.size - 2) * (b2.size - 2) / 2;
43 int passes = is_pass(b->last_move.coord);
44 if (UDEBUGL(8))
45 fprintf(stderr, "--- UCT walk\n");
46 for (; pass; color = stone_other(color)) {
47 if (tree_leaf_node(n)) {
48 if (n->playouts >= u->expand_p)
49 tree_expand_node(t, n, &b2);
51 result = play_random_game(&b2, stone_other(color), u->gamelen, domainhint_policy, u);
52 if (orig_color == color && result >= 0)
53 result = !result;
54 if (UDEBUGL(7))
55 fprintf(stderr, "[%d..%d] %s playout result %d\n", orig_color, color, coord2sstr(n->coord, t->board), result);
56 break;
59 n = u->policy->descend(u->policy, t, n, (color == orig_color ? 1 : -1), pass_limit);
60 if (UDEBUGL(7))
61 fprintf(stderr, "-- UCT sent us to [%s] %f\n", coord2sstr(n->coord, t->board), n->value);
62 struct move m = { n->coord, color };
63 int res = board_play(&b2, &m);
64 if (res < 0 || (!is_pass(m.coord) && !group_at(&b2, m.coord)) /* suicide */
65 || b2.superko_violation) {
66 if (UDEBUGL(6))
67 fprintf(stderr, "deleting invalid node %d,%d\n", coord_x(n->coord,b), coord_y(n->coord,b));
68 tree_delete_node(n);
69 board_done_noalloc(&b2);
70 return -1;
73 if (is_pass(n->coord)) {
74 passes++;
75 if (passes >= 2) {
76 float score = board_official_score(&b2);
77 result = (orig_color == S_BLACK) ? score < 0 : score > 0;
78 if (UDEBUGL(5))
79 fprintf(stderr, "[%d..%d] %s playout result %d (W %f)\n", orig_color, color, coord2sstr(n->coord, t->board), result, score);
80 if (UDEBUGL(6))
81 board_print(&b2, stderr);
82 break;
84 } else {
85 passes = 0;
89 if (result >= 0)
90 tree_uct_update(n, result);
91 board_done_noalloc(&b2);
92 return result;
95 static coord_t *
96 uct_genmove(struct engine *e, struct board *b, enum stone color)
98 struct uct *u = e->data;
100 if (!u->t) {
101 tree_init:
102 u->t = tree_init(b);
103 //board_print(b, stderr);
104 } else {
105 /* XXX: We hope that the opponent didn't suddenly play
106 * several moves in the row. */
107 for (struct tree_node *ni = u->t->root->children; ni; ni = ni->sibling)
108 if (ni->coord == b->last_move.coord) {
109 tree_promote_node(u->t, ni);
110 goto promoted;
112 fprintf(stderr, "CANNOT FIND NODE TO PROMOTE!\n");
113 tree_done(u->t);
114 goto tree_init;
115 promoted:;
118 int i;
119 for (i = 0; i < u->games; i++) {
120 int result = uct_playout(u, b, color, u->t);
121 if (result < 0) {
122 /* Tree descent has hit invalid move. */
123 continue;
126 if (i > 0 && !(i % 1000)) {
127 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
128 if (best && best->playouts >= 500 && best->value >= u->loss_threshold)
129 break;
133 if (UDEBUGL(2))
134 tree_dump(u->t);
136 struct tree_node *best = u->policy->choose(u->policy, u->t->root, b, color);
137 if (!best) {
138 tree_done(u->t); u->t = NULL;
139 return coord_copy(pass);
141 if (UDEBUGL(1))
142 fprintf(stderr, "*** WINNER is %d,%d with score %1.4f (%d games)\n", coord_x(best->coord, b), coord_y(best->coord, b), best->value, i);
143 if (best->value < u->resign_ratio && !is_pass(best->coord)) {
144 tree_done(u->t); u->t = NULL;
145 return coord_copy(resign);
147 tree_promote_node(u->t, best);
148 return coord_copy(best->coord);
152 struct uct *
153 uct_state_init(char *arg)
155 struct uct *u = calloc(1, sizeof(struct uct));
157 u->debug_level = 1;
158 u->games = MC_GAMES;
159 u->gamelen = MC_GAMELEN;
160 u->expand_p = 2;
161 u->mc.capture_rate = 100;
162 u->mc.atari_rate = 100;
163 u->mc.cut_rate = 0;
164 // Looking at the actual playouts, this just encourages MC to make
165 // stupid shapes.
166 u->mc.local_rate = 0;
168 if (arg) {
169 char *optspec, *next = arg;
170 while (*next) {
171 optspec = next;
172 next += strcspn(next, ",");
173 if (*next) { *next++ = 0; } else { *next = 0; }
175 char *optname = optspec;
176 char *optval = strchr(optspec, '=');
177 if (optval) *optval++ = 0;
179 if (!strcasecmp(optname, "debug")) {
180 if (optval)
181 u->debug_level = atoi(optval);
182 else
183 u->debug_level++;
184 } else if (!strcasecmp(optname, "games") && optval) {
185 u->games = atoi(optval);
186 } else if (!strcasecmp(optname, "gamelen") && optval) {
187 u->gamelen = atoi(optval);
188 } else if (!strcasecmp(optname, "expand_p") && optval) {
189 u->expand_p = atoi(optval);
190 } else if (!strcasecmp(optname, "policy") && optval) {
191 char *policyarg = strchr(optval, '+');
192 if (policyarg)
193 *policyarg++ = 0;
194 if (!strcasecmp(optval, "ucb1")) {
195 u->policy = policy_ucb1_init(u, policyarg);
196 } else if (!strcasecmp(optval, "ucb1tuned")) {
197 u->policy = policy_ucb1tuned_init(u, policyarg);
199 } else if (!strcasecmp(optname, "pure")) {
200 u->mc.capture_rate = u->mc.local_rate = u->mc.cut_rate = 0;
201 } else if (!strcasecmp(optname, "capturerate") && optval) {
202 u->mc.capture_rate = atoi(optval);
203 } else if (!strcasecmp(optname, "atarirate") && optval) {
204 u->mc.atari_rate = atoi(optval);
205 } else if (!strcasecmp(optname, "localrate") && optval) {
206 u->mc.local_rate = atoi(optval);
207 } else if (!strcasecmp(optname, "cutrate") && optval) {
208 u->mc.cut_rate = atoi(optval);
209 } else {
210 fprintf(stderr, "uct: Invalid engine argument %s or missing value\n", optname);
215 u->resign_ratio = 0.2; /* Resign when most games are lost. */
216 u->loss_threshold = 0.95; /* Stop reading if after at least 500 playouts this is best value. */
217 u->mc.debug_level = u->debug_level;
218 u->policy = policy_ucb1_init(u, NULL);
220 return u;
224 struct engine *
225 engine_uct_init(char *arg)
227 struct uct *u = uct_state_init(arg);
228 struct engine *e = calloc(1, sizeof(struct engine));
229 e->name = "UCT Engine";
230 e->comment = "I'm playing UCT. When we both pass, I will consider all the stones on the board alive. If you are reading this, write 'yes'. Please bear with me at the game end, I need to fill the whole board; if you help me, we will both be happier. Filling the board will not lose points (NZ rules).";
231 e->genmove = uct_genmove;
232 e->data = u;
234 return e;