Merge branch 'why3tools-register-main' into 'master'
[why3.git] / examples / binary_search.mlw
blob2d9851d9d9a8beb9a3a9e8c3e539e11ecdb07560
1 (* Binary search
3    A classical example. Searches a sorted array for a given value v. *)
5 module BinarySearch
7   use int.Int
8   use int.ComputerDivision
9   use ref.Ref
10   use array.Array
12   (* the code and its specification *)
14   exception Not_found (* raised to signal a search failure *)
16   let binary_search (a: array int) (v: int) : int
17     requires { forall i1 i2. 0 <= i1 <= i2 < length a -> a[i1] <= a[i2] }
18     ensures  { 0 <= result < length a /\ a[result] = v }
19     raises   { Not_found -> forall i. 0 <= i < length a -> a[i] <> v }
20   =
21     let ref l = 0 in
22     let ref u = length a - 1 in
23     while l <= u do
24       invariant { 0 <= l /\ u < length a }
25       invariant {
26         forall i. 0 <= i < length a -> a[i] = v -> l <= i <= u }
27       variant { u - l }
28       let m = l + div (u - l) 2 in
29       assert { l <= m <= u };
30       if a[m] < v then
31         l := m + 1
32       else if a[m] > v then
33         u := m - 1
34       else
35         return m
36     done;
37     raise Not_found
39 end
41 (* A generalization: the midpoint is computed by some abstract function.
42    The only requirement is that it lies between l and u. *)
44 module BinarySearchAnyMidPoint
46   use int.Int
47   use ref.Ref
48   use array.Array
50   exception Not_found (* raised to signal a search failure *)
52   val midpoint (l: int) (u: int) : int
53     requires { l <= u } ensures { l <= result <= u }
55   let binary_search (a: array int) (v: int) : int
56     requires { forall i1 i2. 0 <= i1 <= i2 < length a -> a[i1] <= a[i2] }
57     ensures  { 0 <= result < length a /\ a[result] = v }
58     raises   { Not_found -> forall i. 0 <= i < length a -> a[i] <> v }
59   =
60     let ref l = 0 in
61     let ref u = length a - 1 in
62     while l <= u do
63       invariant { 0 <= l /\ u < length a }
64       invariant { forall i. 0 <= i < length a -> a[i] = v -> l <= i <= u }
65       variant { u - l }
66       let m = midpoint l u in
67       if a[m] < v then
68         l := m + 1
69       else if a[m] > v then
70         u := m - 1
71       else
72         return m
73     done;
74     raise Not_found
76 end
78 (* The following version of binary search is faster in practice, by being
79    friendlier with the branch predictor of most processors. It also happens
80    to be stable, since it always return the highest index. *)
82 module BinarySearchBranchless
84   use int.Int
85   use int.ComputerDivision
86   use ref.Ref
87   use array.Array
89   exception Not_found (* raised to signal a search failure *)
91   let binary_search (a: array int) (v: int) : int
92     requires { forall i1 i2. 0 <= i1 <= i2 < length a -> a[i1] <= a[i2] }
93     ensures  { 0 <= result < length a /\ a[result] = v }
94     ensures  { forall i. result < i < length a -> a[i] <> v }
95     raises   { Not_found -> forall i. 0 <= i < length a -> a[i] <> v }
96   =
97     let ref l = 0 in
98     let ref s = length a in
99     if s = 0 then raise Not_found;
100     while s > 1 do
101       invariant { 0 <= l /\ l + s <= length a /\ s >= 1 }
102       invariant {
103         forall i. 0 <= i < length a -> a[i] = v -> a[l] <= v /\ i < l + s }
104       variant { s }
105       let h = div s 2 in
106       let m = l + h in
107       l := if a[m] > v then l else m;
108       s := s - h;
109     done;
110     if a[l] = v then l
111     else raise Not_found
115 (* binary search using 32-bit integers *)
117 module BinarySearchInt32
119   use int.Int
120   use mach.int.Int32
121   use ref.Ref
122   use mach.array.Array32
124   exception Not_found   (* raised to signal a search failure *)
126   let binary_search (a: array int32) (v: int32) : int32
127     requires { forall i1 i2. 0 <= i1 <= i2 < a.length -> a[i1] <= a[i2] }
128     ensures  { 0 <= result < a.length /\ a[result] = v }
129     raises   { Not_found -> forall i. 0 <= i < a.length -> a[i] <> v }
130   =
131     let ref l = 0 in
132     let ref u = length a - 1 in
133     while l <= u do
134       invariant { 0 <= l /\ u < a.length }
135       invariant { forall i. 0 <= i < a.length -> a[i] = v -> l <= i <= u }
136       variant   { u - l }
137       let m = l + (u - l) / 2 in
138       assert { l <= m <= u };
139       if a[m] < v then
140         l := m + 1
141       else if a[m] > v then
142         u := m - 1
143       else
144         return m
145     done;
146     raise Not_found
150 (* A particular case with Boolean values (0 or 1) and a sentinel 1 at the end.
151    We look for the first position containing a 1. *)
153 module BinarySearchBoolean
155   use int.Int
156   use int.ComputerDivision
157   use ref.Ref
158   use array.Array
160   let binary_search (a: array int) : int
161     requires { 0 < a.length }
162     requires { forall i j. 0 <= i <= j < a.length -> 0 <= a[i] <= a[j] <= 1 }
163     requires { a[a.length - 1] = 1 }
164     ensures  { 0 <= result < a.length }
165     ensures  { a[result] = 1 }
166     ensures  { forall i. 0 <= i < result -> a[i] = 0 }
168     let ref lo = 0 in
169     let ref hi = length a - 1 in
170     while lo < hi do
171       invariant { 0 <= lo <= hi < a.length }
172       invariant { a[hi] = 1 }
173       invariant { forall i. 0 <= i < lo -> a[i] = 0 }
174       variant   { hi - lo }
175       let mid = lo + div (hi - lo) 2 in
176       if a[mid] = 1 then
177         hi := mid
178       else
179         lo := mid + 1
180     done;
181     lo
185 module Complexity
187   use int.Int
188   use int.ComputerDivision
189   use ref.Ref
190   use array.Array
192   let rec function log2 (n: int) : int
193     variant { n }
194   = if n <= 1 then 0 else 1 + log2 (div n 2)
196   let rec lemma log2_monotone (x y: int)
197     requires { x <= y }
198     ensures  { log2 x <= log2 y }
199     variant  { y }
200   = if y > 1 then log2_monotone (div x 2) (div y 2)
202   let function f (n: int) : int
203   = if n = 0 then 0 else 1 + log2 n
205   lemma upper_bound:
206     forall n. n >= 2 -> f n <= 2 * log2 n
208   val ref time: int
210   let binary_search (a: array int) (v: int) : int
211     requires { forall i1 i2. 0 <= i1 <= i2 < length a -> a[i1] <= a[i2] }
212     requires { time = 0 }
213     ensures  { 0 <= result < length a && a[result] = v
214             || result = -1 && forall i. 0 <= i < length a -> a[i] <> v }
215     ensures  { time - old time <= f (length a) }
216   =
217     let ref lo = 0 in
218     let ref hi = length a in
219     while lo < hi do
220       invariant { 0 <= lo <= hi <= length a }
221       invariant { forall i. 0 <= i < lo || hi <= i < length a -> a[i] <> v }
222       invariant { (time - old time) + f (hi - lo) <= f (length a) }
223       variant   { hi - lo }
224       let mid = lo + div (hi - lo) 2 in
225       if a[mid] < v then
226         lo <- mid + 1
227       else if a[mid] > v then
228         hi <- mid
229       else
230         return mid;
231       time <- time + 1
232     done;
233     -1
237 (* Search in a two-dimensional grid where all rows and columns are
238    sorted.  Here is an example:
240                             j
241      +---+---+---+---+---+-->
242      | 1 | 3 | 4 | 4 | 7*|
243      +---+---+---+---+---+
244      | 1 | 4 | 6 | 6 | 8*|
245      +---+---+---+---+---+
246      | 2 | 5 | 7 | 9*|11*|
247      +---+---+---+---+---+
248      | 4 | 8 | 8*|12*|13 |
249      +---+---+---+---+---+
250      |         *
251    i v
253   Algorithm: start from the upper right corner and then move left
254   (resp. down) when the value is greater (resp. smaller) than the
255   value we search for.
257   In the example above, the stars depict the elements that are
258   examined when we search for the value 10. Here we end up outside of
259   the grid and thus the search is unsuccessful.
262 module TwoDimensional
264   use int.Int
265   use matrix.Matrix
267   let search (m: matrix int) (v: int) : bool
268     requires  { [@expl: rows are sorted] forall i. 0 <= i < rows m ->
269                 forall j1 j2. 0 <= j1 <= j2 < columns m ->
270                 get m i j1 <= get m i j2  }
271     requires  { [@expl: columns are sorted] forall j. 0 <= j < columns m ->
272                 forall i1 i2. 0 <= i1 <= i2 < rows m ->
273                 get m i1 j <= get m i2 j  }
274     ensures   { result <->
275       exists i j. 0 <= i < rows m && 0 <= j < columns m && get m i j = v }
276   = let ref i = 0 in
277     let ref j = columns m - 1 in
278     while i < rows m && 0 <= j do
279       invariant {  0 <= i <= rows    m }
280       invariant { -1 <= j <  columns m }
281       invariant { forall i' j'. 0 <= i' < rows m -> 0 <= j' < columns m ->
282                   i' < i || j' > j -> get m i' j' <> v }
283       variant   { rows m - i + j }
284       let x = get m i j in
285       if x = v then return true;
286       if x < v then i <- i + 1 else j <- j - 1
287     done;
288     return false