First commit
[arrocco.git] / search.py
blob4d77f15d10c108cc7624c0fcc4462524448b3acd
1 # search.py
2 # 28 June 2007
4 import time, ctypes, thread, Queue, threading
6 from lib import lib
7 from position import *
# Principal-variation cache: maps a position's FEN string to the position that
# followed it on the most recently stored PV.  The searches below look a
# position up here and try the cached successor before its siblings.
pvCache = dict()


def updatePVCache(pv):
    """Store every consecutive (position -> next position) pair of *pv*.

    *pv* is a sequence of positions.  Existing entries for the same FEN are
    overwritten; the cache is not cleared here.
    """
    for idx, successor in enumerate(pv[1:]):
        pvCache[pv[idx].fen()] = successor
def alphaBetaInC(pos, depth,
                 alpha, beta, stopper = None, childArray = None):
    """Run the C search ``lib.alphaBeta`` on *pos* and collect its result.

    Parameters:
      pos        -- root position; passed to C with ``ctypes.byref``, so it
                    must be a ctypes instance.
      depth      -- search depth handed to the C search.
      alpha/beta -- search window handed to the C search.
      stopper    -- ctypes object passed by reference to the C search as its
                    stop flag (other callers in this file pass a
                    ``ctypes.c_int``).  NOTE(review): presumably a non-zero
                    value asks the C search to stop early -- confirm on the
                    C side.  When None, a fresh ``c_int(0)`` is used.
      childArray -- optional scratch array of ``Position`` handed to the C
                    search; allocated with ``depth * 32`` entries when omitted,
                    so callers can pass one buffer and reuse it across calls.

    Returns a dict with:
      value     -- whatever ``lib.alphaBeta`` returned.
      pv        -- list of positions starting with *pos*, rebuilt by
                   following ``pos.getBranch(i)`` child indices level by level.
      nodeCount -- final value of a counter initialised to 1 and passed by
                   reference to the C search.
    """
    if stopper is None:
        # ctypes.byref(None) raises TypeError, so the documented default used
        # to crash; give the C search a private, never-set stop flag instead.
        stopper = ctypes.c_int(0)

    counter = ctypes.c_long(1)
    if childArray is None:
        childArray = (Position * (depth * 32))()

    value = lib.alphaBeta(ctypes.byref(pos), childArray, depth,
                          alpha, beta, ctypes.byref(counter), ctypes.byref(stopper))

    # Rebuild the PV from the per-level branch indices read off the root
    # position (assumed to be recorded there by the C search -- confirm).
    pv = [pos]
    node = pos
    for i in range(depth):
        try:
            node = node.children()[pos.getBranch(i)]
        except Exception:
            # Best effort: stop at the first level that cannot be followed
            # (e.g. no children, index out of range).  Catching Exception
            # rather than everything keeps Ctrl-C working.
            break
        pv.append(node)

    return dict(value = value, pv = pv,
                nodeCount = counter.value)
def alphaBeta(pos, depth, alpha = -99999, beta = 99999, splits = 0,
              stopper = None, top = False, parentProc = None):
    """Search *pos* to *depth*, timing the search.

    Uses alphaBetaParallel when *splits* is positive, otherwise
    alphaBetaSequential.  When *top* is true the call is treated as a root
    search: ``pos.offswitch`` is reset to 0 and used as the stop flag (any
    *stopper* argument is ignored), and afterwards pvCache is cleared and
    refilled from the returned PV.

    Returns the chosen search's result dict with an added ``'time'`` entry:
    the elapsed wall-clock seconds rounded down, plus one.
    """
    startedAt = time.time()

    if top:
        pos.offswitch.value = 0
        stopper = pos.offswitch

    if splits > 0:
        result = alphaBetaParallel(pos, depth, alpha, beta, splits, stopper, parentProc)
    else:
        result = alphaBetaSequential(pos, depth, alpha, beta, stopper)

    if top:
        # Root search: drop the previous PV and remember this one so the next
        # search tries these moves first.
        pvCache.clear()
        updatePVCache(result['pv'])

    elapsed = time.time() - startedAt
    result['time'] = int(elapsed) + 1
    return result
def alphaBetaSequential(pos, depth, alpha, beta, stopper):
    """Single-threaded alpha-beta search that tries the cached PV move first.

    If *depth* is 0 or *pos* has no pvCache entry, the whole search is handed
    to alphaBetaInC.  Otherwise the cached successor is searched first by
    recursing into this function (so the cached line keeps being followed),
    and the remaining children are searched with alphaBetaInC, all sharing one
    scratch ``Position`` array.

    Returns a dict with:
      value     -- best score found, never below the incoming *alpha*; equal
                   to *beta* when a child causes a cutoff.
      pv        -- starts with *pos*; extended with a child's PV only when
                   that child raised alpha.
      nodeCount -- 1 plus the node counts of every child searched.
    """
    if depth == 0:
        return alphaBetaInC(pos, depth, alpha, beta, stopper)
    key = pos.fen()
    if key not in pvCache:
        return alphaBetaInC(pos, depth, alpha, beta, stopper)

    # Move ordering: cached PV successor first, the rest in generation order.
    favorite = pvCache[key]
    rest = pos.children()
    rest.remove(favorite)
    ordered = [favorite] + rest

    childDepth = depth - 1
    scratch = (Position * (32 * depth))()
    total = 1
    line = [pos]

    for index, child in enumerate(ordered):
        if index:
            sub = alphaBetaInC(child, childDepth, -beta, -alpha, stopper,
                               childArray = scratch)
        else:
            sub = alphaBetaSequential(child, childDepth, -beta, -alpha,
                                      stopper)

        score = -sub['value']
        total += sub['nodeCount']
        if score >= beta:
            # Fail hard on a cutoff.
            alpha = beta
            break
        if score > alpha:
            alpha = score
            line = [pos] + sub['pv']

    return dict(value = alpha, pv = line, nodeCount = total)
94 # parallel stuff
96 import spawn
def _makeChildProc(child, depth, alpha, beta, splits, results, parentProc):
    """Create (but do not start) a proc that runs ``alphaBeta`` on *child*.

    The child search runs with the negated window (-beta, -alpha) and its own
    ``ctypes.c_int`` stop flag.  When it returns, the tuple
    ``(child, proc, result)`` is put on *results* only if the flag is still 0.
    Every value is bound per call, so later rebinding of the caller's loop
    variables cannot leak into a search that is already in flight.
    """
    flag = ctypes.c_int(0)
    # Holds the proc once created; calc needs it but the proc needs calc
    # first (and Python 2 has no `nonlocal`).
    box = []

    def calc():
        ret = child, box[0], alphaBeta(child, depth, -beta, -alpha,
                                       splits, flag, parentProc)
        if flag.value == 0:
            results.put(ret)
        return ret

    def stop():
        flag.value = 1

    if parentProc:
        proc = parentProc.spawn(calc, stop)
    else:
        proc = spawn.spawn(calc, stop)
    box.append(proc)
    proc.position = child
    proc.alpha = alpha
    return proc


def alphaBetaParallel(pos, depth, alpha, beta, splits, stopper, parentProc = None):
    """Alpha-beta search of *pos* whose children are searched concurrently.

    Each child is searched by a proc created through ``parentProc.spawn``
    (when *parentProc* is given) or ``spawn.spawn``, and results are collected
    from a queue.  If *pos* has a pvCache entry, that successor is searched
    alone first; after the first result comes back, up to two child searches
    run at once.

    Parameters:
      pos        -- position to search.
      depth      -- remaining depth; with depth 0 or no children the position
                    is searched in place via ``alphaBeta(pos, 0, ...)``.
      alpha/beta -- search window.
      splits     -- parallel split budget; children receive
                    ``splits - at_a_time + 1``.
      stopper    -- object with a ``.value`` (e.g. ``ctypes.c_int``) or None.
                    Once its value is non-zero, no further results are
                    processed, running children are stopped and the current
                    best is returned.
      parentProc -- optional proc used to spawn the child procs.

    Returns a dict with ``value`` (never below the incoming *alpha*; *beta*
    on a cutoff), ``nodeCount`` and ``pv`` (starts with *pos*).
    """
    kids = pos.children()
    if depth == 0 or len(kids) == 0:
        # Nothing to split: search in place with a private stop flag.
        return alphaBeta(pos, 0, alpha, beta, 0, ctypes.c_int(0))

    running = []              # procs currently searching a child
    results = Queue.Queue()   # (child, proc, result) from finished procs
    for k in kids: k.done = False

    nodeCount = 1
    # Start the PV with pos, matching alphaBetaSequential / alphaBetaInC.
    bestPV = [pos]

    # Children waiting to be searched.  The cached PV successor goes first
    # and runs alone; everything else may run two at a time.
    q = Queue.Queue()
    key = pos.fen()
    if key in pvCache:
        favorite = pvCache[key]
        kids[kids.index(favorite)] = None
        q.put(favorite)
        at_a_time = 1
    else:
        at_a_time = 2
    for k in kids:
        if k is not None:
            q.put(k)

    rfq = 0  # child results accepted so far
    try:
        while rfq < len(kids):
            # Honour the caller's stop flag (previously it was overwritten by
            # the first child's flag and thus ignored).
            if stopper is not None and stopper.value:
                break

            if len(running) < at_a_time:
                try:
                    child = q.get_nowait().copy()
                except Queue.Empty:
                    pass
                else:
                    proc = _makeChildProc(child, depth - 1, alpha, beta,
                                          splits - at_a_time + 1,
                                          results, parentProc)
                    running.append(proc)
                    proc.start()

            try:
                resChild, resProc, result = results.get(True, .001)
            except Queue.Empty:
                continue
            if hasattr(resProc, 'defunct'):
                # Late result from a search that was already superseded.
                continue
            rfq += 1
            at_a_time = 2
            if resProc in running:
                running.remove(resProc)

            value = -result['value']
            nodeCount += result['nodeCount']

            # Stop any duplicate search of the same child (created by the
            # re-queue below; relies on Position equality across copies).
            # Iterate over a copy because entries are removed.
            for proc in list(running):
                if proc.position == resChild:
                    proc.stop()
                    proc.defunct = True
                    running.remove(proc)

            if value >= beta:
                alpha = beta
                break
            elif value > alpha:
                alpha = value
                bestPV = [pos] + result['pv']

                # Searches still running were launched with the old alpha.
                # Put each of their positions back at the front of the queue
                # so it is searched again with the new alpha; whichever copy
                # finishes first is counted and the other is stopped above.
                # (Stopping the old search right here was left disabled.)
                for proc in running:
                    pending = []
                    while not q.empty():
                        pending.append(q.get())
                    q.put(proc.position)
                    for item in pending:
                        q.put(item)
    finally:
        # Never leave child searches running, whatever the exit path
        # (cutoff, stop flag, KeyboardInterrupt or any other exception).
        for kid in running:
            kid.stop()

    return dict(value = alpha, nodeCount = nodeCount, pv = bestPV)