makefile
[arrocco.git] / search.py
blob50392456ca7752a231a6be38272dd337face9518
1 # search.py
2 # 28 June 2007
4 """
5 This file is part of Arrocco, which is Copyright 2007 Thomas Plick
6 (tomplick 'at' gmail.com).
8 Arrocco is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 3 of the License, or
11 (at your option) any later version.
13 Arrocco is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
18 You should have received a copy of the GNU General Public License
19 along with this program. If not, see <http://www.gnu.org/licenses/>.
20 """
22 import time, ctypes, thread, Queue, threading
24 from lib import lib
25 from position import *
# Maps a position's FEN string to the position that followed it in the most
# recently recorded principal variation (see updatePVCache).
pvCache = {}

def updatePVCache(pv):
    """Record every consecutive pair of the principal variation `pv` in pvCache.

    For each adjacent (earlier, later) pair of positions in `pv`, pvCache is
    keyed by earlier.fen() and maps to `later`. Existing entries with the same
    FEN are overwritten; the cache is not cleared here. A PV with fewer than
    two positions leaves the cache unchanged.
    """
    pvCache.update((earlier.fen(), later) for earlier, later in zip(pv, pv[1:]))
def alphaBetaInC(pos, depth,
                 alpha, beta, stopper = None, childArray = None):
    """Run the C alpha-beta search (lib.alphaBeta) on `pos` and rebuild its PV.

    Arguments:
      pos         -- root position, passed by reference to lib.alphaBeta.
      depth       -- search depth in plies; also the maximum number of moves
                     followed when rebuilding the PV.
      alpha, beta -- search window handed to the C search.
      stopper     -- ctypes integer passed by reference to the C search
                     (presumably polled there as an abort flag -- confirm in
                     lib). When None, a fresh ctypes.c_int(0) is used;
                     ctypes.byref(None) would raise TypeError.
      childArray  -- scratch array of Position structs for the C search;
                     allocated with depth * 32 entries when not supplied.

    Returns a dict with keys:
      'value'     -- the return value of lib.alphaBeta,
      'pv'        -- list of positions starting with `pos`,
      'nodeCount' -- final value of the counter handed to the C search
                     (initialised to 1).
    """
    if stopper is None:
        # byref() needs a ctypes instance; an unset flag means "don't stop".
        stopper = ctypes.c_int(0)
    counter = ctypes.c_long(1)
    if childArray is None:
        childArray = (Position * (depth * 32))()

    value = lib.alphaBeta(ctypes.byref(pos), childArray, depth,
                          alpha, beta, ctypes.byref(counter), ctypes.byref(stopper))

    # Rebuild the PV by following, ply by ply, the branch index reported by
    # pos.getBranch(ply) (presumably recorded by the C search -- confirm).
    pv = [pos]
    node = pos
    for ply in range(depth):
        try:
            node = node.children()[pos.getBranch(ply)]
        except Exception:
            # Best effort: stop at the first ply that cannot be followed.
            # Narrowed from a bare except so KeyboardInterrupt / SystemExit
            # still propagate to the caller.
            break
        pv.append(node)

    return dict(value = value, pv = pv,
                nodeCount = counter.value)
# XXX Fold into parallel function
def alphaBeta(pos, depth, alpha = -99999, beta = 99999, splits = 0,
              stopper = None, top = False, parentProc = None):
    """Search `pos` to `depth` plies and report the result with its run time.

    Resets pos.offswitch to 0 and runs alphaBetaParallel with pos.offswitch
    as its stop flag. Afterwards the global pvCache is emptied and refilled
    from the returned PV. alphaBetaParallel also calls this function for child
    searches, so those rewrite pvCache as well.

    NOTE(review): `stopper` and `top` are accepted but never read; the only
    stop flag used is pos.offswitch.

    Returns alphaBetaParallel's result dict with an added 'time' entry: the
    whole number of elapsed wall-clock seconds plus one.
    """
    startedAt = time.time()

    pos.offswitch.value = 0
    outcome = alphaBetaParallel(pos, depth, alpha, beta, splits,
                                pos.offswitch, parentProc)

    # Remember this search's principal variation for move ordering next time.
    pvCache.clear()
    updatePVCache(outcome['pv'])

    finishedAt = time.time()
    outcome['time'] = int(finishedAt - startedAt) + 1
    return outcome
71 # parallel stuff
73 import spawn
def _spawnChildSearch(child, depth, alpha, beta, childSplits, results, parentProc):
    """Create, but do not start, a worker that searches `child` one ply shallower.

    Every input is bound here rather than read from the spawning loop's
    variables. As a result, each worker searches with the window current at
    spawn time, reports its own proc, and stop() flips only its own flag.

    The worker calls alphaBeta(child, depth - 1, -beta, -alpha, childSplits,
    ...). Unless stop() has been called, it then puts (child, proc, result) on
    `results`. Returns the object produced by parentProc.spawn(calc, stop);
    the caller must call its start().
    """
    childStopper = ctypes.c_int(0)
    spawned = []  # receives the proc once spawn() returns; read by calc()

    def calc():
        # parentProc must go by keyword: positionally it would land in
        # alphaBeta's `top` parameter.
        # NOTE(review): alphaBeta ignores its `stopper` argument, so stop()
        # only suppresses posting the result; whether the search itself is
        # interrupted depends on the spawn module's proc.stop() -- confirm.
        result = alphaBeta(child, depth - 1, -beta, -alpha, childSplits,
                           childStopper, parentProc=parentProc)
        if childStopper.value == 0:
            results.put((child, spawned[0], result))

    def stop():
        childStopper.value = 1

    proc = parentProc.spawn(calc, stop)
    spawned.append(proc)
    return proc

def alphaBetaParallel(pos, depth, alpha, beta, splits, stopper, parentProc = None):
    """Alpha-beta (negamax) search of `pos` that farms child searches out to workers.

    The search falls back to the C search (alphaBetaInC) in these cases:
      - at depth 0,
      - when `pos` has no children,
      - when stopper.value is already set,
      - when there is no usable cached PV move and `splits` <= 0.

    Otherwise each child is searched by alphaBeta in a worker created with
    parentProc.spawn (the `spawn` module when parentProc is falsy). When
    pvCache holds a move for this position, that move is searched alone first.
    After that, up to two children run at a time while splits > 0.

    Returns dict(value=..., nodeCount=..., pv=...). `value` is fail-hard:
    beta on a cutoff, otherwise the best score found or the original alpha.
    `pv` starts with `pos`, or is empty when no child raised alpha.
    """
    kids = pos.children()
    if depth == 0 or len(kids) == 0 or stopper.value:
        return alphaBetaInC(pos, 0, alpha, beta, stopper)

    running = []             # workers started and not yet retired
    results = Queue.Queue()  # (child, proc, result) tuples posted by workers
    for k in kids: k.done = False

    nodeCount = 1
    bestPV = []

    q = Queue.Queue()        # children waiting for a worker (front = next)
    # Single lookup: worker searches clear and refill pvCache concurrently
    # (via alphaBeta), so `in` followed by indexing could raise KeyError.
    favorite = pvCache.get(pos.fen())
    if favorite is not None and favorite in kids:
        # Search the previous PV move alone first to establish a good alpha.
        kids[kids.index(favorite)] = None
        q.put(favorite)
        at_a_time = 1
    else:
        at_a_time = 2
        if splits <= 0:
            return alphaBetaInC(pos, depth, alpha, beta, stopper)

    for k in kids:
        if k is not None: q.put(k)

    if not parentProc:
        parentProc = spawn # module spawn

    rfq = 0  # number of children whose result has been accepted
    try:
        while rfq < len(kids):
            if len(running) < at_a_time:
                try:
                    child = q.get_nowait().copy()
                except Queue.Empty:
                    pass
                else:
                    proc = _spawnChildSearch(child, depth, alpha, beta,
                                             splits - at_a_time + 1,
                                             results, parentProc)
                    running.append(proc)
                    proc.position, proc.alpha = child, alpha
                    proc.start()
                    anotherOneStarted()

            try:
                resChild, resProc, result = results.get(True, .05)
            except Queue.Empty:
                continue
            if hasattr(resProc, 'defunct'):
                # a worker already retired in favour of another search of the same child
                continue
            rfq += 1
            if splits > 0: at_a_time = 2
            try:
                running.remove(resProc)
            except ValueError:
                pass

            value = -result['value']
            nodeCount += result['nodeCount']

            # Retire other workers still searching the same child. Iterate over
            # a copy: removing from `running` while iterating it skips entries.
            for other in running[:]:
                if other.position == resChild:
                    other.stop(); other.defunct = True
                    running.remove(other)

            if value >= beta:
                alpha = beta
                break
            elif value > alpha:
                alpha = value
                bestPV = [pos] + result['pv']

                # Re-search each still-running child with the improved alpha by
                # putting its position at the front of the queue.
                for other in running:
                    newQ = Queue.Queue()
                    while not q.empty(): newQ.put(q.get())
                    q.put(other.position)
                    while not newQ.empty(): q.put(newQ.get())
    finally:
        # Stop leftover workers on normal exit, on a beta cutoff, and on any
        # exception (including KeyboardInterrupt, which is re-raised).
        for kid in running: kid.stop()

    return dict(value = alpha, nodeCount = nodeCount, pv = bestPV)
threadsStarted = 0
# Guards threadsStarted. anotherOneStarted() is called from alphaBetaParallel,
# which also runs inside spawned worker searches (presumably on separate
# threads, as the message says), and `+=` on a global is not atomic.
_threadsStartedLock = threading.Lock()

def anotherOneStarted():
    """Count one more started worker and print a notice every 1000 workers.

    The increment and the read of the new total happen under a lock. Each
    multiple of 1000 is therefore reported exactly once, even when called
    from several threads.
    """
    global threadsStarted
    _threadsStartedLock.acquire()
    try:
        threadsStarted += 1
        count = threadsStarted
    finally:
        _threadsStartedLock.release()
    if count % 1000 == 0:
        # Parenthesised single argument: same output as the old print statement
        # under Python 2, and valid syntax under Python 3.
        print("********* Started %s threads" % count)