Split libmap_hash into a libmap_group array, create libmap_group slots only to pre-simula...
[pachi.git] / tactics / goals.c
blob 5c57d9274ec0b7dad67d8518224e582861ed6e0f
1 #include <assert.h>
2 #include <limits.h>
3 #include <stdio.h>
4 #include <stdlib.h>
6 #include "board.h"
7 #include "debug.h"
8 #include "libmap.h"
9 #include "move.h"
10 #include "tactics/goals.h"
11 #include "tactics/util.h"
14 struct libmap_config libmap_config = {
15 .pick_mode = LMP_THRESHOLD,
16 .pick_threshold = 0.7,
17 .pick_epsilon = 10,
19 .explore_p = 0.2,
20 .prior = { .value = 0.5, .playouts = 1 },
21 .tenuki_prior = { .value = 0.4, .playouts = 1 },
23 .mq_merge_groups = true,
24 .counterattack = LMC_DEFENSE | LMC_ATTACK | LMC_DEFENSE_ATTACK,
25 .eval = LME_LVALUE,
28 void
29 libmap_setup(char *arg)
31 if (!arg)
32 return;
34 char *optspec, *next = arg;
35 while (*next) {
36 optspec = next;
37 next += strcspn(next, ":");
38 if (*next) { *next++ = 0; } else { *next = 0; }
40 char *optname = optspec;
41 char *optval = strchr(optspec, '=');
42 if (optval) *optval++ = 0;
44 if (!strcasecmp(optname, "pick_mode") && optval) {
45 if (!strcasecmp(optval, "threshold")) {
46 libmap_config.pick_mode = LMP_THRESHOLD;
47 } else if (!strcasecmp(optval, "ucb")) {
48 libmap_config.pick_mode = LMP_UCB;
49 } else {
50 fprintf(stderr, "Invalid libmap:pick_mode value %s\n", optval);
51 exit(1);
54 } else if (!strcasecmp(optname, "pick_threshold") && optval) {
55 libmap_config.pick_threshold = atof(optval);
56 } else if (!strcasecmp(optname, "pick_epsilon") && optval) {
57 libmap_config.pick_epsilon = atoi(optval);
58 } else if (!strcasecmp(optname, "avoid_bad")) {
59 libmap_config.avoid_bad = !optval || atoi(optval);
61 } else if (!strcasecmp(optname, "explore_p") && optval) {
62 libmap_config.explore_p = atof(optval);
63 } else if (!strcasecmp(optname, "prior") && optval && strchr(optval, 'x')) {
64 libmap_config.prior.value = atof(optval);
65 optval += strcspn(optval, "x") + 1;
66 libmap_config.prior.playouts = atoi(optval);
67 } else if (!strcasecmp(optname, "tenuki_prior") && optval && strchr(optval, 'x')) {
68 libmap_config.tenuki_prior.value = atof(optval);
69 optval += strcspn(optval, "x") + 1;
70 libmap_config.tenuki_prior.playouts = atoi(optval);
72 } else if (!strcasecmp(optname, "mq_merge_groups")) {
73 libmap_config.mq_merge_groups = !optval || atoi(optval);
74 } else if (!strcasecmp(optname, "counterattack") && optval) {
75 /* Combination of letters d, a, x (both), these kinds
76 * of hashes are going to be recorded. */
77 /* Note that using multiple letters makes no sense
78 * if mq_merge_groups is set. */
79 libmap_config.counterattack = 0;
80 if (strchr(optval, 'd'))
81 libmap_config.counterattack |= LMC_DEFENSE;
82 if (strchr(optval, 'a'))
83 libmap_config.counterattack |= LMC_ATTACK;
84 if (strchr(optval, 'x'))
85 libmap_config.counterattack |= LMC_DEFENSE_ATTACK;
86 } else if (!strcasecmp(optname, "eval") && optval) {
87 if (!strcasecmp(optval, "local")) {
88 libmap_config.eval = LME_LOCAL;
89 } else if (!strcasecmp(optval, "lvalue")) {
90 libmap_config.eval = LME_LVALUE;
91 } else if (!strcasecmp(optval, "global")) {
92 libmap_config.eval = LME_GLOBAL;
93 } else {
94 fprintf(stderr, "Invalid libmap:eval value %s\n", optval);
95 exit(1);
97 } else if (!strcasecmp(optname, "tenuki")) {
98 libmap_config.tenuki = !optval || atoi(optval);
99 } else {
100 fprintf(stderr, "Invalid libmap argument %s or missing value\n", optname);
101 exit(1);
107 struct libmap_hash *
108 libmap_init(struct board *b)
110 struct libmap_hash *lm = calloc2(1, sizeof(*lm));
111 lm->b = b;
112 b->libmap = lm;
113 lm->refcount = 1;
115 lm->groups[0] = calloc2(board_size2(b), sizeof(*lm->groups[0]));
116 lm->groups[1] = calloc2(board_size2(b), sizeof(*lm->groups[1]));
117 for (group_t g = 1; g < board_size2(b); g++) // foreach_group
118 if (group_at(b, g) == g)
119 libmap_group_init(lm, b, g, board_at(b, g));
121 return lm;
124 void
125 libmap_put(struct libmap_hash *lm)
127 if (__sync_sub_and_fetch(&lm->refcount, 1) > 0)
128 return;
129 for (group_t g = 0; g < board_size2(lm->b); g++) {
130 if (lm->groups[0][g])
131 free(lm->groups[0][g]);
132 if (lm->groups[1][g])
133 free(lm->groups[1][g]);
135 free(lm->groups[0]);
136 free(lm->groups[1]);
137 free(lm);
140 void
141 libmap_group_init(struct libmap_hash *lm, struct board *b, group_t g, enum stone color)
143 assert(color == S_BLACK || color == S_WHITE);
144 if (lm->groups[color - 1][g])
145 return;
147 struct libmap_group *lmg = calloc2(1, sizeof(*lmg));
148 lmg->group = g;
149 lmg->color = color;
150 lm->groups[color - 1][g] = lmg;
154 void
155 libmap_queue_process(struct board *b, enum stone winner)
157 struct libmap_mq *lmqueue = b->lmqueue;
158 assert(lmqueue->mq.moves <= MQL);
159 for (unsigned int i = 0; i < lmqueue->mq.moves; i++) {
160 struct libmap_move_groupinfo *gi = &lmqueue->gi[i];
161 struct move m = { .coord = lmqueue->mq.move[i], .color = lmqueue->color[i] };
162 struct libmap_group *lg = b->libmap->groups[gi->color - 1][gi->group];
163 if (!lg) continue;
164 floating_t val;
165 if (libmap_config.eval == LME_LOCAL || libmap_config.eval == LME_LVALUE) {
166 val = board_local_value(libmap_config.eval == LME_LVALUE, b, gi->group, gi->goal);
168 } else { assert(libmap_config.eval == LME_GLOBAL);
169 val = winner == gi->goal ? 1.0 : 0.0;
171 libmap_add_result(b->libmap, lg, gi->hash, m, val, 1);
173 lmqueue->mq.moves = 0;
176 void
177 libmap_add_result(struct libmap_hash *lm, struct libmap_group *lg, hash_t hash, struct move move,
178 floating_t result, int playouts)
180 /* If hash line is full, replacement strategy is naive - pick the
181 * move with minimum move[0].stats.playouts; resolve each tie
182 * randomly. */
183 unsigned int min_playouts = INT_MAX; hash_t min_hash = hash;
184 hash_t ih;
185 for (ih = hash; lg->hash[ih & libmap_hash_mask].hash != hash; ih++) {
186 // fprintf(stderr, "%"PRIhash": check %"PRIhash" (%d)\n", hash & libmap_hash_mask, ih & libmap_hash_mask, lg->hash[ih & libmap_hash_mask].moves);
187 if (lg->hash[ih & libmap_hash_mask].moves == 0) {
188 lg->hash[ih & libmap_hash_mask].hash = hash;
189 break;
191 if (ih >= hash + libmap_hash_maxline) {
192 /* Snatch the least used bucket. */
193 ih = min_hash;
194 // fprintf(stderr, "clear %"PRIhash"\n", ih & libmap_hash_mask);
195 memset(&lg->hash[ih & libmap_hash_mask], 0, sizeof(lg->hash[0]));
196 lg->hash[ih & libmap_hash_mask].hash = hash;
197 break;
200 /* Keep track of least used bucket. */
201 assert(lg->hash[ih & libmap_hash_mask].moves > 0);
202 unsigned int hp = lg->hash[ih & libmap_hash_mask].move[0].stats.playouts;
203 if (hp < min_playouts || (hp == min_playouts && fast_random(2))) {
204 min_playouts = hp;
205 min_hash = ih;
209 // fprintf(stderr, "%"PRIhash": use %"PRIhash" (%d)\n", hash & libmap_hash_mask, ih & libmap_hash_mask, lg->hash[ih & libmap_hash_mask].moves);
210 struct libmap_context *lc = &lg->hash[ih & libmap_hash_mask];
211 lc->visits++;
213 for (int i = 0; i < lc->moves; i++) {
214 if (lc->move[i].move.coord == move.coord
215 && lc->move[i].move.color == move.color) {
216 stats_add_result(&lc->move[i].stats, result, playouts);
217 return;
221 int moves = lc->moves; // to preserve atomicity
222 if (moves >= GROUP_REFILL_LIBS) {
223 if (DEBUGL(5))
224 fprintf(stderr, "(%s) too many libs\n", coord2sstr(move.coord, lm->b));
225 return;
227 lc->move[moves].move = move;
228 stats_add_result(&lc->move[moves].stats, result, playouts);
229 lc->moves = ++moves;
232 struct move_stats
233 libmap_board_move_stats(struct libmap_hash *lm, struct board *b, struct move move)
235 struct move_stats tot = { .playouts = 0, .value = 0 };
236 if (is_pass(move.coord))
237 return tot;
238 assert(board_at(b, move.coord) != S_OFFBOARD);
240 neighboring_groups_list(b, board_at(b, c) == S_BLACK || board_at(b, c) == S_WHITE,
241 move.coord, groups, groups_n, groupsbycolor_xxunused);
242 for (int i = 0; i < groups_n; i++) {
243 struct libmap_group *lg = lm->groups[board_at(b, groups[i]) - 1][groups[i]];
244 if (!lg) continue;
245 hash_t hash = group_to_libmap(b, groups[i]);
246 struct move_stats *lp = libmap_move_stats(b->libmap, lg, hash, move);
247 if (!lp) continue;
248 stats_merge(&tot, lp);
251 return tot;