Merge branch '863-forward-propagation-strategy-accept-optionnal-arguments' into ...
[why3.git] / examples / verifythis_2019_ghc_sort.mlw
blob9497fe30731eeddca927f7602d513fcd7f12ad04
1 (** {1 VerifyThis @ ETAPS 2019 competition Challenge 1: Monotonic Segments and GHC sort }
3 Author: Quentin Garchery (LRI, Université Paris-Sud)
4 *)
7 use int.Int
8 use seq.Seq
9 use seq.OfList
10 use seq.FreeMonoid
11 use array.Array
12 use map.Occ
13 use list.ListRich
14 use bool.Bool
15 use exn.Exn
18 (** {2 PART A : Monotonic Segments} *)
21 clone list.Sorted as StrictIncr with type t = int, predicate le = (<)
22 clone list.RevSorted with type t = int, predicate le = (<=)
24 use list.FoldRight
26 let function eqb (b1 b2 : bool) : bool
27   ensures { result <-> b1 = b2 }
29   andb (implb b1 b2) (implb b2 b1)
31 (** Use a type invariant that ensures that the sequence represents the list.
32   It is useful to get a nice specification when dealing with first and last
33   elements of the list (begin-to-end property) *)
34 type list_seq = {
35   list : list int;
36   ghost seq : seq int
37 } invariant {
38   seq = of_list (reverse list)
39 } by {
40   list = Nil;
41   seq = Seq.empty
44 let constant nil : list_seq
45 = { list = Nil;
46     seq = Seq.empty }
48 let extend a lseq
49   ensures { result.list = Cons a lseq.list }
51   { list = Cons a lseq.list;
52     seq = snoc lseq.seq a }
54 (** Compute the monotonic cutpoints of a sequence *)
56 let cutpoints (s : array int)
57   requires { Array.length s > 0 }
58   (* Begin-to-end property *)
59   ensures { get result.seq 0 = 0 }
60   ensures { get result.seq (Seq.length result.seq - 1) = Array.length s }
61   (* Non-empty property *)
62   ensures { length result.list >= 2 }
63   (* Within bounds property *)
64   ensures { forall z. mem z result.list -> 0 <= z <= Array.length s }
65   (* Monotonic property *)
66   ensures { forall k. 0 <= k < Seq.length result.seq - 1 ->
67                       let ck = get result.seq k in
68                       let ck1 = get result.seq (k+1) in
69                       (forall z1 z2. ck <= z1 < z2 < ck1 -> s[z1] < s[z2]) \/
70                       (forall z1 z2. ck <= z1 < z2 < ck1 -> s[z1] >= s[z2]) }
71   (* For the next part, we also need the cutpoints list to be decreasing *)
72   ensures { forall i j. 0 <= i < j < Seq.length result.seq ->
73             get result.seq i <= get result.seq j }
75   let n = s.length in
76   let ref cut = extend 0 nil in
77   let ref x = 0 in
78   let ref y = 1 in
79   let ref increasing = True in
80   while y < n do
81         variant { n - y }
82         invariant { y = x + 1 }
83         invariant { 0 < y <= n+1 }
84         invariant { Seq.length cut.seq > 0 }
85         invariant { get cut.seq 0 = 0 }
86         invariant { get cut.seq (Seq.length cut.seq - 1) = x }
87         invariant { forall z. mem z cut.list -> 0 <= z <= n }
88         invariant { forall k. 0 <= k < Seq.length cut.seq - 1 ->
89                     let ck = get cut.seq k in
90                     let ck1 = get cut.seq (k+1) in
91                     (forall z1 z2. ck <= z1 < z2 < ck1 -> s[z1] < s[z2]) \/
92                     (forall z1 z2. ck <= z1 < z2 < ck1 -> s[z1] >= s[z2])}
93         invariant { forall i j. 0 <= i < j < Seq.length cut.seq ->
94                     get cut.seq i <= get cut.seq j }
95         label StartLoop in
96         increasing <- (s[x] < s[y]);
97         while y < n && eqb (s[y-1] < s[y]) increasing do
98               variant { n - y }
99               invariant { y at StartLoop <= y <= n }
100               invariant { (forall z1 z2. x <= z1 < z2 < y -> s[z1] < s[z2]) \/
101                           (forall z1 z2. x <= z1 < z2 < y -> s[z1] >= s[z2]) }
102               y <- y + 1
103         done;
104         cut <- extend y cut;
105         assert { get (cut.seq at StartLoop) (Seq.length cut.seq - 2) = x };
106         assert { forall k. 0 <= k < Seq.length cut.seq - 2 ->
107                  get cut.seq k at StartLoop = get cut.seq k /\
108                  get cut.seq (k+1) at StartLoop = get cut.seq (k+1) };
109         x <- y;
110         y <- x+1;
111   done;
112   label AfterLoop in
113   if x < n then cut <- extend n cut;
114   assert { get cut.seq (Seq.length cut.seq - 1) = n };
115   assert { forall k. 0 <= k < Seq.length cut.seq - 2 ->
116            get cut.seq k at AfterLoop = get cut.seq k /\
117            get cut.seq (k+1) at AfterLoop = get cut.seq (k+1) };
118   cut
121 (** {2 PART B : GHC Sort} *)
124 lemma reverse_sorted_incr :
125   forall l. Decr.sorted l -> Incr.sorted (reverse l)
126   by Decr.sorted l /\ Incr.sorted Nil /\ compat l Nil
128 let rec lemma lt_le_sorted (l : list int)
129   variant { l }
130   requires { StrictIncr.sorted l}
131   ensures { Incr.sorted l }
133   match l with
134   | Cons _ (Cons h2 t) -> lt_le_sorted (Cons h2 t)
135   | _ -> ()
136   end
138 (** Get an ordered list from a monotonic list *)
139 let function order l
140   requires { StrictIncr.sorted l \/ Decr.sorted l }
141   ensures { Incr.sorted result }
142   ensures { permut l result }
144   match l with
145   | Nil -> l
146   | Cons _ Nil -> l
147   | Cons h1 (Cons h2 _) ->
148       if h1 < h2
149       then (assert { Incr.sorted l by StrictIncr.sorted l }; l)
150       else (assert { Decr.sorted l }; reverse l)
151   end
153 (** Get a monotonic list from two cutpoints in the array *)
154 let rec list_from (a : array int) s e
155   variant { e - s }
156   requires { 0 <= s <= Array.length a }
157   requires { 0 <= e <= Array.length a }
158   requires { (forall z1 z2. s <= z1 < z2 < e -> a[z1] < a[z2]) \/
159              (forall z1 z2. s <= z1 < z2 < e -> a[z1] >= a[z2]) }
160   ensures { forall x. num_occ x result = occ x a.elts s e }
161   ensures { forall x. mem x result -> exists z. s <= z < e /\ a[z] = x }
162   ensures { (forall z1 z2. s <= z1 < z2 < e -> a[z1] < a[z2]) -> StrictIncr.sorted result }
163   ensures { (forall z1 z2. s <= z1 < z2 < e -> a[z1] >= a[z2]) -> Decr.sorted result }
164   ensures { StrictIncr.sorted result \/ Decr.sorted result }
166   if s >= e then Nil
167   else Cons a[s] (list_from a (s+1) e)
169 let rec lemma occ_slice (a : array 'a) (c1 c2 c3 : int)
170   variant { c3 - c2 }
171   requires { 0 <= c1 <= c2 <= c3 <= Array.length a }
172   ensures { forall x. occ x a.elts c1 c3 = occ x a.elts c1 c2 + occ x a.elts c2 c3 }
173 = if c3 <> c2 then occ_slice a c1 c2 (c3-1)
175 (** Get sorted lists from the cutpoints *)
176 let rec sorted_lists (s: array int) (cutp : list_seq)
177   requires { length cutp.list > 0 }
178   (* This is where we need cutpoints to be sorted *)
179   requires { forall x y. 0 <= x < y < Seq.length cutp.seq -> get cutp.seq x <= get cutp.seq y }
180   requires { forall z. mem z cutp.list -> 0 <= z <= Array.length s }
181   requires { forall k. 0 <= k < Seq.length cutp.seq - 1 ->
182                       let ck = get cutp.seq k in
183                       let ck1 = get cutp.seq (k+1) in
184                       (forall z1 z2. ck <= z1 < z2 < ck1 -> s[z1] < s[z2]) \/
185                       (forall z1 z2. ck <= z1 < z2 < ck1 -> s[z1] >= s[z2]) }
186   variant { cutp.list }
187   ensures { forall l. mem l result -> Incr.sorted l }
188   ensures { forall x. let first = get cutp.seq 0 in
189                       let last  = get cutp.seq (Seq.length cutp.seq - 1) in
190                       num_occ x (fold_right (++) result Nil) = occ x s.elts first last }
192   let ls = cutp.list in
193   let seq = cutp.seq in
194   match ls with
195   | Nil | Cons _ Nil -> Nil
196   | Cons h1 (Cons h2 t) ->
197       assert { let k = Seq.length cutp.seq - 2 in
198                h2 = get cutp.seq k /\ h1 = get cutp.seq (k+1) };
199       let seqi = list_from s h2 h1 in
200       let lseq = { list = Cons h2 t; seq = seq[0..(length seq - 1)] } in
201       (* occ_slice s.elts (get seq 0) (get lseq.seq (Seq.length lseq.seq - 1)) *)
202       (*               (get seq (Seq.length seq - 1)); *)
203       assert { Incr.sorted (order seqi) };
204       Cons (order seqi) (sorted_lists s lseq)
205   end
207 (** The merge of mergesort ! *)
208 let rec merge l1 l2
209   variant { length l1 + length l2 }
210   requires { Incr.sorted l1 }
211   requires { Incr.sorted l2 }
212   ensures { Incr.sorted result }
213   ensures { permut result (l1 ++ l2) }
215   match l1, l2 with
216   | Nil, l | l, Nil -> l
217   | Cons h1 t1, Cons h2 t2 ->
218     if h1 < h2
219     then (assert { forall x. mem x (t1 ++ l2) -> h1 <= x };
220          Cons h1 (merge t1 l2))
221     else (assert { forall x. mem x (l1 ++ t2) -> h2 <= x };
222          Cons h2 (merge l1 t2))
223   end
225 (** Merge pair by pair for efficiency *)
226 let rec merge_pair ls
227   variant { length ls }
228   requires { forall l. mem l ls -> Incr.sorted l }
229   ensures { length result <= length ls }
230   ensures { length ls > 1 -> 0 < length result < length ls }
231   ensures { forall l. mem l result -> Incr.sorted l }
232   ensures { permut (fold_right (++) result Nil) (fold_right (++) ls Nil) }
234   match ls with
235   | Nil | Cons _ Nil -> ls
236   | Cons l1 (Cons l2 r) -> Cons (merge l1 l2) (merge_pair r)
237   end
239 (** Repeat previous merge *)
240 let rec mergerec ls
241   requires { forall l. mem l ls -> Incr.sorted l }
242   variant { length ls }
243   ensures { Incr.sorted result }
244   ensures { permut result (fold_right (++) ls Nil) }
246   match ls with
247   | Nil -> Nil
248   | Cons l Nil -> l
249   | Cons _ (Cons _ _) -> mergerec (merge_pair ls)
250   end
252 use seq.Occ as SO
253 use seq.OfList
254 (** Show that the result of <mergerec> has the same length has the initial array *)
255 (** By induction, when increasing the size of the sub-array by one, we remove the new
256  element in the corresponding array (use <find> to find the index to remove) *)
257 let rec find (seq : seq int) (v : int) (s e : int)
258   variant { e - s }
259   requires { 0 <= s < e <= Seq.length seq }
260   requires { SO.occ v seq s e >= 1 }
261   ensures { s <= result < e /\ get seq result = v }
263   if get seq s = v then s
264   else find seq v (s+1) e
266 let rec lemma same_occs_same_lengths (a : array int) (seq : seq int) (s : int)
267   variant { Array.length a - s }
268   requires { 0 <= s <= Array.length a }
269   requires { forall x. occ x a.elts s (Array.length a) = SO.occ x seq 0 (Seq.length seq) }
270   ensures { Seq.length seq = Array.length a - s }
272   let na = Array.length a in
273   let ns = Seq.length seq in
274   if s = na
275   then (assert { ns > 0 -> SO.occ (get seq 0) seq 0 ns > 0 }; ())
276   else (assert { ns = 0 -> false by SO.occ a[s] seq 0 ns = 0 };
277        let i = find seq a[s] 0 ns in
278        let rem_as = seq[0..i] ++ seq[i+1..ns] in (* seq where a[s] is removed *)
279        assert { forall x. (if x = a[s]
280                 then SO.occ x seq 0 ns = SO.occ x rem_as 0 (ns-1) + 1
281                 else SO.occ x seq 0 ns = SO.occ x rem_as 0 (ns-1))
282                 by seq == Seq.(++) seq[0..i] (Seq.(++) seq[i..i+1] seq[i+1..ns])
283                 so SO.occ x seq 0 ns = SO.occ x seq[0..i] 0 i +
284                                        SO.occ x seq[i..i+1] 0 1 +
285                                        SO.occ x seq[i+1..ns] 0 (ns-i-1)
286                 so SO.occ x rem_as 0 (ns-1) = SO.occ x seq[0..i] 0 i +
287                                               SO.occ x seq[i+1..ns] 0 (ns-i-1) };
288        same_occs_same_lengths a rem_as (s+1))
290 let rec lemma num_occ_seq_occ (l : list 'a) (x : 'a)
291   variant { l }
292   ensures { num_occ x l = SO.occ x (of_list l) 0 (length l) }
294   match l with
295   | Nil -> ()
296   | Cons h t -> assert { of_list l = Seq.(++) (singleton h) (of_list t) /\
297                          if x = h
298                          then SO.occ x (singleton h) 0 1 = 1
299                          else SO.occ x (singleton h) 0 1 = 0 };
300                 num_occ_seq_occ t x
301   end
303 let sort_to_list a =
304   requires { Array.length a > 0 }
305   ensures { Incr.sorted result }
306   ensures { forall x. occ x a.elts 0 (Array.length a) = num_occ x result }
307   ensures { length result = Array.length a }
308   let res = mergerec (sorted_lists a (cutpoints a)) in
309   same_occs_same_lengths a (of_list res) 0;
310   res
312 use array.IntArraySorted
313 use array.ArrayPermut
314 use option.Option
315 use list.NthNoOpt
317 (** Copy a list in an array, element by element, starting at a given index *)
318 let rec copy_list (l : list int) (a : array int) (s : int) : unit
319   variant { l }
320   requires { s >= 0 }
321   requires { length l = Array.length a - s }
322   ensures { forall x. s <= x < Array.length a -> a[x] = nth (x - s) l }
323   ensures { forall x. 0 <= x < s -> a[x] = (old a)[x] }
324   ensures { forall x. occ x a.elts s (Array.length a) = num_occ x l }
326   match l with
327   | Nil -> ()
328   | Cons h t -> a[s] <- h; copy_list t a (s+1)
329   end
331 let rec lemma mem_nth_in_bounds (l : list 'a) (j : int)
332   requires { 0 <= j < length l }
333   ensures { mem (nth j l) l }
335   match l with
336   | Nil -> ()
337   | Cons _ t -> if j > 0 then mem_nth_in_bounds t (j-1)
338   end
340 (** Useful to deduce sorted on arrays from sorted on lists *)
341 let rec lemma sorted_list_nth (l : list int) (i j : int)
342   variant { l }
343   requires { Incr.sorted l }
344   requires { 0 <= i <= j < length l }
345   ensures { nth i l <= nth j l }
347   match l with
348   | Nil -> ()
349   | Cons _ t -> if i > 0 then sorted_list_nth t (i-1) (j-1)
350   end
352 let ghc_sort a
353   ensures { sorted a }
354   ensures { permut_all a (old a) }
356   if Array.length a = 0 then ()
357   else let l = sort_to_list a in
358        assert { length l = Array.length a };
359        copy_list l a 0