tree_merge(): Take amaf_prior argument, properly compute merged node values
[pachi.git] / uct / tree.h
blobd08ecafaa84206c58102b66c305c82455031e394
1 #ifndef ZZGO_UCT_TREE_H
2 #define ZZGO_UCT_TREE_H
4 #include <stdbool.h>
5 #include "move.h"
7 struct board;
8 struct uct;
/*
 *            +------+
 *            | node |
 *            +------+
 *           /          <- parent
 * +------+   v- sibling    +------+
 * | node | --------------- | node |
 * +------+                 +------+
 *     |  <- children           |
 * +------+   +------+     +------+   +------+
 * | node | - | node |     | node | - | node |
 * +------+   +------+     +------+   +------+
 */
/* Win statistics gathered for one kind of evidence about a move
 * (regular playouts, prior knowledge, AMAF, ...). */
struct move_stats {
	int playouts; // # of playouts coming through this node
	int wins; // # of BLACK wins coming through this node
	/* Cached win rate. For the u and amaf members of a tree_node it is
	 * recomputed by tree_update_node_value() / tree_update_node_rvalue(),
	 * which may fold the node's prior into the ratio, so it is not always
	 * exactly wins/playouts of this struct alone. */
	float value;
};
30 struct tree_node {
31 hash_t hash;
32 struct tree_node *parent, *sibling, *children;
34 /*** From here on, struct is saved/loaded from opening book */
36 int depth; // just for statistics
38 coord_t coord;
40 struct move_stats u;
41 struct move_stats prior;
42 /* XXX: Should be way for policies to add their own stats */
43 struct move_stats amaf;
44 /* Stats before starting playout; used for multi-thread normalization. */
45 struct move_stats pu, pamaf;
46 int hints;
49 struct tree {
50 struct board *board;
51 struct tree_node *root;
52 struct board_symmetry root_symmetry;
53 enum stone root_color;
55 // Statistics
56 int max_depth;
59 struct tree *tree_init(struct board *board, enum stone color);
60 void tree_done(struct tree *tree);
61 void tree_dump(struct tree *tree, int thres);
62 void tree_save(struct tree *tree, struct board *b, int thres);
63 void tree_load(struct tree *tree, struct board *b);
64 struct tree *tree_copy(struct tree *tree);
65 void tree_merge(struct tree *dest, struct tree *src, bool amaf_prior);
66 void tree_normalize(struct tree *tree, int factor);
68 void tree_expand_node(struct tree *tree, struct tree_node *node, struct board *b, enum stone color, int radar, struct uct *u, int parity);
69 void tree_delete_node(struct tree *tree, struct tree_node *node);
70 void tree_promote_node(struct tree *tree, struct tree_node *node);
71 bool tree_promote_at(struct tree *tree, struct board *b, coord_t c);
73 static bool tree_leaf_node(struct tree_node *node);
74 static void tree_update_node_value(struct tree_node *node, bool rave_prior);
75 static void tree_update_node_rvalue(struct tree_node *node, bool rave_prior);
/* Get black parity from parity within the tree: yields @parity unchanged
 * when the root colour is S_WHITE and negated otherwise. Macro arguments
 * are parenthesized so expressions like &t or p + 1 expand correctly. */
#define tree_parity(tree, parity) \
	((tree)->root_color == S_WHITE ? (parity) : -1 * (parity))

/* Get a value to maximize; @parity is parity within the tree.
 * move_stats.wins counts BLACK wins, so when the black parity is positive
 * the raw value/wins are returned, otherwise their complement
 * (1 - value, playouts - wins). @type names a struct move_stats member of
 * the node (u, amaf, prior, ...), so it cannot be parenthesized. */
#define tree_node_get_value(tree, node, type, parity) \
	(tree_parity((tree), (parity)) > 0 ? (node)->type.value : 1 - (node)->type.value)
#define tree_node_get_wins(tree, node, type, parity) \
	(tree_parity((tree), (parity)) > 0 ? (node)->type.wins : (node)->type.playouts - (node)->type.wins)
87 static inline bool
88 tree_leaf_node(struct tree_node *node)
90 return !(node->children);
93 static inline void
94 tree_update_node_value(struct tree_node *node, bool rave_prior)
96 node->u.value = (float)(node->u.wins + (!rave_prior ? node->prior.wins : 0))
97 / (node->u.playouts + (!rave_prior ? node->prior.playouts : 0));
100 static inline void
101 tree_update_node_rvalue(struct tree_node *node, bool rave_prior)
103 node->amaf.value = (float)(node->amaf.wins + (rave_prior ? node->prior.wins : 0))
104 / (node->amaf.playouts + (rave_prior ? node->prior.playouts : 0));
107 #endif