initial commit
[arachne.git] / position.hs
blobfba079015b605b10848b308b0dac5bd02c829b75
1 -- position.hs
2 -- 20 Mar 2008
4 {-# OPTIONS -fbang-patterns #-}
6 module Main where
8 import Data.Array
9 import qualified Data.Array.Base as ABase
10 import qualified Char
11 import qualified System
12 import Control.Monad
13 import Control.Monad.ST
14 import Time
15 import Text.Printf
16 import Data.STRef
17 import Debug.Trace
-- | Contents of a board cell; the same type also names the side to move.
-- TODO: rename this Player and make cell values be Just XX, Just OO, or
-- Nothing.
data Value
  = Empty
  | XX
  | OO
  deriving (Eq)
-- | The other player; 'Empty' has no opponent and maps to itself.
opponent :: Value -> Value
opponent v = case v of
  XX -> OO
  OO -> XX
  Empty -> Empty
-- | Sign of a cell value from X's point of view: +1 for X, -1 for O and
-- 0 for an empty cell.
sideValue :: Num a => Value -> a
sideValue v = case v of
  XX -> 1
  OO -> -1
  Empty -> 0
-- | One-character rendering of a cell: "X", "O", or "-" when empty.
instance Show Value where
  show v = case v of
    XX -> "X"
    OO -> "O"
    Empty -> "-"
-- | A game state: board contents, board width, side to move and the list
-- of winning lines for this board size.
-- TODO: use an unboxed array for the board.
data Position = Position
  { _board :: Array Int Value -- ^ cell contents, indexed by cell number
  , _side :: Int -- ^ board width; 'newPos' builds a _side x _side board
  , _turn :: Value -- ^ the player to move
  , _ways :: ![[Int]] -- ^ lines of cell indices checked by the winner test
  }
  deriving (Eq)
-- | Prefix every @side@-th string (starting with the first) with a '/' and
-- concatenate the results. Only the first @side * side@ strings are used.
--
-- For example @insertSlashes 2 ["a","b","c","d"] == "/ab/cd"@.
--
-- 'concat' is used instead of @foldl (++) ""@, which left-nests the
-- appends and so copies the accumulated prefix once per element.
insertSlashes :: Integral a => a -> [String] -> String
insertSlashes side str = concat (zipWith (++) markers str)
  where
    -- "/" at the start of each row, "" elsewhere; its length bounds the output.
    markers = [if x `mod` side == 0 then "/" else "" | x <- [0 .. side * side - 1]]
-- | Render a position as its rows separated by '/', then a space and the
-- side to move in lower case. For example, "X--/-O-/--- o" is a 3x3 board
-- with O to move.
instance Show Position where
  show (Position board side turn _) =
    -- insertSlashes also puts a '/' before the first row; drop it.
    drop 1 $
      insertSlashes side (map show (elems board))
        ++ " "
        ++ map Char.toLower (show turn)
-- | An empty @side@ x @side@ board with X to move.
newPos :: Int -> Position
newPos side = Position
  { _board = listArray (0, cells - 1) (replicate cells Empty)
  , _side = side
  , _turn = XX
  , _ways = winningWays side
  }
  where
    cells = side * side
-- | A sample empty 3x3 position. Nothing in this code refers to it;
-- presumably it is kept for interactive (GHCi) experimentation.
p :: Position
p = newPos 3
-- | Contents of the cell with the given index.
at :: Position -> Int -> Value
at pos cell = _board pos ! cell
-- | The position after the side to move claims @cell@. The turn passes to
-- the opponent. The cell is not checked for being empty.
afterMove :: Position -> Int -> Position
afterMove pos cell =
  pos { _board = _board pos // [(cell, mover)], _turn = opponent mover }
  where
    mover = _turn pos
-- | Play a sequence of moves (cell indices) in order, alternating sides as
-- 'afterMove' does.
--
-- The explicit signature keeps this from falling under the monomorphism
-- restriction, which the original point-free @afterMoves = foldl afterMove@
-- did. Each intermediate position is forced to weak head normal form, so a
-- long move list does not build a chain of suspended updates.
afterMoves :: Position -> [Int] -> Position
afterMoves pos [] = pos
afterMoves pos (cell : cells) =
  let next = afterMove pos cell
   in next `seq` afterMoves next cells
-- | All positions reachable in one move: the side to move fills each empty
-- cell, in index order.
--
-- A position that already has a winner is finished and has no children.
-- Both searches evaluate a position without children with 'posValue', so
-- a won game now ends the line instead of being played on.
children :: Position -> [Position]
children pos =
  case winner pos of
    Just _ -> []
    Nothing -> [afterMove pos cell | (cell, Empty) <- assocs (_board pos)]
-- | @run len step start@ is the arithmetic progression of @len@ terms that
-- begins at @start@ and advances by @step@.
run :: (Enum a, Num a) => a -> a -> a -> [a]
run len step start = map (\k -> start + step * k) [0 .. len - 1]
-- | Every line of cell indices that wins on a @side@ x @side@ board: the
-- two diagonals first, then the columns, then the rows.
winningWays :: (Enum a, Num a) => a -> [[a]]
winningWays side = diagonal : antiDiagonal : columns ++ rows
  where
    line = run side
    diagonal = line (side + 1) 0
    antiDiagonal = line (side - 1) (side - 1)
    columns = [line side c | c <- [0 .. side - 1]]
    rows = [line 1 r | r <- [0, side .. side * side - 1]]
-- | The owner of the first way in the list whose cells all hold the same
-- non-empty value, or 'Nothing' if no way is complete.
-- An empty way is skipped. The previous version called 'head'/'tail' on it
-- and crashed.
winner' :: Position -> [[Int]] -> Maybe Value
winner' _ [] = Nothing
winner' pos (way : rest) =
  case map (pos `at`) way of
    piece : others
      | piece /= Empty && all (== piece) others -> Just piece
    _ -> winner' pos rest
-- | The player owning a completed line of this position, checking the
-- lines in the order stored in '_ways'.
winner :: Position -> Maybe Value
winner pos = winner' pos ways
  where
    ways = _ways pos
-- | Static score of a position from the point of view of the side to move:
-- +88888 when the side to move owns a completed line, -88888 when the
-- opponent does, and 0 otherwise.
posValue :: Num a => Position -> a
posValue pos = sideValue (_turn pos) * lineScore (winner pos)
  where
    -- Score of the completed line from X's point of view. 'winner' never
    -- yields 'Just Empty', but the previous local helper had no case for it.
    -- This version is total and maps it to 0.
    lineScore = maybe 0 ((winScore *) . sideValue)
    winScore = 88888
-- | Plain negamax search to a fixed depth, without pruning.
-- Returns the score from the point of view of the side to move, plus the
-- principal variation starting at @pos@. Among equally good moves the
-- first one generated is kept. The chosen child's score is negated and
-- scaled by 99/100 (rounded down) at each ply.
negamax :: (Integral v, Eq d, Num d) => Position -> d -> (v, [Position])
negamax pos 0 = (posValue pos, [pos])
negamax pos depth
  | null evals = (posValue pos, [pos])
  | otherwise = (bestValue * (-99) `div` 100, pos : bestChain)
  where
    evals = [negamax child (depth - 1) | child <- children pos]
    -- The child result that is worst for the opponent (lowest score);
    -- on ties the earliest child wins.
    (bestValue, bestChain) = foldl1 pickLower evals
    pickLower best candidate
      | fst candidate < fst best = candidate
      | otherwise = best
-- | Length of a clock difference in seconds, built from the whole-second and
-- picosecond fields of the 'TimeDiff'. The other fields (minutes, hours,
-- days, ...) are not read.
seconds :: TimeDiff -> Double
seconds tdiff = wholeSecs + picoSecs / 1e12
  where
    wholeSecs = fromIntegral (tdSec tdiff)
    picoSecs = fromIntegral (tdPicosec tdiff)
-- | Like 'forM_', but the action returns a 'Bool'. The action runs on each
-- element in turn, and the loop stops at the first element for which it
-- returns False.
whileM_ :: (Monad m) => [a] -> (a -> m Bool) -> m ()
whileM_ xs fn = foldr step (return ()) xs
  where
    step x continue = fn x >>= \keepGoing -> when keepGoing continue
-- | Fail-hard alpha-beta search in negamax form.
-- Returns (value of position, principal variation, node count). The value
-- is from the point of view of the side to move. Children are tried in the
-- order 'children' generates them. The final score is scaled by 127/128
-- (rounded down), pulling it toward zero at every ply.
--
-- NOTE(review): the 127/128 scaling is applied to returned scores but not
-- to the window handed to children. When alpha is negative, a child's
-- fail-high bound can come back scaled to just above alpha and be accepted
-- as an improvement. Confirm whether this is intended.
alphaBeta :: Position -> Int -> Int -> Int -> ST s (Int, [Position], Int)
alphaBeta pos 0 _ _ = return (posValue pos, [pos], 1)
alphaBeta pos depth alpha0 beta
  | null kids = return (posValue pos, [pos], 1)
  | otherwise = do
      pvRef <- newSTRef ([] :: [Position])
      alphaRef <- newSTRef alpha0
      nodesRef <- newSTRef 1
      whileM_ kids $ \child -> do
        lowerBound <- readSTRef alphaRef
        -- Each child is searched in its own runST with the negated window.
        let (childVal, childPv, !childNodes) =
              runST (alphaBeta child (depth - 1) (negate beta) (negate lowerBound))
        modifySTRef nodesRef (+ childNodes)
        let score = negate childVal
        if score >= beta
          then writeSTRef alphaRef beta >> return False -- cutoff: stop the loop
          else if score > lowerBound
            then do
              writeSTRef alphaRef score
              writeSTRef pvRef childPv
              return True
            else return True
      bestPv <- readSTRef pvRef
      finalAlpha <- readSTRef alphaRef
      nodes <- readSTRef nodesRef
      return (finalAlpha * 127 `div` 128, pos : bestPv, nodes)
  where
    kids = children pos
-- | Root call of 'alphaBeta' with the window (-1000000, 1000000).
alphaBetaTop :: Position -> Int -> (Int, [Position], Int)
alphaBetaTop pos depth = runST (alphaBeta pos depth (negate bound) bound)
  where
    bound = 1000000
-- | Short alias for 'alphaBetaTop'.
abTop :: Position -> Int -> (Int, [Position], Int)
abTop pos depth = alphaBetaTop pos depth
-- | Run one alpha-beta search of the given depth from @pos@. Prints the
-- result, the elapsed wall-clock time and the number of nodes searched per
-- second.
tryPos :: Position -> Int -> IO ()
tryPos pos depth = do
  printf "\nDepth %d:\n" depth
  t1 <- getClockTime
  let eval@(value, _, count) = alphaBetaTop pos depth
  -- Force the score and node count before reading the clock again. The
  -- measurement then covers the search itself, not printing its
  -- principal variation. (Previously t2 was read after 'print eval'.)
  value `seq` count `seq` return ()
  t2 <- getClockTime
  spc
  print eval
  -- The extra millisecond keeps the rate finite when the search is too
  -- fast for the clock to register.
  let duration = seconds (diffClockTimes t2 t1) + 0.001
  printf " Alpha-beta: %.3f seconds\n" duration
  printf " %.0f nodes per second\n" (fromIntegral count / duration)
  where
    spc = putStr " "
-- | Entry point: @position [SIDE]@. Builds an empty SIDE x SIDE board
-- (default 3) and runs alpha-beta searches at depths 0, 1, 2, ... without
-- end. Arguments after the first are ignored.
main :: IO ()
main = do
  args <- System.getArgs
  let side = case args of
        [] -> 3
        (arg : _) -> parseSide arg
  forM_ [0 ..] (tryPos (newPos side))
  where
    -- Accept the surrounding whitespace that 'read' would. Reject anything
    -- that is not a positive integer with a clear message: sides <= 0 give
    -- boards whose winning lines are empty.
    parseSide :: String -> Int
    parseSide arg =
      case [n | (n, rest) <- reads arg, all Char.isSpace rest] of
        [n] | n > 0 -> n
        _ -> error ("position: board side must be a positive integer, got " ++ show arg)