Initial snarf.
[shack.git] / naml / naml_ir_partial.ml
blobdee55b3c623f6d63be5fd92e0a29975ce7f0b849
1 (*
2 * partial application handling for IR
4 * ----------------------------------------------------------------
6 * @begin[license]
7 * Copyright (c) Geoffrey Irving, Dylan Symon
9 * This program is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU General Public License
11 * as published by the Free Software Foundation; either version 2
12 * of the License, or (at your option) any later version.
14 * This program is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 * GNU General Public License for more details.
19 * You should have received a copy of the GNU General Public License
20 * along with this program; if not, write to the Free Software
21 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
22 * @end[license]
25 open Symbol
26 open Fir
27 open Naml_ir
28 open Naml_ir_exn
29 open Naml_ir_util
(*
 * Continuation-passing helper for a variable used as a value.
 * If [v] is in [fs], look up its type in [venv], bind a fresh symbol to
 * a LetClosure over [v] with no captured arguments, and pass that fresh
 * symbol to [c].  Otherwise pass [v] to [c] unchanged.
 *)
let pa_var venv fs v c =
  if not (SymbolSet.mem fs v) then
    c v
  else begin
    (* Type lookup happens before the fresh symbol is drawn. *)
    let closure_ty = SymbolTable.find venv v in
    let closure_sym = new_symbol v in
    LetClosure (closure_sym, closure_ty, v, [], c closure_sym)
  end
(*
 * Continuation-passing helper for an atom used as a value.
 * An AtomVar naming a member of [fs] is replaced by an AtomVar of a fresh
 * symbol bound (by pa_var) to a zero-argument LetClosure over it.  Any
 * other atom is passed to [c] unchanged.
 *)
let pa_atom venv fs a c =
  match a with
  | AtomVar v when SymbolSet.mem fs v ->
      (* pa_var hands back the fresh closure symbol; re-wrap it as an atom. *)
      pa_var venv fs v (fun closure_sym -> c (AtomVar closure_sym))
  | _ ->
      c a
(*
 * Continuation-passing map of pa_atom over a list of atoms.
 * Atoms are processed left to right.  Any LetClosure bindings they need
 * are nested in that order around the expression built by [c].  [c]
 * receives the rewritten atoms in their original order.
 *)
let pa_atom_list venv fs al c =
  let rec walk acc = function
    | [] -> c (List.rev acc)
    | atom :: rest ->
        pa_atom venv fs atom (fun atom' -> walk (atom' :: acc) rest)
  in
  walk [] al
(*
 * Return the type left after supplying exactly one argument to a value
 * of function type:
 *   - for a one-parameter TyFun, the result type;
 *   - for a TyFun with more parameters, the same TyFun without its first parameter.
 * A TyApply is expanded through [tenv] with apply_type and retried.
 *
 * Raises IRException for any other type, including a TyFun with an
 * empty parameter list.
 *)
let rec strip_fun_arg tenv = function
  | TyFun (_ :: remaining, result) ->
      (match remaining with
       | [] -> result
       | _ -> TyFun (remaining, result))
  | TyApply (tv, tl) ->
      strip_fun_arg tenv (apply_type tenv tv tl)
  | _ ->
      raise (IRException "wierd function in strip_fun_arg")
(*
 * Apply [f] (whose type is [t]) to the atoms [al] one argument at a time.
 * Each intermediate result is bound to a fresh symbol, typed by stripping
 * one argument from the previous type with strip_fun_arg.  The last
 * application binds [v] with type [rt] and continues with the
 * already-built expression [e].
 *
 * Raises IRException if [al] is empty.
 *)
let rec make_partials tenv v f al t rt e =
  match al with
  | [] ->
      raise (IRException "tried to apply no arguments in make_partials")
  | [_] ->
      (* Last argument: the result goes straight into [v]. *)
      LetApply (v, rt, f, al, e)
  | arg :: rest ->
      let partial = new_symbol f in
      let partial_ty = strip_fun_arg tenv t in
      LetApply (partial, partial_ty, f, [arg],
                make_partials tenv v partial rest partial_ty rt e)
(*
 * Pass every atom of allocation operator [aop] through pa_atom /
 * pa_atom_list, left to right, and hand the rebuilt operator to [c].
 * AllocFrame carries no atoms here and is passed through unchanged.
 *)
let pa_alloc_op venv fs aop c =
  match aop with
  | AllocTuple (tclass, ty, ty_vars, al) ->
      pa_atom_list venv fs al (fun al ->
        c (AllocTuple (tclass, ty, ty_vars, al)))
  | AllocDTuple (ty, ty_var, a, al) ->
      pa_atom venv fs a (fun a ->
      pa_atom_list venv fs al (fun al ->
        c (AllocDTuple (ty, ty_var, a, al))))
  | AllocArray (ty, al) ->
      pa_atom_list venv fs al (fun al ->
        c (AllocArray (ty, al)))
  | AllocVArray (ty, si, a1, a2) ->
      pa_atom venv fs a1 (fun a1 ->
      pa_atom venv fs a2 (fun a2 ->
        c (AllocVArray (ty, si, a1, a2))))
  | AllocUnion (ty, ty_vars, tv, i, al) ->
      pa_atom_list venv fs al (fun al ->
        c (AllocUnion (ty, ty_vars, tv, i, al)))
  | AllocMalloc (ty, a) ->
      pa_atom venv fs a (fun a ->
        c (AllocMalloc (ty, a)))
  | AllocFrame _ ->
      c aop

(*
 * Partial-application pass over expression [e].
 *
 *   tenv - type environment; only used to expand TyApply in strip_fun_arg.
 *   venv - types of the variables bound so far (each Let* adds its binder).
 *   fs   - symbols of functions bound by an enclosing LetFuns, plus
 *          whatever the caller seeds it with.
 *
 * Handling of references to functions in [fs]:
 *   - Used as a value (in an atom or a LetExnHandler): wrapped in a
 *     zero-argument LetClosure.
 *   - Applied with fewer arguments than parameters: becomes a LetClosure
 *     capturing the supplied arguments.
 *   - Applied with exactly as many arguments as parameters: stays a LetApply.
 *   - Applied with more arguments: a saturated LetApply, then the result is
 *     applied to the remaining arguments one at a time.
 * A callee not in [fs] is always applied one argument at a time.
 *
 * Raises IRException when a LetFuns binding has a non-TyFun type or when
 * [e] already contains a LetClosure.
 *)
let rec pa_expr tenv venv fs e =
  match e with
  | LetAtom (v, ty, a, e) ->
      pa_atom venv fs a (fun a ->
        let venv = SymbolTable.add venv v ty in
        LetAtom (v, ty, a, pa_expr tenv venv fs e))
  | LetUnop (v, ty, op, a, e) ->
      pa_atom venv fs a (fun a ->
        let venv = SymbolTable.add venv v ty in
        LetUnop (v, ty, op, a, pa_expr tenv venv fs e))
  | LetBinop (v, ty, op, a1, a2, e) ->
      pa_atom venv fs a1 (fun a1 ->
      pa_atom venv fs a2 (fun a2 ->
        let venv = SymbolTable.add venv v ty in
        LetBinop (v, ty, op, a1, a2, pa_expr tenv venv fs e)))
  | LetExt (v, ty, s, ty2, al, e) ->
      pa_atom_list venv fs al (fun al ->
        let venv = SymbolTable.add venv v ty in
        LetExt (v, ty, s, ty2, al, pa_expr tenv venv fs e))
  | TailCall (f, al) ->
      (* Only the arguments are rewritten; the callee is left as is. *)
      pa_atom_list venv fs al (fun al -> TailCall (f, al))
  | Match (a, sel) ->
      pa_atom venv fs a (fun a ->
        Match (a, List.map (fun (s, e) -> s, pa_expr tenv venv fs e) sel))
  | LetAlloc (v, aop, e) ->
      (* v is added to venv before the operator's atoms are processed. *)
      let venv = SymbolTable.add venv v (type_of_alloc_op aop) in
      pa_alloc_op venv fs aop (fun aop ->
        LetAlloc (v, aop, pa_expr tenv venv fs e))
  | LetSubscript (so, v, ty, a, ai, e) ->
      (* Neither the subscripted atom nor the index goes through pa_atom. *)
      let venv = SymbolTable.add venv v ty in
      LetSubscript (so, v, ty, a, ai, pa_expr tenv venv fs e)
  | SetSubscript (so, a, ai, ty, av, e) ->
      (* Only the stored value goes through pa_atom. *)
      pa_atom venv fs av (fun av ->
        SetSubscript (so, a, ai, ty, av, pa_expr tenv venv fs e))
  | LetFuns (fdl, e) ->
      (* Every function of the group is in scope, and in fs, for all
         of the bodies and for the continuation e. *)
      let venv =
        List.fold_left (fun venv (v, _, _, t, _, _) -> SymbolTable.add venv v t) venv fdl in
      let fs =
        List.fold_left (fun fs (v, _, _, _, _, _) -> SymbolSet.add fs v) fs fdl in
      LetFuns (List.map (fun (v, dl, fc, t, vl, e) ->
                 match t with
                 | TyFun (tl, _) ->
                     (* Bind each parameter to its declared type. *)
                     let venv = List.fold_left2 SymbolTable.add venv vl tl in
                     v, dl, fc, t, vl, pa_expr tenv venv fs e
                 | _ -> raise (IRException "non function type in pa_expr"))
               fdl,
               pa_expr tenv venv fs e)
  | LetApply (v, t, f, al, e) ->
      pa_atom_list venv fs al (fun al ->
        let fty = SymbolTable.find venv f in
        (match fty with
         | TyFun (atl, rt) when SymbolSet.mem fs f ->
             let n_params = List.length atl in
             let n_args = List.length al in
             if n_args < n_params then
               (* Under-applied: a closure capturing the supplied arguments. *)
               LetClosure (v, t, f, al, pa_expr tenv venv fs e)
             else if n_args = n_params then
               LetApply (v, t, f, al, pa_expr tenv venv fs e)
             else
               (* Over-applied: saturate f first (into ps, of type rt), then
                  apply ps to the leftover arguments one at a time.
                  Mc_list_util.split presumably yields (first n_params, rest). *)
               let ps = new_symbol f in
               let alh, alt = Mc_list_util.split n_params al in
               LetApply (ps, rt, f, alh,
                         make_partials tenv v ps alt rt t (pa_expr tenv venv fs e))
         | _ ->
             (* Callee not in fs: apply one argument at a time. *)
             make_partials tenv v f al fty t (pa_expr tenv venv fs e)))
  | Return a ->
      pa_atom venv fs a (fun a -> Return a)
  | LetExnHandler (f, e) ->
      pa_var venv fs f (fun f -> LetExnHandler (f, pa_expr tenv venv fs e))
  | Raise a ->
      pa_atom venv fs a (fun a -> Raise a)
  | LetClosure _ ->
      raise (IRException "letclosure? we haven't made them yet in pa_expr")
(*
 * Run the partial-application pass over a whole program, starting with
 * an empty variable environment.  The set of known functions is seeded
 * with the program continuation (prog_cont).  Only prog_body is replaced;
 * all other fields are copied unchanged.
 *
 * NOTE(review): prog_cont is in fs but not in the initial venv, so a
 * value use or application of it only succeeds if the body binds it
 * (e.g. via LetFuns); otherwise the SymbolTable.find lookup in
 * pa_atom / pa_expr presumably fails. TODO confirm.
 *)
let partial_prog prog =
  let tenv = prog.prog_types in
  let known_funs = SymbolSet.add SymbolSet.empty prog.prog_cont in
  let body = pa_expr tenv SymbolTable.empty known_funs prog.prog_body in
  { prog with prog_body = body }