test case for #57
[sqlgg.git] / src / gen.ml
blob5cd8f19465000c68d79bfcdddef54dedd76859f0
1 (* Code generation *)
3 open Printf
4 open ExtLib
5 open Prelude
6 open Stmt
(** Placeholder syntax used when parameters are written back into SQL text
    (see [get_sql]). *)
type subst_mode =
  | Named       (** ["@"] followed by the parameter name *)
  | Unnamed     (** a bare ["?"] *)
  | Oracle      (** [":"] followed by the parameter name *)
  | PostgreSQL  (** ["$"] followed by the 1-based parameter position *)
(** One SQL statement as handed to the code generators. *)
type stmt = {
  schema : Sql.Schema.t;
  (* NOTE(review): presumably the result columns of the statement -- confirm in Sql *)
  params : Sql.params;
  (* parameters; [get_sql] passes them to [substitute_params] *)
  kind : kind;
  (* statement kind ([Stmt.kind]), used by [choose_name] *)
  props : Props.t;
  (* properties; this file reads "sql", "name" and "subst" *)
}
(** Selects how parameter placeholders are rewritten by [get_sql].
    While it holds [None] (the initial value) the SQL text is returned
    unchanged. *)
let params_mode = ref None
(** Indentation state shared by the output helpers.
    [inc_indent ()] / [dec_indent ()] move the current level by two columns;
    [make_indent ()] returns a string of that many spaces. *)
let (inc_indent, dec_indent, make_indent) =
  let level = ref 0 in
  let shift delta () = level := !level + delta in
  let spaces () = String.make !level ' ' in
  (shift 2, shift (-2), spaces)
(** Writes the current indentation to stdout. *)
let print_indent () =
  let padding = make_indent () in
  print_string padding

(** Writes the current indentation followed by [s], without a newline. *)
let indent s =
  print_indent ();
  print_string s

(** Writes the current indentation followed by [s] and a newline. *)
let indent_endline s =
  print_indent ();
  print_endline s
(** [output fmt args...] formats like [sprintf] and prints the result as one
    indented line on stdout.
    Uses [ksprintf]; [kprintf] is only a deprecated synonym for it. *)
let output fmt = ksprintf indent_endline fmt
(** Prints every string of [lines] as its own indented line. *)
let output_l lines = List.iter (fun line -> indent_endline line) lines
(** [print fmt args...] formats like [sprintf] and prints the result followed
    by a newline on stdout, without indentation.
    Uses [ksprintf]; [kprintf] is only a deprecated synonym for it. *)
let print fmt = ksprintf print_endline fmt
(** [indented k] runs [k ()] with the indentation level increased by one step
    and restores the previous level afterwards.
    The level is restored even when [k] raises, so a failure inside a nested
    section does not leave later output shifted; the exception is re-raised. *)
let indented k =
  inc_indent ();
  (try k () with e -> dec_indent (); raise e);
  dec_indent ()
(** Name of attribute [attr]; an attribute whose name is empty is called
    ["_<index>"] instead. *)
let name_of attr index =
  let given = attr.Sql.name in
  if given = "" then sprintf "_%u" index else given
(** Name of a parameter: its own name when it has one, otherwise
    ["_<index>"]. The second component of the id is ignored here. *)
let param_name_to_string ((name, _) : Sql.param_id) index =
  match name with
  | Some given -> given
  | None -> sprintf "_%u" index
(** Returns the ["name"] property of [props] when present, [default]
    otherwise. *)
let make_name props default =
  match Props.get props "name" with
  | Some explicit -> explicit
  | None -> default
(** [default_name str index] is [str], an underscore and [index]
    (formatted with [%u]). *)
let default_name str index =
  let build = sprintf "%s_%u" in
  build str index
(** [choose_name props kind index] picks the identifier for a statement.
    A name is derived from the statement [kind] (and, for most kinds, the
    table name and [index]); an explicit ["name"] property in [props]
    overrides it (see [make_name]). *)
let choose_name props kind index =
  (* ASCII letters, digits and '_' are kept; any other character becomes '_' *)
  let sanitize_char = function
    | 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' as c -> c
    | _ -> '_'
  in
  (* When a "subst" property [x] is set, the marker made of [x] wrapped in
     double percent signs is replaced by [x] itself before sanitizing
     (ExtLib's String.replace; NOTE(review): it is documented to replace
     only the first occurrence -- confirm). *)
  let fix raw =
    let expanded =
      match Props.get props "subst" with
      | None -> raw
      | Some x -> snd (String.replace ~str:raw ~sub:("%%"^x^"%%") ~by:x)
    in
    String.map sanitize_char expanded
  in
  let generated =
    match kind with
    | Create t -> sprintf "create_%s" (fix t)
    | CreateIndex t -> sprintf "create_index_%s" (fix t)
    | Update (Some t) -> sprintf "update_%s_%u" (fix t) index
    | Update None -> sprintf "update_%u" index
    | Insert (_,t) -> sprintf "insert_%s_%u" (fix t) index
    | Delete t -> sprintf "delete_%s_%u" (fix t) index
    | Alter t -> sprintf "alter_%s_%u" (fix t) index
    | Drop t -> sprintf "drop_%s" (fix t)
    | Select _ -> sprintf "select_%u" index
    | Other -> sprintf "statement_%u" index
  in
  make_name props generated
(** [substitute_params s params f] rebuilds [s] with each parameter's span
    replaced by [f index param], where [index] counts parameters from 0 in
    list order.
    Each parameter carries a pair of offsets [(i1, i2)]: the text up to [i1]
    is copied and copying resumes at [i2].
    NOTE(review): this relies on [params] being ordered by position in [s] --
    confirm at the producer. *)
let substitute_params s params f =
  let buf = Buffer.create (String.length s) in
  (* [pos] is where the text not yet copied starts *)
  let rec loop pos index = function
    | [] -> pos
    | ((_, (start, stop)), _ as param) :: rest ->
      Buffer.add_string buf (String.slice ~first:pos ~last:start s);
      Buffer.add_string buf (f index param);
      loop stop (index + 1) rest
  in
  let tail_start = loop 0 0 params in
  Buffer.add_string buf (String.slice ~first:tail_start s);
  Buffer.contents buf
(** Placeholder for [Named] mode: ["@"] followed by the parameter name. *)
let subst_named index (id, _) =
  let name = param_name_to_string id index in
  "@" ^ name

(** Placeholder for [Oracle] mode: [":"] followed by the parameter name. *)
let subst_oracle index (id, _) =
  let name = param_name_to_string id index in
  ":" ^ name
(** Placeholder for [PostgreSQL] mode: ["$"] followed by the 1-based
    position; the parameter itself is ignored. *)
let subst_postgresql index _param =
  let position = succ index in
  "$" ^ string_of_int position

(** Placeholder for [Unnamed] mode: always ["?"]. *)
let subst_unnamed _index _param = "?"
(** SQL text of [stmt], taken from its ["sql"] property, with parameter
    placeholders rewritten according to [!params_mode]; returned unchanged
    when the mode is [None].
    [Option.get] raises when the statement has no ["sql"] property. *)
let get_sql stmt =
  let sql = Option.get (Props.get stmt.props "sql") in
  let rewrite_with subst = substitute_params sql stmt.params subst in
  match !params_mode with
  | None -> sql
  | Some Named -> rewrite_with subst_named
  | Some Unnamed -> rewrite_with subst_unnamed
  | Some Oracle -> rewrite_with subst_oracle
  | Some PostgreSQL -> rewrite_with subst_postgresql
(** Current UTC time formatted as [YYYY-MM-DDTHH:MMZ] (minute resolution). *)
let time_string () =
  let tm = Unix.gmtime (Unix.time ()) in
  sprintf "%04u-%02u-%02uT%02u:%02uZ"
    (1900 + tm.Unix.tm_year)
    (tm.Unix.tm_mon + 1)
    tm.Unix.tm_mday
    tm.Unix.tm_hour
    tm.Unix.tm_min
(** Mapping from SQL types to type names of a target language. *)
module type LangTypes = sig

  (** Type name used by [Translate.param_type_to_string]. *)
  val as_api_type : Sql.Type.t -> string

  (** Type name used for schema attributes and parameter values in
      [Translate]. *)
  val as_lang_type : Sql.Type.t -> string

end
(** Turns schemas and parameter lists into [(name, type name)] string pairs,
    with type names supplied by [T]. *)
module Translate(T : LangTypes) = struct

(** Alias for [T.as_api_type]. *)
let param_type_to_string = T.as_api_type

(** One [(name, language type)] pair per attribute; attributes with an empty
    name get a positional name (see [name_of]). *)
let schema_to_values schema =
  List.mapi (fun i attr -> (name_of attr i, T.as_lang_type attr.Sql.domain)) schema

(* let schema_to_string = G.Values.to_string $ schema_to_values *)

(** One [(name, language type)] pair per parameter, then deduplicated by
    name with [List.unique]. *)
let all_params_to_values l =
  let same_name (n1, _) (n2, _) = String.equal n1 n2 in
  let named =
    List.mapi (fun i (n, t) -> (param_name_to_string n i, T.as_lang_type t)) l
  in
  List.unique ~cmp:same_name named

(* rev unique rev -- to preserve ordering with respect to first occurrences *)
let params_to_values l =
  all_params_to_values l
  |> List.rev
  |> List.unique ~cmp:(=)
  |> List.rev

end
(** Interface a language backend provides to [Make]. *)
module type Generator = sig

  (** Backend state threaded through all the calls below. *)
  type t

  (** [generate out name stmts] emits code for [stmts]; [name] is passed
      through from [Make.process]. *)
  val generate : t -> string -> stmt list -> unit

  (** Creates the backend state. *)
  val start : unit -> t

  (** [comment out fmt args...] emits a comment built from a printf-style
      format. *)
  val comment : t -> ('a, unit, string, unit) format4 -> 'a

  (** Emits an empty line. *)
  val empty_line : t -> unit

end
129 module Make(S : Generator) = struct
(** Emits the "do not edit" banner as backend comments: a warning line, an
    empty comment, the sqlgg version with the generation time, then an empty
    line. *)
let generate_header out =
  let emit fmt = S.comment out fmt in
  emit "DO NOT EDIT MANUALLY";
  emit "";
  emit "generated by sqlgg %s on %s" Sqlgg_config.version (time_string ());
  S.empty_line out
(** [process name stmts] starts the backend, emits the header when
    [Sqlgg_config.gen_header] is set, then generates code for [stmts]. *)
let process name stmts =
  let out = S.start () in
  let () = if !Sqlgg_config.gen_header then generate_header out in
  S.generate out name stmts