Merge branch '863-forward-propagation-strategy-accept-optionnal-arguments' into ...
[why3.git] / examples / ring_decision / strassen.mlw
blob3c4cedfd903dbd8026dfd52654cae27a1a953e0c
1 theory InfMatrixGen
3 use int.Int
5 type mat 'a
7 clone algebra.UnitaryCommutativeRing as F with axiom .
9 function get (mat F.t) int int : F.t
10 function set (mat F.t) int int F.t : mat F.t
12 function row_zeros (mat F.t) int : int
13 function col_zeros (mat F.t) int : int
15 axiom row_zeros_def:
16   forall m: mat F.t, i j: int. 0 <= i -> j >= row_zeros m i -> get m i j = F.zero
18 axiom col_zeros_def:
19   forall m: mat F.t, i j: int. 0 <= j -> i >= col_zeros m j -> get m i j = F.zero
21 axiom row_zeros_nonneg:
22   forall m: mat F.t, i: int. 0 <= i -> 0 <= row_zeros m i
24 axiom col_zeros_nonneg:
25   forall m: mat F.t, j: int. 0 <= j -> 0 <= col_zeros m j
26 (*FIXME should be invariants*)
28 axiom set_def_changed:
29   forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
30   get (set m i j v) i j = v
32 axiom set_def_unchanged:
33   forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
34   forall i' j': int. 0 <= i' -> 0 <= j' -> (i <> i' \/ j <> j') ->
35   get (set m i j v) i' j' = get m i' j'
37 axiom set_def_rowz_changed:
38   forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
39   j >= row_zeros m i -> row_zeros (set m i j v) i = j+1
41 axiom set_def_colz_changed:
42   forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
43   i >= col_zeros m j -> col_zeros (set m i j v) j = i+1
45 axiom set_def_rowz_unchanged:
46   forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
47   j < row_zeros m i -> row_zeros (set m i j v) i = row_zeros m i
49 axiom set_def_colz_unchanged:
50   forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
51   i < col_zeros m j -> col_zeros (set m i j v) j = col_zeros m j
53 axiom set_def_other_rowz:
54   forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
55   forall i': int. 0 <= i' -> i <> i' ->
56   row_zeros (set m i j v) i' = row_zeros m i'
58 axiom set_def_other_colz:
59   forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
60   forall j': int. 0 <= j' -> j <> j' ->
61   col_zeros (set m i j v) j' = col_zeros m j'
63 predicate (==) (m1 m2: mat F.t) =
64   forall i j: int. 0 <= i -> 0 <= j -> get m1 i j = get m2 i j
66 axiom extensionality:
67   forall m1 m2: mat F.t. m1 == m2 -> m1 = m2
69 predicate (===) (m1 m2: mat F.t) =
70   forall i j: int. 0 <= i -> 0 <= j ->
71   row_zeros m1 i = row_zeros m2 i /\ col_zeros m1 j = col_zeros m2 j
73 predicate in_bounds (m: mat F.t) (i j: int) =
74   0 <= i < col_zeros m j /\ 0 <= j < row_zeros m i
76 let lemma ext_by_bounds (m1 m2: mat F.t)
77   requires { m1 === m2 }
78   requires { forall i j. in_bounds m1 i j -> get m1 i j = get m2 i j }
79   ensures  { m1 == m2 }
80 = ()
82 lemma oob_zero:
83   forall m: mat F.t, i j: int. 0 <= i -> 0 <= j -> not in_bounds m i j
84   -> get m i j = F.zero
86 predicate size (m: mat F.t) (r c: int) =
87   (forall i: int. 0 <= i -> row_zeros m i = c)
88   /\ (forall j: int. 0 <= j -> col_zeros m j = r)
90 lemma size_to_bounds:
91   forall m: mat F.t, r c i j: int. size m r c -> (in_bounds m i j <-> (0 <= i < r /\ 0 <= j < c))
93 lemma iso_size:
94   forall a b: mat F.t, r c: int. a === b -> (size a r c <-> size b r c)
96 lemma size_rows_ib:
97   forall a: mat F.t, r c i: int. size a r c ->
98   0 <= i < r -> row_zeros a i = c
99   by forall j: int. in_bounds a i j -> 0 <= j < c
101 lemma size_iso:
102   forall a b: mat F.t, r c: int. size a r c -> size b r c -> a === b
106 module InfMatrix
108   type t
109   constant tzero: t
111   clone export algebra.UnitaryCommutativeRing with
112     type t = t, constant zero = tzero, axiom .
114   use int.Int
115   clone export relations.MinMax with
116     type t = int, predicate le = (<=),
117     axiom . (* FIXME: replace with "goal" and prove *)
119   type mat
121   val function get mat int int : t
123   val function row_zeros mat int : int
124   val function col_zeros mat int : int
126   val function create (rz: int -> int) (cz: int -> int) (f: int -> int -> t)
127              : mat
129   axiom create_rowz:
130     forall rz cz: int -> int, f: int -> int -> t, i: int.
131     0 <= i -> 0 <= rz i -> row_zeros (create rz cz f) i = rz i
133   axiom create_colz:
134     forall rz cz: int -> int, f: int -> int -> t, j: int.
135     0 <= j -> 0 <= cz j -> col_zeros (create rz cz f) j = cz j
137   axiom create_get_ib:
138     forall rz cz: int -> int, f: int -> int -> t, i j: int.
139     0 <= i < cz j -> 0 <= j < rz i -> get (create rz cz f) i j = f i j
141   axiom create_get_oob:
142     forall rz cz: int -> int, f: int -> int -> t, i j: int.
143     0 <= i -> 0 <= j -> (i >= cz j \/ j >= rz i) ->
144     get (create rz cz f) i j = tzero
146   let ghost function set (m: mat) (i j:int) (v:t) : mat =
147     if 0 <= i && 0 <= j
148     then
149     create
150       (fun i1 -> if i1 = i then max (j+1) (row_zeros m i) else row_zeros m i1)
151       (fun j1 -> if j1 = j then max (i+1) (col_zeros m j) else col_zeros m j1)
152       (fun i1 j1 -> if i1 = i && j1 = j then v else get m i1 j1)
153     else m
155   clone export InfMatrixGen with type mat 'a = mat,
156     type F.t = t,
157     function get = get,
158     function set = set,
159     function row_zeros = row_zeros,
160     function col_zeros = col_zeros,
161     lemma set_def_changed,
162     lemma set_def_unchanged,
163     lemma set_def_colz_changed,
164     lemma set_def_colz_unchanged,
165     lemma set_def_rowz_changed,
166     lemma set_def_rowz_unchanged,
167     lemma set_def_other_rowz,
168     lemma set_def_other_colz,
169     axiom . (* FIXME: replace with "goal" and prove *)
171   let ghost function fcreate (r c: int) (f: int -> int -> t) : mat =
172     create (fun _ -> max 0 c) (fun _ -> max 0 r) f
174   lemma fcreate_get_ib:
175     forall r c i j: int, f: int -> int -> t.
176     0 <= i < r -> 0 <= j < c -> get (fcreate r c f) i j = f i j
178   lemma fcreate_get_oob:
179     forall r c i j: int, f: int -> int -> t.
180     0 <= i -> 0 <= j -> (i >= r \/ j >= c) -> get (fcreate r c f) i j = tzero
182   lemma fcreate_size:
183     forall r c: int, f: int -> int -> t. 0 <= r -> 0 <= c ->
184     size (fcreate r c f) r c
189 (* copied over from verifythis_2016_matrix_multiplication *)
190 module Sum_extended
192   use int.Int
193   use int.Sum
195   function addf (f g:int -> int) : int -> int = fun x -> f x + g x
197   function smulf (f:int -> int) (l:int) : int -> int = fun x -> l * f x
199   let rec lemma sum_mult (f:int -> int) (a b l:int) : unit
200     ensures { sum (smulf f l) a b = l * sum f a b }
201     variant { b - a }
202   = if b > a then sum_mult f a (b-1) l
205   let rec lemma sum_add (f g:int -> int) (a b:int) : unit
206     ensures { sum (addf f g) a b = sum f a b + sum g a b }
207     variant { b - a }
208   = if b > a then sum_add f g a (b-1)
211   function sumf (f:int -> int -> int) (a b:int) : int -> int = fun x -> sum (f x) a b
213   let rec lemma fubini (f1 f2: int -> int -> int) (a b c d: int) : unit
214     requires { forall x y. a <= x < b /\ c <= y < d -> f1 x y = f2 y x }
215     ensures  { sum (sumf f1 c d) a b = sum (sumf f2 a b) c d }
216     variant  { b - a }
217   = if b <= a
218     then assert { forall x. sumf f2 a b x = 0 }
219     else begin
220       fubini f1 f2 a (b-1) c d;
221       assert { let ha = addf (sumf f2 a (b-1)) (f1 (b-1)) in
222         sum (sumf f2 a b) c d = sum ha c d
223         by forall y. c <= y < d -> sumf f2 a b y = ha y }
224     end
226   let ghost sum_ext (f g: int -> int) (a b: int) : unit
227     requires { forall i. a <= i < b -> f i = g i }
228     ensures  { sum f a b  = sum g a b }
229   = ()
233 theory MaxFun
235   use int.Int
236   use int.MinMax
238   (* maximum of a function over an interval; always at least 0 *)
240   let rec function maxf (f: int -> int) (a b: int) : int
241     variant { b - a }
242   = if b <= a then zero else max (maxf f a (b - 1)) (f (b-1))
244   let rec lemma maxf_nonneg (f: int -> int) (a b: int)
245     requires { a <= b }
246     ensures { maxf f a b >= 0 }
247     variant { b - a }
248   = if a = b then () else maxf_nonneg f a (b-1)
250   let rec lemma maxf_larger (f: int -> int) (a b i: int)
251     requires { a <= i < b }
252     variant { b - a }
253     ensures { maxf f a b >= f i }
254   = if i = b-1 then () else maxf_larger f a (b-1) i
256   let rec lemma max_left (f: int -> int) (a b: int)
257     requires { a < b }
258     ensures { maxf f a b = max (f a) (maxf f (a+1) b) }
259     variant { b - a }
260   = if a = b-1 then () else max_left f a (b-1)
262   let rec lemma max_ext (f g: int -> int) (a b: int)
263     requires { a < b }
264     variant { b - a }
265     requires { forall i. a <= i < b -> f i = g i }
266     ensures { maxf f a b = maxf g a b }
267   = if a = b-1 then () else max_ext f g a (b-1)
269   let rec lemma max_decomp (f: int -> int) (a b c: int)
270     requires { a <= b <= c }
271     variant { c - b }
272     ensures  { maxf f a c = max (maxf f a b) (maxf f b c) }
273   = if b = c
274     then assert { maxf f a c = max (maxf f a b) (maxf f b c)
275                   by maxf f b c = 0
276                   so maxf f a b >= 0 }
277     else begin
278       max_decomp f a b (c-1);
279       assert { maxf f a c = max (maxf f a b) (maxf f b c)
280                by maxf f a c = max (maxf f a (c-1)) (f (c-1))
281                so maxf f b c = max (maxf f b (c-1)) (f (c-1))
282                so maxf f a c = max (max (maxf f a b) (maxf f b (c-1))) (f (c-1))
283                              = max (maxf f a b) (maxf f b c) }
284     end
286   let rec lemma max_constant (f: int -> int) (v a b: int)
287       requires { v >= 0 }
288       requires { a < b }
289       requires { forall i. a <= i < b -> f i = v }
290       ensures  { maxf f a b = v }
291       variant  { b-a }
292   = if a = b-1 then () else max_constant f v a (b-1)
296 module InfIntMatrix
298   use int.Int
299   clone export InfMatrix with
300     type t = int, constant tzero = zero,
301     axiom . (* FIXME: replace with "goal" and prove *)
302   use int.Sum
304   use int.Int (*FIXME needed so i < i+1 ?*)
306   (* Zero matrix *)
308   let constant zerof : int -> int -> int = fun _ _ -> 0
310   let ghost constant mzero : mat = fcreate 0 0 zerof
312   let ghost function zerorc (r c: int) : mat = fcreate r c zerof
314   (* Identity matrix *)
316   let constant idf : int -> int -> int = fun x y -> if x = y then 1 else 0
318   let constant id : mat = create (fun i -> i+1) (fun j -> j+1) idf
320   lemma id_def:
321     forall i. 0 <= i -> get id i i = 1
323   function idrc (r c: int) : mat = fcreate r c idf
325   (* Matrix addition *)
327   let ghost function add2f (a b: mat) : int -> int -> int =
328     fun x y -> get a x y + get b x y
330   function f_add (a b: mat) : mat =
331     create (fun i -> max (row_zeros a i) (row_zeros b i))
332            (fun j -> max (col_zeros a j) (col_zeros b j))
333            (add2f a b)
335   val function add (a b: mat) : mat
336     ensures { result = f_add a b }
338   lemma add_get:
339     forall a b: mat, i j: int. 0 <= i -> 0 <= j ->
340     get (add a b) i j = get a i j + get b i j
342   lemma add_iso:
343     forall a b: mat. a === b -> a === add a b === b
345   lemma add_size:
346     forall a b: mat, r c: int. size a r c -> size b r c -> size (add a b) r c
347     by (forall i j:int.
348        (in_bounds a i j \/ in_bounds b i j) <-> in_bounds (add a b) i j)
350   lemma add_assoc:
351     forall a b c: mat. add (add a b) c = add a (add b c)
352     by add (add a b) c == add a (add b c)
354   lemma add_commutative:
355     forall a b: mat. add a b = add b a by add a b == add b a
357   lemma zero_neutral:
358     forall a. add a mzero = a by add a mzero == a
360   (* Matrix additive inverse *)
361   function opp2f (a: mat) : int -> int -> int =
362     fun x y -> - get a x y
364   function f_opp (a: mat) : mat =
365     create (row_zeros a) (col_zeros a) (opp2f a)
367   val function opp (a: mat) : mat
368     ensures { result = f_opp a }
370   let function sub (a b: mat) : mat = add a (opp b)
372   lemma sub_size:
373     forall a b: mat, r c: int. size a r c -> size b r c -> size (sub a b) r c
374     by (forall i j:int.
375        (in_bounds a i j \/ in_bounds b i j) <-> in_bounds (sub a b) i j)
377   lemma opp_involutive:
378     forall m. opp (opp m) = m by opp (opp m) == m
380   (* Matrix multiplication *)
382   function mul_atom (a b: mat) (i j: int) : int -> int =
383     fun k -> get a i k * get b k j
385   lemma atom_oob:
386     forall a b: mat, i j k: int. 0 <= i -> 0 <= j ->
387     k >= row_zeros a i \/ k >= col_zeros b j -> mul_atom a b i j k = 0
388     by if k >= row_zeros a i then get a i k = 0
389        else k >= col_zeros b j so get b k j = 0
391   function mul_cell_bound (a b: mat) (i j: int) : int
392     = min (row_zeros a i) (col_zeros b j) (* row_zeros a i*)
394   function mul_cell (a b: mat) : int -> int -> int =
395     fun i j -> sum (mul_atom a b i j) 0 (mul_cell_bound a b i j)
397   use MaxFun
399   lemma cell_oob_r:
400     forall a b: mat, i j: int.
401     j >= maxf (fun k -> row_zeros b k) 0 (row_zeros a i) ->
402     mul_cell a b i j = 0
403     by forall k. 0 <= k < mul_cell_bound a b i j ->
404        mul_atom a b i j k = 0
405        by 0 <= k < row_zeros a i
406        so j >= row_zeros b k
407        so get b k j = 0
409   lemma cell_oob_c:
410     forall a b: mat, i j: int.
411     i >= maxf (fun k -> col_zeros a k) 0 (col_zeros b j) ->
412     mul_cell a b i j = 0
413     by forall k. 0 <= k < mul_cell_bound a b i j ->
414        mul_atom a b i j k = 0
415        by 0 <= k < col_zeros b j
416        so i >= col_zeros a k
417        so get a i k = 0
419   function f_mul (a b: mat) : mat =
420     create (fun i -> maxf (fun k -> row_zeros b k) 0 (row_zeros a i))
421            (fun j -> maxf (fun k -> col_zeros a k) 0 (col_zeros b j))
422            (mul_cell a b)
424   val function mul (a b: mat) : mat
425     ensures { result = f_mul a b }
427   let lemma mul_sizes (m1 m2: mat) (m n p: int)
428     requires { size m1 m n /\ size m2 n p }
429     requires { 0 < m /\ 0 < n /\ 0 < p }
430     ensures  { size (mul m1 m2) m p }
431   =
432     let r = mul m1 m2 in
433     max_constant (fun k -> row_zeros m2 k) p 0 n;
434     assert { forall i. 0 <= i -> row_zeros r i = p };
435     max_constant (fun k -> col_zeros m1 k) m 0 n;
436     assert { forall j. 0 <= j -> col_zeros r j = m }
439   lemma id_neutral_r:
440     forall m: mat. mul m id = m
441     by (mul m id == m
442        by (forall i j. in_bounds m i j -> mul_cell m id i j = get m i j)
443           /\ (forall i j. 0 <= i -> 0 <= j -> not (in_bounds m i j)
444              -> mul_cell m id i j = 0 = get m i j))
446   lemma id_neutral_l:
447     forall m: mat. mul id m = m
448     by (mul id m == m
449        by (forall i j. in_bounds m i j ->
450              mul_cell id m i j = get m i j
451              by let t = mul_cell_bound id m i j in
452              sum (mul_atom id m i j) 0 i = 0
453              so sum (mul_atom id m i j) (i+1) t = 0
454              so mul_cell id m i j
455              = sum (mul_atom id m i j) 0 t
456              = sum (mul_atom id m i j) 0 i + mul_atom id m i j i + sum (mul_atom id m i j) (i+1) t
457              = 0 + (get id i i)*(get m i j) + 0
458              = get m i j)
459           /\ (forall i j. 0 <= i -> 0 <= j -> not (in_bounds m i j)
460             -> (forall k. 0 <= k < mul_cell_bound id m i j -> mul_atom id m i j k = 0)
461                so mul_cell id m i j
462                = sum (fun k -> mul_atom id m i j k) 0 (mul_cell_bound id m i j)
463                = 0 = get m i j))
465   use Sum_extended
467   function ft1 (a b c: mat) (i j: int)  : int -> int -> int =
468                 fun k -> smulf (mul_atom a b i k) (get c k j)
470   function ft2 (a b c: mat) (i j: int) : int -> int -> int =
471                 fun k -> smulf (mul_atom b c k j) (get a i k)
473   let lemma mul_assoc_get (a b c: mat) (i j: int)
474     requires { 0 <= i /\ 0 <= j }
475     ensures  { get (mul (mul a b) c) i j = get (mul a (mul b c)) i j }
476   = let ft1 = ft1 a b c i j in
477     let ft2 = ft2 a b c i j in
478     let ab = mul a b in
479     let bc = mul b c in
480     let m_ab_c = mul_cell_bound ab c i j in
481     let m_a_bc = mul_cell_bound a bc i j in
482     fubini ft1 ft2 0 m_ab_c 0 m_a_bc;
483     assert { forall k. 0 <= k < m_ab_c -> mul_cell_bound a b i k <= m_a_bc
484              by mul_cell_bound a b i k <= row_zeros a i
485              so mul_cell_bound a b i k <= col_zeros b k
486              so col_zeros bc j = maxf (fun k -> col_zeros b k) 0 (col_zeros c j)
487              so 0 <= k < col_zeros c j
488              so col_zeros b k <= maxf (fun k -> col_zeros b k) 0 (col_zeros c j)
489              so col_zeros b k <= col_zeros bc j };
490     assert { forall k. 0 <= k < m_ab_c ->
491                 sumf ft1 0 m_a_bc k = sumf ft1 0 (mul_cell_bound a b i k) k
492              by sumf ft1 0 m_a_bc k
493                 = sum (ft1 k) 0 m_a_bc
494                 = sum (ft1 k) 0 (mul_cell_bound a b i k)
495                   + sum (ft1 k) (mul_cell_bound a b i k) m_a_bc
496                 = sumf ft1 0 (mul_cell_bound a b i k) k
497                   + sum (ft1 k) (mul_cell_bound a b i k) m_a_bc
498              so (forall l. l >= mul_cell_bound a b i k ->
499                  mul_atom a b i k l = 0
500                  so ft1 k l = 0)
501              so sum (ft1 k) (mul_cell_bound a b i k) m_a_bc = 0 };
502     assert { forall k. 0 <= k < m_ab_c ->
503              mul_atom ab c i j k = sumf ft1 0 m_a_bc k
504              by get ab i k = mul_cell a b i k
505              so sumf ft1 0 m_a_bc k
506                 = sumf ft1 0 (mul_cell_bound a b i k) k
507                 = sum (ft1 k) 0 (mul_cell_bound a b i k)
508                 = sum (smulf (mul_atom a b i k) (get c k j))
509                       0 (mul_cell_bound a b i k)
510                 = get c k j * sum (mul_atom a b i k) 0 (mul_cell_bound a b i k)
511                 = get c k j * get ab i k
512                 = mul_atom ab c i j k };
513     sum_ext (mul_atom ab c i j) (sumf ft1 0 m_a_bc) 0 m_ab_c;
514     assert { get (mul ab c) i j = sum (sumf ft1 0 m_a_bc) 0 m_ab_c };
515     assert { forall k. 0 <= k < m_a_bc -> mul_cell_bound b c k j <= m_ab_c
516              by mul_cell_bound b c k j <= col_zeros c j
517              so mul_cell_bound b c k j <= row_zeros b k
518              so row_zeros ab i = maxf (fun k -> row_zeros b k) 0 (row_zeros a i)
519              so 0 <= k < row_zeros a i
520              so row_zeros b k <= maxf (fun k -> row_zeros b k) 0 (row_zeros a i)
521              so row_zeros b k <= row_zeros ab i };
522     assert { forall k. 0 <= k < m_a_bc ->
523                 sumf ft2 0 m_ab_c k = sumf ft2 0 (mul_cell_bound b c k j) k
524              by sumf ft2 0 m_ab_c k
525                 = sum (ft2 k) 0 m_ab_c
526                 = sum (ft2 k) 0 (mul_cell_bound b c k j)
527                   + sum (ft2 k) (mul_cell_bound b c k j) m_ab_c
528                 = sumf ft2 0 (mul_cell_bound b c k j) k
529                   + sum (ft2 k) (mul_cell_bound b c k j) m_ab_c
530              so (forall l. l >= mul_cell_bound b c k j ->
531                  mul_atom b c k j l = 0
532                  so ft2 k l = 0)
533              so sum (ft2 k) (mul_cell_bound b c k j) m_ab_c = 0 };
534     assert { forall k. 0 <= k < m_a_bc ->
535              mul_atom a bc i j k = sumf ft2 0 m_ab_c k
536              by get bc k j = mul_cell b c k j
537              so sumf ft2 0 m_ab_c k
538                 = sumf ft2 0 (mul_cell_bound b c k j) k
539                 = sum (ft2 k) 0 (mul_cell_bound b c k j)
540                 = sum (smulf (mul_atom b c k j) (get a i k))
541                       0 (mul_cell_bound b c k j)
542                 = get a i k * sum (mul_atom b c k j) 0 (mul_cell_bound b c k j)
543                 = get a i k * get bc k j
544                 = mul_atom a bc i j k };
545     sum_ext (mul_atom a bc i j) (sumf ft2 0 m_ab_c) 0 m_a_bc;
546     assert { get (mul a bc) i j = sum (sumf ft2 0 m_ab_c) 0 m_a_bc }
548   lemma mul_assoc:
549     forall a b c: mat.
550     let ab = mul a b in
551     let bc = mul b c in
552     let a_bc = mul a bc in
553     let ab_c = mul ab c in
554     a_bc =  ab_c by a_bc == ab_c
557   let lemma mul_distr_right_get (a b c: mat) (i j: int)
558     requires { 0 <= i /\ 0 <= j }
559     ensures  { get (mul (add a b) c) i j = get (add (mul a c) (mul b c)) i j }
560   = let mac = mul_atom a c i j in
561     let mbc = mul_atom b c i j in
562     let b_ac = mul_cell_bound a c i j in
563     let b_bc = mul_cell_bound b c i j in
564     let ma = max b_ac b_bc in
565     assert { get (add (mul a c) (mul b c)) i j = sum (addf mac mbc) 0 ma
566              by sum mac 0 ma = sum mac 0 b_ac + sum mac b_ac ma
567              so forall k. k >= b_ac -> mac k = 0
568              so sum mac b_ac ma = 0
569              so sum mac 0 b_ac = sum mac 0 ma
570              so sum mbc 0 ma = sum mbc 0 b_bc + sum mbc b_bc ma
571              so forall k. k >= b_bc -> mbc k = 0
572              so sum mbc b_bc ma = 0
573              so sum mbc 0 b_bc = sum mbc 0 ma
574              so get (mul a c) i j = sum mac 0 b_ac
575              so get (mul b c) i j = sum mbc 0 b_bc
576              so get (add (mul a c) (mul b c)) i j
577              = get (mul a c) i j + get (mul b c) i j
578              = sum mac 0 b_ac + sum mbc 0 b_bc
579              = sum mac 0 ma + sum mbc 0 ma
580              = sum (addf mac mbc) 0 ma };
581     sum_ext (addf mac mbc) (mul_atom (add a b) c i j) 0 ma;
582     assert { get (mul (add a b) c) i j = mul_cell (add a b) c i j }
585 (* External product *)
587 function extf (c: int) (a: mat): int -> int -> int =
588   fun x y -> c * (get a x y)
590 function f_extp (c: int) (a: mat) : mat =
591   create (fun i -> row_zeros a i) (fun j -> col_zeros a j) (extf c a)
593 val function extp (c: int) (a: mat) : mat
594   ensures { result = f_extp c a }
596 lemma ext_iso:
597   forall m: mat, r: int. extp r m === m
599 lemma ext_get:
600   forall m: mat, r i j: int. 0 <= i -> 0 <= j ->
601   get (extp r m) i j = r * (get m i j)
603 lemma ext_dist_sum_mat:
604   forall x y: mat, r: int. extp r (add x y) = add (extp r x) (extp r y)
605   by extp r (add x y) == add (extp r x) (extp r y)
607 lemma ext_dist_sum_r:
608   forall x: mat, r s: int. extp (r+s) x = add (extp r x) (extp s x)
609   by extp (r+s) x == add (extp r x) (extp s x)
611 lemma assoc_mul_ext:
612   forall x: mat, r s: int. extp (r*s) x = extp r (extp s x)
613   by extp (r*s) x == extp r (extp s x)
615 lemma unit_ext:
616   forall x: mat. extp 1 x = x by extp 1 x == x
618 let lemma comm_mul_ext_ij (x y: mat) (r i j: int)
619   requires { 0 <= i /\ 0 <= j }
620   ensures { get (mul (extp r x) y) i j = r * (get (mul x y) i j) }
621   ensures { get (mul x (extp r y)) i j = r * (get (mul x y) i j) }
623   let b = mul_cell_bound x y i j in
624   assert { mul_cell_bound (extp r x) y i j = b
625            = mul_cell_bound x (extp r y) i j };
626   sum_ext (mul_atom (extp r x) y i j) (smulf (mul_atom x y i j) r) 0 b;
627   sum_ext (mul_atom x (extp r y) i j) (smulf (mul_atom x y i j) r) 0 b;
628   sum_mult (mul_atom (extp r x) y i j) 0 b r;
629   sum_mult (mul_atom x (extp r y) i j) 0 b r;
630   assert { get (mul (extp r x) y) i j
631            = r * (get (mul x y) i j)
632            = get (mul x (extp r y)) i j
633            by get (mul (extp r x) y) i j
634               = mul_cell (extp r x) y i j
635               = r * mul_cell x y i j
636               = mul_cell x (extp r y) i j
637               = get (mul x (extp r y)) i j
638            so r * mul_cell x y i j = r * (get (mul x y) i j) }
640 lemma comm_mul_ext:
641   forall x y: mat, r: int.
642      extp r (mul x y) = mul (extp r x) y = mul x (extp r y)
643   by extp r (mul x y) == mul (extp r x) y == mul x (extp r y)
647 module InfIntMatrixDecision
649 use InfIntMatrix
650 use int.Int
652 let predicate eq0_int (x:int) = x = 0
654 clone export ringdecision.AssocAlgebraDecision with type r = int, type a = mat, val rzero = Int.zero, val rone = Int.one, val rplus = (+), val ropp = (-_), val rtimes = (*), val azero = mzero, val aone = id, val aplus = add, val aopp = opp, val atimes = mul, val asub = sub, val ($) = extp, goal AUnitary, goal ANonTrivial, goal ExtDistSumA, goal ExtDistSumR, goal AssocMulExt, goal UnitExt, goal CommMulExt, val eq0 = eq0_int, goal A.MulAssoc.Assoc, goal A.Unit_def_l, goal A.Unit_def_r, goal A.Comm, goal A.Assoc, goal A.Mul_distr_l, goal A.Mul_distr_r, goal asub_def, goal A.Inv_def_l, goal A.Inv_def_r,
655   axiom . (* FIXME: replace with "goal" and prove *)
657 meta reflection val norm_f
661 module MatrixTests
662   use InfIntMatrix
663   use int.Int
665   use InfIntMatrixDecision
666   use int.Sum
667   use Sum_extended
669   function cols (a: mat) : int (* if matrix is a finite rectangle, return number of cols *)
670   function rows (a: mat) : int
672  (* lemma t: forall a: mat, r1 r2 c: int. size a r1 c -> size a r2 c -> r1 = r2*)
674   axiom rows_def:
675     forall a: mat, r c: int. 0 <= r -> 0 <= c -> size a r c -> rows a = r
677   axiom cols_def:
678     forall a: mat, r c: int. 0 <= r -> 0 <= c -> size a r c -> cols a = c
680   predicate is_finite (m: mat) = size m m.rows m.cols
682   function ofs2 (a: mat) (ai aj: int) : int -> int -> int
683     = fun i j -> get a (ai + i) (aj + j)
685   function block (a: mat) (r dr c dc: int) : mat =
686     fcreate dr dc (ofs2 a r c)
688   predicate c_blocks (a a1 a2: mat) =
689     0 <= a1.cols <= a.cols /\ a1 = block a 0 a.rows 0 a1.cols /\
690     a2 = block a 0 a.rows a1.cols (a.cols - a1.cols)
692   predicate r_blocks (a a1 a2: mat) =
693     0 <= a1.rows <= a.rows /\ a1 = block a 0 a1.rows 0 a.cols /\
694     a2 = block a a1.rows (a.rows - a1.rows) 0 a.cols
696   let rec lemma block_mul_ij (a a1 a2 b b1 b2: mat) (k: int)
697     requires { a.cols = b.rows /\ a1.cols = b1.rows}
698     requires { 0 <= k <= a.cols }
699     requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
700     ensures  { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
701                  0 <= k <= a1.cols ->
702                    sum (mul_atom a b i j) 0 k = sum (mul_atom a1 b1 i j) 0 k }
703     ensures  { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
704                 a1.cols <= k <= a.cols ->
705                   sum (mul_atom a b i j) 0 k =
706                     sum (mul_atom a1 b1 i j) 0 a1.cols +
707                     sum (mul_atom a2 b2 i j) 0 (k - a1.cols) }
708     variant { k }
709   = if 0 < k then begin
710       let k = k - 1 in
711       assert { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
712                 if k < a1.cols
713                 then mul_atom a b i j k = mul_atom a1 b1 i j k
714                 else (mul_atom a b i j k = mul_atom a2 b2 i j (k - a1.cols)
715                       by get a i k = get a2 i (k - a1.cols)
716                       so get b k j = get b2 (k-a1.cols) j)};
717       block_mul_ij a a1 a2 b b1 b2 k
718     end
720   let lemma mul_split (a a1 a2 b b1 b2: mat) : unit
721     requires { is_finite a /\ is_finite b }
722     requires { a.cols = b.rows /\ a1.cols = b1.rows}
723     requires { 0 < a.rows /\ 0 < a.cols /\ 0 < b.cols
724                /\ 0 < a1.cols /\ 0 < a2.cols }
725     requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
726     ensures  { add (mul a1 b1) (mul a2 b2) = mul a b }
727   = block_mul_ij a a1 a2 b b1 b2 a.cols;
728     mul_sizes a b a.rows a.cols b.cols;
729     mul_sizes a1 b1 a.rows a1.cols b.cols;
730     mul_sizes a2 b2 a.rows (a.cols - a1.cols) b.cols;
731     assert { add (mul a1 b1) (mul a2 b2) === mul a b
732              by size (add (mul a1 b1) (mul a2 b2)) a.rows b.cols
733              so size (mul a b) a.rows b.cols };
734     assert { forall i j. in_bounds (mul a b) i j ->
735              get (add (mul a1 b1) (mul a2 b2)) i j = get (mul a b) i j
736              by mul_cell_bound a1 b1 i j = a1.cols
737                      so mul_cell_bound a2 b2 i j = a2.cols = a.cols - a1.cols
738                      so get (mul a b) i j
739                      = mul_cell a b i j
740                      = sum (mul_atom a b i j) 0 a.cols
741                      =  sum (mul_atom a1 b1 i j) 0 a1.cols
742                         + sum (mul_atom a2 b2 i j) 0 (a.cols - a1.cols)
743                      = get (mul a1 b1) i j + get (mul a2 b2) i j
744                      = get (add (mul a1 b1) (mul a2 b2)) i j };
745     ext_by_bounds (add (mul a1 b1) (mul a2 b2)) (mul a b)
747   let lemma mul_block_cell (a b: mat) (r dr c dc i j: int) : unit
748     requires { is_finite a /\ is_finite b }
749     requires { a.cols = b.rows }
750     requires { 0 <= r /\ r + dr <= a.rows }
751     requires { 0 <= c /\ c + dc <= b.cols }
752     requires { 0 <= i < dr /\ 0 <= j < dc }
753     ensures  { ofs2 (mul a b) r c i j =
754                get (mul (block a r dr 0 a.cols) (block b 0 b.rows c dc)) i j }
755   = let a' = block a r dr 0 a.cols in
756     let b' = block b 0 b.rows c dc in
757     sum_ext (mul_atom a b (i + r) (j + c)) (mul_atom a' b' i j) 0 a.cols;
758     assert { ofs2 (mul a b) r c i j = get (mul a b) (i+r) (j+c)
759              = sum (mul_atom a b (i+r) (j+c)) 0 a.cols
760              = sum (mul_atom a' b' i j) 0 a.cols
761              = get (mul a' b') i j }
763   let lemma mul_block (a b a' b' m': mat) (r dr c dc: int)
764     requires { a.cols = b.rows }
765     requires { 0 <= r <= r + dr <= a.rows }
766     requires { 0 <= c <= c + dc <= b.cols }
767     requires { a' = block a r dr 0 a.cols }
768     requires { b' = block b 0 b.rows c dc }
769     requires { m' = block (mul a b) r dr c dc }
770     ensures  { m' =  mul a' b' }
771   = assert { m' == mul a' b' }
773   predicate quarters (a a11 a12 a21 a22: mat) =
774     (is_finite a /\ is_finite a11 /\ is_finite a12 /\ is_finite a21 /\ is_finite a22) /\
775     (rows a11 = rows a12 = rows a21 = rows a22 = cols a11 = cols a12 = cols a21 = cols a22) /\
776     rows a = cols a = 2 * rows a11 /\
777     a11 = block a 0 a11.rows 0 a11.cols /\ a12 = block a 0 a11.rows a11.cols a11.cols /\
778     a21 = block a a11.rows a11.rows 0 a11.cols /\ a22 = block a a11.rows a11.rows a11.cols a11.cols
780   let lemma naive_blocks (a b c a11 a12 a21 a22 b11 b12 b21 b22 c11 c12 c21 c22: mat)
781     requires { is_finite a /\ is_finite b /\ is_finite c }
782     requires { quarters a a11 a12 a21 a22 }
783     requires { quarters b b11 b12 b21 b22 }
784     requires { quarters c c11 c12 c21 c22 }
785     requires { c11 = add (mul a11 b11) (mul a12 b21) }
786     requires { c12 = add (mul a11 b12) (mul a12 b22) }
787     requires { c21 = add (mul a21 b11) (mul a22 b21) }
788     requires { c22 = add (mul a21 b12) (mul a22 b22) }
789     ensures  { c = mul a b }
790   =
791     assert { c == mul a b }
793   use int.Power
794   use number.Parity
795   use int.ComputerDivision
797   let ghost function cut_quarters (a: mat) : (mat, mat, mat, mat)
798     requires { is_finite a }
799     requires { rows a = cols a }
800     requires { even (rows a) }
801     returns  { (a11, a12, a21, a22) -> quarters a a11 a12 a21 a22 }
802   =
803     let s = div (rows a) 2 in
804     (block a 0 s 0 s, block a 0 s s s, block a s s 0 s, block a s s s s)
806   let ghost function paste_quarters (a11 a12 a21 a22: mat): mat
807     requires { is_finite a11 /\ is_finite a12 /\ is_finite a21 /\ is_finite a22 }
808     requires { rows a11 = rows a12 = rows a21 = rows a22
809                = cols a11 = cols a12 = cols a21 = cols a22 }
810     ensures  { quarters result a11 a12 a21 a22 }
811   =
812     let s = rows a11 in
813     let r = fcreate (2 * s) (2 * s)
814             (fun i j -> if i < s && j < s then get a11 i j
815                         else if i < s then get a12 i (j-s)
816                         else if j < s then get a21 (i-s) j
817                         else get a22 (i-s) (j-s)) in
818     assert { a11 = block r 0 s 0 s by a11 == block r 0 s 0 s };
819     assert { a12 = block r 0 s s s by a12 == block r 0 s s s };
820     assert { a21 = block r s s 0 s by a21 == block r s s 0 s };
821     assert { a22 = block r s s s s by a22 == block r s s s s };
822     r
824   meta "compute_max_steps" 0x100000
826   let rec ghost function strassen_pow2 (a b: mat) (ghost k: int)
827     requires { 0 <= k }
828     requires { size a (power 2 k) (power 2 k) }
829     requires { size b (power 2 k) (power 2 k) }
830     ensures  { result = mul a b }
831     variant  { k }
832   =
833     let cutoff = begin ensures { result >= 1 } 4 end in
834     if k <= cutoff then mul a b
835     else begin
836       let (a11, a12, a21, a22) = cut_quarters a in
837       let (b11, b12, b21, b22) = cut_quarters b in
838       let s = power 2 (k-1) in
839       assert { s > 0 by k-1 >= 1 so power 2 (k-1) >= power 2 1 = 2};
840       assert { size a11 s s /\ size a12 s s /\ size a21 s s /\ size a22 s s };
841       assert { size b11 s s /\ size b12 s s /\ size b21 s s /\ size b22 s s };
842       let ghost c11 = add (mul a11 b11) (mul a12 b21) in
843       let ghost c12 = add (mul a11 b12) (mul a12 b22) in
844       let ghost c21 = add (mul a21 b11) (mul a22 b21) in
845       let ghost c22 = add (mul a21 b12) (mul a22 b22) in
846       mul_sizes a11 b11 s s s;
847       assert { size c11 s s /\ size c12 s s /\ size c21 s s /\ size c22 s s };
848       let ghost c = paste_quarters c11 c12 c21 c22 in
849       assert { c = mul a b };
850       let m1 = strassen_pow2 (add a11 a22) (add b11 b22) (k-1) in
851       let m2 = strassen_pow2 (add a21 a22) b11 (k-1) in
852       let m3 = strassen_pow2 a11 (sub b12 b22) (k-1) in
853       let m4 = strassen_pow2 a22 (sub b21 b11) (k-1) in
854       let m5 = strassen_pow2 (add a11 a12) b22 (k-1) in
855       let m6 = strassen_pow2 (sub a21 a11) (add b11 b12) (k-1) in
856       let m7 = strassen_pow2 (sub a12 a22) (add b21 b22) (k-1) in
857       let s11 = add m1 (add m4 (sub m7 m5)) in
858       let s12 = add m3 m5 in
859       let s21 = add m2 m4 in
860       let s22 = add m1 (add m3 (sub m6 m2)) in
861       (* assertions proved by reflection *)
862       assert { s11 = c11 };
863       assert { s12 = c12 };
864       assert { s21 = c21 };
865       assert { s22 = c22 };
866       paste_quarters s11 s12 s21 s22
867       end