caml: do not generate duplicate code in Fold (close #48)
[sqlgg.git] / src / gen_caml.ml
blob9413b79facead3009cb43ba4db5e40a7e72d243e
1 (* OCaml code generation *)
3 open ExtLib
4 open Prelude
5 open Printf
7 open Stmt
8 open Gen
9 open Sql
(* Join a list of strings into one string, separated by single spaces. *)
let inline_values items = String.concat " " items
(* Render [s] as the text of an OCaml string literal, surrounding double
   quotes included.

   - a newline becomes the two-character newline escape followed by a
     backslash and a real line break, so the generated literal continues
     on the next source line;
   - carriage returns are dropped;
   - double quotes are backslash-escaped;
   - backslashes are doubled, so that a backslash present in [s] is still
     a backslash once the generated literal is read back by the compiler
     (previously it was copied through and reinterpreted as the start of
     an escape sequence);
   - every other character is copied unchanged. *)
let quote s =
  let buf = Buffer.create (String.length s + 2) in
  let add_escaped = function
    | '\n' -> Buffer.add_string buf "\\n\\\n"
    | '\r' -> ()
    | '"' -> Buffer.add_string buf "\\\""
    | '\\' -> Buffer.add_string buf "\\\\"
    | c -> Buffer.add_char buf c
  in
  Buffer.add_char buf '"';
  String.iter add_escaped s;
  Buffer.add_char buf '"';
  Buffer.contents buf
(* Apply [String.replace ~sub ~by] to [str] over and over, feeding each
   result back in, until the call reports that nothing was replaced; the
   last string obtained is returned. *)
let replace_all ~str ~sub ~by =
  let rec loop current =
    let (changed, next) = String.replace ~str:current ~sub ~by in
    if changed then loop next else next
  in
  loop str
(* Make [str] safe to embed inside an OCaml comment: a space is inserted
   into every comment-closing sequence first, then into every
   comment-opening sequence. *)
let quote_comment_inline str =
  let without_closers = replace_all ~str ~sub:"*)" ~by:"* )" in
  let without_openers = replace_all ~str:without_closers ~sub:"(*" ~by:"( *" in
  without_openers
(* Wrap [str] in OCaml comment delimiters after neutralising any comment
   delimiters it contains. *)
let make_comment str = sprintf "(* %s *)" (quote_comment_inline str)
(* printf-style helper: format the arguments according to [fmt], turn the
   resulting text into a comment with [make_comment] and hand it to
   [indent_endline]. *)
let comment () fmt =
  let emit text = indent_endline (make_comment text) in
  Printf.kprintf emit fmt
(* Print a bare newline via [print_newline]. *)
let empty_line = fun () -> print_newline ()
module L = struct
  (* Textual name used for a SQL type in the generated code; [Type.Blob]
     is given the same name as [Type.Text]. *)
  let as_lang_type sql_type =
    let effective =
      match sql_type with
      | Type.Blob -> Type.Text
      | other -> other
    in
    Type.to_string effective

  (* The API-facing type name is the same as the language type name. *)
  let as_api_type = as_lang_type
end
(* Build the source text of the expression that reads column number
   [index] of the current row, picking the getter from the type name of
   [attr.domain]. *)
let get_column index attr =
  let type_name = L.as_lang_type attr.domain in
  sprintf "(T.get_column_%s stmt %u)" type_name index
43 module T = Translate(L)
45 (* open L *)
46 open T
(* Emit a local function named invoke_callback that applies callback to
   one column getter per element of [schema], each on its own line.
   The first argument is ignored.  Returns the name of the emitted
   function. *)
let output_schema_binder _ schema =
  let name = "invoke_callback" in
  let emit_column position attr = output "%s" (get_column position attr) in
  let emit_arguments () = List.iteri emit_column schema in
  let emit_body () =
    output "callback";
    indented emit_arguments
  in
  output "let %s stmt =" name;
  indented emit_body;
  output "in";
  name
(* Emit a local function named get_row whose body is the comma-separated
   list of column getters for [schema].  The first argument is ignored.
   Returns the name of the emitted function. *)
let output_select1_cb _ schema =
  let name = "get_row" in
  let emit_tuple () =
    let columns = List.mapi get_column schema in
    indent_endline (String.concat ", " columns)
  in
  output "let %s stmt =" name;
  indented emit_tuple;
  output "in";
  name
(* Choose how the result rows of a statement are consumed and emit the
   matching row binder.  Returns the pair (name of the T function to
   call, name of the emitted binder):
   - empty schema: execute, with an empty binder name and nothing emitted;
   - a select of kind `Zero_one or `One: select1 with a get_row binder;
   - anything else: select with an invoke_callback binder. *)
let output_schema_binder index schema kind =
  match schema, kind with
  | [], _ -> "execute", ""
  | _ :: _, Stmt.Select (`Zero_one | `One) ->
    "select1", output_select1_cb index schema
  | _ :: _, _ ->
    "select", output_schema_binder index schema
(* True when the generated function for [stmt] takes a callback: the
   schema is non-empty and the statement is not a select of kind
   `Zero_one or `One. *)
let is_callback stmt =
  match stmt.kind with
  | Stmt.Select (`Zero_one | `One) -> false
  | _ ->
    (match stmt.schema with
     | [] -> false
     | _ :: _ -> true)
(* Shadow the previously visible [params_to_values]: run it, then keep
   only the first component of each pair it returns. *)
let params_to_values params = List.map fst (params_to_values params)
(* Emit the line that binds one statement parameter: the setter is chosen
   from the parameter type, followed by the position [index] and the name
   of the value to bind. *)
let set_param index (id, t) =
  let type_name = param_type_to_string t in
  let value_name = param_name_to_string id index in
  output "T.set_param_%s p %u %s;" type_name index value_name
(* Emit a local function named set_params that starts parameter binding
   for as many parameters as [params] holds, sets each one in order and
   finishes the binding.  The first argument is ignored.  Returns the
   name of the emitted function. *)
let output_params_binder _ params =
  let binder_name = "set_params" in
  let count = List.length params in
  output "let set_params stmt =";
  inc_indent ();
  output "let p = T.start_params stmt %u in" count;
  List.iteri (fun position param -> set_param position param) params;
  output "T.finish_params p";
  dec_indent ();
  output "in";
  binder_name
(* Return the name of the parameter binder to use for a statement.  With
   no parameters nothing is emitted and T.no_params is used; otherwise a
   set_params function is emitted and its name returned. *)
let output_params_binder index = function
  | [] -> "T.no_params"
  | params -> output_params_binder index params
(* [prepend prefix s] is [s] with [prefix] put in front of it. *)
let prepend prefix s = prefix ^ s
(* The generator state is just unit. *)
type t = unit

(* Produce the initial state. *)
let start () : t = ()
(* Emit the OCaml binding for one statement.

   [fold] selects the flavour being generated (plain or Fold), [index] is
   the position of the statement and is passed on to the name chooser and
   to the binders, [stmt] is the statement itself.

   When [fold] is set but the statement takes no callback, the plain
   implementation would be identical, so only an alias to it is emitted.
   Otherwise a full function is emitted: its header, an optional block
   that substitutes the subst properties into the SQL text, the row and
   parameter binders, and the call into T.  In the Fold flavour the
   accumulator is kept in a reference that the callback updates and whose
   final content is the result. *)
let generate_stmt fold index stmt =
  let name = String.uncapitalize (choose_name stmt.props stmt.kind index) in
  let subst = Props.get_all stmt.props "subst" in
  let values =
    inline_values (List.map (prepend "~") (subst @ params_to_values stmt.params))
  in
  let with_callback = is_callback stmt in
  if fold && not with_callback then
    (* alias non-fold impl, identical *)
    output "let %s = %s" name name
  else begin
    let folding = fold && with_callback in
    let callback_param = if with_callback then " callback" else "" in
    let acc_param = if folding then " acc" else "" in
    let all_params = values ^ callback_param ^ acc_param in
    output "let %s db %s =" name all_params;
    inc_indent ();
    let sql_literal = quote (get_sql stmt) in
    (* Expression denoting the SQL text in the generated code: either the
       literal itself or the name of a local holding the substituted text. *)
    let sql_expr =
      match subst with
      | [] -> sql_literal
      | vars ->
        output "let __sqlgg_sql =";
        output "  let replace_all ~str ~sub ~by =";
        output "    let rec loop str = match ExtString.String.replace ~str ~sub ~by with";
        output "    | true, str -> loop str";
        output "    | false, s -> s";
        output "    in loop str";
        output "  in";
        output "  let sql = %s in" sql_literal;
        List.iter
          (fun var ->
            output "  let sql = replace_all ~str:sql ~sub:(\"%%%%%s%%%%\") ~by:%s in" var var)
          vars;
        output "  sql";
        output "in";
        "__sqlgg_sql"
    in
    let (func, callback) = output_schema_binder index stmt.schema stmt.kind in
    let params_binder_name = output_params_binder index stmt.params in
    (* Last argument of the call into T, leading space included. *)
    let trailing =
      if folding then " (fun x -> r_acc := " ^ callback ^ " x !r_acc);"
      else if callback = "" then ""
      else " " ^ callback
    in
    if folding then output "let r_acc = ref acc in";
    output "T.%s db %s %s%s" func sql_expr params_binder_name trailing;
    if folding then output "!r_acc";
    dec_indent ();
    empty_line ()
  end
(* Emit the whole generated module for the statements in [stmts].

   The output is a functor over T named after the capitalized [name].
   It contains one binding per statement, followed by a nested module
   Fold holding the folding flavour of the same statements.  [stmts] is
   cloned before each pass, so both passes see every statement. *)
let generate () name stmts =
  (* Computed once; used by both the opening and the closing line. *)
  let module_name = String.capitalize name in
  output "module %s (T : Sqlgg_traits.M) = struct" module_name;
  empty_line ();
  inc_indent ();
  Enum.iteri (generate_stmt false) (Enum.clone stmts);
  output "module Fold = struct";
  inc_indent ();
  Enum.iteri (generate_stmt true) (Enum.clone stmts);
  dec_indent ();
  output "end (* module Fold *)";
  dec_indent ();
  output "end (* module %s *)" module_name