pattern3_record(): Take pattern3s instead of direct table pointer
[pachi/derm.git] / stats.h
blob7dcae03ca7e22c90ec42cc1be3fe011a54817aac
1 #ifndef ZZGO_STATS_H
2 #define ZZGO_STATS_H
4 #include <math.h>
6 /* Move statistics; we track how good each move's value is. */
7 /* These operations are supposed to be atomic - reasonably
8 * safe to perform by multiple threads at once on the same stats.
9 * What this means in practice is that perhaps the value will get
10 * slightly wrong, but not drastically corrupted. */
/* Accumulated playout statistics for a single move. */
struct move_stats {
	int playouts;	/* number of playouts folded into this record */
	float value;	/* BLACK wins / playouts */
};
/* Fold @playouts playouts with average outcome @result into @s:
 * the playout count grows by @playouts and the value moves towards
 * @result in proportion to the share of the new playouts. */
static void stats_add_result(struct move_stats *s, float result, int playouts);

/* Take @playouts playouts with average outcome @result back out of @s.
 * If @s does not hold more than @playouts playouts, only its playout
 * count is reset to zero and the value is left untouched. */
static void stats_rm_result(struct move_stats *s, float result, int playouts);

/* Fold @src into @dest (a no-op when @src has no playouts).
 * THIS IS NOT ATOMIC! */
static void stats_merge(struct move_stats *dest, struct move_stats *src);

/* Flip the value to the other side's point of view (value -> 1 - value). */
static void stats_reverse_parity(struct move_stats *s);

/* Temper @val against the parent value @pval using the formula selected
 * by @mode (1 to 4) - the result should then be usable standalone,
 * representing an improvement against the parent value. */
static float stats_temper_value(float val, float pval, int mode);
34 /* We actually do the atomicity in a pretty hackish way - we simply
35 * rely on the fact that int,float operations should be atomic with
36 * reasonable compilers (gcc) on reasonable architectures (i386,
37 * x86_64). */
38 /* There is a write order dependency - when we bump the playouts,
39 * our value must be already correct, otherwise the node will receive
40 * invalid evaluation if that's made in parallel, esp. when
41 * current s->playouts is zero. */
43 static inline void
44 stats_add_result(struct move_stats *s, float result, int playouts)
46 int s_playouts = s->playouts;
47 float s_value = s->value;
48 /* Force the load, another thread can work on the
49 * values in parallel. */
50 __sync_synchronize(); /* full memory barrier */
52 s_playouts += playouts;
53 s_value += (result - s_value) * playouts / s_playouts;
55 /* We rely on the fact that these two assignments are atomic. */
56 s->value = s_value;
57 __sync_synchronize(); /* full memory barrier */
58 s->playouts = s_playouts;
61 static inline void
62 stats_rm_result(struct move_stats *s, float result, int playouts)
64 if (s->playouts > playouts) {
65 int s_playouts = s->playouts;
66 float s_value = s->value;
67 /* Force the load, another thread can work on the
68 * values in parallel. */
69 __sync_synchronize(); /* full memory barrier */
71 s_playouts -= playouts;
72 s_value += (s_value - result) * playouts / s_playouts;
74 /* We rely on the fact that these two assignments are atomic. */
75 s->value = s_value;
76 __sync_synchronize(); /* full memory barrier */
77 s->playouts = s_playouts;
79 } else {
80 /* We don't touch the value, since in parallel, another
81 * thread can be adding a result, thus raising the
82 * playouts count after we zero the value. Instead,
83 * leaving the value as is with zero playouts should
84 * not break anything. */
85 s->playouts = 0;
89 static inline void
90 stats_merge(struct move_stats *dest, struct move_stats *src)
92 /* In a sense, this is non-atomic version of stats_add_result(). */
93 if (src->playouts) {
94 dest->playouts += src->playouts;
95 dest->value += (src->value - dest->value) * src->playouts / dest->playouts;
99 static inline void
100 stats_reverse_parity(struct move_stats *s)
102 s->value = 1 - s->value;
105 static inline float
106 stats_temper_value(float val, float pval, int mode)
108 float tval = val;
109 float expd = val - pval;
110 switch (mode) {
111 case 1: /* no tempering */
112 tval = val;
113 break;
114 case 2: /* 0.5+(result-expected)/2 */
115 tval = 0.5 + expd / 2;
116 break;
117 case 3: { /* 0.5+bzz((result-expected)^2) */
118 float ntval = expd * expd;
119 /* val = 1 pval = 0.8 : ntval = 0.04 tval = 0.54
120 * val = 1 pval = 0.6 : ntval = 0.16 tval = 0.66
121 * val = 1 pval = 0.3 : ntval = 0.49 tval = 0.99
122 * val = 1 pval = 0.1 : ntval = 0.81 tval = 1.31 */
123 tval = 0.5 + (val > 0.5 ? 1 : -1) * ntval;
124 break; }
125 case 4: /* 0.5+sqrt(result-expected)/2 */
126 tval = 0.5 + copysignf(sqrt(fabs(expd)), expd) / 2;
127 break;
128 default: assert(0); break;
130 return tval;
133 #endif