deriving (Enum) now uses Deriving_Enum module [#1]
[deriving.git] / lib / monad.ml
blob c5f60dd3ad69bf9aab0bfa80a2e83c604b5b7876
(* Copyright Jeremy Yallop 2007.
   This file is free software, distributed under the MIT license.
   See the file COPYING for details.
*)
(** Signature of a monad: a covariant type constructor ['a m] together
    with unit, bind, value-less sequencing and failure. *)
module type Monad = sig
  type +'a m

  (** [return x] injects the pure value [x]. *)
  val return : 'a -> 'a m

  (** [m >>= k] passes the result of [m] to the continuation [k]. *)
  val (>>=) : 'a m -> ('a -> 'b m) -> 'b m

  (** [m >> n] sequences [m] and [n]; by its type, the overall result
      is that of [n]. *)
  val (>>) : 'a m -> 'b m -> 'b m

  (** [fail msg] is a failing computation carrying the message [msg]. *)
  val fail : string -> 'a m
end
(** A {!Monad} extended with an identity computation [mzero] and a
    binary combination [mplus]. *)
module type MonadPlus = sig
  include Monad

  (** The zero computation. *)
  val mzero : 'a m

  (** Combines two alternative computations. *)
  val mplus : 'a m -> 'a m -> 'a m
end
(** [MonadDefault (M)] completes a monad from just [return], [fail] and
    [(>>=)], deriving [(>>)] as a bind whose continuation ignores the
    value it receives and yields the second computation. *)
module MonadDefault (M : sig
    type +'a m
    val return : 'a -> 'a m
    val fail : string -> 'a m
    val (>>=) : 'a m -> ('a -> 'b m) -> 'b m
  end)
  : Monad with type 'a m = 'a M.m =
struct
  include M

  let (>>) first second =
    let ignore_result _ = second in
    first >>= ignore_result
end
(** The option monad: [fail] and [mzero] are [None]; [mplus] yields its
    first argument unless that is [None], in which case it yields the
    second. *)
module Monad_option : MonadPlus with type 'a m = 'a option =
struct
  include MonadDefault (struct
    type 'a m = 'a option
    let return v = Some v
    let fail (_ : string) = None
    let (>>=) opt k =
      match opt with
      | Some v -> k v
      | None -> None
  end)

  let mzero = None

  let mplus left right =
    match left with
    | None -> right
    | Some _ -> left
end
(** The list monad: bind concatenates the lists produced for each
    element (in list order); [fail] and [mzero] are the empty list and
    [mplus] is list append. *)
module Monad_list : MonadPlus with type 'a m = 'a list =
struct
  include MonadDefault (struct
    type 'a m = 'a list
    let fail (_ : string) = []
    let return v = [v]
    let (>>=) xs k = List.flatten (List.map k xs)
  end)

  let mzero = []

  let mplus xs ys = xs @ ys
end
(** A thunk-based IO monad: an ['a m] is a suspended computation
    [unit -> 'a] whose effects happen only when it is applied to [()]
    (e.g. by [runIO]). *)
module IO =
struct
  type 'a m = unit -> 'a

  (** [return a] is an action that yields [a] and has no effects. *)
  let return a = fun () -> a

  (** [m >>= k] is an action that, when run, runs [m] and then runs the
      action that [k] builds from [m]'s result. *)
  let (>>=) m k =
    fun () ->
      let v = m () in
      k v ()

  let (>>) x y = x >>= (fun _ -> y)

  (** [fail msg] is an action that raises [Failure msg] when it is run.
      Building the action does not raise, so a failure is sequenced like
      any other effect. *)
  let fail msg = fun () -> failwith msg

  (** [putStr s] is an action that prints [s] with [print_string] when run. *)
  let putStr s = fun () -> print_string s

  (** [runIO f] performs the action [f]. *)
  let runIO f = f ()

  (** [mkIO f] lifts the effectful thunk [f] into an action.  [f] is
      called each time the action is run, not when the action is built. *)
  let mkIO (f : unit -> 'b) : 'b m = fun () -> f ()
end
(** Signature listing the Control.Monad-style combinators that
    [MonadUtils] defines, on top of the basic {!Monad} operations. *)
module type MonadUtilsSig =
sig
  include Monad

  (** Lifting pure functions of one to five arguments. *)
  val liftM : ('a -> 'b) -> 'a m -> 'b m
  val liftM2 : ('a -> 'b -> 'c) -> 'a m -> 'b m -> 'c m
  val liftM3 : ('a -> 'b -> 'c -> 'd) -> 'a m -> 'b m -> 'c m -> 'd m
  val liftM4 :
    ('a -> 'b -> 'c -> 'd -> 'e) -> 'a m -> 'b m -> 'c m -> 'd m -> 'e m
  val liftM5 :
    ('a -> 'b -> 'c -> 'd -> 'e -> 'f) ->
    'a m -> 'b m -> 'c m -> 'd m -> 'e m -> 'f m

  (** Monadic function application. *)
  val ap : ('a -> 'b) m -> 'a m -> 'b m

  (** Traversals; the [_]-suffixed variants return [unit m] instead of
      the collected results. *)
  val sequence : 'a m list -> 'a list m
  val sequence_ : 'a m list -> unit m
  val mapM : ('a -> 'b m) -> 'a list -> 'b list m
  val mapM_ : ('a -> 'b m) -> 'a list -> unit m

  (** Bind with its arguments flipped. *)
  val ( =<< ) : ('a -> 'b m) -> 'a m -> 'b m

  (** Removes one level of monadic nesting. *)
  val join : 'a m m -> 'a m

  val filterM : ('a -> bool m) -> 'a list -> 'a list m
  val mapAndUnzipM :
    ('a -> ('b * 'c) m) -> 'a list -> ('b list * 'c list) m
  val zipWithM : ('a -> 'b -> 'c m) -> 'a list -> 'b list -> 'c list m
  val zipWithM_ : ('a -> 'b -> 'c m) -> 'a list -> 'b list -> unit m
  val foldM : ('a -> 'b -> 'a m) -> 'a -> 'b list -> 'a m
  val foldM_ : ('a -> 'b -> 'a m) -> 'a -> 'b list -> unit m
  val replicateM : int -> 'a m -> 'a list m
  val replicateM_ : int -> 'a m -> unit m

  (** Conditional execution; [quand] stands in for Haskell's [when],
      which is a keyword in OCaml. *)
  val quand : bool -> unit m -> unit m
  val unless : bool -> unit m -> unit m
end
(** Haskell's Control.Monad combinators, defined generically for any
    module [M] of signature {!Monad}.  The result includes all of [M]. *)
module MonadUtils (M : Monad) =
struct
  include M

  (** [liftM f m1] binds [m1] and returns [f] applied to its result. *)
  let liftM : ('a1 -> 'r) -> 'a1 m -> 'r m
    = fun f m1 -> m1 >>= (fun x1 -> return (f x1))

  (** [liftM2] .. [liftM5] generalise [liftM] to several computations,
      which are bound left to right. *)
  let liftM2 : ('a1 -> 'a2 -> 'r) -> 'a1 m -> 'a2 m -> 'r m
    = fun f m1 m2
    -> m1 >>= (fun x1
    -> m2 >>= (fun x2
    -> return (f x1 x2)))

  let liftM3 : ('a1 -> 'a2 -> 'a3 -> 'r) -> 'a1 m -> 'a2 m -> 'a3 m -> 'r m
    = fun f m1 m2 m3
    -> m1 >>= (fun x1
    -> m2 >>= (fun x2
    -> m3 >>= (fun x3
    -> return (f x1 x2 x3))))

  let liftM4 : ('a1 -> 'a2 -> 'a3 -> 'a4 -> 'r) -> 'a1 m -> 'a2 m -> 'a3 m -> 'a4 m -> 'r m
    = fun f m1 m2 m3 m4
    -> m1 >>= (fun x1
    -> m2 >>= (fun x2
    -> m3 >>= (fun x3
    -> m4 >>= (fun x4
    -> return (f x1 x2 x3 x4)))))

  let liftM5 : ('a1 -> 'a2 -> 'a3 -> 'a4 -> 'a5 -> 'r) -> 'a1 m -> 'a2 m -> 'a3 m -> 'a4 m -> 'a5 m -> 'r m
    = fun f m1 m2 m3 m4 m5
    -> m1 >>= (fun x1
    -> m2 >>= (fun x2
    -> m3 >>= (fun x3
    -> m4 >>= (fun x4
    -> m5 >>= (fun x5
    -> return (f x1 x2 x3 x4 x5))))))

  (** [ap mf mx] binds [mf], then [mx], and applies the function to the value. *)
  let ap : ('a -> 'b) m -> 'a m -> 'b m
    = fun f -> liftM2 (fun x -> x) f

  (** [sequence ms] binds the computations of [ms] in list order and
      returns the list of their results. *)
  let sequence : ('a m) list -> ('a list) m
    = let mcons p q = p >>= (fun x -> q >>= (fun y -> return (x::y))) in
      fun l -> List.fold_right mcons l (return [])

  (** [sequence_ ms] binds the computations of [ms] in list order,
      discarding their results. *)
  let sequence_ : ('a m) list -> unit m
    = fun l -> List.fold_right (>>) l (return ())

  (** [mapM f xs] is [sequence (List.map f xs)]. *)
  let mapM : ('a -> 'b m) -> 'a list -> ('b list) m
    = fun f xs -> sequence (List.map f xs)

  (** [mapM_ f xs] is [sequence_ (List.map f xs)]. *)
  let mapM_ : ('a -> 'b m) -> 'a list -> unit m
    = fun f xs -> sequence_ (List.map f xs)

  (** Bind with its arguments flipped. *)
  let (=<<) : ('a -> 'b m) -> 'a m -> 'b m
    = fun f x -> x >>= f

  (** [join mm] binds [mm] and continues with the inner computation. *)
  let join : ('a m) m -> 'a m
    = fun x -> x >>= (fun x -> x)

  (** [filterM p xs] keeps, in order, the elements of [xs] for which the
      monadic predicate [p] yields [true]. *)
  let rec filterM : ('a -> bool m) -> 'a list -> ('a list) m
    = fun p -> function
      | [] -> return []
      | x::xs -> p x >>= (fun flg ->
                 filterM p xs >>= (fun ys ->
                 return (if flg then (x::ys) else ys)))

  (** [mapAndUnzipM f xs] maps [f] over [xs], sequences the results and
      splits the resulting pairs. *)
  let mapAndUnzipM : ('a -> ('b *'c) m) -> 'a list -> ('b list * 'c list) m
    = fun f xs -> sequence (List.map f xs) >>= fun x -> return (List.split x)

  (** [zipWithM f xs ys] sequences [f] applied pointwise to [xs] and [ys].
      NOTE(review): this uses [List.map2], which raises [Invalid_argument]
      when the lists differ in length (Haskell's zipWithM truncates). *)
  let zipWithM : ('a -> 'b -> 'c m) -> 'a list -> 'b list -> ('c list) m
    = fun f xs ys -> sequence (List.map2 f xs ys)

  (** As [zipWithM], discarding the results; same length caveat. *)
  let zipWithM_ : ('a -> 'b -> 'c m) -> 'a list -> 'b list -> unit m
    = fun f xs ys -> sequence_ (List.map2 f xs ys)

  (** [foldM f a xs] threads the accumulator [a] through [f] over [xs]
      from left to right, binding each step. *)
  let rec foldM : ('a -> 'b -> 'a m) -> 'a -> 'b list -> 'a m
    = fun f a -> function
      | [] -> return a
      | x::xs -> f a x >>= (fun fax -> foldM f fax xs)

  (** As [foldM], discarding the final accumulator. *)
  let foldM_ : ('a -> 'b -> 'a m) -> 'a -> 'b list -> unit m
    = fun f a xs -> foldM f a xs >> return ()

  (** [replicateM n m] sequences [n] copies of [m]; [replicateM_] discards
      the results.  A non-positive [n] gives zero copies. *)
  let ((replicateM : int -> 'a m -> ('a list) m),
       (replicateM_ : int -> 'a m -> unit m))
    = let replicate n i =
        (* Guard with [<= 0] rather than matching [0]: a negative count
           must terminate with [] instead of recursing forever. *)
        let rec aux accum k =
          if k <= 0 then accum else aux (i :: accum) (k - 1)
        in aux [] n
      in
      ((fun n x -> sequence (replicate n x)),
       (fun n x -> sequence_ (replicate n x)))

  (** [quand p s] (Haskell's [when], a keyword in OCaml) is [s] when [p]
      holds and [return ()] otherwise. *)
  let quand (* when *) : bool -> unit m -> unit m
    = fun p s -> if p then s else return ()

  (** [unless p s] is [return ()] when [p] holds and [s] otherwise. *)
  let unless : bool -> unit m -> unit m
    = fun p s -> if p then return () else s
end
(** Signature combining the {!MonadUtilsSig} combinators with the
    MonadPlus operations [mzero] and [mplus] and the derived [guard]
    and [msum]. *)
module type MonadPlusUtilsSig =
sig
  include MonadUtilsSig
  val mzero : 'a m
  val mplus : 'a m -> 'a m -> 'a m
  val guard : bool -> unit m
  val msum : 'a m list -> 'a m
end
(** [MonadPlusUtils (M)] provides the [MonadUtils] combinators for [M],
    re-exports [M.mzero] and [M.mplus], and adds [guard] and [msum]. *)
module MonadPlusUtils (M : MonadPlus) =
struct
  include MonadUtils(M)

  (* [MonadUtils] only sees [M] through the [Monad] signature, so the
     MonadPlus operations have to be re-exported explicitly. *)
  let mzero = M.mzero
  let mplus = M.mplus

  (** [guard b] is [return ()] when [b] holds and [mzero] otherwise. *)
  let guard : bool -> unit M.m
    = function
      | true -> M.return ()
      | false -> M.mzero

  (** [msum ms] folds [ms] with [mplus] from the right, starting from
      [mzero]. *)
  let msum : ('a M.m) list -> 'a M.m
    = fun l -> List.fold_right M.mplus l M.mzero
end
(* Combinator modules instantiated for the concrete monads: the IO monad
   gets the plain utilities, option and list get the MonadPlus ones. *)
module Monad_IO = MonadUtils (MonadDefault (IO))
module MonadPlusUtils_option = MonadPlusUtils (Monad_option)
module MonadPlusUtils_list = MonadPlusUtils (Monad_list)
(** Signature of a state monad over an abstract [state] type, carrying
    the {!MonadUtilsSig} combinators. *)
module type Monad_state_type =
sig
  include MonadUtilsSig
  type state

  (** Yields the current state. *)
  val get : state m

  (** Replaces the current state. *)
  val put : state -> unit m

  (** [runState m s] runs [m] from the initial state [s], returning its
      result paired with a state. *)
  val runState : 'a m -> state -> 'a * state
end
(** State monad over [A.state]: an ['a m] wraps a function from the
    current state to a result paired with the next state. *)
module Monad_state_impl (A : sig type state end) =
struct
  type state = A.state
  type 'a m = State of (A.state -> ('a * A.state))

  (** Yields the current state and leaves it unchanged. *)
  let get = State (fun s -> s,s)

  (** [put s] replaces the state with [s] and yields [()]. *)
  let put s = State (fun _ -> (), s)

  (** [runState m s] runs [m] from initial state [s], returning its
      result and the final state. *)
  let runState (State s) = s

  (** [return a] yields [a] and leaves the state unchanged. *)
  let return a = State (fun state -> (a, state))

  (** [fail s] is a computation that raises [Failure] (message prefixed
      with "state monad error ") when it is run.  Building it does not
      raise, so it can be stored or combined like any other action. *)
  let fail s = State (fun _ -> failwith ("state monad error " ^ s))

  (** [m >>= f] runs [m], then runs [f]'s result from the updated state. *)
  let (>>=) (State x) f = State (fun s -> (let v, s' = x s in
                                           runState (f v) s'))

  let (>>) s f = s >>= fun _ -> f
end
262 module Monad_state(S : sig type state end) :
263 Monad_state_type with type state = S.state =
264 struct
265 module M = Monad_state_impl(S)
266 include MonadUtils(M)
267 type state = M.state
268 let get = M.get
269 let put = M.put
270 let runState = M.runState