Merge branch '863-forward-propagation-strategy-accept-optionnal-arguments' into ...
[why3.git] / examples / sumrange.mlw
blob166b550d84c89f03971e2577f81603521e71e5f8
2 (** {1 Range Sum Queries}
4 We are interested in specifying and proving correct
5 data structures that support efficient computation of the sum of the
6 values over an arbitrary range of an array.
7 Concretely, given an array of integers `a`, and given a range
8 delimited by indices `i` (inclusive) and `j` (exclusive), we wish
9 to compute the value: `\sum_{k=i}^{j-1} a[k]`.
11 In the first part, we consider a simple loop
12 for computing the sum in linear time.
14 In the second part, we introduce a cumulative sum array
15 that allows answering arbitrary range queries in constant time.
17 In the third part, we explore a tree data structure that
18 supports modification of values from the underlying array `a`,
19 with logarithmic time operations.
24 (** {2 Specification of Range Sum} *)
26 module ArraySum
28   use int.Int
29   use array.Array
31   (** `sum a i j` denotes the sum `\sum_{i <= k < j} a[k]`.
32       It is axiomatizated by the two following axioms expressing
33       the recursive definition
35       if `i <= j` then `sum a i j = 0`
37       if `i < j` then `sum a i j = a[i] + sum a (i+1) j`
39   *)
40   let rec function sum (a:array int) (i j:int) : int
41    requires { 0 <= i <= j <= a.length }
42    variant { j - i }
43    = if j <= i then 0 else a[i] + sum a (i+1) j
45   (** lemma for summation from the right:
47       if `i < j` then `sum a i j = sum a i (j-1) + a[j-1]`
49  *)
50   lemma sum_right : forall a : array int, i j : int.
51     0 <= i < j <= a.length  ->
52     sum a i j = sum a i (j-1) + a[j-1]
54 end
59 (** {2 First algorithm, a linear one} *)
61 module Simple
63   use int.Int
64   use array.Array
65   use ArraySum
66   use ref.Ref
68   (** `query a i j` returns the sum of elements in `a` between
69       index `i` inclusive and index `j` exclusive *)
70   let query (a:array int) (i j:int) : int
71     requires { 0 <= i <= j <= a.length }
72     ensures { result = sum a i j }
73   = let s = ref 0 in
74     for k=i to j-1 do
75       invariant { !s = sum a i k }
76       s := !s + a[k]
77     done;
78     !s
80 end
85 (** {2 Additional lemmas on `sum`}
86   needed in the remaining code *)
88 module ExtraLemmas
90   use int.Int
91   use array.Array
92   use ArraySum
94   (** summation in adjacent intervals *)
95   lemma sum_concat : forall a:array int, i j k:int.
96     0 <= i <= j <= k <= a.length ->
97     sum a i k = sum a i j + sum a j k
99   (** Frame lemma for `sum`, that is `sum a i j` depends only
100       of values of `a[i..j-1]` *)
101   lemma sum_frame : forall a1 a2 : array int, i j : int.
102     0 <= i <= j ->
103     j <= a1.length ->
104     j <= a2.length ->
105     (forall k : int. i <= k < j -> a1[k] = a2[k]) ->
106     sum a1 i j = sum a2 i j
108   (** Updated lemma for `sum`: how does `sum a i j` changes when
109       `a[k]` is changed for some `k` in `[i..j-1]` *)
110   lemma sum_update : forall a:array int, i v l h:int.
111     0 <= l <= i < h <= a.length ->
112     sum (a[i<-v]) l h = sum a l h + v - a[i]
120 (** {2 Algorithm 2: using a cumulative array}
122    creation of cumulative array is linear
124    query is in constant time
126    array update is linear
131 module CumulativeArray
133   use int.Int
134   use array.Array
135   use ArraySum
136   use ExtraLemmas
138   predicate is_cumulative_array_for (c:array int) (a:array int) =
139     c.length = a.length + 1 /\
140     forall i. 0 <= i < c.length -> c[i] = sum a 0 i
142   (** `create a` builds the cumulative array associated with `a`. *)
143   let create (a:array int) : array int
144     ensures { is_cumulative_array_for result a }
145   = let l = a.length in
146     let s = Array.make (l+1) 0 in
147     for i=1 to l do
148       invariant { forall k. 0 <= k < i -> s[k] = sum a 0 k }
149       s[i] <- s[i-1] + a[i-1]
150     done;
151     s
153   (** `query c i j a` returns the sum of elements in `a` between
154       index `i` inclusive and index `j` exclusive, in constant time *)
155   let query (c:array int) (i j:int) (ghost a:array int): int
156     requires { is_cumulative_array_for c a }
157     requires { 0 <= i <= j < c.length }
158     ensures { result = sum a i j }
159   = c[j] - c[i]
162   (** `update c i v a` updates cell `a[i]` to value `v` and updates
163       the cumulative array `c` accordingly *)
164   let update (c:array int) (i:int) (v:int) (ghost a:array int) : unit
165     requires { is_cumulative_array_for c a }
166     requires { 0 <= i < a.length }
167     writes  { c, a }
168     ensures { is_cumulative_array_for c a }
169     ensures { a[i] = v }
170     ensures { forall k. 0 <= k < a.length /\ k <> i ->
171               a[k] = (old a)[k] }
172   = let incr = v - c[i+1] + c[i] in
173     a[i] <- v;
174     for j=i+1 to c.length-1 do
175       invariant { forall k. j <= k < c.length -> c[k] = sum a 0 k - incr }
176       invariant { forall k. 0 <= k < j -> c[k] = sum a 0 k }
177       c[j] <- c[j] + incr
178     done
187 (** {2 Algorithm 3: using a cumulative tree}
189   creation is linear
191   query is logarithmic
193   update is logarithmic
200 module CumulativeTree
202   use int.Int
203   use array.Array
204   use ArraySum
205   use ExtraLemmas
206   use int.ComputerDivision
208   type indexes =
209     { low : int;
210       high : int;
211       isum : int;
212     }
214   type tree = Leaf indexes | Node indexes tree tree
216   let function indexes (t:tree) : indexes =
217     match t with
218     | Leaf ind -> ind
219     | Node ind _ _ -> ind
220     end
222   predicate is_indexes_for (ind:indexes) (a:array int) (i j:int) =
223     ind.low = i /\ ind.high = j /\
224     0 <= i < j <= a.length /\
225     ind.isum = sum a i j
227   predicate is_tree_for (t:tree) (a:array int) (i j:int) =
228     match t with
229     | Leaf ind ->
230         is_indexes_for ind a i j /\ j = i+1
231     | Node ind l r ->
232         is_indexes_for ind a i j /\
233         i = l.indexes.low /\ j = r.indexes.high /\
234         let m = l.indexes.high in
235         m = r.indexes.low /\
236         i < m < j /\ m = div (i+j) 2 /\
237         is_tree_for l a i m /\
238         is_tree_for r a m j
239     end
241   (** {3 creation of cumulative tree} *)
243   let rec tree_of_array (a:array int) (i j:int) : tree
244     requires { 0 <= i < j <= a.length }
245     variant { j - i }
246     ensures { is_tree_for result a i j }
247     = if i+1=j then begin
248        Leaf { low = i; high = j; isum = a[i] }
249        end
250       else
251         begin
252         let m = div (i+j) 2 in
253         assert { i < m < j };
254         let l = tree_of_array a i m in
255         let r = tree_of_array a m j in
256         let s = l.indexes.isum + r.indexes.isum in
257         assert { s = sum a i j };
258         Node { low = i; high = j; isum = s} l r
259         end
262   let create (a:array int) : tree
263     requires { a.length >= 1 }
264     ensures { is_tree_for result a 0 a.length }
265   = tree_of_array a 0 a.length
268 (** {3 query using cumulative tree} *)
271   let rec query_aux (t:tree) (ghost a: array int)
272       (i j:int) : int
273     requires { is_tree_for t a t.indexes.low t.indexes.high }
274     requires { 0 <= t.indexes.low <= i < j <= t.indexes.high <= a.length }
275     variant { t }
276     ensures { result = sum a i j }
277   = match t with
278     | Leaf ind ->
279       ind.isum
280     | Node ind l r ->
281       let k1 = ind.low in
282       let k3 = ind.high in
283       if i=k1 && j=k3 then ind.isum else
284       let m = l.indexes.high in
285       if j <= m then query_aux l a i j else
286       if i >= m then query_aux r a i j else
287       query_aux l a i m + query_aux r a m j
288     end
291   let query (t:tree) (ghost a: array int) (i j:int) : int
292     requires { 0 <= i <= j <= a.length }
293     requires { is_tree_for t a 0 a.length }
294     ensures { result = sum a i j }
295   = if i=j then 0 else query_aux t a i j
298   (** frame lemma for predicate `is_tree_for` *)
299   lemma is_tree_for_frame : forall t:tree, a:array int, k v i j:int.
300     0 <= k < a.length ->
301     k < i \/ k >= j ->
302     is_tree_for t a i j ->
303     is_tree_for t a[k<-v] i j
305 (** {3 update cumulative tree} *)
308   let rec update_aux
309       (t:tree) (i:int) (ghost a :array int) (v:int) : (t': tree, delta: int)
310     requires { is_tree_for t a t.indexes.low t.indexes.high }
311     requires { t.indexes.low <= i < t.indexes.high }
312     variant { t }
313     ensures {
314         delta = v - a[i] /\
315         t'.indexes.low = t.indexes.low /\
316         t'.indexes.high = t.indexes.high /\
317         is_tree_for t' a[i<-v] t'.indexes.low t'.indexes.high }
318   = match t with
319     | Leaf ind ->
320         assert { i = ind.low };
321         (Leaf { ind with isum = v }, v - ind.isum)
322     | Node ind l r ->
323         let m = l.indexes.high in
324       if i < m then
325         let l',delta = update_aux l i a v in
326         assert { is_tree_for l' a[i<-v] t.indexes.low m };
327         assert { is_tree_for r a[i<-v] m t.indexes.high };
328         (Node {ind with isum = ind.isum + delta } l' r, delta)
329       else
330         let r',delta = update_aux r i a v in
331         assert { is_tree_for l a[i<-v] t.indexes.low m };
332         assert { is_tree_for r' a[i<-v] m t.indexes.high };
333         (Node {ind with isum = ind.isum + delta} l r',delta)
334     end
336   let update (t:tree) (ghost a:array int) (i v:int) : tree
337      requires { 0 <= i < a.length }
338      requires { is_tree_for t a 0 a.length }
339      writes { a }
340      ensures { a[i] = v }
341      ensures { forall k. 0 <= k < a.length /\ k <> i -> a[k] = (old a)[k] }
342      ensures { is_tree_for result a 0 a.length }
343   = let t,_ = update_aux t i a v in
344     assert { is_tree_for t a[i <- v] 0 a.length };
345     a[i] <- v;
346     t
349 (** {2 complexity analysis}
351   We would like to prove that `query` is really logarithmic. This is
352   non-trivial because there are two recursive calls in some cases.
354   So far, we are only able to prove that `update` is logarithmic
356   We express the complexity by passing a "credit" as a ghost
357   parameter. We pose the precondition that the credit is at least
358   equal to the depth of the tree.
362   (** preliminaries: definition of the depth of a tree, and showing
363       that it is indeed logarithmic in function of the number of its
364       elements *)
366   use int.MinMax
368   function depth (t:tree) : int =
369     match t with
370     | Leaf _ -> 1
371     | Node _ l r -> 1 + max (depth l) (depth r)
372     end
374   lemma depth_min : forall t. depth t >= 1
376   use bv.Pow2int
378   let rec lemma depth_is_log (t:tree) (a :array int) (k:int)
379      requires { k >= 0 }
380      requires { is_tree_for t a t.indexes.low t.indexes.high }
381      requires { t.indexes.high - t.indexes.low <= pow2 k }
382      variant { t }
383      ensures { depth t <= k+1 }
384   = match t with
385     | Leaf _ -> ()
386     | Node _ l r ->
387        depth_is_log l a (k-1);
388        depth_is_log r a (k-1)
389     end
392   (** `update_aux` function instrumented with a credit *)
394   use ref.Ref
396   let rec update_aux_complexity
397         (t:tree) (i:int) (ghost a :array int)
398         (v:int) (ghost c:ref int) : (t': tree, delta: int)
399      requires { is_tree_for t a t.indexes.low t.indexes.high }
400      requires { t.indexes.low <= i < t.indexes.high }
401      variant { t }
402      ensures { !c - old !c <= depth t }
403      ensures {
404         delta = v - a[i] /\
405         t'.indexes.low = t.indexes.low /\
406         t'.indexes.high = t.indexes.high /\
407         is_tree_for t' a[i<-v] t'.indexes.low t'.indexes.high }
408   = c := !c + 1;
409     match t with
410     | Leaf ind ->
411       assert { i = ind.low };
412       (Leaf { ind with isum = v }, v - ind.isum)
413     | Node ind l r ->
414       let m = l.indexes.high in
415       if i < m then
416         let l',delta = update_aux_complexity l i a v c in
417         assert { is_tree_for l' a[i<-v] t.indexes.low m };
418         assert { is_tree_for r a[i<-v] m t.indexes.high };
419         (Node {ind with isum = ind.isum + delta } l' r, delta)
420       else
421         let r',delta = update_aux_complexity r i a v c in
422         assert { is_tree_for l a[i<-v] t.indexes.low m };
423         assert { is_tree_for r' a[i<-v] m t.indexes.high };
424         (Node {ind with isum = ind.isum + delta} l r',delta) (*>*)
425     end
427   (** `query_aux` function instrumented with a credit *)
429   let rec query_aux_complexity (t:tree) (ghost a: array int)
430       (i j:int) (ghost c:ref int) : int
431     requires { is_tree_for t a t.indexes.low t.indexes.high }
432     requires { 0 <= t.indexes.low <= i < j <= t.indexes.high <= a.length }
433     variant { t }
434     ensures { !c - old !c <=
435          if i = t.indexes.low /\ j = t.indexes.high then 1 else
436          if i = t.indexes.low \/ j = t.indexes.high then 2 * depth t else
437           4 * depth t }
438     ensures { result = sum a i j }
439   = c := !c + 1;
440     match t with
441     | Leaf ind ->
442       ind.isum
443     | Node ind l r ->
444       let k1 = ind.low in
445       let k3 = ind.high in
446       if i=k1 && j=k3 then ind.isum else
447       let m = l.indexes.high in
448       if j <= m then query_aux_complexity l a i j c else
449       if i >= m then query_aux_complexity r a i j c else
450       query_aux_complexity l a i m c + query_aux_complexity r a m j c
451     end