fix build with menhir 20211230
[sqlgg.git] / lib / sql.ml
blobcf6817f99eae57de9b0e1f3e2227541ee5241a0b
1 (** *)
3 open Printf
4 open ExtLib
5 open Prelude
(** SQL value types, and the small signature language used to type SQL
    function calls. *)
module Type =
struct
  type t =
    | Unit of [`Interval]
    | Int
    | Text
    | Blob
    | Float
    | Bool
    | Datetime
    | Decimal
    | Any
    [@@deriving show {with_path=false}]

  let to_string = show

  (** [matches x y]: [Any] is compatible with every type; any other type is
      compatible only with itself. *)
  let matches x y = x = Any || y = Any || x = y

  (** [is_unit t] holds for [Unit _] types only. *)
  let is_unit t =
    match t with
    | Unit `Interval -> true
    | Int | Text | Blob | Float | Bool | Datetime | Decimal | Any -> false

  (** [order x y] relates two types:
      - [`Equal] when they are the same;
      - [`Order (lo, hi)] when they are convertible, with [lo] the common
        subtype and [hi] the common supertype ([Any] is the supertype of all);
      - [`No] when they are unrelated. *)
  let order x y =
    match x, y with
    | _ when x = y -> `Equal
    | Any, other | other, Any -> `Order (other, Any)
    | Int, Float | Float, Int -> `Order (Int, Float)
    | Text, Blob | Blob, Text -> `Order (Text, Blob)
    | Int, Datetime | Datetime, Int -> `Order (Int, Datetime)
    | Text, Datetime | Datetime, Text -> `Order (Datetime, Text)
    | _ -> `No

  (* Helper: [pick] selects one side of the [`Order] pair. Shadowed below by the
     boolean [common_type]. *)
  let common_type pick x y =
    match order x y with
    | `No -> None
    | `Equal -> Some x
    | `Order pair -> Some (pick pair)

  let common_supertype = common_type snd
  let common_subtype = common_type fst

  (** [common_type x y] holds when [x] and [y] are related by [order]. *)
  let common_type x y =
    match common_subtype x y with
    | Some _ -> true
    | None -> false

  type tyvar = Typ of t | Var of int

  (** Concrete types print via [to_string]; variables print as ['a], ['b], ... *)
  let string_of_tyvar tv =
    match tv with
    | Typ ty -> to_string ty
    | Var i -> sprintf "'%c" (Char.chr (Char.code 'a' + i))

  type func =
    | Group of t (* _ -> t *)
    | Agg (* 'a -> 'a *)
    | Multi of tyvar * tyvar (* 'a -> ... -> 'a -> 'b *)
    | Ret of t (* _ -> t *) (* TODO eliminate *)
    | F of tyvar * tyvar list

  (** Signature with concrete return and argument types. *)
  let monomorphic ret args = F (Typ ret, List.map (fun arg -> Typ arg) args)
  let fixed = monomorphic

  (** The ['a -> 'a] signature. *)
  let identity = F (Var 0, [Var 0])

  (** Formatter printer for signatures. [expr]'s derived printer refers to it
      by name. *)
  let pp_func fmt func =
    match func with
    | Agg -> Format.fprintf fmt "|'a| -> 'a"
    | Group ret -> Format.fprintf fmt "|_| -> %s" (to_string ret)
    | Ret ret -> Format.fprintf fmt "_ -> %s" (to_string ret)
    | F (ret, args) ->
      let args_text = String.concat " -> " (List.map string_of_tyvar args) in
      Format.fprintf fmt "%s -> %s" args_text (string_of_tyvar ret)
    | Multi (ret, each_arg) ->
      Format.fprintf fmt "{ %s }+ -> %s" (string_of_tyvar each_arg) (string_of_tyvar ret)

  let string_of_func func = Format.asprintf "%a" pp_func func

  (** Aggregate-style signatures ([Group], [Agg]) versus the rest. *)
  let is_grouping func =
    match func with
    | Group _ | Agg -> true
    | Ret _ | F _ | Multi _ -> false
end
(** Column-level constraints. [Constraints] below is [Set.Make] of this module,
    so its ordering comes from the derived [compare] ([ord]). ppx_deriving's
    [ord] ranks constructors by declaration position, so do not reorder them. *)
module Constraint =
struct
  (** Conflict-resolution algorithm carried by [OnConflict]. *)
  type conflict_algo = | Ignore | Replace | Abort | Fail | Rollback
  [@@deriving show{with_path=false}, ord]

  type t = | PrimaryKey | NotNull | Null | Unique | Autoincrement | OnConflict of conflict_algo
  [@@deriving show{with_path=false}, ord]
end
(** Sets of column constraints, plus printers for [@@deriving show] users. *)
module Constraints = struct
  include Set.Make(Constraint)

  (** Render the set as a list of its elements, in ascending [Constraint.compare] order. *)
  let show set =
    let items = elements set in
    [%derive.show: Constraint.t list] items

  (** Formatter printer; emits exactly [show set]. *)
  let pp fmt set = Format.pp_print_string fmt (show set)
end
(** A column: its name, its SQL type ([domain]) and its constraints ([extra]). *)
type attr = {
  name : string;
  domain : Type.t;
  extra : Constraints.t;
}
[@@deriving show {with_path=false}]

(** Build an [attr] from its three components. *)
let make_attribute name domain extra = { extra = extra; domain = domain; name = name }
103 module Schema =
104 struct
(** A schema: the ordered list of a relation's attributes. *)
type t = attr list
[@@deriving show]

(** Raised by schema operations; carries the schema being operated on and a
    human-readable message. *)
exception Error of t * string
(** [by_name name attr] holds when [attr] is named exactly [name].
    FIXME attribute case sensitivity? (the comparison is case-sensitive) *)
let by_name name attr = name = attr.name

(** All attributes of [t] named [name], in schema order. *)
let find_by_name t name = List.filter (by_name name) t

(** The unique attribute of [t] named [name].
    @raise Error if there is no such attribute, or more than one. *)
let find t name =
  let error prefix = raise (Error (t, prefix ^ name)) in
  match find_by_name t name with
  | [] -> error "missing attribute : "
  | [attr] -> attr
  | _ :: _ :: _ -> error "duplicate attribute : "
(** Collapse attributes that share a non-empty name; unnamed ([""])
    attributes are never considered duplicates. *)
let make_unique = List.unique ~cmp:(fun a b -> a.name <> "" && a.name = b.name)

(** [is_unique t] holds when no two attributes of [t] share a non-empty name. *)
let is_unique t = List.length t = List.length (make_unique t)

(** Returns [true] when [t] is unique.
    @raise Error ("duplicate attributes") otherwise. *)
let check_unique t =
  if is_unique t then true
  else raise (Error (t,"duplicate attributes"))
(** [project names t] picks the attributes [names] from [t], in the order given
    (repetitions allowed).
    @raise Error if a name is missing or ambiguous in [t]. *)
let project names t = List.map (fun name -> find t name) names

(** Replace the attribute named [before] with [after], keeping its position.
    @raise Error if [before] is missing or ambiguous in [t]. *)
let change_inplace t before after =
  let _ : attr = find t before in
  List.map (fun attr -> if by_name before attr then after else attr) t
(** [exists t name] holds when [t] has exactly one attribute named [name].
    [find] raises for both missing and duplicated names, so a duplicated name
    also yields [false]. *)
let exists t name =
  match find_by_name t name with
  | [_] -> true
  | [] | _ :: _ :: _ -> false

(** Rename column [oldname] to [newname].
    @raise Error if [oldname] does not exist or [newname] already exists. *)
let rename t oldname newname =
  match exists t oldname, exists t newname with
  | false, _ -> raise (Error (t, "no such column : " ^ oldname))
  | true, true -> raise (Error (t, "column already exists : " ^ newname))
  | true, false ->
    let rename_one attr =
      match attr.name = oldname with
      | true -> { attr with name = newname }
      | false -> attr
    in
    List.map rename_one t
(** Cartesian-product schema: attributes of [t1] followed by those of [t2]. *)
let cross t1 t2 = t1 @ t2

(** [contains t attr] tests whether schema [t] contains attribute [attr]: same
    name, same domain and the same set of constraints.
    Constraints are compared with [Constraints.equal]. Polymorphic [=] on a
    [Set] compares the internal tree shape, which depends on insertion order,
    so two equal constraint sets could wrongly compare as different.
    @raise Error if [t] has no attribute, or several attributes, named [attr.name]. *)
let contains t attr =
  let found = find t attr.name in
  (* [find] already guarantees [found.name = attr.name] *)
  found.domain = attr.domain && Constraints.equal found.extra attr.extra

(** @raise Error if [t] does not contain [attr] (see [contains]). *)
let check_contains t attr =
  if not (contains t attr) then
    raise (Error (t,"type mismatch for attribute " ^ attr.name))

(** [sub l a] keeps the elements of [l] that do not occur in [a].
    NOTE: this is polymorphic and uses structural equality ([List.mem]). For
    attributes, that compares [Constraints.t] trees structurally. It is kept
    polymorphic so that existing callers keep working. *)
let sub l a = List.filter (fun x -> not (List.mem x a)) l
(** Render as ["[Type name, Type name, ...]"]. *)
let to_string v =
  let describe attr = Type.to_string attr.domain ^ " " ^ attr.name in
  "[" ^ String.concat ", " (List.map describe v) ^ "]"

(** Render only the attribute names, as ["[a,b,c]"]. *)
let names t = "[" ^ String.concat "," (List.map (fun attr -> attr.name) t) ^ "]"
(** Natural-join schema: the attributes common to both inputs, then those only
    in [t1], then those only in [t2].
    @raise Failure when there are no common attributes. *)
let natural_ t1 t2 =
  match List.partition (fun x -> List.mem x t2) t1 with
  | [], _ -> failwith "natural'"
  | (common, t1only) -> common @ t1only @ sub t2 common

(** Like [natural_], but reports failure as [Error] carrying [t1]. *)
let natural t1 t2 =
  match natural_ t1 t2 with
  | joined -> joined
  | exception _ ->
    raise (Error (t1, "no common attributes for natural join of " ^ names t1 ^ " and " ^ names t2))
(** Schema of [t1 JOIN t2 USING (l)]: the [l] columns (taken from [t1]), then
    the remaining columns of [t1], then the remaining columns of [t2].
    @raise Error if a [l] column is missing or ambiguous in [t1], or is not
    contained in [t2]. *)
let join_using l t1 t2 =
  let shared = List.map (fun name -> find t1 name) l in
  List.iter (fun attr -> check_contains t2 attr) shared;
  let only_left = sub t1 shared and only_right = sub t2 shared in
  shared @ only_left @ only_right
(** [check_types t1 t2] checks that [t1] and [t2] are pairwise type-compatible
    ([Type.matches]: [Any] matches anything, otherwise the types must be equal)
    and have the same length.
    @raise Error (carrying [t1]) on the first incompatible pair, or when the
    lengths differ. *)
let check_types t1 t2 =
  let check_pair a1 a2 =
    if not (Type.matches a1.domain a2.domain) then
      raise (Error (t1, sprintf "Attributes do not match : %s of type %s and %s of type %s"
        a1.name (Type.to_string a1.domain)
        a2.name (Type.to_string a2.domain)))
  in
  (* ExtLib's [List.iter2] signals a length mismatch with [Different_list_size] *)
  try List.iter2 check_pair t1 t2 with
  | List.Different_list_size _ ->
    raise (Error (t1, (to_string t1) ^ " differs in size to " ^ (to_string t2)))

(** Runs [check_types t1 t2] and returns [t1] as the resulting schema. *)
let compound t1 t2 = check_types t1 t2; t1
(** [add t col pos] inserts [col] into [t]: [`First] prepends it, [`Default]
    appends it, and [`After name] places it right after the first attribute
    named [name].
    @raise Error if [t] already has a column named [col.name], or if
    [`After name] refers to a column [t] does not have. *)
let add t col pos =
  match find_by_name t col.name with
  | _ :: _ -> raise (Error (t,"Already has column " ^ col.name))
  | [] ->
    match pos with
    | `First -> col :: t
    | `Default -> t @ [col]
    | `After name ->
      (* splice [col] in right after the first attribute named [name] *)
      let rec insert_after = function
        | [] -> raise (Error (t,"Can't insert column " ^ col.name ^ " after non-existing column " ^ name))
        | attr :: rest ->
          if by_name name attr then attr :: col :: rest
          else attr :: insert_after rest
      in
      insert_after t
(** Remove the column named [col].
    @raise Error if [col] is missing or ambiguous in [t]. *)
let drop t col =
  let _ : attr = find t col in
  List.remove_if (by_name col) t

(** [change t oldcol col pos] replaces column [oldcol] with [col]. With
    [`Default] it keeps the old position; otherwise [col] is re-inserted at [pos]. *)
let change t oldcol col = function
  | `Default -> change_inplace t oldcol col
  | (`First | `After _) as pos -> add (drop t oldcol) col pos
(** Debug rendering via the derived [show]. This shadows the bracketed
    [to_string] defined earlier in the module. *)
let to_string (s : t) = show s

(** Print [to_string x] on stderr. *)
let print x = x |> to_string |> prerr_endline
(** A table name with an optional database qualifier. *)
type table_name = { db : string option; tn : string } [@@deriving show]

(* Shadows the record-style printer generated by [@@deriving show]:
   ["db.table"] when qualified, bare ["table"] otherwise. *)
let show_table_name name =
  match name.db with
  | None -> name.tn
  | Some db -> sprintf "%s.%s" db name.tn

let make_table_name ?db tn = { tn; db }
(** Top-level alias of [Schema.t]. *)
type schema = Schema.t [@@deriving show]
(** A table: its name paired with its schema. *)
type table = table_name * schema [@@deriving show]
(** Write a table to [out]: its name on one line, then one line per column
    (type right-aligned to 10 characters, name, constraints), then an empty line. *)
let print_table out (tbl_name, columns) =
  let print_column col =
    IO.printf out "%10s %s %s\n" (Type.to_string col.domain) col.name (Constraints.show col.extra)
  in
  IO.write_line out (show_table_name tbl_name);
  List.iter print_column columns;
  IO.write_line out ""
(** optional name and start/end position in string *)
type param_id = {
  label : string option;
  pos : int * int;
} [@@deriving show]

(** A query parameter: its id, its type and, when known, an associated attribute. *)
type param = {
  id : param_id;
  typ : Type.t;
  attr : attr option;
} [@@deriving show]

let new_param ?attr id typ = { attr = attr; typ = typ; id = id }

type params = param list [@@deriving show]
(** Mutually recursive description of query variables. A [Choice] holds a
    list of [ctor] alternatives, and a [Simple] alternative may itself carry
    variables. NOTE(review): the meaning of the two strings in [Verbatim] is
    not visible here; confirm at the use sites. *)
type ctor =
| Simple of param_id * var list option
| Verbatim of string * string
and var =
| Single of param
| SingleIn of param
| Choice of param_id * ctor list
[@@deriving show]
type vars = var list [@@deriving show]
(** Where a column is placed; interpreted by [Schema.add] / [Schema.change]. *)
type alter_pos = [ `After of string | `Default | `First ]
(** One action of an [Alter] statement (see [stmt]). *)
type alter_action = [
| `Add of attr * alter_pos
| `RenameTable of table_name
| `RenameColumn of string * string
| `RenameIndex of string * string
| `Drop of string
| `Change of string * attr * alter_pos
| `None ]

(** Result of typing a select: its output schema and its parameters. *)
type select_result = (schema * param list)
(** Direction of an ORDER BY term (see [order]): fixed in the query text, or
    presumably supplied through the given parameter. *)
type direction = [ `Fixed | `Param of param_id ] [@@deriving show]

type int_or_param = [`Const of int | `Limit of param]
type limit_t = [ `Limit | `Offset ]
(* SQL query AST. The [col_name], [select] and [select_full] records were
   missing their closing braces, which made this group unparseable. *)

(** A column reference, optionally qualified by a table. *)
type col_name = {
  cname : string; (** column name *)
  tname : table_name option;
}
(* NOTE(review): the meaning of the [bool] is not visible here; confirm at use sites *)
and limit = param list * bool
(** A FROM clause: the first source, followed by joined sources with their join conditions. *)
and nested = source * (source * join_cond) list
and source = [ `Select of select_full | `Table of table_name | `Nested of nested ] * table_name option (* alias *)
and join_cond = [ `Cross | `Search of expr | `Default | `Natural | `Using of string list ]
and select = {
  columns : column list;
  from : nested option;
  where : expr option;
  group : expr list;
  having : expr option;
}
(** A first [select] plus further compounded selects, with ORDER BY and LIMIT. *)
and select_full = {
  select : select * select list;
  order : order;
  limit : limit option;
}
and order = (expr * direction option) list
and 'expr choices = (param_id * 'expr option) list
and expr =
  | Value of Type.t (** literal value *)
  | Param of param
  | Inparam of param
  | Choices of param_id * expr choices
  | Fun of Type.func * expr list (** parameters *)
  | SelectExpr of select_full * [ `AsValue | `Exists ]
  | Column of col_name
  | Inserted of string (** inserted value *)
and column =
  | All
  | AllOf of table_name
  | Expr of expr * string option (** name *)
  [@@deriving show {with_path=false}]
type columns = column list [@@deriving show]

(* The polymorphic variant below was missing its closing bracket, which
   made the declaration unparseable. *)
type expr_q = [ `Value of Type.t (** literal value *)
              | `Param of param
              | `Inparam of param
              | `Choice of param_id * expr_q choices
              | `Func of Type.func * expr_q list (** signature, parameters *)
              ]
              [@@deriving show]

(** Debug rendering of an expression (the printer derived for [expr]). *)
let expr_to_string = show_expr

(** Column assignments, as in [SET col = expr, ...]. *)
type assignments = (col_name * expr) list
(* The record's opening and closing braces were missing, which made the
   declaration unparseable. *)

(** An INSERT: target table, the rows to insert, and an optional
    [on_duplicate] assignment list. *)
type insert_action =
{
  target : table_name;
  action : [ `Set of assignments option
           | `Values of (string list option * [ `Expr of expr | `Default ] list list option) (* column names * list of value tuples *)
           | `Select of (string list option * select_full) ];
  on_duplicate : assignments option;
}
(** A parsed SQL statement. *)
type stmt =
(* table created from an explicit schema or from a select *)
| Create of table_name * [ `Schema of schema | `Select of select_full ]
| Drop of table_name
| Alter of table_name * alter_action list
(* presumably (from, to) name pairs -- confirm in the parser *)
| Rename of (table_name * table_name) list
| CreateIndex of string * table_name * string list (* index name, table name, columns *)
| Insert of insert_action
(* presumably the optional expression is the WHERE clause *)
| Delete of table_name * expr option
| DeleteMulti of table_name list * nested * expr option
| Set of string * expr
| Update of table_name * assignments * expr option * order * param list (* where, order, limit *)
| UpdateMulti of source list * assignments * expr option
| Select of select_full
(* name, optional return type, and (name, type, optional default?) parameters -- confirm *)
| CreateRoutine of string * Type.t option * (string * Type.t * expr option) list
(* Ad-hoc smoke checks for Schema, kept disabled. If enabled, they would run
   at module initialisation: they print to stderr, and the projection on
   columns b and d deliberately raises Schema.Error because column d does
   not exist. The test records now include the mandatory [extra] field so the
   snippet compiles if re-enabled.

open Schema

let test = [
  {name="a"; domain=Type.Int; extra=Constraints.empty};
  {name="b"; domain=Type.Int; extra=Constraints.empty};
  {name="c"; domain=Type.Text; extra=Constraints.empty};
]

let () = print test
let () = print (project ["b";"c";"b"] test)
let () = print (project ["b";"d"] test)
let () = print (rename test "a" "new_a")
*)
(** Registry of SQL function signatures, keyed by lowercased name and arity.
    Registering the same name/arity twice fails. *)
module Function : sig
(** [lookup name narg] is the signature of [name] (case-insensitive) called
    with [narg] arguments. An exact-arity registration takes priority over a
    variadic one. An arity marked with [exclude], or an unknown name, yields
    an untyped catch-all signature and prints a warning. *)
val lookup : string -> int -> Type.func
(** [add narg typ name] registers [name] taking exactly [narg] arguments. *)
val add : int -> Type.func -> string -> unit
(** [exclude narg name] marks [narg] as an invalid arity for [name]. *)
val exclude : int -> string -> unit
(** [monomorphic ret args name] registers a fixed-arity function with concrete types. *)
val monomorphic : Type.t -> Type.t list -> string -> unit
(** [multi ~ret each_arg name] registers a variadic function. *)
val multi : ret:Type.tyvar -> Type.tyvar -> string -> unit
(** [multi_polymorphic name] registers a variadic ['a -> ... -> 'a -> 'a] function. *)
val multi_polymorphic : string -> unit
end = struct
(* (lowercased name, Some arity | None for variadic) -> Some signature | None for an excluded arity *)
let h = Hashtbl.create 10

(** Register [typ] under [(name, narg)]; fails if that key is already taken. *)
let add_ narg typ name =
  let name = String.lowercase name in
  let key = (name, narg) in
  match Hashtbl.mem h key with
  | false -> Hashtbl.add h key typ
  | true ->
    let description =
      match narg with
      | Some n -> sprintf "%S of %d arguments" name n
      | None -> sprintf "%S" name
    in
    fail "Function %s already registered" description
(* Front-ends over [add_]. A [None] arity key means "variadic"; a [None]
   signature means "this arity is explicitly invalid". *)
let add narg typ = add_ (Some narg) (Some typ)
let add_multi typ = add_ None (Some typ)
let exclude narg = add_ (Some narg) None
(* Catch-all signature for calls we cannot type: any number of [Any]
   arguments returning [Any]. *)
let sponge = Type.(Multi (Typ Any, Typ Any))

(** [lookup name narg]: an exact-arity registration wins, then a variadic one.
    An excluded arity or an unknown name falls back to [sponge], with a
    warning printed via [eprintfn]. Only [Not_found] (the table miss) is
    handled; any other exception propagates instead of being silently
    turned into "unknown function". *)
let lookup name narg =
  let name = String.lowercase name in
  match Hashtbl.find h (name, Some narg) with
  | Some t -> t
  | None ->
    (* this arity was registered through [exclude] *)
    eprintfn "W: wrong number of arguments for known function %S, treating as untyped" name;
    sponge
  | exception Not_found ->
    match Hashtbl.find h (name, None) with
    | Some t -> t
    | None -> assert false (* variadic keys are only ever added with [Some typ] *)
    | exception Not_found ->
      eprintfn "W: unknown function %S of %d arguments, treating as untyped" name narg;
      sponge
(* Registration helpers that build the [Type.func] for the caller. *)
let monomorphic ret args name =
  let arity = List.length args in
  add arity (Type.monomorphic ret args) name
let multi_polymorphic name = add_multi (Type.Multi (Type.Var 0, Type.Var 0)) name
let multi ~ret each_arg name = add_multi (Type.Multi (ret, each_arg)) name
400 let () =
401 let open Type in
402 let open Function in
403 let (||>) x f = List.iter f x in
404 "count" |> add 0 (Group Int); (* count( * ) - asterisk is treated as no parameters in parser *)
405 "count" |> add 1 (Group Int);
406 "avg" |> add 1 (Group Float);
407 ["max";"min";"sum"] ||> add 1 Agg;
408 ["max";"min"] ||> multi_polymorphic; (* sqlite3 *)
409 ["lower";"upper"] ||> monomorphic Text [Text];
410 "length" |> monomorphic Int [Text];
411 ["random"] ||> monomorphic Int [];
412 ["nullif";"ifnull"] ||> add 2 (F (Var 0, [Var 0; Var 0]));
413 ["least";"greatest";"coalesce"] ||> multi_polymorphic;
414 "strftime" |> exclude 1; (* requires at least 2 arguments *)
415 ["concat";"strftime"] ||> multi ~ret:(Typ Text) (Typ Text);
416 ["date";"time"] ||> monomorphic Text [Datetime];
417 "julianday" |> multi ~ret:(Typ Float) (Typ Text);
418 "from_unixtime" |> monomorphic Datetime [Int];
419 "from_unixtime" |> monomorphic Text [Int;Text];
420 ["pow"; "power"] ||> monomorphic Float [Float;Int];
421 "unix_timestamp" |> monomorphic Int [];
422 "unix_timestamp" |> monomorphic Int [Datetime];
423 ["timestampdiff";"timestampadd"] ||> monomorphic Int [Unit `Interval;Datetime;Datetime];
424 "any_value" |> add 1 (F (Var 0,[Var 0])); (* 'a -> 'a but not aggregate *)
425 "substring" |> monomorphic Text [Text; Int];
426 "substring" |> monomorphic Text [Text; Int; Int];
427 "substring_index" |> monomorphic Text [Text; Text; Int];
428 "last_insert_id" |> monomorphic Int [];
429 "last_insert_id" |> monomorphic Int [Int];