(* Copyright Jeremy Yallop 2007.
   This file is free software, distributed under the MIT license.
   See the file COPYING for details.
*)
9 val return
: 'a
-> 'a m
10 val fail
: string -> 'a m
11 val (>>=) : 'a m
-> ('a
-> 'b m
) -> 'b m
12 val (>>) : 'a m
-> 'b m
-> 'b m
15 module type MonadPlus
=
19 val mplus
: 'a m
-> 'a m
-> 'a m
26 val return
: 'a
-> 'a m
27 val fail
: string -> 'a m
28 val (>>=) : 'a m
-> ('a
-> 'b m
) -> 'b m
29 end) : Monad
with type 'a m
= 'a
M.m
=
(* Sequencing derived from bind: bind [first], discard the value it
   produces, and continue with [next]. *)
let (>>) first next = first >>= fun _ -> next
35 module Monad_option
: MonadPlus
36 with type 'a m
= 'a
option =
50 let mplus l r
= match l
, r
with
55 module Monad_list
: MonadPlus
56 with type 'a m
= 'a list
=
(* Bind for lists: apply [k] to every element of [xs] and join the
   resulting lists, preserving order. *)
let (>>=) xs k = List.flatten (List.map k xs)
(* An action is a suspended computation: a function that produces its
   result only when applied to [()]. *)
type 'a m = unit -> 'a

(* [return a] is the action that yields [a] when applied to [()]. *)
let return a () = a
(* [first >> next] binds [first], ignores the value it yields, and
   continues with [next]. *)
let (>>) first next = first >>= fun _ -> next
(* [putStr s] is the action that, when applied to [()], writes [s] with
   [print_string]. Nothing is printed until the action is applied. *)
let putStr s () = print_string s
(* Wrap the result of [f] as an action.
   NOTE: [f ()] is evaluated immediately, at the time [mkIO] is called;
   the action built by [return] only hands back the value that was
   already computed. *)
let mkIO (f : unit -> 'b) =
  let result = f () in
  return result
84 module type MonadUtilsSig
=
87 val liftM
: ('a
-> 'b
) -> 'a m
-> 'b m
88 val liftM2
: ('a
-> 'b
-> 'c
) -> 'a m
-> 'b m
-> 'c m
89 val liftM3
: ('a
-> 'b
-> 'c
-> 'd
) -> 'a m
-> 'b m
-> 'c m
-> 'd m
91 ('a
-> 'b
-> 'c
-> 'd
-> 'e
) -> 'a m
-> 'b m
-> 'c m
-> 'd m
-> 'e m
93 ('a
-> 'b
-> 'c
-> 'd
-> 'e
-> 'f
) ->
94 'a m
-> 'b m
-> 'c m
-> 'd m
-> 'e m
-> 'f m
95 val ap
: ('a
-> 'b
) m
-> 'a m
-> 'b m
96 val sequence
: 'a m list
-> 'a list m
97 val sequence_
: 'a m list
-> unit m
98 val mapM
: ('a
-> 'b m
) -> 'a list
-> 'b list m
99 val mapM_
: ('a
-> 'b m
) -> 'a list
-> unit m
100 val ( =<< ) : ('a
-> 'b m
) -> 'a m
-> 'b m
101 val join
: 'a m m
-> 'a m
102 val filterM
: ('a
-> bool m
) -> 'a list
-> 'a list m
104 ('a
-> ('b
* 'c
) m
) -> 'a list
-> ('b list
* 'c list
) m
105 val zipWithM
: ('a
-> 'b
-> 'c m
) -> 'a list
-> 'b list
-> 'c list m
106 val zipWithM_
: ('a
-> 'b
-> 'c m
) -> 'a list
-> 'b list
-> unit m
107 val foldM
: ('a
-> 'b
-> 'a m
) -> 'a
-> 'b list
-> 'a m
108 val foldM_
: ('a
-> 'b
-> 'a m
) -> 'a
-> 'b list
-> unit m
109 val replicateM
: int -> 'a m
-> 'a list m
110 val replicateM_
: int -> 'a m
-> unit m
111 val quand
: bool -> unit m
-> unit m
112 val unless
: bool -> unit m
-> unit m
116 module MonadUtils
(M
: Monad
) =
(* [liftM f m1] binds the value produced by [m1] and returns [f] applied
   to it. *)
let liftM : ('a1 -> 'r) -> 'a1 m -> 'r m =
  fun f m1 ->
    let wrap x1 = return (f x1) in
    m1 >>= wrap
121 let liftM2 : ('a1
-> 'a2
-> 'r
) -> 'a1 m
-> 'a2 m
-> 'r m
125 -> return (f x1 x2
)))
126 let liftM3 : ('a1
-> 'a2
-> 'a3
-> 'r
) -> 'a1 m
-> 'a2 m
-> 'a3 m
-> 'r m
131 -> return (f x1 x2 x3
))))
132 let liftM4 : ('a1
-> 'a2
-> 'a3
-> 'a4
-> 'r
) -> 'a1 m
-> 'a2 m
-> 'a3 m
-> 'a4 m
-> 'r m
138 -> return (f x1 x2 x3 x4
)))))
139 let liftM5 : ('a1
-> 'a2
-> 'a3
-> 'a4
-> 'a5
-> 'r
) -> 'a1 m
-> 'a2 m
-> 'a3 m
-> 'a4 m
-> 'a5 m
-> 'r m
140 = fun f m1 m2 m3 m4 m5
146 -> return (f x1 x2 x3 x4 x5
))))))
(* [ap mf mx] combines a monadic function with a monadic argument by
   lifting plain function application through [liftM2]. *)
let ap : ('a -> 'b) m -> 'a m -> 'b m =
  fun mf -> liftM2 (fun g x -> g x) mf
(* [sequence actions] chains the actions with [>>=] from left to right
   and returns the list of their results, in the same order.  The empty
   list gives [return []]. *)
let sequence : ('a m) list -> ('a list) m =
  fun actions ->
    List.fold_right
      (fun action rest ->
         action >>= fun v ->
         rest >>= fun vs ->
         return (v :: vs))
      actions
      (return [])
(* [sequence_ actions] chains the actions with [>>], discarding every
   result, and finishes with [return ()]. *)
let sequence_ : ('a m) list -> unit m =
  fun actions ->
    List.fold_right
      (fun action rest -> action >> rest)
      actions
      (return ())
(* [mapM f xs] maps [f] over [xs] to obtain a list of actions, then
   collects their results with [sequence]. *)
let mapM : ('a -> 'b m) -> 'a list -> ('b list) m =
  fun f xs ->
    let actions = List.map f xs in
    sequence actions
(* [mapM_ f xs] maps [f] over [xs] to obtain a list of actions, then
   chains them with [sequence_], discarding the results. *)
let mapM_ : ('a -> 'b m) -> 'a list -> unit m =
  fun f xs ->
    let actions = List.map f xs in
    sequence_ actions
164 let (=<<) : ('a
-> 'b m
) -> 'a m
-> 'b m
(* [join mm] flattens one level of nesting: bind the outer computation
   and continue with the inner computation it yields. *)
let join : ('a m) m -> 'a m =
  fun mm -> mm >>= fun inner -> inner
170 let rec filterM : ('a
-> bool m
) -> 'a list
-> ('a list
) m
173 | x
::xs
-> p x
>>= (fun flg
->
174 filterM p xs
>>= (fun ys
->
175 return (if flg
then (x
::ys
) else ys
)))
(* [mapAndUnzipM f xs] maps the pair-producing [f] over [xs], sequences
   the resulting actions, and splits the list of pairs into a pair of
   lists with [List.split]. *)
let mapAndUnzipM : ('a -> ('b * 'c) m) -> 'a list -> ('b list * 'c list) m =
  fun f xs ->
    let paired = sequence (List.map f xs) in
    paired >>= fun pairs -> return (List.split pairs)
(* [zipWithM f xs ys] applies [f] pairwise to [xs] and [ys] and collects
   the results with [sequence].  The pairing is done by [List.map2],
   which raises [Invalid_argument] when the lists differ in length. *)
let zipWithM : ('a -> 'b -> 'c m) -> 'a list -> 'b list -> ('c list) m =
  fun f xs ys ->
    let actions = List.map2 f xs ys in
    sequence actions
(* [zipWithM_ f xs ys] applies [f] pairwise to [xs] and [ys] and chains
   the resulting actions with [sequence_], discarding the results.  The
   pairing is done by [List.map2], which raises [Invalid_argument] when
   the lists differ in length. *)
let zipWithM_ : ('a -> 'b -> 'c m) -> 'a list -> 'b list -> unit m =
  fun f xs ys ->
    let actions = List.map2 f xs ys in
    sequence_ actions
186 let rec foldM : ('a
-> 'b
-> 'a m
) -> 'a
-> 'b list
-> 'a m
187 = fun f a
-> function
189 | x
::xs
-> f a x
>>= (fun fax
-> foldM f fax xs
)
(* [foldM_ f init items] is [foldM f init items] with the final
   accumulator discarded: the fold is followed by [return ()]. *)
let foldM_ : ('a -> 'b -> 'a m) -> 'a -> 'b list -> unit m =
  fun f init items ->
    foldM f init items >> return ()
194 let ((replicateM
: int -> 'a m
-> ('a list
) m
),
195 (replicateM_
: int -> 'a m
-> unit m
))
196 = let replicate n i
=
197 let rec aux accum
= function
199 | n
-> aux (i
::accum
) (n
-1)
202 ((fun n x
-> sequence (replicate n x
)),
203 (fun n x
-> sequence_ (replicate n x
)))
(* Conditional execution ("when"; [when] itself is a reserved word in
   OCaml): [quand cond action] is [action] if [cond] holds and
   [return ()] otherwise. *)
let quand : bool -> unit m -> unit m =
  fun cond action ->
    match cond with
    | true -> action
    | false -> return ()
(* [unless cond action] is [action] if [cond] is false and [return ()]
   otherwise. *)
let unless : bool -> unit m -> unit m =
  fun cond action ->
    if not cond then action else return ()
212 module type MonadPlusUtilsSig
=
214 include MonadUtilsSig
216 val mplus : 'a m
-> 'a m
-> 'a m
217 val guard
: bool -> unit m
218 val msum
: 'a m list
-> 'a m
221 module MonadPlusUtils
(M
: MonadPlus
) =
223 include MonadUtils
(M
)
226 let guard : bool -> unit M.m
228 | true -> M.return ()
(* [msum choices] combines all the alternatives with [M.mplus], folding
   from the right and using [M.mzero] for the empty list. *)
let msum : ('a M.m) list -> 'a M.m =
  fun choices ->
    List.fold_right
      (fun choice rest -> M.mplus choice rest)
      choices
      M.mzero
(* Ready-made instantiations of the utility functors. *)

(* MonadPlus utilities over [Monad_option]. *)
module MonadPlusUtils_option = MonadPlusUtils (Monad_option)

(* MonadPlus utilities over [Monad_list]. *)
module MonadPlusUtils_list = MonadPlusUtils (Monad_list)

(* Monad utilities over [IO], passed through [MonadDefault]. *)
module Monad_IO = MonadUtils (MonadDefault (IO))
239 module type Monad_state_type
=
241 include MonadUtilsSig
244 val put
: state
-> unit m
245 val runState
: 'a m
-> state
-> 'a
* state
248 module Monad_state_impl
(A
: sig type state
end) =
251 type 'a m
= State
of (A.state
-> ('a
* A.state
))
(* [get] yields the current state as its result and leaves the state
   unchanged. *)
let get = State (fun current -> (current, current))
(* [put s] ignores the incoming state, installs [s] as the new state and
   yields [()]. *)
let put s = State (fun _previous -> ((), s))
(* [runState m] unwraps the [State] constructor, exposing the underlying
   state-transition function. *)
let runState = function
  | State run -> run
(* [return a] yields [a] and passes the state through untouched. *)
let return a = State (fun s -> a, s)
(* [fail msg] raises [Failure] with [msg] prefixed by
   "state monad error ".  The exception is raised as soon as [fail] is
   applied; no monadic value is built. *)
let fail msg = raise (Failure ("state monad error " ^ msg))
257 let (>>=) (State x
) f
= State
(fun s
-> (let v, s'
= x s
in
(* [first >> next] binds [first], ignores the value it yields, and
   continues with [next]. *)
let (>>) first next = first >>= (fun _ -> next)
262 module Monad_state
(S
: sig type state
end) :
263 Monad_state_type
with type state
= S.state
=
265 module M
= Monad_state_impl
(S
)
266 include MonadUtils
(M
)
(* Re-export of the implementation module's [runState]. *)
let runState m = M.runState m