Merge branch 'cleaning_again_example_sin_cos' into 'master'
[why3.git] / stdlib / array.mlw
blobc5ce0cf0f82ab8b3f41037621fca1d63aff209de
1 (** {1 Arrays} *)
3 (** {2 Generic Arrays}
5 The length is a non-mutable field, so that we get for free that
6 modification of an array does not modify its length.
8 *)
10 module Array
12   use int.Int
13   use map.Map
15   type array [@extraction:array] 'a = private {
16     mutable ghost elts : int -> 'a;
17                 length : int
18   } invariant { 0 <= length }
20   function ([]) (a: array 'a) (i: int) : 'a = a.elts i
22   val ([]) (a: array 'a) (i: int) : 'a
23     requires { [@expl:index in array bounds] 0 <= i < length a }
24     ensures  { result = a[i] }
26   val ghost function ([<-]) (a: array 'a) (i: int) (v: 'a): array 'a
27     ensures { result.length = a.length }
28     ensures { result.elts = Map.set a.elts i v }
30   val ([]<-) (a: array 'a) (i: int) (v: 'a) : unit writes {a}
31     requires { [@expl:index in array bounds] 0 <= i < length a }
32     ensures  { a.elts = Map.set (old a).elts i v }
33     ensures  { a = (old a)[i <- v] }
35   (** unsafe get/set operations with no precondition *)
37   exception OutOfBounds
39   let defensive_get (a: array 'a) (i: int)
40     ensures { 0 <= i < length a /\ result = a[i] }
41     raises  { OutOfBounds -> i < 0 \/ i >= length a }
42   = if i < 0 || i >= length a then raise OutOfBounds;
43     a[i]
45   let defensive_set (a: array 'a) (i: int) (v: 'a)
46     ensures { 0 <= i < length a }
47     ensures { a = (old a)[i <- v] }
48     raises  { OutOfBounds -> (i < 0 \/ i >= length a) /\ a = old a }
49   = if i < 0 || i >= length a then raise OutOfBounds;
50     a[i] <- v
52   function make (n: int) (v: 'a) : array 'a
54   axiom make_spec : forall n:int, v:'a.
55     n >= 0 ->
56     (forall i:int. 0 <= i < n -> (make n v)[i] = v) /\
57     length (make n v) = n
59   val make [@extraction:array_make] (n: int) (v: 'a) : array 'a
60     requires { [@expl:array creation size] n >= 0 }
61     ensures { forall i:int. 0 <= i < n -> result[i] = v }
62     ensures { result.length = n }
64   val empty () : array 'a
65     ensures { result.length = 0 }
67   let copy (a: array 'a) : array 'a
68     ensures  { length result = length a }
69     ensures  { forall i:int. 0 <= i < length result -> result[i] = a[i] }
70   =
71     let len = length a in
72     if len = 0 then empty ()
73     else begin
74       let b = make len a[0] in
75       for i = 1 to len - 1 do
76         invariant { forall k. 0 <= k < i -> b[k] = a[k] }
77         b[i] <- a[i]
78       done;
79       b
80     end
82   let sub (a: array 'a) (ofs: int) (len: int) : array 'a
83     requires { 0 <= ofs /\ 0 <= len /\ ofs + len <= length a }
84     ensures  { length result = len }
85     ensures  { forall i:int. 0 <= i < len -> result[i] = a[ofs + i] }
86   =
87     if length a = 0 then begin
88      assert { len = 0 };
89      empty ()
90     end else begin
91       let b = make len a[0] in
92       for i = 0 to len-1 do
93         invariant { forall k. 0 <= k < i -> b[k] = a[ofs+k] }
94         b[i] <- a[ofs+i];
95       done;
96       b
97     end
99   let fill (a: array 'a) (ofs: int) (len: int) (v: 'a)
100     requires { 0 <= ofs /\ 0 <= len /\ ofs + len <= length a }
101     ensures  { forall i:int.
102       (0 <= i < ofs \/ ofs + len <= i < length a) -> a[i] = old a[i] }
103     ensures  { forall i:int. ofs <= i < ofs + len -> a[i] = v }
104   =
105     for k = 0 to len - 1 do
106       invariant { forall i:int.
107         (0 <= i < ofs \/ ofs + len <= i < length a) -> a[i] = old a[i] }
108       invariant { forall i:int. ofs <= i < ofs + k -> a[i] = v }
109       a[ofs + k] <- v
110     done
112   let blit (a1: array 'a) (ofs1: int)
113                  (a2: array 'a) (ofs2: int) (len: int) : unit writes {a2}
114     requires { 0 <= ofs1 /\ 0 <= len /\ ofs1 + len <= length a1 }
115     requires { 0 <= ofs2 /\             ofs2 + len <= length a2 }
116     ensures  { forall i:int.
117       (0 <= i < ofs2 \/ ofs2 + len <= i < length a2) -> a2[i] = old a2[i] }
118     ensures  { forall i:int.
119       ofs2 <= i < ofs2 + len -> a2[i] = a1[ofs1 + i - ofs2] }
120   =
121     for i = 0 to len - 1 do
122       invariant { forall k. not (0 <= k < i) -> a2[ofs2 + k] = old a2[ofs2 + k] }
123       invariant { forall k. 0 <= k < i -> a2[ofs2 + k] = a1[ofs1 + k] }
124       a2[ofs2 + i] <- a1[ofs1 + i];
125     done
127   let append (a1: array 'a) (a2: array 'a) : array 'a
128     ensures { length result = length a1 + length a2 }
129     ensures { forall i. 0 <= i < length a1 -> result[i] = a1[i] }
130     ensures { forall i. 0 <= i < length a2 -> result[length a1 + i] = a2[i] }
131   =
132     if length a1 = 0 then copy a2
133     else begin
134       let a = make (length a1 + length a2) a1[0] in
135       blit a1 0 a 0 (length a1);
136       blit a2 0 a (length a1) (length a2);
137       a
138     end
140   let self_blit (a: array 'a) (ofs1: int) (ofs2: int) (len: int) : unit
141     writes {a}
142     requires { 0 <= ofs1 /\ 0 <= len /\ ofs1 + len <= length a }
143     requires { 0 <= ofs2 /\             ofs2 + len <= length a }
144     ensures  { forall i:int.
145       (0 <= i < ofs2 \/ ofs2 + len <= i < length a) -> a[i] = old a[i] }
146     ensures  { forall i:int.
147       ofs2 <= i < ofs2 + len -> a[i] = old a[ofs1 + i - ofs2] }
148   =
149     if ofs1 <= ofs2 then (* from right to left *)
150       for k = len - 1 downto 0 do
151         invariant  { forall i:int.
152           (0 <= i <= ofs2 + k \/ ofs2 + len <= i < length a) ->
153           a[i] = (old a)[i] }
154         invariant  { forall i:int.
155           ofs2 + k < i < ofs2 + len -> a[i] = (old a)[ofs1 + i - ofs2] }
156         a[ofs2 + k] <- a[ofs1 + k]
157       done
158     else (* from left to right *)
159       for k = 0 to len - 1 do
160         invariant  { forall i:int.
161           (0 <= i < ofs2 \/ ofs2 + k <= i < length a) ->
162           a[i] = (old a)[i] }
163         invariant  { forall i:int.
164           ofs2 <= i < ofs2 + k -> a[i] = (old a)[ofs1 + i - ofs2] }
165         a[ofs2 + k] <- a[ofs1 + k]
166       done
168   (*** TODO?
169      - concat : 'a array list -> 'a array
170      - to_list
171      - of_list
172   *)
176 module Init
178   use int.Int
179   use export Array
181   let init (n: int) (f: int -> 'a) : array 'a
182     requires { [@expl:array creation size] n >= 0 }
183     ensures { forall i:int. 0 <= i < n -> result[i] = f i }
184     ensures { result.length = n }
185   =
186     if n = 0 then empty ()
187     else begin
188       let a = make n (f 0) in
189       for i = 1 to n - 1 do
190         invariant { forall k. 0 <= k < i -> a[k] = f k }
191         a[i] <- f i
192       done;
193       a
194     end
199 (** {2 Sorted Arrays} *)
201 module IntArraySorted
203   use int.Int
204   use Array
205   clone map.MapSorted as M with type elt = int, predicate le = (<=)
207   predicate sorted_sub (a : array int) (l u : int) =
208     M.sorted_sub a.elts l u
209   (** `sorted_sub a l u` is true whenever the array segment `a(l..u-1)`
210       is sorted w.r.t order relation `le` *)
212   predicate sorted (a : array int) =
213     M.sorted_sub a.elts 0 a.length
214   (** `sorted a` is true whenever the array `a` is sorted w.r.t `le` *)
218 module Sorted
220   use int.Int
221   use Array
223   type elt
225   predicate le elt elt
227   predicate sorted_sub (a: array elt) (l u: int) =
228     forall i1 i2 : int. l <= i1 < i2 < u -> le a[i1] a[i2]
229   (** `sorted_sub a l u` is true whenever the array segment `a(l..u-1)`
230       is sorted w.r.t order relation `le` *)
232   predicate sorted (a: array elt) =
233     forall i1 i2 : int. 0 <= i1 < i2 < length a -> le a[i1] a[i2]
234   (** `sorted a` is true whenever the array `a` is sorted w.r.t `le` *)
238 (** {2 Arrays Equality} *)
240 module ArrayEq
242   use int.Int
243   use Array
244   use map.MapEq
246   predicate array_eq_sub (a1 a2: array 'a) (l u: int) =
247     a1.length = a2.length /\ 0 <= l <= a1.length /\ 0 <= u <= a1.length /\
248     map_eq_sub a1.elts a2.elts l u
250   predicate array_eq (a1 a2: array 'a) =
251     a1.length = a2.length /\ map_eq_sub a1.elts a2.elts 0 (length a1)
255 module ArrayExchange
257   use int.Int
258   use Array
259   use map.MapExchange as M
261   predicate exchange (a1 a2: array 'a) (i j: int) =
262     a1.length = a2.length /\
263     M.exchange a1.elts a2.elts 0 a1.length i j
264   (** `exchange a1 a2 i j` means that arrays `a1` and `a2` only differ
265       by the swapping of elements at indices `i` and `j` *)
269 (** {2 Permutation} *)
271 module ArrayPermut
273   use int.Int
274   use Array
275   use map.MapPermut as M
276   use map.MapEq
277   use ArrayEq
278   use export ArrayExchange
280   predicate permut (a1 a2: array 'a) (l u: int) =
281     a1.length = a2.length /\ 0 <= l <= a1.length /\ 0 <= u <= a1.length /\
282     M.permut a1.elts a2.elts l u
283   (** `permut a1 a2 l u` is true when the segment
284       `a1(l..u-1)` is a permutation of the segment `a2(l..u-1)`.
285       Values outside of the interval `(l..u-1)` are ignored. *)
287   predicate permut_sub (a1 a2: array 'a) (l u: int) =
288     map_eq_sub a1.elts a2.elts 0 l /\
289     permut a1 a2 l u /\
290     map_eq_sub a1.elts a2.elts u (length a1)
291   (** `permut_sub a1 a2 l u` is true when the segment
292       `a1(l..u-1)` is a permutation of the segment `a2(l..u-1)`
293       and values outside of the interval `(l..u-1)` are equal. *)
295   predicate permut_all (a1 a2: array 'a) =
296     a1.length = a2.length /\ M.permut a1.elts a2.elts 0 a1.length
297   (** `permut_all a1 a2 l u` is true when array `a1` is a permutation
298       of array `a2`. *)
300   lemma exchange_permut_sub:
301     forall a1 a2: array 'a, i j l u: int.
302     exchange a1 a2 i j -> l <= i < u -> l <= j < u ->
303     0 <= l -> u <= length a1 -> permut_sub a1 a2 l u
305   lemma permut_sub_trans:
306     forall a1 a2 a3: array 'a, l u: int.
307     0 <= l -> u <= length a1 -> permut_sub a1 a2 l u ->
308     permut_sub a2 a3 l u -> permut_sub a1 a3 l u
310   (** we can always enlarge the interval *)
311   lemma permut_sub_weakening:
312     forall a1 a2: array 'a, l1 u1 l2 u2: int.
313     permut_sub a1 a2 l1 u1 -> 0 <= l2 <= l1 -> u1 <= u2 <= length a1 ->
314     permut_sub a1 a2 l2 u2
316   lemma exchange_permut_all:
317     forall a1 a2: array 'a, i j: int.
318     exchange a1 a2 i j -> permut_all a1 a2
322 module ArraySwap
324   use int.Int
325   use Array
326   use export ArrayExchange
328   let swap (a:array 'a) (i:int) (j:int) : unit
329     requires { 0 <= i < length a /\ 0 <= j < length a }
330     writes   { a }
331     ensures  { exchange (old a) a i j }
332   = let v = a[i] in
333     a[i] <- a[j];
334     a[j] <- v
338 (** {2 Sum of elements} *)
340 module ArraySum
342   use Array
343   use int.Sum as S
345   (** `sum a l h` is the sum of `a[i]` for `l <= i < h` *)
346   function sum (a: array int) (l h: int) : int = S.sum a.elts l h
350 (** {2 Number of array elements satisfying a given predicate} *)
352 module NumOf
353   use Array
354   use int.NumOf as N
356   (** the number of `a[i]` such that `l <= i < u` and `pr i a[i]` *)
357   function numof (pr: int -> 'a -> bool) (a: array 'a) (l u: int) : int =
358     N.numof (fun i -> pr i a[i]) l u
362 module NumOfEq
363   use Array
364   use int.NumOf as N
366   (** the number of `a[i]` such that `l <= i < u` and `a[i] = v` *)
367   function numof (a: array 'a) (v: 'a) (l u: int) : int =
368     N.numof (fun i -> a[i] = v) l u
372 module ToList
373   use int.Int
374   use Array
375   use list.List
377   let rec function to_list (a: array 'a) (l u: int) : list 'a
378     requires { l >= 0 /\ u <= a.length }
379     variant  { u - l }
380   = if u <= l then Nil else Cons a[l] (to_list a (l+1) u)
382   use list.Append
384   let rec lemma to_list_append (a: array 'a) (l m u: int)
385     requires { 0 <= l <= m <= u <= a.length }
386     variant  { m - l }
387     ensures  { to_list a l m ++ to_list a m u = to_list a l u }
388   = if l < m then to_list_append a (l+1) m u
392 module ToSeq
393   use int.Int
394   use Array
395   use seq.Seq as S
397   let rec function to_seq_sub (a: array 'a) (l u: int) : S.seq 'a
398     requires { l >= 0 /\ u <= a.length }
399     variant { u - l }
400   = if u <= l then S.empty else S.cons a[l] (to_seq_sub a (l+1) u)
402   let rec lemma to_seq_length (a: array 'a) (l u: int)
403     requires { 0 <= l <= u <= length a }
404     variant  { u - l }
405     ensures  { S.length (to_seq_sub a l u) = u - l }
406   = if l < u then to_seq_length a (l+1) u
408   let rec lemma to_seq_nth (a: array 'a) (l i u: int)
409     requires { 0 <= l <= i < u <= length a }
410     variant  { i - l }
411     ensures  { S.get (to_seq_sub a l u) (i - l) = a[i] }
412   = if l < i then to_seq_nth a (l+1) i u
414   let function to_seq (a: array 'a) : S.seq 'a = to_seq_sub a 0 (length a)
415   meta coercion function to_seq
419 (** {2 Number of inversions in an array of integers}
421     We show that swapping two elements that are ill-sorted decreases
422     the number of inversions. Useful to prove the termination of
423     sorting algorithms that use swaps. *)
425 module Inversions
427   use Array
428   use ArrayExchange
429   use int.Int
430   use int.Sum
431   use int.NumOf
433   (* to prove termination, we count the total number of inversions *)
434   predicate inversion (a: array int) (i j: int) =
435     a[i] > a[j]
437   function inversions_for (a: array int) (i: int) : int =
438     numof (inversion a i) i (length a)
440   function inversions (a: array int) : int =
441     sum (inversions_for a) 0 (length a)
443   (* the key lemma to prove termination: whenever we swap two consecutive
444      values that are ill-sorted, the total number of inversions decreases *)
445   let lemma exchange_inversion (a1 a2: array int) (i0: int)
446     requires { 0 <= i0 < length a1 - 1 }
447     requires { a1[i0] > a1[i0 + 1] }
448     requires { exchange a1 a2 i0 (i0 + 1) }
449     ensures  { inversions a2 < inversions a1 }
450   = assert { inversion a1 i0 (i0+1) };
451     assert { not (inversion a2 i0 (i0+1)) };
452     assert { forall i. 0 <= i < i0 ->
453              inversions_for a2 i = inversions_for a1 i
454              by numof (inversion a2 i) i (length a2)
455               = numof (inversion a2 i) i i0
456               + numof (inversion a2 i) i0 (i0+1)
457               + numof (inversion a2 i) (i0+1) (i0+2)
458               + numof (inversion a2 i) (i0+2) (length a2)
459              /\ numof (inversion a1 i) i (length a1)
460               = numof (inversion a1 i) i i0
461               + numof (inversion a1 i) i0 (i0+1)
462               + numof (inversion a1 i) (i0+1) (i0+2)
463               + numof (inversion a1 i) (i0+2) (length a1)
464              /\ numof (inversion a2 i) i0 (i0+1)
465                 = numof (inversion a1 i) (i0+1) (i0+2)
466              /\ numof (inversion a2 i) (i0+1) (i0+2)
467                 = numof (inversion a1 i) i0 (i0+1)
468              /\ numof (inversion a2 i) i i0 = numof (inversion a1 i) i i0
469              /\ numof (inversion a2 i) (i0+2) (length a2)
470                 = numof (inversion a1 i) (i0+2) (length a1)
471               };
472     assert { forall i. i0 + 1 < i < length a1 ->
473              inversions_for a2 i = inversions_for a1 i };
474     assert { inversions_for a2 i0 = inversions_for a1 (i0+1)
475              by numof (inversion a1 (i0+1)) (i0+2) (length a1)
476               = numof (inversion a2 i0    ) (i0+2) (length a1) };
477     assert { 1 + inversions_for a2 (i0+1) = inversions_for a1 i0
478              by numof (inversion a1 i0) i0 (length a1)
479               = numof (inversion a1 i0) (i0+1) (length a1)
480               = 1 + numof (inversion a1 i0) (i0+2) (length a1)
481               = 1 + numof (inversion a2 (i0+1)) (i0+2) (length a2) };
482     let sum_decomp (a: array int) (i j k: int)
483       requires { 0 <= i <= j <= k <= length a = length a1 }
484       ensures  { sum (inversions_for a) i k =
485                  sum (inversions_for a) i j + sum (inversions_for a) j k }
486     = () in
487     let decomp (a: array int)
488       requires { length a = length a1 }
489       ensures  { inversions a = sum (inversions_for a) 0 i0
490                               + inversions_for a i0
491                               + inversions_for a (i0+1)
492                               + sum (inversions_for a) (i0+2) (length a) }
493     = sum_decomp a 0 i0 (length a);
494       sum_decomp a i0 (i0+1) (length a);
495       sum_decomp a (i0+1) (i0+2) (length a);
496     in
497     decomp a1; decomp a2;
498     ()