added support for AbsOpt
[bugg-scheme-compiler.git] / src / sml / cg.sml~
blobf31c95fb5f9a6aa5f42d90f4496096ebe0f59c51
1 (* Code Generation *)
3 (* Data Segment - initialized data segment (constants) *)
4 structure DataSegment: sig
5     val reset: unit -> unit            (* reset symbol and constant tables between code generations *)
6     val add: Sexpr -> string           (* register a new constant *)
7     val emit: unit -> string           (* emit code for all constants *)
8     val symbols: (string*string) list  (* fetch all constant symbols *)
9 end = struct
10     (* new_name - generate unique names *)
11     val initial_names =
12         ["a", "boot",
13         "sc_undef", "sc_void", "sc_nil", "sc_false_data", "sc_true_data"]
14     val (names:string list ref) = ref initial_names
16     local
17         fun is_new name = not (List.exists (fn n => n=name) (!names)) (* allocate new distinct name *)
18         fun add name = names := name :: !names                        (* add names to the table of existing ones *)
19     in
20         (* if the name is not in the table, register and return it,
21         otherwise, add an increasing number to it until a distinct name is found *)
22         fun new_name name =
23             if is_new name then
24             ( add name
25             ; name )
26             else
27             let fun subname name i =
28                 let val name' = name ^ (Int.toString i)
29                 in if is_new name' then
30                     ( add name'
31                     ; name' )
32                 else
33                     subname name (i + 1) end
34             in subname name 0 end
35     end; (* local *)
37     (* list of constants *)
38     val consts = ref [] : (string*Sexpr) list ref;
40     (* convert the collected consts to global variables in C *)
42     (* table of symbol names and representations;
43     used to
44         - maintain a single representation for all occurences of a symbol
45         - register symbols in run-time symbol table for symbol->string and string->symbol *)
46     val (symbols: (string*string) list ref) = ref []
48     (* return existing name for singletons, otherwise generate a new one;
49     the singletons are hardcoded once and forever, so that they are eq? comparable *)
50     fun const_name (x) =
51         case x of
52         Void => "sc_void"
53         | Nil => "sc_nil"
54         | Bool false => "sc_false"
55         | Bool true => "sc_true"
56         | Symbol s => (* symbols must be eq? comparable *)
57         (case List.find (fn sc => (#1 sc)=s) (!symbols)
58             of SOME sc => #2 sc
59              | NONE => let val name = new_name ("sc_symbol")
60                 in symbols := (s, name) :: (!symbols);
61                    name
62                 end)
63         | _ => (* anything else is just new constant every time *)
64         (new_name (case x 
65                 of (Pair _) => "sc_pair"
66                 | (Vector _) => "sc_vector"
67                 | (String _) => "sc_string"
68                 | (Number _) => "sc_number"
69                 | (Char _) => "sc_char"))
71     (* generate code for a single constant:
72         data definition
73     followed by
74         constant definition - a value of SchemeObject type *)
76     fun const_code (name, scheme_type, data_name, data_variant, data_value) =
77         "SchemeObjectData " ^ data_name ^ " = {" ^ "." ^ data_variant ^ " = {" ^ data_value ^ "}};\n" ^
78         "SchemeObject " ^ name ^ " = {" ^ scheme_type ^ ", &" ^ data_name ^ "};\n"
80     (* symbol representation is written once for every symbol *)
82     val (written_symbols: string list ref) = ref []
84     fun emit_const (name, value) = 
85         case value
86         of Pair (a,d) => emit_pair (name, a, d)
87         | Vector es => emit_vector (name, es)
88         | Symbol s => (* the same symbol may appear several times, but must be emitted only once *)
89         if (List.exists (fn n=>n=name) (!written_symbols))
90         then ""
91         else ( written_symbols := name :: (!written_symbols)
92             ; emit_symbol (name, s) )
93         | String s => emit_string (name, s)
94         | Number i => emit_number (name, i)
95         | Char c => emit_char (name, c)
96         | _ => "" (* singletons *)
98     and emit_pair (name, a, d) = 
99         let
100             val a_name = const_name (a)
101             val d_name = const_name (d)
102             val data_name = new_name (name ^ "_data")
103             (* recusrively define car and cdr *)
104             val a_code = emit_const (a_name, a)
105             val d_code = emit_const (d_name, d)
106         in (a_code ^ d_code ^ (const_code (name, "SCHEME_PAIR", data_name,
107                                 "spd", "&" ^ a_name ^ ", &" ^ d_name))) end
109     (* generates names and statements for a list of constants.
110        returns a list of names and a (matching) list of statements *)
111     and sexprs_to_stmts [] names stmts = (names,stmts)
112       | sexprs_to_stmts (em :: rest) names stmts =
113         let val name = const_name em        (* generate name *)
114             val stmt = emit_const (name,em) (* generate code *)
115         in
116             sexprs_to_stmts rest (names @ [name]) (stmts @ [stmt])
117         end
119     and emit_vector (name, es) =
120         let
121             (* generate names and statements for the vector elements *)
122             val (em_names,em_stmts) = sexprs_to_stmts es [] []
123             (* generate name,statement for the array that holds the elements *)
124             val arr_name = new_name (name^"_arr")
125             val arr_stmt = "SchemeObject* "^arr_name^"[] = {&"^
126                            (String.concatWith ", &" em_names)^"};\n"
127             val data_name = new_name (name^"_data")
128         in
129             (String.concat em_stmts)^
130             arr_stmt^
131             const_code (name, "SCHEME_VECTOR", data_name, "svd",
132                 (Int.toString (List.length es))^", "^arr_name)
133         end
135     and emit_symbol (name, s) =
136         let val data_name = new_name (name ^ "_data")
137             val syment_name = new_name (name ^ "_syment")
138         in ("SymbolEntry "^syment_name^" = {\""^(String.toCString s)^"\",0,NULL};\n"^
139             const_code (name, "SCHEME_SYMBOL", data_name, "smd",
140                 "&"^syment_name))
141         end
143     and emit_string (name, s) =
144         let val data_name = new_name (name ^ "_data")
145         in (const_code (name, "SCHEME_STRING", data_name, "ssd",
146                 (Int.toString (String.size s))^", \""^(String.toCString s)^"\""))
147         end
149     and emit_number (name, i) =
150         let val data_name = new_name (name ^ "_data")
151         in (const_code (name, "SCHEME_INT", data_name, "sid",
152                 "(int)" ^ (if i<0 then "-"^(Int.toString (~i))
153                                   else (Int.toString i))))
154         end
156     and emit_char (name, c) =
157         let val data_name = new_name (name ^ "_data")
158         in (const_code (name, "SCHEME_CHAR", data_name,
159                 "scd", "(char)" ^ (Int.toString (Char.ord c)))) end
161     fun reset () =
162         ( consts := []
163         ; symbols := []
164         ; written_symbols := [] )
166     fun add x =
167         let val name = const_name x
168         in consts := (name, x) :: (!consts) (* repeated symbols are detected in emit_const *)
169         ; name end
171     fun emit () = 
172         "SchemeObject sc_undef = {-1, NULL};\n" ^
173         "SchemeObject sc_void = {SCHEME_VOID, NULL};\n" ^
174         "SchemeObject sc_nil = {SCHEME_NIL, NULL};\n" ^
175         "SchemeObjectData sc_false_data = {.sbd = {0}};\n" ^
176         "SchemeObject sc_false = {SCHEME_BOOL, &sc_false_data};\n" ^
177         "SchemeObjectData sc_true_data = {.sbd = {1}};\n" ^
178         "SchemeObject sc_true = {SCHEME_BOOL, &sc_true_data};\n" ^
179         String.concat (map emit_const (List.rev (!consts)))
181     (* freeze the symbols *)
182     val symbols = !symbols
184 end; (* DataSegment *)
186 structure ErrType = struct
187     datatype Type =
188         None
189         | ArgsCount of int
190         | AppNonProc
191         | NotAPair
192         | UndefinedSymbol of string;
194     fun toString None = "\"\""
195         | toString (ArgsCount expected) = "MSG_ERR_ARGCOUNT("^(intToCString expected)^")"
196         | toString AppNonProc = "MSG_ERR_APPNONPROC"
197         | toString NotAPair = "MSG_ERR_NOTPAIR"
198         | toString (UndefinedSymbol name) = "\"Symbol "^(String.toCString name)^" not defined\""
199 end;
201 structure CodeSegment (* : sig
202     val ErrType;
203     type StatementType;
204     val reset : unit -> unit;
205     val add : StatementType -> unit;
206     val emit : unit -> string;
207 end *) = struct
208     datatype StatementType =
209         Comment of string
210       | Label of string
211       | Assertion of string * ErrType.Type
212       | Error of ErrType.Type
213       | ErrorIf of string * ErrType.Type
214       | Branch of string
215       | BranchIf of string * string
216       | Set of string * string
217       | Push of string
218       | Pop of string
219       | Return
220       | Statement of string;
222     fun statementToString (Statement stmt) = "\t"^stmt^";"
223       | statementToString (Set (n,v)) = "\t"^n^" = "^v^";"
224       | statementToString (Branch l) = "\tgoto "^l^";"
225       | statementToString (BranchIf (c,l)) = "\tif ("^c^") goto "^l^";"
226       | statementToString (Comment s) = "\t/* "^s^" */"
227       | statementToString (Label s) = s^":"
228       | statementToString (Assertion (p,e)) = "\tASSERT_ALWAYS("^p^","^(ErrType.toString e)^");"
229       | statementToString (Error e) = "\tfprintf(stderr,\"%s\\n\","^(ErrType.toString e)^"); exit(-1);"
230       | statementToString (ErrorIf (p,e)) = "\tif ("^p^") {fprintf(stderr,\"%s\\n\","^(ErrType.toString e)^"); exit(-1);}"
231       | statementToString (Push s) = "\tpush("^s^");"
232       | statementToString (Pop s) = "\t"^s^" = pop();"
233       | statementToString Return = "\tRETURN();"
234     ;
236     val statements = ref [] : StatementType list ref;
238     fun reset () =
239         statements := [];
241     fun add stmt = statements := !statements @ [stmt];
243     fun emit () =
244         ""^
245         (String.concatWith "\n" (List.map statementToString (!statements)))^
246         "\n\t;";
248 end; (* CodeSegment *)
250 structure Program: sig
251     val reset : unit -> unit;
252     val gen : Expr -> int -> unit;
253     val emit : int * int -> string;
254 end = struct
256     fun makeLabeler prefix =
257         let
258             val number = ref 0;
259         in
260             fn () => (number:= !number + 1
261                      ;prefix^(Int.toString (!number)))
262         end;
264     val makeLabelElse = makeLabeler "else";
265     val makeLabelEndif = makeLabeler "endIf";
266     val makeLabelEndOr = makeLabeler "endOr";
267     val makeLabelSkipBody = makeLabeler "skipBody";
268     val makeLabelBody = makeLabeler "body";
269     val makeLabelRet = makeLabeler "ret";
270     val makeLabelApp = makeLabeler "app";
272     val CSadd = CodeSegment.add;
274     val lblExit = "_exit";
276     fun addProlog () =
277         (CSadd (CodeSegment.Comment "<prolog>")
278         ;CSadd (CodeSegment.Comment "Push an initial activation frame")
279         ;CSadd (CodeSegment.Push "0") (* no arguments *)
280         ;CSadd (CodeSegment.Push "(int)NULL") (* no enviroments *)
281         ;CSadd (CodeSegment.Push ("(int)&&"^lblExit))
282         ;CSadd (CodeSegment.Push "fp")
283         ;CSadd (CodeSegment.Set ("fp","sp"))
284         ;CSadd (CodeSegment.Comment "</prolog>")
285         );
287     fun addEpilog () =
288         (CSadd (CodeSegment.Comment "<epilog>")
289         ;CSadd (CodeSegment.Pop "fp")
290         ;CSadd (CodeSegment.Label lblExit)
291         ;CSadd (CodeSegment.Comment "</epilog>")
292         );
294     fun reset () =
295         (DataSegment.reset ()
296         ;CodeSegment.reset ()
297         ;addProlog ()
298         );
300     fun maprtl f l = map f (List.rev l);
302     (* Generate code for a given expression
303        THE INVARIANT:      r_res contains the value of the
304                              expression after execution
305     *)
306     fun gen (Const se) absDepth =
307         let
308             val lblConst = DataSegment.add se
309         in
310             CSadd (CodeSegment.Set ("r_res","(int)&"^lblConst))
311         end
312       | gen (Var _)         absDepth = raise Match (* shouldn't be here *)
313       | gen (VarFree name)  absDepth =
314         let
315             val lblElse = makeLabelElse ()
316             val lblEndif = makeLabelEndif ()
317         in (* probe for symbol in runtime data-structure *)
318             (CSadd (CodeSegment.Set ("r_res","(int)probeSymbolDefined(\""^name^"\",topLevel)"))
319             ;CSadd (CodeSegment.BranchIf ("r_res==0",lblElse))
320             ;CSadd (CodeSegment.BranchIf ("! ((SymbolEntry*)r_res)->isDefined",lblElse))
321             ;CSadd (CodeSegment.Set ("r_res","(int)((SymbolEntry*)r_res)->sob"))
322             ;CSadd (CodeSegment.Branch lblEndif)
323             ;CSadd (CodeSegment.Label lblElse)
324             ;CSadd (CodeSegment.Error (ErrType.UndefinedSymbol name))
325             ;CSadd (CodeSegment.Label lblEndif)
326             )
327         end
328       | gen (VarParam (name,ndx))         absDepth =
329         (CSadd (CodeSegment.Assertion ("("^(intToCString ndx)^">=0) & ("^(intToCString ndx)^"<ST_ARG_COUNT())",ErrType.None))
330         ;CSadd (CodeSegment.Set ("r_res",("ST_ARG("^(intToCString ndx)^")")))
331         )
332       | gen (VarBound (name,major,minor)) absDepth =
333         CSadd (CodeSegment.Set ("r_res",("((int**)ST_ENV())["^(intToCString major)^"]["^(intToCString minor)^"]")))
334       | gen (If (test,dit,dif))           absDepth =
335         let
336             val lblElse = makeLabelElse ()
337             val lblEndif = makeLabelEndif ()
338             val lblFalse = DataSegment.add (Bool false)
339         in
340             (gen test absDepth
341             ;CSadd (CodeSegment.BranchIf ("(SchemeObject*)r_res==&"^lblFalse,lblElse))
342             ;gen dit absDepth
343             ;CSadd (CodeSegment.Branch lblEndif)
344             ;CSadd (CodeSegment.Label lblElse)
345             ;gen dif absDepth
346             ;CSadd (CodeSegment.Label lblEndif)
347             )
348         end
349       | gen (abs as Abs _)    absDepth = genAbs abs absDepth
350       | gen (abs as AbsOpt _) absDepth = genAbs abs absDepth
351       | gen (abs as AbsVar _) absDepth = genAbs abs absDepth
352       | gen (App (proc,args)) absDepth =
353         let
354             val lblRet = makeLabelRet ()
355             val lblApp = makeLabelApp ()
356         in
357             (* for each arg in args (backwards) do:
358                   evaluate arg
359                   push r_res to stack *)
360             ((maprtl (fn arg => (gen arg absDepth;
361                                  CSadd (CodeSegment.Push "r_res")))
362                      args)
363             (* push length(args) to stack *)
364             ;CSadd (CodeSegment.Push (intToCString (List.length args)))
365             (* evaluate proc *)
366             ;gen proc absDepth
367             (* if r_res is not a closure then: error *)
368             ;CSadd (CodeSegment.BranchIf ("IS_SOB_CLOSURE(r_res)",lblApp))
369             ;CSadd (CodeSegment.Error ErrType.AppNonProc)
370             ;CSadd (CodeSegment.Label lblApp)
371             (* push proc.env to stack *)
372             ;CSadd (CodeSegment.Push "(int)SOB_CLOSURE_ENV(r_res)")
373             (* push return address *)
374             ;CSadd (CodeSegment.Push ("(int)&&"^lblRet))
375             (* goto proc.code *)
376             ;CSadd (CodeSegment.Branch "*(SOB_CLOSURE_CODE(r_res))")
377             (* return address *)
378             ;CSadd (CodeSegment.Label lblRet)
379             ;CSadd (CodeSegment.Set ("sp","fp"))
380             )
381         end
382       | gen (AppTP (proc,args))         absDepth =
383         (* for each arg in args (backwards) do:
384                 evaluate arg
385                 push r_res to stack *)
386         ((maprtl (fn arg => (gen arg absDepth;
387                                 CSadd (CodeSegment.Push "r_res")))
388                     args)
389         (* push length(args) to stack *)
390         ;CSadd (CodeSegment.Push (intToCString (List.length args)))
391         (* evaluate proc *)
392         ;gen proc absDepth
393         (* if r_res is not a closure then: error *)
394         ;CSadd (CodeSegment.ErrorIf ("! IS_SOB_CLOSURE(r_res)",ErrType.AppNonProc))
395         (* push proc.env to stack *)
396         ;CSadd (CodeSegment.Push "(int)SOB_CLOSURE_ENV(r_res)")
397         (* push return address (of current activation frame) *)
398         ;CSadd (CodeSegment.Push "ST_RET()")
399         (* override current activation frame *)
400         ;CSadd (CodeSegment.Statement ("shiftActFrmDown()"))
401         (* goto proc.code *)
402         ;CSadd (CodeSegment.Branch "*(SOB_CLOSURE_CODE(r_res))")
403         )
404       | gen (Seq [])                    absDepth = ()
405       | gen (Seq (e :: rest))           absDepth = (gen e absDepth; gen (Seq rest) absDepth)
406       | gen (Or preds)                  absDepth =
407         let
408             val lblEndOr = makeLabelEndOr ()
409         in
410            (genOrPreds lblEndOr preds
411            ;CSadd (CodeSegment.Label lblEndOr)
412            )
413         end
414       | gen (Set ((VarFree name),value)) absDepth =
415         (* Set on VarFree is just the same as Def on VarFree *)
416         gen (Def ((VarFree name),value)) absDepth
417       | gen (Set ((VarParam (name,ndx)),value)) absDepth =
418         let val lblVoid = DataSegment.add Void
419         in
420             (gen value absDepth
421             ;CSadd (CodeSegment.Assertion ("("^(intToCString ndx)^">=0) & ("^(intToCString ndx)^"<ST_ARG_COUNT())",ErrType.None))
422             ;CSadd (CodeSegment.Set (("ST_ARG("^(intToCString ndx)^")"),"r_res"))
423             ;CSadd (CodeSegment.Set ("r_res","(int)&"^lblVoid))
424             )
425         end
426       | gen (Set ((VarBound (name,major,minor)),value)) absDepth =
427         let val lblVoid = DataSegment.add Void
428         in
429             (gen value absDepth
430             ;CSadd (CodeSegment.Set (("((int**)ST_ENV())["^(intToCString major)^"]["^(intToCString minor)^"]"),"r_res"))
431             ;CSadd (CodeSegment.Set ("r_res","(int)&"^lblVoid))
432             )
433         end
434       | gen (Def ((VarFree name),value)) absDepth =
435         let val lblVoid = DataSegment.add Void
436         in
437             (gen value absDepth
438             ;CSadd (CodeSegment.Set ("r[0]","(int)getSymbol(\""^name^"\",topLevel)"))
439             ;CSadd (CodeSegment.Set ("((SymbolEntry*)r[0])->isDefined","1"))
440             ;CSadd (CodeSegment.Set ("((SymbolEntry*)r[0])->sob","(SchemeObject*)r_res"))
441             ;CSadd (CodeSegment.Set ("r_res","(int)&"^lblVoid))
442             )
443         end
444       | gen (Def (_,_)) absDepth = raise Match (* shouldn't be here *)
445     and genOrPreds _ [] absDepth = ()
446       | genOrPreds lblEndOr (p :: rest) absDepth =
447         let
448             val lblFalse = DataSegment.add (Bool false)
449         in
450             (gen p absDepth
451             ;CSadd (CodeSegment.BranchIf ("(SchemeObject*)r_res!=&"^lblFalse,lblEndOr))
452             ;genOrPreds lblEndOr rest absDepth
453             )
454         end
455     and genAbs abs absDepth =
456         let
457             val formalParams = case abs of
458                 Abs    (params,_)   => List.length params
459               | AbsOpt (params,_,_) => List.length params + 1
460               | AbsVar (_,_)        => 1
461               | _ => raise Match (* shouldn't be here *)
462             val body = case abs of
463                 Abs    (_,body)   => body
464               | AbsOpt (_,_,body) => body
465               | AbsVar (_,body)   => body
466               | _ => raise Match (* shouldn't be here *)
467             val lblSkipBody = makeLabelSkipBody ()
468             val lblBody = makeLabelBody ()
469         in
470             (* 1. extend enviroment *)
471             (CSadd (CodeSegment.Set ("r[0]","(int)extendEnviroment( (int**)"^
472                                             (if absDepth=0 then "NULL"
473                                                            else "ST_ENV()")^
474                                             ", "^
475                                             (intToCString absDepth)^
476                                             ")"))
477             (* 2. prepare code *)
478             ;CSadd (CodeSegment.Branch lblSkipBody)
479             ;CSadd (CodeSegment.Label lblBody)
480             (* prolog *)
481             ;CSadd (CodeSegment.Push "fp")
482             ;CSadd (CodeSegment.Set ("fp","sp"))
483             (* fix stack if needed *)
484             ;case abs of
485                 (Abs _)    => ()
486               | (AbsOpt _) => CSadd (CodeSegment.Statement ("prepareStackForAbsOpt("^(intToCString formalParams)^")"))
487               | (AbsVar _) => raise Match (* todo *)
488               | _ => raise Match (* shouldn't be here *)
489             (* verify number of actual arguments *)
490             ;CSadd (CodeSegment.ErrorIf ("ST_ARG_COUNT()!="^(intToCString formalParams),
491                                          (ErrType.ArgsCount formalParams)))
492             (* body *)
493             ;gen body (absDepth+1)
494             (* epilog *)
495             ;CSadd (CodeSegment.Pop ("fp"))
496             ;CSadd CodeSegment.Return
497             ;CSadd (CodeSegment.Label lblSkipBody)
498             (* 3. create closure *)
499             ;CSadd (CodeSegment.Set ("r_res","(int)makeSchemeClosure((void*)r[0],&&"^lblBody^")"))
500             )
501         end
502     ;
504     fun emit (nregs,stacksize) =
505        (addEpilog ();
506         "/* COMP091 Scheme->C Compiler Generated Code */\n\n" ^
507         "#include \"scheme.h\"\n" ^
508         "#include \"assertions.h\"\n" ^
509         "#include \"arch.h\"\n" ^
510         "#include \"rtemgr.h\"\n" ^
511         "#include \"strings.h\"\n" ^
512         "extern SymbolNode *topLevel;\n"^
513         "\n/* Data Segment */\n" ^
514         (DataSegment.emit ()) ^
515         "\n/* Code Segment */\n" ^
516         "void schemeCompiledFunction() {\n" ^
517         "\t#include \"builtins.c\"\n\n"^
518         "\tinitArchitecture("^Int.toString(stacksize)^","^Int.toString(nregs)^");\n" ^
519         "\n" ^
520         (CodeSegment.emit ()) ^
521         "\n}\n"
522        );
524 end; (* Program *)