fix enum_tests.ml to compile
[deriving.git] / syntax / base.ml
blob79fe4fefc1138f74512dde0c25ddf57fbe865fd5
1 (*pp camlp4of *)
3 (* Copyright Jeremy Yallop 2007.
4 This file is free software, distributed under the MIT license.
5 See the file COPYING for details.
6 *)
8 open Utils
9 open Type
10 open Camlp4.PreCast
(* Per-group state shared by all code generators. *)
type context = {
  (* location attached to every generated AST node *)
  loc : Loc.t;
  (* mapping from type parameters to functor arguments *)
  argmap : name NameMap.t;
  (* ordered list of type parameters *)
  params : param list;
  (* names of the types in the (mutually recursive) group *)
  tnames : NameSet.t;
}
(* Raised when a class cannot be derived for a given type; the payload is
   a human-readable explanation. *)
exception Underivable of string

(* Raised when a deriver lookup fails; the payload is the class name that
   was looked up. *)
exception NoSuchClass of string

(* [error loc msg] prints [msg] at [loc] through camlp4's warning printer
   and then terminates the process with exit status 1. *)
let error loc (msg : string) =
  let () = Syntax.print_warning loc msg in
  exit 1

(* Signature of a structure carrying nothing but a source location; used
   as the argument type of the [InContext] functor. *)
module type Loc = sig
  val loc : Loc.t
end
(* [contains_tvars e] is [true] when the type expression [e] mentions a
   type parameter ([`Param]) anywhere; [contains_tvars_decl d] answers
   the same question for a whole declaration.  Results from sub-terms
   are combined with [List.exists F.id], i.e. "any sub-term said yes". *)
let contains_tvars, contains_tvars_decl =
  let finder =
    object
      inherit [bool] fold as super
      method crush = List.exists F.id
      method expr = function
        | `Param _ -> true
        | other -> super#expr other
    end
  in
  (finder#expr, finder#decl)
(* Code generators parameterised by a source location.  [include L] puts
   [loc] in scope, and the camlp4 quotations below use it implicitly for
   every AST node they build. *)
module InContext(L : Loc) =
struct
  include L
  module Untranslate = Untranslate(L)

  (* [instantiate lookup e] replaces each [`Param (name, _)] occurring in
     the type expression [e] with [lookup name]; every other node is
     handled by the inherited [transform] traversal.
     [instantiate_repr lookup r] does the same for a [repr]. *)
  let instantiate, instantiate_repr =
    let o lookup = object
      inherit transform as super
      method expr = function
        | `Param (name, _) -> lookup name
        | e -> super # expr e
    end in
      (fun (lookup : name -> expr) -> (o lookup)#expr),
      (fun (lookup : name -> expr) -> (o lookup)#repr)

  (* Replace each type parameter ['p] with the type [a] of the functor
     argument that [ctxt.argmap] associates with ['p] (setup_context
     builds names of the form [V_p], giving [V_p.a]).
     Raises [Failure] for a parameter missing from [ctxt.argmap]. *)
  let instantiate_modargs, instantiate_modargs_repr =
    let lookup ctxt var =
      try
        `Constr ([NameMap.find var ctxt.argmap; "a"], [])
      with Not_found ->
        failwith ("Unbound type parameter '" ^ var)
    in (fun ctxt -> instantiate (lookup ctxt)),
       (fun ctxt -> instantiate_repr (lookup ctxt))

  (* Rename type parameters according to [env] (old name -> new name),
     keeping each parameter's variance; parameters not in [env] are left
     untouched. *)
  let substitute env =
    (object
       inherit transform as default
       method expr = function
         | `Param (p,v) when NameMap.mem p env ->
             `Param (NameMap.find p env,v)
         | e -> default# expr e
     end) # expr

  (* [cast_pattern ctxt ~param t] returns a triple:
       - a pattern binding the variable [param];
       - an expression that is [true] iff [param] matches [#t];
       - an expression that coerces [param] to [t], hitting
         [assert false] when it does not match.
     [t] is first instantiated with [instantiate_modargs]; since it is used
     in a [#t] pattern it must denote a polymorphic variant type. *)
  let cast_pattern ctxt ?(param="x") t =
    let t = Untranslate.expr (instantiate_modargs ctxt t) in
      (<:patt< $lid:param$ >>,
       <:expr<
         let module M =
             struct
               type t = $t$
               let test = function #t -> true | _ -> false
             end in M.test $lid:param$ >>,
       <:expr<
         (let module M =
              struct
                type t = $t$
                let cast = function #t as t -> t | _ -> assert false
              end in M.cast $lid:param$ )>>)

  (* The sequence expression [l; r]. *)
  let seq l r = <:expr< $l$ ; $r$ >>

  (* The record pattern [{l1 = prefix^l1; ...}]: each field is bound to a
     variable named after its label, with [prefix] prepended. *)
  let record_pattern ?(prefix="") (fields : Type.field list) : Ast.patt =
    <:patt<{$list:
            (List.map (fun (label,_,_) -> <:patt< $lid:label$ = $lid:prefix ^ label$ >>)
               fields) $}>>

  (* The record expression [{l1 = e1; ...}] built from (label, expression)
     pairs.  Bindings are joined with [List.fold_left1] (from Utils), which
     presumably rejects an empty list -- TODO confirm. *)
  let record_expr : (string * Ast.expr) list -> Ast.expr =
    fun fields ->
      let fs =
        List.fold_left1
          (fun l r -> <:rec_binding< $l$ ; $r$ >>)
          (List.map (fun (label, exp) -> <:rec_binding< $lid:label$ = $exp$ >>)
             fields) in
        Ast.ExRec (loc, fs, Ast.ExNil loc)

  (* The record expression [{l1 = prefix^l1; ...}], the expression
     counterpart of [record_pattern]. *)
  let record_expression ?(prefix="") : Type.field list -> Ast.expr =
    fun fields ->
      let es = List.fold_left1
        (fun l r -> <:rec_binding< $l$ ; $r$ >>)
        (List.map (fun (label,_,_) -> <:rec_binding< $lid:label$ = $lid:prefix ^ label$ >>)
           fields) in
        Ast.ExRec (loc, es, Ast.ExNil loc)

  (* Project the value [name] out of a module expression: [M.name] when
     [mexpr] is a plain module path, [let module M = mexpr in M.name]
     otherwise. *)
  let mproject mexpr name =
    match mexpr with
      | <:module_expr< $id:m$ >> -> <:expr< $id:m$.$lid:name$ >>
      | _ -> <:expr< let module M = $mexpr$ in M.$lid:name$ >>

  (* The list literal [e1 :: e2 :: ... :: []]. *)
  let expr_list : Ast.expr list -> Ast.expr =
    (fun exprs ->
       List.fold_right
         (fun car cdr -> <:expr< $car$ :: $cdr$ >>)
         exprs
         <:expr< [] >>)

  (* The list pattern [p1 :: p2 :: ... :: []]. *)
  let patt_list : Ast.patt list -> Ast.patt =
    (fun patts ->
       List.fold_right
         (fun car cdr -> <:patt< $car$ :: $cdr$ >>)
         patts
         <:patt< [] >>)

  (* [()] for no expressions, the expression itself for one, and a tuple
     expression otherwise. *)
  let tuple_expr : Ast.expr list -> Ast.expr = function
    | [] -> <:expr< () >>
    | [x] -> x
    | x::xs -> Ast.ExTup (loc, List.fold_left (fun e t -> Ast.ExCom (loc, e,t)) x xs)

  (* [tuple ~param n] returns a matching (pattern, expression) pair over
     the variables [param0], [param1], ...: [()] for 0, a single variable
     for 1, and a tuple for larger [n] (variables numbered from
     [List.range 0 n], presumably 0 .. n-1). *)
  let tuple ?(param="v") n : Ast.patt * Ast.expr =
    let v n = Printf.sprintf "%s%d" param n in
      match n with
        | 0 -> <:patt< () >>, <:expr< () >>
        | 1 -> <:patt< $lid:v 0$ >>, <:expr< $lid:v 0$ >>
        | n ->
            let patts, exprs =
              (* At time of writing I haven't managed to write anything
                 using quotations that generates an n-tuple.  The
                 accumulator starts from the empty quotations, so the
                 left-most leaf of each comma tree is a nil node. *)
              List.fold_left
                (fun (p, e) (patt, expr) -> Ast.PaCom (loc, p, patt), Ast.ExCom (loc, e, expr))
                (<:patt< >>, <:expr< >>)
                (List.map (fun n -> <:patt< $lid:v n$ >>, <:expr< $lid:v n $ >>)
                   (List.range 0 n))
            in
              Ast.PaTup (loc, patts), Ast.ExTup (loc, exprs)

  (* Turn a qualified type name [A.B.t] into the module path
     [A.B.Classname_t] of its generated instance.
     Raises [Invalid_argument] on an empty name. *)
  let rec modname_from_qname ~qname ~classname =
    match qname with
      | [] -> invalid_arg "modname_from_qname"
      | [t] -> <:ident< $uid:classname ^ "_"^ t$ >>
      | t::ts -> <:ident< $uid:t$.$modname_from_qname ~qname:ts ~classname$ >>

  (* The functor application [f (a1) (a2) ...]. *)
  let apply_functor (f : Ast.module_expr) (args : Ast.module_expr list) : Ast.module_expr =
    List.fold_left (fun f p -> <:module_expr< $f$ ($p$) >>) f args

  (* Base class for class-specific generators.  Subclasses provide the
     [variant], [sum], [record] and [tuple] cases; this class handles the
     remaining type forms and the dispatch over expressions and
     declaration right-hand sides. *)
  class virtual make_module_expr ~classname ~allow_private =
  object (self)

    (* Apply [funct] to the instance modules derived for [args]. *)
    method mapply ctxt (funct : Ast.module_expr) args =
      apply_functor funct (List.map (self#expr ctxt) args)

    method virtual variant : context -> decl -> variant -> Ast.module_expr
    method virtual sum : ?eq:expr -> context -> decl -> summand list -> Ast.module_expr
    method virtual record : ?eq:expr -> context -> decl -> field list -> Ast.module_expr
    method virtual tuple : context -> expr list -> Ast.module_expr

    (* A type parameter maps to its functor argument module
       (raises [Not_found] if the parameter is not in [ctxt.argmap]). *)
    method param ctxt (name, _) =
      <:module_expr< $uid:NameMap.find name ctxt.argmap$ >>

    method object_ _ _ = raise (Underivable (classname ^ " cannot be derived for object types"))
    method class_ _ _ = raise (Underivable (classname ^ " cannot be derived for class types"))
    method label _ _ = raise (Underivable (classname ^ " cannot be derived for label types"))
    method function_ _ _ = raise (Underivable (classname ^ " cannot be derived for function types"))

    (* A type from the current group refers to its sibling module
       [Classname_name]; any other type constructor becomes its instance
       module (see [modname_from_qname]) applied to the instances for its
       type arguments. *)
    method constr ctxt (qname, args) =
      match qname with
        | [name] when NameSet.mem name ctxt.tnames ->
            <:module_expr< $uid:Printf.sprintf "%s_%s" classname name$ >>
        | _ ->
            let f = (modname_from_qname ~qname ~classname) in
              self#mapply ctxt (Ast.MeId (loc, f)) args

    (* Dispatch on the shape of a type expression. *)
    method expr (ctxt : context) : expr -> Ast.module_expr = function
      | `Param p -> self#param ctxt p
      | `Object o -> self#object_ ctxt o
      | `Class c -> self#class_ ctxt c
      | `Label l -> self#label ctxt l
      | `Function f -> self#function_ ctxt f
      | `Constr c -> self#constr ctxt c
      | `Tuple t -> self#tuple ctxt t

    (* Dispatch on the right-hand side of a declaration.  Private fresh
       types are rejected unless [allow_private]; [`Nothing] yields an
       empty module expression. *)
    method rhs ctxt (_, _, rhs, _, _ as decl : Type.decl) : Ast.module_expr =
      match rhs with
        | `Fresh (_, _, (`Private : [`Private|`Public])) when not allow_private ->
            raise (Underivable ("The class "^ classname ^" cannot be derived for private types"))
        | `Fresh (eq, Sum summands, _) -> self#sum ?eq ctxt decl summands
        | `Fresh (eq, Record fields, _) -> self#record ?eq ctxt decl fields
        | `Expr e -> self#expr ctxt e
        | `Variant v -> self# variant ctxt decl v
        | `Nothing -> <:module_expr< >>
  end

  (* A type expression with its parameters replaced by functor-argument
     types (see [instantiate_modargs]), as an OCaml AST type. *)
  let atype_expr ctxt expr =
    Untranslate.expr (instantiate_modargs ctxt expr)

  (* The type [a] of the instance for a declaration: for fresh, variant
     and [`Nothing] declarations, the declared type applied to the
     functor arguments' [a] types; for an alias, the aliased expression
     instantiated with [atype_expr]. *)
  let atype ctxt (name, params, rhs, _, _) =
    match rhs with
      | `Fresh _ | `Variant _ | `Nothing ->
          Untranslate.expr (`Constr ([name],
                                     List.map (fun (p,_) -> `Constr ([NameMap.find p ctxt.argmap; "a"],[])) params))
      | `Expr e -> atype_expr ctxt e

  (* Re-order a set of mutually recursive modules in an attempt to make
     initialization problems less likely.

     [List.sort] places x before y when the comparison is negative, so the
     resulting order is: fresh/variant declarations, then aliases whose
     body is neither a bare parameter nor a tuple, then parameter/tuple
     aliases, then [`Nothing] declarations.
     NOTE(review): the original comment said things that must come first
     get a *positive* score on the left; that is the opposite of how
     [List.sort] reads the result -- confirm which order was intended. *)
  let make_safe (decls : (decl * Ast.module_binding) list) : Ast.module_binding list =
    List.map snd
      (List.sort
         (fun ((_,_,lrhs,_,_), _) ((_,_,rrhs,_,_), _) -> match (lrhs : rhs), rrhs with
            | (`Fresh _|`Variant _), (`Fresh _|`Variant _) -> 0
            | (`Fresh _|`Variant _), _ -> -1
            | _, (`Fresh _|`Variant _) -> 1
            | (`Nothing, `Nothing) -> 0
            | (`Nothing, _) -> 1
            | (_, `Nothing) -> -1
            | `Expr l, `Expr r ->
                let module M =
                    struct
                      type low =
                          [`Param of param
                          |`Tuple of expr list]
                    end in
                  match l, r with
                    (* two "low" aliases are equal: without this case the
                       comparison returned 1 in both directions, which
                       breaks the contract List.sort relies on *)
                    | #M.low, #M.low -> 0
                    | #M.low, _ -> 1
                    | _, #M.low -> -1
                    | _ -> 0)
         decls)

  (* Generate the structure items implementing [classname] for a
     recursive group [decls].  With no type parameters the result is
     [open Modulename module rec ...].  Otherwise the recursive modules
     are wrapped in a functor module [Classname_<random_id 32>] over the
     parameters, and each [Classname_t] is re-exported as a functor that
     applies the wrapper to its own arguments and includes the projected
     module.  Raises [Underivable] for a type named [a]. *)
  let generate ~context ~decls ~make_module_expr ~classname ?default_module () =
    (* plan:
       set up an enclosing recursive module
       generate functors for all types in the clique
       project out the inner modules afterwards.

       later: generate simpler code for simpler cases:
       - where there are no type parameters
       - where there's only one type
       - where there's no recursion
       - etc.
    *)
    (* FIXME implicit requirement of classname being equal to module name, hence this hack for Enum *)
    let modulename = if classname = "Enum" then "Deriving_Enum" else classname in
    (* let _ = ensure_no_polymorphic_recursion in *)
    let wrapper_name = Printf.sprintf "%s_%s" classname (random_id 32) in
    (* Functor argument signatures and the defaults functor must name the
       runtime module [modulename], exactly like the binding signatures
       below; for Enum that module is [Deriving_Enum], not [Enum]. *)
    let make_functor =
      List.fold_right
        (fun (p,_) rhs ->
           let arg = NameMap.find p context.argmap in
             <:module_expr< functor ($arg$ : $uid:modulename$.$uid:classname$) -> $rhs$ >>)
        context.params in
    let apply_defaults mexpr = match default_module with
      | None -> mexpr
      | Some default -> <:module_expr< $uid:modulename$.$uid:default$ ($mexpr$) >> in
    let mbinds =
      List.map
        (fun (name,_,_,_,_ as decl) ->
           if name = "a" then
             raise (Underivable ("deriving: types called `a' are not allowed.\n"
                                 ^"Please change the name of your type and try again."))
           else
             (decl,
              <:module_binding<
                $uid:classname ^ "_"^ name$
                : $uid:modulename$.$uid:classname$ with type a = $atype context decl$
                = $apply_defaults (make_module_expr context decl)$ >>))
        decls in
    let sorted_mbinds = make_safe mbinds in
    let mrec =
      <:str_item< open $uid:modulename$ module rec $list:sorted_mbinds$ >> in
      match context.params with
        | [] -> mrec
        | _ ->
            let fixed = make_functor <:module_expr< struct $mrec$ end >> in
            let applied = apply_functor <:module_expr< $uid:wrapper_name$ >>
              (List.map (fun (p,_) -> <:module_expr< $uid:NameMap.find p context.argmap$>>)
                 context.params) in
            let projected =
              List.map (fun (name,_,_,_,_) ->
                          let modname = classname ^ "_"^ name in
                          let rhs = <:module_expr< struct module P = $applied$ include P.$uid:modname$ end >> in
                            <:str_item< module $uid:modname$ = $make_functor rhs$>>)
                decls in
            let m = <:str_item< module $uid:wrapper_name$ = $fixed$ >> in
              <:str_item< $m$ $list:projected$ >>

  (* The signature item [module Classname_t : ...] for one declaration: a
     functor over the type parameters returning [Modulename.Classname]
     with [a] fixed by [atype].  Declarations whose last component is
     [true] (presumably already-generated ones) give an empty item.
     Raises [Underivable] for a type named [a]. *)
  let gen_sig ~classname ~context (tname,params,_,_,generated as decl) =
    (* FIXME implicit requirement of classname being equal to module name, hence this hack for Enum *)
    let modulename = if classname = "Enum" then "Deriving_Enum" else classname in
      if tname = "a" then
        raise (Underivable ("deriving: types called `a' are not allowed.\n"
                            ^"Please change the name of your type and try again."))
      else
        if generated then <:sig_item< >>
        else
          (* the functor argument signatures use [modulename], matching the
             result signature (and [generate]) *)
          let t = List.fold_right
            (fun (p,_) m -> <:module_type< functor ($NameMap.find p context.argmap$ : $uid:modulename$.$uid:classname$) -> $m$ >>)
            params
            <:module_type< $uid:modulename$.$uid:classname$ with type a = $atype context decl$ >> in
            <:sig_item< module $uid:Printf.sprintf "%s_%s" classname tname$ : $t$ >>

  (* The signature items for a whole group, one [gen_sig] per declaration. *)
  let gen_sigs ~classname ~context ~decls =
    <:sig_item< $list:List.map (gen_sig ~classname ~context) decls$ >>
end
(* Collect the names of group types that are used non-regularly: a type
   from [tnames] applied to anything other than the group's own
   parameters, in order.  Type arguments are paired with [params] by
   [List.concat_map2] (from Utils).  An empty result means the group is
   regular. *)
let find_non_regular params tnames decls : name list =
  let checker =
    object
      inherit [name list] fold as super
      method crush = List.concat
      method expr = function
        | `Constr ([tname], args) when NameSet.mem tname tnames ->
            List.concat_map2
              (fun (p, _) arg ->
                 match arg with
                 | `Param (q, _) when p = q -> []
                 | _ -> [tname])
              params
              args
        | other -> super#expr other
    end
  in
  List.concat_map checker#decl decls
(* Return the type parameters shared by every declaration of a recursive
   group.
   @raise Invalid_argument if the group is empty.
   @raise Underivable if the declarations do not all have exactly the same
   parameter list. *)
let extract_params =
  let has_params params (_, ps, _, _, _) = ps = params in
  function
  | [] -> invalid_arg "extract_params"
  | (_, params, _, _, _) :: rest
    when List.for_all (has_params params) rest ->
      params
  | _ :: _ ->
      (* all types in a clique must have the same parameters *)
      raise (Underivable ("Instances can only be derived for "
                          ^"recursive groups where all types\n"
                          ^"in the group have the same parameters."))
(* Build the generation context for a recursive group of declarations.
   The shared parameters come from [extract_params] (which raises
   [Underivable] when they differ).  Non-regular recursion is rejected
   with [Failure].  Each parameter ['p] is mapped to the functor-argument
   name ["V_p"]. *)
let setup_context loc (tdecls : decl list) : context =
  let params = extract_params tdecls in
  let tnames =
    NameSet.fromList (List.map (fun (tname, _, _, _, _) -> tname) tdecls) in
  match find_non_regular params tnames tdecls with
  | [] ->
      let add_arg (p, _) acc = NameMap.add p ("V_" ^ p) acc in
      let argmap = List.fold_right add_arg params NameMap.empty in
      { loc = loc;
        argmap = argmap;
        params = params;
        tnames = tnames }
  | names ->
      failwith ("The following types contain non-regular recursion:\n "
                ^ String.concat ", " names
                ^ "\nderiving does not support non-regular types")
(* A deriver turns (location, context, declarations) into structure
   items; a sigderiver produces the corresponding signature items. *)
type deriver = Loc.t * context * Type.decl list -> Ast.str_item
and sigderiver = Loc.t * context * Type.decl list -> Ast.sig_item

(* Registry of (deriver, sigderiver) pairs, keyed by class name. *)
let derivers : (name, (deriver * sigderiver)) Hashtbl.t = Hashtbl.create 15

(* [register classname pair] records [pair] under [classname]; with
   [Hashtbl.add], a later registration under the same name hides an
   earlier one. *)
let register classname pair = Hashtbl.add derivers classname pair

(* Look up the pair registered under [classname].
   @raise NoSuchClass if nothing is registered under that name. *)
let find classname =
  if Hashtbl.mem derivers classname
  then Hashtbl.find derivers classname
  else raise (NoSuchClass classname)

(* [true] iff something is registered under [classname]. *)
let is_registered : name -> bool =
  fun classname -> Hashtbl.mem derivers classname