3 (* Copyright Jeremy Yallop 2007.
4 This file is free software, distributed under the MIT license.
5 See the file COPYING for details.
14 (* mapping from type parameters to functor arguments *)
15 argmap
: name
NameMap.t
;
16 (* ordered list of type parameters *)
(* Raised when an instance cannot be derived for a type; the string is a
   human-readable explanation of why (e.g. "... cannot be derived for
   function types"). *)
22 exception Underivable
of string
(* Raised by [find] when no deriver has been registered under the requested
   class name; the payload is that class name. *)
23 exception NoSuchClass
of string
25 (* display a fatal error and exit *)
26 let error loc
(msg
: string) =
27 Syntax.print_warning loc msg
;
(* Signature of the argument to [InContext]: a module supplying the source
   location used when building generated AST fragments. *)
module type Loc = sig
  val loc : Loc.t
end
32 let contains_tvars, contains_tvars_decl
=
34 inherit [bool] fold
as default
35 method crush
= List.exists
F.id
36 method expr
= function
39 end in (o#expr
, o#decl
)
41 module InContext
(L
: Loc
) =
(* Instantiate the [Untranslate] functor with this context's location
   module [L].  The local [Untranslate] shadows the functor for the rest of
   the body, so later uses such as [Untranslate.expr] refer to this
   instance. *)
44 module Untranslate
= Untranslate
(L
)
46 let instantiate, instantiate_repr
=
48 inherit transform
as super
49 method expr
= function
50 | `Param
(name
, _
) -> lookup name
53 (fun (lookup
: name
-> expr
) -> (o lookup
)#expr
),
54 (fun (lookup
: name
-> expr
) -> (o lookup
)#repr
)
56 let instantiate_modargs, instantiate_modargs_repr
=
59 `Constr
([NameMap.find var ctxt
.argmap
; "a"], [])
61 failwith
("Unbound type parameter '" ^ var
)
62 in (fun ctxt
-> instantiate (lookup ctxt
)),
63 (fun ctxt
-> instantiate_repr
(lookup ctxt
))
67 inherit transform
as default
68 method expr
= function
69 | `Param
(p
,v
) when NameMap.mem p env
->
70 `Param
(NameMap.find p env
,v
)
71 | e
-> default# expr e
74 let cast_pattern ctxt ?
(param
="x") t
=
75 let t = Untranslate.expr
(instantiate_modargs ctxt
t) in
76 (<:patt
< $lid
:param$
>>,
81 let test = function #
t -> true | _
-> false
82 end in M.test $lid
:param$
>>,
87 let cast = function #
t as t -> t | _
-> assert false
88 end in M.cast $lid
:param$
)>>)
(* [seq first second] builds the sequence expression [first; second]. *)
let seq first second =
  <:expr< $first$; $second$ >>
92 let record_pattern ?
(prefix
="") (fields
: Type.field list
) : Ast.patt
=
94 (List.map
(fun (label
,_
,_
) -> <:patt
< $lid
:label$
= $lid
:prefix ^ label$
>>)
97 let record_expr : (string * Ast.expr
) list
-> Ast.expr
=
101 (fun l r
-> <:rec_binding
< $l$
; $r$
>>)
102 (List.map
(fun (label
, exp
) -> <:rec_binding
< $lid
:label$
= $exp$
>>)
104 Ast.ExRec
(loc
, fs, Ast.ExNil loc
)
106 let record_expression ?
(prefix
="") : Type.field list
-> Ast.expr
=
108 let es = List.fold_left1
109 (fun l r
-> <:rec_binding
< $l$
; $r$
>>)
110 (List.map
(fun (label
,_
,_
) -> <:rec_binding
< $lid
:label$
= $lid
:prefix ^ label$
>>)
112 Ast.ExRec
(loc
, es, Ast.ExNil loc
)
114 let mproject mexpr name
=
116 | <:module_expr
< $id
:m$
>> -> <:expr
< $id
:m$
.$lid
:name$
>>
117 | _
-> <:expr
< let module M
= $mexpr$
in M.$lid
:name$
>>
119 let expr_list : Ast.expr list
-> Ast.expr
=
122 (fun car cdr
-> <:expr
< $car$
:: $cdr$
>>)
126 let patt_list : Ast.patt list
-> Ast.patt
=
129 (fun car cdr
-> <:patt
< $car$
:: $cdr$
>>)
133 let tuple_expr : Ast.expr list
-> Ast.expr
= function
134 | [] -> <:expr
< () >>
136 | x
::xs
-> Ast.ExTup
(loc
, List.fold_left
(fun e
t -> Ast.ExCom
(loc
, e
,t)) x xs
)
138 let tuple ?
(param
="v") n
: Ast.patt
* Ast.expr
=
139 let v n
= Printf.sprintf
"%s%d" param n
in
141 | 0 -> <:patt
< () >>, <:expr
< () >>
142 | 1 -> <:patt
< $lid
:v 0$
>>, <:expr
< $lid
:v 0$
>>
145 (* At time of writing I haven't managed to write anything
146 using quotations that generates an n-tuple *)
148 (fun (p
, e
) (patt
, expr
) -> Ast.PaCom
(loc
, p
, patt
), Ast.ExCom
(loc
, e
, expr
))
149 (<:patt
< >>, <:expr
< >>)
150 (List.map
(fun n
-> <:patt
< $lid
:v n$
>>, <:expr
< $lid
:v n $
>>)
153 Ast.PaTup
(loc
, patts), Ast.ExTup
(loc
, exprs
)
155 let rec modname_from_qname ~qname ~classname
=
157 | [] -> invalid_arg
"modname_from_qname"
158 | [t] -> <:ident
< $uid
:classname ^
"_"^
t$
>>
159 | t::ts
-> <:ident
< $uid
:t$
.$
modname_from_qname ~qname
:ts ~classname$
>>
(* [apply_functor f [a1; ...; an]] builds the module expression
   [f (a1) ... (an)], applying the arguments from left to right.
   An empty argument list yields [f] itself. *)
let apply_functor (f : Ast.module_expr) (args : Ast.module_expr list)
    : Ast.module_expr =
  let apply_one functor_expr arg = <:module_expr< $functor_expr$ ($arg$) >> in
  List.fold_left apply_one f args
164 class virtual make_module_expr ~classname ~allow_private
=
167 method mapply ctxt
(funct
: Ast.module_expr
) args
=
168 apply_functor funct
(List.map
(self#expr ctxt
) args
)
170 method virtual variant
: context
-> decl
-> variant
-> Ast.module_expr
171 method virtual sum
: ?eq
:expr
-> context
-> decl
-> summand list
-> Ast.module_expr
172 method virtual record
: ?eq
:expr
-> context
-> decl
-> field list
-> Ast.module_expr
173 method virtual tuple : context
-> expr list
-> Ast.module_expr
175 method param ctxt
(name
, variance
) =
176 <:module_expr
< $uid
:NameMap.find name ctxt
.argmap$
>>
178 method object_ _
o = raise
(Underivable
(classname ^
" cannot be derived for object types"))
179 method class_ _ c
= raise
(Underivable
(classname ^
" cannot be derived for class types"))
180 method label _ l
= raise
(Underivable
(classname ^
" cannot be derived for label types"))
181 method function_ _ f
= raise
(Underivable
(classname ^
" cannot be derived for function types"))
183 method constr ctxt
(qname
, args
) =
185 | [name
] when NameSet.mem name ctxt
.tnames
->
186 <:module_expr
< $uid
:Printf.sprintf
"%s_%s" classname name$
>>
188 let f = (modname_from_qname ~qname ~classname
) in
189 self#mapply ctxt
(Ast.MeId
(loc
, f)) args
191 method expr
(ctxt
: context
) : expr
-> Ast.module_expr
= function
192 | `Param p
-> self#param ctxt p
193 | `Object
o -> self#object_ ctxt
o
194 | `Class c
-> self#class_ ctxt c
195 | `Label l
-> self#label ctxt l
196 | `Function
f -> self#function_ ctxt
f
197 | `Constr c
-> self#constr ctxt c
198 | `Tuple
t -> self#
tuple ctxt
t
200 method rhs ctxt
(tname
, params
, rhs
, constraints
, _
as decl
: Type.decl
) : Ast.module_expr
=
202 | `Fresh
(_
, _
, (`Private
: [`Private
|`Public
])) when not allow_private
->
203 raise
(Underivable
("The class "^ classname ^
" cannot be derived for private types"))
204 | `Fresh
(eq
, Sum summands
, _
) -> self#sum ?eq ctxt decl summands
205 | `Fresh
(eq
, Record fields
, _
) -> self#record ?eq ctxt decl fields
206 | `Expr e
-> self#expr ctxt e
207 | `Variant
v -> self# variant ctxt decl
v
208 | `Nothing
-> <:module_expr
< >>
(* Convert a type expression back to camlp4 syntax after replacing its
   type parameters via [instantiate_modargs], which looks each parameter
   up in [ctxt.argmap] and refers to that functor argument's [a] type. *)
let atype_expr ctxt ty =
  let instantiated = instantiate_modargs ctxt ty in
  Untranslate.expr instantiated
214 let atype ctxt
(name
, params
, rhs
, _
, _
) =
216 | `Fresh _
| `Variant _
| `Nothing
->
217 Untranslate.expr
(`Constr
([name
],
218 List.map
(fun (p
,_
) -> `Constr
([NameMap.find p ctxt
.argmap
; "a"],[])) params
))
219 | `Expr e
-> atype_expr ctxt e
221 let make_safe (decls
: (decl
* Ast.module_binding
) list
) : Ast.module_binding list
=
222 (* re-order a set of mutually recursive modules in an attempt to
223 make initialization problems less likely *)
226 (fun ((_
,_
,lrhs
,_
,_
), _
) ((_
,_
,rrhs
,_
,_
), _
) -> match (lrhs
: rhs
), rrhs
with
227 (* aliases to types in the group score higher than
230 In general, things that must come first receive a
231 positive score when they occur on the left and a
232 negative score when they occur on the right. *)
233 | (`Fresh _
|`Variant _
), (`Fresh _
|`Variant _
) -> 0
234 | (`Fresh _
|`Variant _
), _
-> -1
235 | _
, (`Fresh _
|`Variant _
) -> 1
236 | (`Nothing
, `Nothing
) -> 0
238 | (_
, `Nothing
) -> -1
239 | `Expr l
, `Expr r
->
244 |`Tuple
of expr list
]
252 let generate ~context ~decls ~make_module_expr ~classname ?default_module
() =
254 set up an enclosing recursive module
255 generate functors for all types in the clique
256 project out the inner modules afterwards.
258 later: generate simpler code for simpler cases:
259 - where there are no type parameters
260 - where there's only one type
261 - where there's no recursion
264 (* FIXME implicit requirement of classname being equal to module name, hence this hack for Enum *)
265 let modulename = if classname
= "Enum" then "Deriving_Enum" else classname
in
266 (* let _ = ensure_no_polymorphic_recursion in *)
267 let wrapper_name = Printf.sprintf
"%s_%s" classname
(random_id
32) in
271 let arg = NameMap.find p context
.argmap
in
272 <:module_expr
< functor ($
arg$
: $uid
:classname$
.$uid
:classname$
) -> $rhs$
>>)
274 let apply_defaults mexpr
= match default_module
with
276 | Some default
-> <:module_expr
< $uid
:classname$
.$uid
:default$
($mexpr$
) >> in
279 (fun (name
,_,_,_,_ as decl
) ->
281 raise
(Underivable
("deriving: types called `a' are not allowed.\n"
282 ^
"Please change the name of your type and try again."))
286 $uid
:classname ^
"_"^ name$
287 : $uid
:modulename$
.$uid
:classname$
with type a
= $
atype context decl$
288 = $
apply_defaults (make_module_expr context decl
)$
>>))
290 let sorted_mbinds = make_safe mbinds in
292 <:str_item
< open $uid
:modulename$
module rec $list
:sorted_mbinds$
>> in
293 match context
.params
with
296 let fixed = make_functor <:module_expr
< struct $
mrec$
end >> in
297 let applied = apply_functor <:module_expr
< $uid
:wrapper_name$
>>
298 (List.map
(fun (p
,_) -> <:module_expr
< $uid
:NameMap.find p context
.argmap$
>>)
301 List.map
(fun (name
,params
,rhs
,_,_) ->
302 let modname = classname ^
"_"^ name
in
303 let rhs = <:module_expr
< struct module P
= $
applied$
include P.$uid
:modname$
end >> in
304 <:str_item
< module $uid
:modname$
= $
make_functor rhs$
>>)
306 let m = <:str_item
< module $uid
:wrapper_name$
= $
fixed$
>> in
307 <:str_item
< $
m$ $list
:projected$
>>
309 let gen_sig ~classname ~context
(tname
,params
,_,_,generated
as decl
) =
310 (* FIXME implicit requirement of classname being equal to module name, hence this hack for Enum *)
311 let modulename = if classname
= "Enum" then "Deriving_Enum" else classname
in
313 raise
(Underivable
("deriving: types called `a' are not allowed.\n"
314 ^
"Please change the name of your type and try again."))
316 if generated
then <:sig_item
< >>
318 let t = List.fold_right
319 (fun (p
,_) m -> <:module_type
< functor ($
NameMap.find p context
.argmap$
: $uid
:classname$
.$uid
:classname$
) -> $
m$
>>)
321 <:module_type
< $uid
:modulename$
.$uid
:classname$
with type a
= $
atype context decl$
>> in
322 <:sig_item
< module $uid
:Printf.sprintf
"%s_%s" classname tname$
: $
t$
>>
(* Build a single signature item grouping the signatures that [gen_sig]
   produces for each declaration in [decls]. *)
let gen_sigs ~classname ~context ~decls =
  let items = List.map (gen_sig ~classname ~context) decls in
  <:sig_item< $list:items$ >>
328 let find_non_regular params tnames decls
: name list
=
331 inherit [name list
] fold
as default
332 method crush
= List.concat
333 method expr
= function
334 | `Constr
([t], args
)
335 when NameSet.mem
t tnames
->
337 (fun (p
,_) a
-> match a
with
338 | `Param
(q
,_) when p
= q
-> []
342 | e
-> default#expr e
346 let has_params params
(_, ps
, _, _, _) = ps
= params
in
348 | [] -> invalid_arg
"extract_params"
349 | (_,params
,_,_,_)::rest
350 when List.for_all
(has_params params
) rest
->
352 | (_,_,rhs,_,_)::_ ->
353 (* all types in a clique must have the same parameters *)
354 raise
(Underivable
("Instances can only be derived for "
355 ^
"recursive groups where all types\n"
356 ^
"in the group have the same parameters."))
358 let setup_context loc
(tdecls
: decl list
) : context
=
359 let params = extract_params tdecls
360 and tnames
= NameSet.fromList
(List.map
(fun (name
,_,_,_,_) -> name
) tdecls
) in
361 match find_non_regular params tnames tdecls
with
363 failwith
("The following types contain non-regular recursion:\n "
364 ^
String.concat
", " names
365 ^
"\nderiving does not support non-regular types")
369 (fun (p
,_) m -> NameMap.add p
(Printf.sprintf
"V_%s" p
) m)
(* A deriver turns a location, a derivation context and a group of type
   declarations into the structure items implementing the instances. *)
type deriver = Loc.t * context * Type.decl list -> Ast.str_item

(* Signature-side counterpart of [deriver]: same input, but it produces
   signature items. *)
type sigderiver = Loc.t * context * Type.decl list -> Ast.sig_item
(* Registry of available classes: maps a class name to its pair of
   structure and signature derivers. *)
let derivers =
  (Hashtbl.create 15 : (name, deriver * sigderiver) Hashtbl.t)
(* [register classname (deriver, sigderiver)] makes [classname] available
   to [find].  [Hashtbl.replace] is used rather than [Hashtbl.add] so that
   registering the same class again overwrites the previous entry instead
   of stacking a hidden binding underneath it.  [find] returns the most
   recently registered pair in either case. *)
let register classname derivs = Hashtbl.replace derivers classname derivs
382 try Hashtbl.find derivers classname
383 with Not_found
-> raise
(NoSuchClass classname
)
(* [is_registered classname] is [true] exactly when a deriver pair has been
   registered under [classname], i.e. when [find classname] would succeed
   rather than raise [NoSuchClass]. *)
let is_registered : name -> bool =
  fun classname -> Hashtbl.mem derivers classname