lots
[arrocco.git] / search.py
blob465b763144297440167b07772b9a87116507e764
1 # search.py
2 # 28 June 2007
4 """
5 This file is part of Arrocco, which is Copyright 2007 Thomas Plick
6 (tomplick 'at' gmail.com).
8 Arrocco is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 3 of the License, or
11 (at your option) any later version.
13 Arrocco is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
18 You should have received a copy of the GNU General Public License
19 along with this program. If not, see <http://www.gnu.org/licenses/>.
20 """
22 import time, ctypes, thread, Queue, threading
24 from lib import lib
25 from position import *
# Maps a position's FEN string to the position that followed it in the most
# recently recorded principal variation. The search functions consult it to
# try that successor first.
pvCache = {}


def updatePVCache(pv):
    """Record every step of a principal variation in pvCache.

    For each consecutive pair (earlier, later) in `pv`, store
    pvCache[earlier.fen()] = later. The final position gets no entry, and
    existing entries with the same FEN are overwritten.
    """
    predecessor = None
    for index, position in enumerate(pv):
        if index > 0:
            pvCache[predecessor.fen()] = position
        predecessor = position
def alphaBetaInC(pos, depth,
        alpha, beta, stopper = None, childArray = None):
    """Run the C alpha-beta search (lib.alphaBeta) on `pos` and collect its PV.

    pos         -- root Position, passed to the C routine by reference.
    depth       -- search depth in plies.
    alpha, beta -- the search window.
    stopper     -- ctypes integer flag passed to the C routine by reference;
                   presumably a nonzero value makes the search abort early
                   (TODO confirm against lib). When omitted, a fresh unset
                   flag is used.
    childArray  -- optional scratch array of Position. When omitted, one with
                   depth * 32 entries is allocated; callers may pass one in to
                   reuse it across calls.

    Returns dict(value=<C search result>, pv=<list of Positions starting
    with pos>, nodeCount=<node counter after the search>).
    """
    if stopper is None:
        # ctypes.byref(None) raises TypeError, so without this the documented
        # default could never reach the C call.
        stopper = ctypes.c_int(0)
    if childArray is None:
        childArray = (Position * (depth * 32))()
    counter = ctypes.c_long(1)

    value = lib.alphaBeta(ctypes.byref(pos), childArray, depth,
                          alpha, beta, ctypes.byref(counter),
                          ctypes.byref(stopper))

    # Rebuild the principal variation by following, ply by ply, the branch
    # indices read back from the root position (presumably recorded there by
    # the C search).
    pv = [pos]
    node = pos
    for ply in range(depth):
        try:
            node = node.children()[pos.getBranch(ply)]
        except Exception:
            # Stop at the first ply without a usable branch. Unlike the old
            # bare `except`, this lets KeyboardInterrupt/SystemExit propagate.
            break
        pv.append(node)

    return dict(value = value, pv = pv, nodeCount = counter.value)
def alphaBeta(pos, depth, alpha = -99999, beta = 99999, splits = 0,
        stopper = None, top = False, parentProc = None):
    """Search `pos` to `depth` plies, sequentially or in parallel.

    - splits <= 0 selects alphaBetaSequential. Otherwise alphaBetaParallel is
      used, receiving `splits` and `parentProc`.
    - When `top` is true this is a root search. pos.offswitch is reset to 0
      and used as the stopper, replacing any `stopper` argument. After the
      search, pvCache is cleared and refilled from the result's PV.

    Returns the selected search's result dict with an added 'time' entry:
    the elapsed wall-clock seconds truncated to an int, plus one (so it is
    always at least 1).
    """
    started = time.time()

    if top:
        pos.offswitch.value = 0
        stopper = pos.offswitch

    if splits > 0:
        result = alphaBetaParallel(pos, depth, alpha, beta, splits,
                                   stopper, parentProc)
    else:
        result = alphaBetaSequential(pos, depth, alpha, beta, stopper)

    if top:
        pvCache.clear()
        updatePVCache(result['pv'])

    elapsed = time.time() - started
    result['time'] = int(elapsed) + 1
    return result
def alphaBetaSequential(pos, depth, alpha, beta, stopper):
    """Sequential alpha-beta search that tries the cached PV move first.

    If depth is 0 or pos has no pvCache entry, the whole search is handed to
    alphaBetaInC. Otherwise:
    - The cached successor is moved to the front of pos.children() and
      searched by this same function, so the cached line is followed further
      down.
    - Every other child is searched by alphaBetaInC, all sharing one scratch
      childArray.

    The window is fail-hard: a child scoring >= beta ends the loop and beta
    is returned.

    Returns dict(value, pv, nodeCount). The pv is just [pos] when no child
    raised alpha.
    """
    if depth == 0 or pos.fen() not in pvCache:
        return alphaBetaInC(pos, depth, alpha, beta, stopper)

    favorite = pvCache[pos.fen()]
    ordered = pos.children()
    ordered.remove(favorite)
    ordered[:0] = [favorite]

    bestPV = [pos]
    nodeCount = 1
    childDepth = depth - 1

    # One scratch buffer reused by every C search below.
    scratch = (Position * (32 * depth))()

    for index, child in enumerate(ordered):
        if index:
            result = alphaBetaInC(child, childDepth, -beta, -alpha, stopper,
                                  childArray = scratch)
        else:
            result = alphaBetaSequential(child, childDepth, -beta, -alpha,
                                         stopper)

        nodeCount += result['nodeCount']
        score = -result['value']
        if score >= beta:
            alpha = beta
            break
        if score > alpha:
            alpha = score
            bestPV = [pos] + result['pv']

    return dict(value = alpha, pv = bestPV, nodeCount = nodeCount)
112 # parallel stuff
114 import spawn
def alphaBetaParallel(pos, depth, alpha, beta, splits, stopper, parentProc = None):
    """Alpha-beta search that evaluates the children of `pos` in workers.

    Each child is searched by a worker created with
    parentProc.spawn(calc, stop), or spawn.spawn(calc, stop) when there is no
    parent, running alphaBeta(child, depth - 1, ...).

    Scheduling:
    - If pos has a pvCache entry, that child is searched first and alone.
    - After the first result, up to two workers run at once.
    - When a result raises alpha, each child still running is queued again
      at the front so it is re-searched with the tighter window. The first
      report for a child is used; its other searches are stopped and their
      late results are ignored.

    pos, depth, alpha, beta -- as for alphaBeta.
    splits     -- parallel split budget. Each worker receives
                  splits - at_a_time + 1, fixed when the worker is created
                  (at_a_time is 1 or 2).
    stopper    -- ctypes integer flag from the caller, or None. Once it is
                  nonzero the loop ends, running workers are stopped and the
                  partial result is returned.
    parentProc -- optional worker object whose spawn() creates sub-workers.

    Returns dict(value, nodeCount, pv). The value is fail-hard: beta on a
    cutoff. Running workers are always stopped before returning or raising.
    """
    kids = pos.children()
    if depth == 0 or not kids:
        return alphaBeta(pos, 0, alpha, beta, 0, ctypes.c_int(0))

    total = len(kids)         # results needed, one per child
    running = []              # workers currently searching
    results = Queue.Queue()   # (child, worker, result dict) posted by workers
    nodeCount = 1
    bestPV = []

    # Children waiting to be searched, in start order.
    q = Queue.Queue()
    if pos.fen() in pvCache:
        # Search the cached PV move first, alone, to establish a good alpha.
        favorite = pvCache[pos.fen()]
        kids[kids.index(favorite)] = None
        q.put(favorite)
        at_a_time = 1
    else:
        at_a_time = 2
    for k in kids:
        if k is not None:
            q.put(k)

    def spawnWorker(child, childAlpha, childSplits):
        """Create (but do not start) a worker searching `child`.

        The child, its own stop flag, the window and the split budget are
        bound here, per worker. The old inline closures read the loop's
        variables at run time, so stop() flagged the most recently spawned
        worker instead of its own.
        """
        flag = ctypes.c_int(0)
        box = []  # receives the worker object once spawn() returns it

        def calc():
            ret = (child, box[0],
                   alphaBeta(child, depth - 1, -beta, -childAlpha,
                             childSplits, flag, parentProc))
            # Only report if this worker was not stopped (superseded).
            if flag.value == 0:
                results.put(ret)
            return ret

        def stop():
            flag.value = 1

        if parentProc:
            proc = parentProc.spawn(calc, stop)
        else:
            proc = spawn.spawn(calc, stop)
        box.append(proc)
        proc.position = child
        proc.alpha = childAlpha
        return proc

    reported = 0
    try:
        while reported < total:
            # Honour a stop request from whoever started this search.
            if stopper is not None and stopper.value:
                break

            if len(running) < at_a_time:
                try:
                    child = q.get_nowait().copy()
                except Queue.Empty:
                    pass
                else:
                    proc = spawnWorker(child, alpha, splits - at_a_time + 1)
                    running.append(proc)
                    proc.start()

            try:
                resChild, resProc, result = results.get(True, .001)
            except Queue.Empty:
                continue
            if hasattr(resProc, 'defunct'):
                # A superseded duplicate search finished anyway; ignore it.
                continue
            reported += 1
            at_a_time = 2
            if resProc in running:
                running.remove(resProc)

            value = -result['value']
            nodeCount += result['nodeCount']

            # Stop any other search of the same child. Iterate over a copy
            # because matching workers are removed from `running`.
            for proc in list(running):
                if proc.position == resChild:
                    proc.stop()
                    proc.defunct = True
                    running.remove(proc)

            if value >= beta:
                alpha = beta
                break
            elif value > alpha:
                alpha = value
                bestPV = [pos] + result['pv']
                # Re-search each child still running with the new alpha by
                # putting its position at the front of the queue. The old
                # search keeps going; whichever reports first is used.
                for proc in running:
                    pending = []
                    while not q.empty():
                        pending.append(q.get())
                    q.put(proc.position)
                    for item in pending:
                        q.put(item)
    finally:
        # Runs on normal exit, cutoff, stop request, KeyboardInterrupt or any
        # other exception, so no worker is left searching.
        for kid in running:
            kid.stop()

    return dict(value = alpha, nodeCount = nodeCount, pv = bestPV)