some late comments :)
[gostyle.git] / pachi.py
blob 8f4d64cdc7607cb9c22e827c7238cf74ac42b92c
1 import logging
2 import subprocess
3 from subprocess import PIPE
5 import os
6 from os import remove
7 from os.path import abspath
9 import sys
10 import shutil
11 import re
12 from collections import namedtuple
14 from utils import utils, misc, db_cache
15 from utils.db_cache import declare_pure_function, cache_result
16 from utils.colors import PLAYER_COLOR_BLACK, PLAYER_COLOR_WHITE
17 from utils.godb_models import ProcessingError
18 from result_file import ResultFile
19 import result_file
21 from config import PACHI_DIR
# A wrapper allowing more comfy use of the pachi go engine.
# (Kept as a comment: a string placed after the imports is not the module
# docstring, it is just a no-op expression statement.)

# Path of the spatial pattern dictionary inside the pachi directory
# (PACHI_DIR comes from the project's config module).
PACHI_SPATIAL_DICT = os.path.join(PACHI_DIR, 'patterns.spat')
class Pattern:
    """A pachi pattern: an ordered list of (feature, payload) pairs.

    Built either by parsing pachi's textual form, e.g.
    '(capture:104 border:6 atari:0 atari:0 cont:1 s:2620)', or directly
    from a list of pairs.  Iteration yields the pairs in their stored
    order; the same feature may appear more than once.
    """

    # Whole pattern is enclosed in parentheses; trailing spaces are allowed.
    _PATTERN_RE = re.compile(r'^\((.*)\) *$')

    def __init__(self, pattern=None, fpairs=None):
        """Create a pattern from a pachi pattern string or from pairs.

        @pattern: string like '(s:99 atariescape:8)'; wins over @fpairs
            when both are given.  Payloads are converted with int().
        @fpairs: list of (feature, payload) pairs, stored as given.

        Raises RuntimeError if the string is not parenthesized or if
        neither argument is given; ValueError if a feature pair does not
        contain exactly one ':' or its payload is not an integer.
        """
        if pattern is not None:
            match = self._PATTERN_RE.match(pattern)
            if not match:
                raise RuntimeError("Pattern format wrong: '%s'"%pattern)

            # e.g. (capture:104 border:6 atari:0 atari:0 cont:1 s:2620)
            self.fpairs = []
            for featpair in match.group(1).split():
                feat, payload = featpair.split(':')
                self.fpairs.append((feat, int(payload)))
        elif fpairs is not None:
            self.fpairs = fpairs
        else:
            raise RuntimeError("Pattern unspecified...")

    def reduce(self, filterfc):
        """Return a new Pattern with only the pairs where filterfc(feature, payload) is true."""
        return Pattern(fpairs=[(f, p) for f, p in self if filterfc(f, p)])

    def iter_feature_payloads(self, feature):
        """Yield the payload of every occurrence of @feature, in order."""
        for f, p in self:
            if f == feature:
                yield p

    def first_payload(self, feature):
        """Return the payload of the first occurrence of @feature.

        Raises StopIteration when the feature is not present.
        """
        # builtin next() works on Python 2.6+ and 3; gen.next() is Py2-only
        return next(self.iter_feature_payloads(feature))

    def has_feature(self, feature):
        """Return True if @feature occurs at least once in the pattern."""
        return any(f == feature for f, _ in self)

    def __iter__(self):
        return iter(self.fpairs)

    def __str__(self):
        return "(%s)"%( ' '.join( "%s:%s"%(feat, payload) for feat, payload in self ) )
class IllegalMove(Exception):
    """Error signalling an illegal move in a replayed game.

    NOTE(review): in the code visible here nothing raises it --
    scan_raw_patterns raises ProcessingError("Illegal move") instead.
    """
@cache_result
@declare_pure_function
def generate_spatial_dictionary(game_list, spatmin=4, patargs='', check_size=329):
    """
    Generates pachi spatial dictionary from games in the @game_list.

    Runs pachi's tools/pattern_spatial_gen.sh (from PACHI_DIR) over the SGF
    files of @game_list.games, skipping games with more than 9 handicap
    stones, and writes the dictionary into a new '.spat' result file
    obtained from result_file.get_output_resultfile().

    @spatmin is handed to the script as the SPATMIN environment variable.
    @patargs is handed to the script as the PATARGS environment variable.
    @check_size is a sanity threshold: if the resulting file size is less
    than or equal to @check_size bytes, RuntimeError is raised.
    Set this to 0 to disable the check. (328 is the size of empty spatial dict header)

    Returns the result file holding the spatial dictionary.
    Raises RuntimeError when the script exits non-zero or the size check fails.
    """
    logging.info("Generating spatial dictionary from %s"%(repr(game_list)))

    # pachi does not handle larger number of handicap stones than 9
    without_large_handi = [g for g in game_list.games
                           if int(g.sgf_header.get('HA', 0)) <= 9]
    skipped = len(game_list.games) - len(without_large_handi)
    if skipped:
        logging.warning("The spatial dictionary list contains %d games with # of handicap stones >= 10. Skipping those."%(
            skipped,))

    games = '\n'.join(abspath(game.sgf_file) for game in without_large_handi)
    # communicate() wants bytes. On Python 2, calling .encode() on a byte
    # string first ASCII-decodes it and fails on non-ASCII paths, so only
    # encode text.
    if not isinstance(games, bytes):
        games = games.encode('utf-8')

    spatial_dict = result_file.get_output_resultfile('.spat')
    assert not spatial_dict.exists()

    # Settings go through the environment and the working directory is set
    # via cwd=, instead of being spliced into a shell command line where a
    # space or quote in a path would break the command.
    env = dict(os.environ,
               SPATMIN=str(spatmin),
               SPATIAL_DICT_FILE=abspath(spatial_dict.filename),
               PATARGS=patargs)
    p = subprocess.Popen('tools/pattern_spatial_gen.sh -', shell=True,
                         cwd=PACHI_DIR, env=env, stdin=PIPE)
    # the script reads the list of SGF paths from stdin ('-')
    p.communicate(input=games)
    if p.returncode:
        raise RuntimeError("Child process `pachi/tools/pattern_spatial_gen` failed, exitcode %d."%(p.returncode,))
    if check_size and os.stat(spatial_dict.filename).st_size <= check_size:
        raise RuntimeError("Spatial dict is empty. Probably an uncaught error in subprocess.")

    logging.info("Returning spatial dictionary %s"%(repr(spatial_dict)))
    return spatial_dict
def _parse_patternscan_output(gtpstream, pats, skip_empty):
    """Pair each GTP `play` command with the pattern pachi emitted for it.

    @gtpstream: the GTP commands fed to pachi, one per line.
    @pats: pachi's stdout; '= ...' lines are responses, '? ...' lines are
        failures (an illegal move).
    @skip_empty: drop moves whose pattern is the empty '[()]'.

    Returns a list of (player_color, Pattern) pairs.
    Raises ProcessingError on an illegal move; RuntimeError when commands
    and responses do not line up or a pattern is not enclosed in [...].
    """
    commands = [line for line in gtpstream.split('\n') if line]
    out_lines = pats.split('\n')

    # '? ' marks a failed GTP command, i.e. pachi rejected a move
    if any(line.startswith('? ') for line in out_lines):
        raise ProcessingError("Illegal move")

    # keep only successful responses, without the leading '= '
    responses = [line[2:] for line in out_lines if line.startswith('= ')]

    # each gtp command emits one line of patterns from pachi; check it
    # explicitly (an assert vanishes under -O and zip() would then silently
    # attribute patterns to the wrong moves)
    if len(commands) != len(responses):
        raise RuntimeError("pachi emitted %d responses for %d gtp commands"
                           % (len(responses), len(commands)))

    result = []
    for gtp, pat in zip(commands, responses):
        # boardsize, handi, komi, ... respond with nothing but whitespace
        if not pat.strip():
            continue
        # only moves matter; e.g. 'fixed_handicap' responds with the
        # positions of the handicap stones
        if not gtp.startswith('play'):
            continue
        # '[()]' is an empty pattern
        if skip_empty and len(pat) == 4:
            continue
        # remove brackets enclosing features:
        # [(s:99 atariescape:8)] => (s:99 atariescape:8)
        if not (pat.startswith('[') and pat.endswith(']')):
            raise RuntimeError("Unexpected pachi pattern line: '%s'" % pat)
        # 'play W <coord>' / 'play B <coord>': index 5 is the color letter
        color = PLAYER_COLOR_WHITE if gtp[5] == 'W' else PLAYER_COLOR_BLACK
        result.append((color, Pattern(pat[1:-1])))
    return result


@cache_result
@declare_pure_function
def scan_raw_patterns(game, spatial_dict=None, patargs='', skip_empty=True):
    """
    For a @game, returns list of pairs (player_color, pattern) for each move.
    The pachi should be compiled to output all the features.

    @game: object whose .sgf_file is converted to GTP by pachi's
        tools/sgf2gtp.py.
    @spatial_dict: result file with the spatial dictionary; may be None
        only if 'xspat=0' is among the comma-separated @patargs.
    @patargs: comma-separated options for pachi's patternscan engine.
    @skip_empty: drop moves whose pattern has no features ('[()]').

    Raises RuntimeError when the spatial dict is missing or absent on disk,
    when pachi exits non-zero, or when its output cannot be parsed;
    ProcessingError when pachi reports an illegal move.
    """
    if spatial_dict is None:
        if 'xspat=0' not in patargs.split(','):
            raise RuntimeError("Spatial dict not specified, though the spatial features are not turned off.")
        spatial_str = ""
    else:
        # explicit check instead of an assert, so it also runs under -O
        if not spatial_dict.exists(warn=True):
            raise RuntimeError("Spatial dict %s does not exist." % repr(spatial_dict))
        spatial_str = "spatial_dict_filename=%s" % abspath(spatial_dict.filename)

    ## TODO
    ## pachi needs some patterns.spat to be present even with xspat=0,
    ## otherwise it segfaults, though it does not use it...

    # NOTE(review): the SGF path is still spliced into a shell string; a
    # path containing a single quote will break this command.
    gtpscript = "cd %s && ./tools/sgf2gtp.py --stdout '%s'\n" % (
        PACHI_DIR, abspath(game.sgf_file))
    gtpstream = utils.check_output(gtpscript, shell=True)

    # argv list + cwd= instead of a shell script: no quoting issues with
    # PACHI_DIR or the (possibly empty) option string
    pachi_args = ','.join(misc.filter_null([spatial_str, patargs]))
    cmd = ['./pachi', '-d', '0', '-e', 'patternscan', pachi_args]
    p = subprocess.Popen(cmd, cwd=PACHI_DIR, stdout=PIPE, stdin=PIPE, stderr=PIPE)

    pats, stderr = p.communicate(input=gtpstream)
    if stderr:
        logging.warning("subprocess pachi:\n\tCMD:\n%s\n\tSTDERR\n%s" % (' '.join(cmd), stderr))

    if p.returncode:
        raise RuntimeError("Child process `pachi` failed, exitcode %d."%(p.returncode,))

    return _parse_patternscan_output(gtpstream, pats, skip_empty)
if __name__ == '__main__':
    # Smoke run: scan a test game with spatial features disabled and print
    # every (color, features) pair.  For progress output, raise the root
    # logger level to logging.INFO first; db_cache.init_cache(filename=':memory:')
    # can be called to keep the result cache in memory.
    from utils.godb_session import godb_session_maker

    s = godb_session_maker(filename=':memory:')

    game = s.godb_sgf_to_game('./TEST_FILES/test_capture.sgf')

    pats = scan_raw_patterns(game, patargs='xspat=0')
    for c, p in pats:
        # same output as the Python 2 statement `print c, list(p)`,
        # but valid on Python 3 too
        print("%s %s" % (c, list(p)))