libgen
[arrocco.git] / position.py
blobe768583500561f180a6019c88c2f9e67285571dd
1 # position.py
2 # 1 Mar 2007
4 """
5 This file is part of Arrocco, which is Copyright 2007 Thomas Plick
6 (tomplick 'at' gmail.com).
8 Arrocco is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 3 of the License, or
11 (at your option) any later version.
13 Arrocco is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
18 You should have received a copy of the GNU General Public License
19 along with this program. If not, see <http://www.gnu.org/licenses/>.
20 """
22 import ctypes, time, threading, thread, copy, Queue
24 from lib import lib
# Byte size of the C position struct; used as the length of Position's
# opaque byte field and as the memmove length in Position.copy().
SIZE_OF_POSITION = lib.sizeOfPosition()
c_int = ctypes.c_int
c_short = ctypes.c_short
# NOTE: despite its name this is an *unsigned* 64-bit ctypes integer.
int64 = ctypes.c_ulonglong

# Piece codes are indices into this string: 0 is an empty square, then
# upper/lower-case FEN letter pairs for P, N, B, R, Q, K.
charForPiece = " PpNnBbRrQqKk"

# Define module-level square constants a1 .. h8 holding their 0-63 index,
# 8 * rank + file (a1 = 0, h1 = 7, a8 = 56, h8 = 63). Assigning through
# globals() avoids building and exec-ing source strings.
for r in range(8):
    for c in range(8):
        globals()["abcdefgh"[c] + "12345678"[r]] = 8 * r + c
class Position(ctypes.Structure):
    """A chess position stored in an opaque C struct managed by ``lib``.

    Python never interprets the struct's bytes; every read and write goes
    through a ``lib`` call. Squares are addressed as ``(r, c)`` tuples, where
    ``r`` is the rank index (0 = rank 1) and ``c`` is the file index
    (0 = file a). The C side sees them as ``8 * r + c``. Piece values index
    into ``charForPiece`` (0 = empty square).
    """
    _fields_ = [("opaque", ctypes.c_byte * SIZE_OF_POSITION)]

    # Capacity of the scratch buffer handed to lib.makeChildren.
    # NOTE(review): lib.makeChildren receives no size argument, so this
    # assumes the C move generator never writes more children than this.
    MAX_CHILDREN = 512

    def __init__(self):
        ctypes.Structure.__init__(self)
        # Flag set by stop() and cleared by unstop(). Nothing in this class
        # reads it; presumably a search routine elsewhere polls it (confirm).
        self.offswitch = ctypes.c_int(0)

    def getTurn(self): return lib.getTurn(ctypes.pointer(self))
    def setTurn(self, x): lib.setTurn(ctypes.pointer(self), x)
    # Side to move. fen()/setFen() map even values to 'w' and odd to 'b'.
    turn = property(getTurn, setTurn)

    def getValue(self): return lib.getValue(ctypes.pointer(self))
    def setValue(self, x): lib.setValue(ctypes.pointer(self), x)
    # Score as reported by lib.getValue. __setitem__ asks the C side to
    # recompute it after every board edit.
    value = property(getValue, setValue)

    def getBranch(self, i):
        """Thin wrapper over lib.getBranch; its meaning is defined in C."""
        return lib.getBranch(ctypes.byref(self), i)

    def __getitem__(self, square):
        """Return the piece code on ``square``, an ``(r, c)`` tuple."""
        r, c = square
        return lib.getPieceAt(ctypes.pointer(self), 8 * r + c)

    def __setitem__(self, square, v):
        """Put piece code ``v`` on ``square``, then recompute the value."""
        r, c = square
        lib.setPieceAt(ctypes.pointer(self), 8 * r + c, v)
        lib.calculatePositionValue(ctypes.byref(self))

    def fen(self):
        """Return the board and side to move as a two-field FEN string.

        Castling rights, en passant square and move counters are not
        emitted.
        """
        def rowString(r):
            return ''.join(charForPiece[self[r, c]] for c in range(8))
        # FEN lists rank 8 first. Mark every empty square as '1' for now.
        boardString = '/'.join(rowString(r) for r in range(7, -1, -1)).replace(' ', '1')
        # Collapse runs of empty squares into counts, longest runs first.
        for i in range(8, 1, -1):
            boardString = boardString.replace('1' * i, str(i))
        return "%s %s" % (boardString, ['w', 'b'][self.turn & 1])

    def setFen(self, fen):
        """Load the board and side to move from a FEN string.

        Only the first two whitespace-separated fields are used. Later
        fields (castling, en passant, clocks) are ignored. A missing side
        field means White; any value other than 'w' means Black. Missing
        board data leaves squares empty, and unrecognised characters are
        treated as empty squares.
        """
        fields = fen.split()
        boardString = fields[0] if fields else ""
        turn = fields[1] if len(fields) > 1 else "w"

        # Flatten the ranks and pad so there are always at least 64 cells.
        # Then expand the digit run-lengths into explicit empty squares.
        boardString = boardString.replace('/', '') + (' ' * 64)
        for i in range(1, 9):
            boardString = boardString.replace(str(i), ' ' * i)
        for n in range(64):
            row, col = divmod(n, 8)
            piece = charForPiece.find(boardString[n])
            # Cell 0 is a8, so the FEN row index counts down from rank 8.
            self[7 - row, col] = max(piece, 0)

        self.turn = 0 if turn == 'w' else 1

    def __repr__(self): return "position(%r)" % self.fen()
    def __str__(self): return ''.join(self.display())

    def display(self):
        """Yield the board one character at a time, rank 8 first.

        Empty squares are shown as '_' and each rank is followed by '\\n'.
        """
        for r in reversed(range(8)):
            for c in range(8):
                yield charForPiece[self[r, c]].replace(' ', '_')
            yield '\n'

    def _makeChildren(self):
        """Run lib.makeChildren into a fresh buffer; return (buffer, count).

        The count is clamped to ``[0, len(buffer)]`` so callers never index
        past the buffer.
        """
        childArray = (Position * self.MAX_CHILDREN)()
        count = lib.makeChildren(ctypes.byref(self), ctypes.byref(childArray))
        return childArray, max(0, min(count, len(childArray)))

    def children(self):
        """Return a list of the child positions generated by lib.makeChildren.

        NOTE(review): the children are elements of a ctypes array and do not
        appear to go through __init__, so they have no ``offswitch``.
        stop()/unstop() on them would raise AttributeError.
        """
        childArray, count = self._makeChildren()
        return childArray[0:count]

    def legalChildren(self):
        """Return the children for which lib.inCheck(child, self.turn) is false.

        Presumably this drops moves that leave the mover's own king in
        check (confirm lib.inCheck's contract).
        """
        mover = self.getTurn()
        return [child for child in self.children()
                if not lib.inCheck(ctypes.byref(child), mover)]

    def childrenGen(self):
        """Yield the child positions one at a time; same sequence as children()."""
        childArray, count = self._makeChildren()
        for i in range(count):
            yield childArray[i]

    def afterMove(self, sq1, sq2, promote=None):
        """Return a new Position produced by lib.afterMove(self, sq1, sq2).

        ``sq1``/``sq2`` are passed unchanged to C (presumably 0-63 indices
        like the module-level a1..h8 constants; confirm).
        NOTE(review): ``promote`` is accepted but never passed to
        lib.afterMove, so it currently has no effect.
        """
        result = Position()
        lib.afterMove(ctypes.byref(self), sq1, sq2, ctypes.byref(result))
        return result

    def afterPass(self):
        """Return a copy of this position with the side to move flipped."""
        result = self.copy()
        result.turn ^= 1
        return result

    def copy(self):
        """Return a byte-for-byte copy of the C struct with a fresh offswitch."""
        result = Position()
        ctypes.memmove(ctypes.byref(result), ctypes.byref(self), SIZE_OF_POSITION)
        return result

    # Positions compare and hash by fen(), i.e. by board and side to move.
    def __cmp__(self, other): return cmp(self.fen(), other.fen())

    def __eq__(self, other):
        # Non-Positions are simply unequal instead of raising AttributeError.
        return isinstance(other, Position) and self.fen() == other.fen()

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self): return hash(self.fen())

    def unstop(self): self.offswitch.value = 0
    def stop(self): self.offswitch.value = 1
def position(s):
    """Create a Position and load it from the FEN string ``s``.

    Board and side to move are read by Position.setFen.
    """
    newPosition = Position()
    newPosition.setFen(s)
    return newPosition
allSquares = [(r, c) for r in range(8) for c in range(8)]


def algebraicForPV(pv):
    """Render a principal variation as long-algebraic move strings.

    ``pv`` is a sequence of positions (anything indexable by an ``(r, c)``
    square tuple). Each consecutive pair becomes ``<PIECE><from><to>``, for
    example ``"Pe2e4"``. The piece letter is always upper-case. The squares
    are found by diffing the two boards.

    Rendering stops at the first pair that is not a plain two-square move,
    for example a pass where no square changed. The moves collected so far
    are returned.

    NOTE(review): only the first two changed squares are used. Castling and
    en passant change three or four squares and render incorrectly; White
    king-side castling would come out as "Ke1f1".
    """
    def squareName(square):
        r, c = square
        return "abcdefgh"[c] + "12345678"[r]

    moves = []
    for parent, child in zip(pv, pv[1:]):
        changed = [sq for sq in allSquares if parent[sq] != child[sq]]
        try:
            # The origin square is the one left empty in the child. If the
            # first changed square is occupied afterwards, flip the order.
            if child[changed[0]] > 0:
                changed.reverse()
            moves.append(charForPiece[parent[changed[0]]].upper()
                         + squareName(changed[0]) + squareName(changed[1]))
        except IndexError:
            # Fewer than two changed squares, or an out-of-range piece code:
            # not a renderable move, so stop here.
            break
    return moves
# Sample positions built at import time:
#   ipos  - the standard initial position (setFen reads only the board and
#           side fields, so "KQkq" is ignored);
#   pos12 - a sparse board with no side field, so White is to move.
ipos, pos12 = (
    position("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq"),
    position("k7/8/KR"),
)