UCT pattern prior: Implement exceedingly naive pattern-based prior
[pachi/t.git] / patternprob.c
blob 30a37784be04f4a01e605a7a62fbb52208f72941
1 #define DEBUG
2 #include <assert.h>
3 #include <ctype.h>
4 #include <stdio.h>
5 #include <stdlib.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "pattern.h"
10 #include "patternsp.h"
11 #include "patternprob.h"
/* Loading the probability dictionary can be slow, so the first dictionary
 * that loads successfully is remembered here and handed back by
 * pattern_pdict_init() instead of re-reading the file.
 * Stays NULL until such a load has happened. */
static struct pattern_pdict *cached_dict = NULL;
18 struct pattern_pdict *
19 pattern_pdict_init(char *filename, struct pattern_config *pc)
21 if (cached_dict)
22 return cached_dict;
24 if (!filename)
25 filename = "patterns.prob";
26 FILE *f = fopen(filename, "r");
27 if (!f) {
28 if (DEBUGL(1))
29 fprintf(stderr, "No pattern probtable, will not use learned patterns.\n");
30 return NULL;
33 struct pattern_pdict *dict = calloc2(1, sizeof(*dict));
34 dict->pc = pc;
35 dict->table = calloc2(pc->spat_dict->nspatials + 1, sizeof(*dict->table));
37 char *sphcachehit = malloc(pc->spat_dict->nspatials);
38 hash_t (*sphcache)[PTH__ROTATIONS] = malloc(pc->spat_dict->nspatials * sizeof(sphcache[0]));
40 int i = 0;
41 char sbuf[1024];
42 while (fgets(sbuf, sizeof(sbuf), f)) {
43 struct pattern_prob *pb = calloc2(1, sizeof(*pb));
44 int c, o;
46 char *buf = sbuf;
47 if (buf[0] == '#') continue;
48 while (isspace(*buf)) buf++;
49 while (!isspace(*buf)) buf++; // we recompute the probability
50 while (isspace(*buf)) buf++;
51 c = strtol(buf, &buf, 10);
52 while (isspace(*buf)) buf++;
53 o = strtol(buf, &buf, 10);
54 pb->prob = (floating_t) c / o;
55 while (isspace(*buf)) buf++;
56 str2pattern(buf, &pb->p);
58 uint32_t spi = pattern2spatial(dict, &pb->p);
59 pb->next = dict->table[spi];
60 dict->table[spi] = pb;
62 /* We rehash spatials in the order of loaded patterns. This way
63 * we make sure that the most popular patterns will be hashed
64 * last and therefore take priority. */
65 if (!sphcachehit[spi]) {
66 sphcachehit[spi] = 1;
67 for (int r = 0; r < PTH__ROTATIONS; r++)
68 sphcache[spi][r] = spatial_hash(r, &pc->spat_dict->spatials[spi]);
70 for (int r = 0; r < PTH__ROTATIONS; r++)
71 spatial_dict_addh(pc->spat_dict, sphcache[spi][r], spi);
73 i++;
76 free(sphcache);
77 free(sphcachehit);
78 if (DEBUGL(3))
79 spatial_dict_hashstats(pc->spat_dict);
81 fclose(f);
82 if (DEBUGL(1))
83 fprintf(stderr, "Loaded %d pattern-probability pairs.\n", i);
84 cached_dict = dict;
85 return dict;
88 floating_t
89 pattern_rate_moves(struct pattern_config *pc, pattern_spec *ps, struct pattern_pdict *pd,
90 struct board *b, enum stone color,
91 struct pattern *pats, floating_t *probs)
93 floating_t total = 0;
94 for (int f = 0; f < b->flen; f++) {
95 probs[f] = NAN;
97 struct move mo = { .coord = b->f[f], .color = color };
98 if (is_pass(mo.coord))
99 continue;
100 if (!board_is_valid_move(b, &mo))
101 continue;
103 pattern_match(pc, *ps, &pats[f], b, &mo);
104 floating_t prob = pattern_prob(pd, &pats[f]);
105 if (!isnan(prob)) {
106 probs[f] = prob;
107 total += prob;
110 return total;