7 clone algebra.UnitaryCommutativeRing as F with axiom .
9 function get (mat F.t) int int : F.t
10 function set (mat F.t) int int F.t : mat F.t
12 function row_zeros (mat F.t) int : int
13 function col_zeros (mat F.t) int : int
16 forall m: mat F.t, i j: int. 0 <= i -> j >= row_zeros m i -> get m i j = F.zero
19 forall m: mat F.t, i j: int. 0 <= j -> i >= col_zeros m j -> get m i j = F.zero
21 axiom row_zeros_nonneg:
22 forall m: mat F.t, i: int. 0 <= i -> 0 <= row_zeros m i
24 axiom col_zeros_nonneg:
25 forall m: mat F.t, j: int. 0 <= j -> 0 <= col_zeros m j
26 (*FIXME should be invariants*)
28 axiom set_def_changed:
29 forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
30 get (set m i j v) i j = v
32 axiom set_def_unchanged:
33 forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
34 forall i' j': int. 0 <= i' -> 0 <= j' -> (i <> i' \/ j <> j') ->
35 get (set m i j v) i' j' = get m i' j'
37 axiom set_def_rowz_changed:
38 forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
39 j >= row_zeros m i -> row_zeros (set m i j v) i = j+1
41 axiom set_def_colz_changed:
42 forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
43 i >= col_zeros m j -> col_zeros (set m i j v) j = i+1
45 axiom set_def_rowz_unchanged:
46 forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
47 j < row_zeros m i -> row_zeros (set m i j v) i = row_zeros m i
49 axiom set_def_colz_unchanged:
50 forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
51 i < col_zeros m j -> col_zeros (set m i j v) j = col_zeros m j
53 axiom set_def_other_rowz:
54 forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
55 forall i': int. 0 <= i' -> i <> i' ->
56 row_zeros (set m i j v) i' = row_zeros m i'
58 axiom set_def_other_colz:
59 forall m: mat F.t, i j: int, v: F.t. 0 <= i -> 0 <= j ->
60 forall j': int. 0 <= j' -> j <> j' ->
61 col_zeros (set m i j v) j' = col_zeros m j'
63 predicate (==) (m1 m2: mat F.t) =
64 forall i j: int. 0 <= i -> 0 <= j -> get m1 i j = get m2 i j
67 forall m1 m2: mat F.t. m1 == m2 -> m1 = m2
69 predicate (===) (m1 m2: mat F.t) =
70 forall i j: int. 0 <= i -> 0 <= j ->
71 row_zeros m1 i = row_zeros m2 i /\ col_zeros m1 j = col_zeros m2 j
73 predicate in_bounds (m: mat F.t) (i j: int) =
74 0 <= i < col_zeros m j /\ 0 <= j < row_zeros m i
76 let lemma ext_by_bounds (m1 m2: mat F.t)
77 requires { m1 === m2 }
78 requires { forall i j. in_bounds m1 i j -> get m1 i j = get m2 i j }
83 forall m: mat F.t, i j: int. 0 <= i -> 0 <= j -> not in_bounds m i j
86 predicate size (m: mat F.t) (r c: int) =
87 (forall i: int. 0 <= i -> row_zeros m i = c)
88 /\ (forall j: int. 0 <= j -> col_zeros m j = r)
91 forall m: mat F.t, r c i j: int. size m r c -> (in_bounds m i j <-> (0 <= i < r /\ 0 <= j < c))
94 forall a b: mat F.t, r c: int. a === b -> (size a r c <-> size b r c)
97 forall a: mat F.t, r c i: int. size a r c ->
98 0 <= i < r -> row_zeros a i = c
99 by forall j: int. in_bounds a i j -> 0 <= j < c
102 forall a b: mat F.t, r c: int. size a r c -> size b r c -> a === b
111 clone export algebra.UnitaryCommutativeRing with
112 type t = t, constant zero = tzero, axiom .
115 clone export relations.MinMax with
116 type t = int, predicate le = (<=),
117 axiom . (* FIXME: replace with "goal" and prove *)
121 val function get mat int int : t
123 val function row_zeros mat int : int
124 val function col_zeros mat int : int
126 val function create (rz: int -> int) (cz: int -> int) (f: int -> int -> t)
130 forall rz cz: int -> int, f: int -> int -> t, i: int.
131 0 <= i -> 0 <= rz i -> row_zeros (create rz cz f) i = rz i
134 forall rz cz: int -> int, f: int -> int -> t, j: int.
135 0 <= j -> 0 <= cz j -> col_zeros (create rz cz f) j = cz j
138 forall rz cz: int -> int, f: int -> int -> t, i j: int.
139 0 <= i < cz j -> 0 <= j < rz i -> get (create rz cz f) i j = f i j
141 axiom create_get_oob:
142 forall rz cz: int -> int, f: int -> int -> t, i j: int.
143 0 <= i -> 0 <= j -> (i >= cz j \/ j >= rz i) ->
144 get (create rz cz f) i j = tzero
146 let ghost function set (m: mat) (i j:int) (v:t) : mat =
150 (fun i1 -> if i1 = i then max (j+1) (row_zeros m i) else row_zeros m i1)
151 (fun j1 -> if j1 = j then max (i+1) (col_zeros m j) else col_zeros m j1)
152 (fun i1 j1 -> if i1 = i && j1 = j then v else get m i1 j1)
155 clone export InfMatrixGen with type mat 'a = mat,
159 function row_zeros = row_zeros,
160 function col_zeros = col_zeros,
161 lemma set_def_changed,
162 lemma set_def_unchanged,
163 lemma set_def_colz_changed,
164 lemma set_def_colz_unchanged,
165 lemma set_def_rowz_changed,
166 lemma set_def_rowz_unchanged,
167 lemma set_def_other_rowz,
168 lemma set_def_other_colz,
169 axiom . (* FIXME: replace with "goal" and prove *)
171 let ghost function fcreate (r c: int) (f: int -> int -> t) : mat =
172 create (fun _ -> max 0 c) (fun _ -> max 0 r) f
174 lemma fcreate_get_ib:
175 forall r c i j: int, f: int -> int -> t.
176 0 <= i < r -> 0 <= j < c -> get (fcreate r c f) i j = f i j
178 lemma fcreate_get_oob:
179 forall r c i j: int, f: int -> int -> t.
180 0 <= i -> 0 <= j -> (i >= r \/ j >= c) -> get (fcreate r c f) i j = tzero
183 forall r c: int, f: int -> int -> t. 0 <= r -> 0 <= c ->
184 size (fcreate r c f) r c
189 (* copied over from verifythis_2016_matrix_multiplication *)
195 function addf (f g:int -> int) : int -> int = fun x -> f x + g x
197 function smulf (f:int -> int) (l:int) : int -> int = fun x -> l * f x
199 let rec lemma sum_mult (f:int -> int) (a b l:int) : unit
200 ensures { sum (smulf f l) a b = l * sum f a b }
202 = if b > a then sum_mult f a (b-1) l
205 let rec lemma sum_add (f g:int -> int) (a b:int) : unit
206 ensures { sum (addf f g) a b = sum f a b + sum g a b }
208 = if b > a then sum_add f g a (b-1)
211 function sumf (f:int -> int -> int) (a b:int) : int -> int = fun x -> sum (f x) a b
213 let rec lemma fubini (f1 f2: int -> int -> int) (a b c d: int) : unit
214 requires { forall x y. a <= x < b /\ c <= y < d -> f1 x y = f2 y x }
215 ensures { sum (sumf f1 c d) a b = sum (sumf f2 a b) c d }
218 then assert { forall x. sumf f2 a b x = 0 }
220 fubini f1 f2 a (b-1) c d;
221 assert { let ha = addf (sumf f2 a (b-1)) (f1 (b-1)) in
222 sum (sumf f2 a b) c d = sum ha c d
223 by forall y. c <= y < d -> sumf f2 a b y = ha y }
226 let ghost sum_ext (f g: int -> int) (a b: int) : unit
227 requires { forall i. a <= i < b -> f i = g i }
228 ensures { sum f a b = sum g a b }
238 (* maximum of a function over an interval; always at least 0 *)
240 let rec function maxf (f: int -> int) (a b: int) : int
242 = if b <= a then zero else max (maxf f a (b - 1)) (f (b-1))
244 let rec lemma maxf_nonneg (f: int -> int) (a b: int)
246 ensures { maxf f a b >= 0 }
248 = if a = b then () else maxf_nonneg f a (b-1)
250 let rec lemma maxf_larger (f: int -> int) (a b i: int)
251 requires { a <= i < b }
253 ensures { maxf f a b >= f i }
254 = if i = b-1 then () else maxf_larger f a (b-1) i
256 let rec lemma max_left (f: int -> int) (a b: int)
258 ensures { maxf f a b = max (f a) (maxf f (a+1) b) }
260 = if a = b-1 then () else max_left f a (b-1)
262 let rec lemma max_ext (f g: int -> int) (a b: int)
265 requires { forall i. a <= i < b -> f i = g i }
266 ensures { maxf f a b = maxf g a b }
267 = if a = b-1 then () else max_ext f g a (b-1)
269 let rec lemma max_decomp (f: int -> int) (a b c: int)
270 requires { a <= b <= c }
272 ensures { maxf f a c = max (maxf f a b) (maxf f b c) }
274 then assert { maxf f a c = max (maxf f a b) (maxf f b c)
278 max_decomp f a b (c-1);
279 assert { maxf f a c = max (maxf f a b) (maxf f b c)
280 by maxf f a c = max (maxf f a (c-1)) (f (c-1))
281 so maxf f b c = max (maxf f b (c-1)) (f (c-1))
282 so maxf f a c = max (max (maxf f a b) (maxf f b (c-1))) (f (c-1))
283 = max (maxf f a b) (maxf f b c) }
286 let rec lemma max_constant (f: int -> int) (v a b: int)
289 requires { forall i. a <= i < b -> f i = v }
290 ensures { maxf f a b = v }
292 = if a = b-1 then () else max_constant f v a (b-1)
299 clone export InfMatrix with
300 type t = int, constant tzero = zero,
301 axiom . (* FIXME: replace with "goal" and prove *)
304 use int.Int (*FIXME needed so i < i+1 ?*)
308 let constant zerof : int -> int -> int = fun _ _ -> 0
310 let ghost constant mzero : mat = fcreate 0 0 zerof
312 let ghost function zerorc (r c: int) : mat = fcreate r c zerof
314 (* Identity matrix *)
316 let constant idf : int -> int -> int = fun x y -> if x = y then 1 else 0
318 let constant id : mat = create (fun i -> i+1) (fun j -> j+1) idf
321 forall i. 0 <= i -> get id i i = 1
323 function idrc (r c: int) : mat = fcreate r c idf
325 (* Matrix addition *)
327 let ghost function add2f (a b: mat) : int -> int -> int =
328 fun x y -> get a x y + get b x y
330 function f_add (a b: mat) : mat =
331 create (fun i -> max (row_zeros a i) (row_zeros b i))
332 (fun j -> max (col_zeros a j) (col_zeros b j))
335 val function add (a b: mat) : mat
336 ensures { result = f_add a b }
339 forall a b: mat, i j: int. 0 <= i -> 0 <= j ->
340 get (add a b) i j = get a i j + get b i j
343 forall a b: mat. a === b -> a === add a b === b
346 forall a b: mat, r c: int. size a r c -> size b r c -> size (add a b) r c
348 (in_bounds a i j \/ in_bounds b i j) <-> in_bounds (add a b) i j)
351 forall a b c: mat. add (add a b) c = add a (add b c)
352 by add (add a b) c == add a (add b c)
354 lemma add_commutative:
355 forall a b: mat. add a b = add b a by add a b == add b a
358 forall a. add a mzero = a by add a mzero == a
360 (* Matrix additive inverse *)
361 function opp2f (a: mat) : int -> int -> int =
362 fun x y -> - get a x y
364 function f_opp (a: mat) : mat =
365 create (row_zeros a) (col_zeros a) (opp2f a)
367 val function opp (a: mat) : mat
368 ensures { result = f_opp a }
370 let function sub (a b: mat) : mat = add a (opp b)
373 forall a b: mat, r c: int. size a r c -> size b r c -> size (sub a b) r c
375 (in_bounds a i j \/ in_bounds b i j) <-> in_bounds (sub a b) i j)
377 lemma opp_involutive:
378 forall m. opp (opp m) = m by opp (opp m) == m
380 (* Matrix multiplication *)
382 function mul_atom (a b: mat) (i j: int) : int -> int =
383 fun k -> get a i k * get b k j
386 forall a b: mat, i j k: int. 0 <= i -> 0 <= j ->
387 k >= row_zeros a i \/ k >= col_zeros b j -> mul_atom a b i j k = 0
388 by if k >= row_zeros a i then get a i k = 0
389 else k >= col_zeros b j so get b k j = 0
391 function mul_cell_bound (a b: mat) (i j: int) : int
392 = min (row_zeros a i) (col_zeros b j) (* row_zeros a i*)
394 function mul_cell (a b: mat) : int -> int -> int =
395 fun i j -> sum (mul_atom a b i j) 0 (mul_cell_bound a b i j)
400 forall a b: mat, i j: int.
401 j >= maxf (fun k -> row_zeros b k) 0 (row_zeros a i) ->
403 by forall k. 0 <= k < mul_cell_bound a b i j ->
404 mul_atom a b i j k = 0
405 by 0 <= k < row_zeros a i
406 so j >= row_zeros b k
410 forall a b: mat, i j: int.
411 i >= maxf (fun k -> col_zeros a k) 0 (col_zeros b j) ->
413 by forall k. 0 <= k < mul_cell_bound a b i j ->
414 mul_atom a b i j k = 0
415 by 0 <= k < col_zeros b j
416 so i >= col_zeros a k
419 function f_mul (a b: mat) : mat =
420 create (fun i -> maxf (fun k -> row_zeros b k) 0 (row_zeros a i))
421 (fun j -> maxf (fun k -> col_zeros a k) 0 (col_zeros b j))
424 val function mul (a b: mat) : mat
425 ensures { result = f_mul a b }
427 let lemma mul_sizes (m1 m2: mat) (m n p: int)
428 requires { size m1 m n /\ size m2 n p }
429 requires { 0 < m /\ 0 < n /\ 0 < p }
430 ensures { size (mul m1 m2) m p }
433 max_constant (fun k -> row_zeros m2 k) p 0 n;
434 assert { forall i. 0 <= i -> row_zeros r i = p };
435 max_constant (fun k -> col_zeros m1 k) m 0 n;
436 assert { forall j. 0 <= j -> col_zeros r j = m }
440 forall m: mat. mul m id = m
442 by (forall i j. in_bounds m i j -> mul_cell m id i j = get m i j)
443 /\ (forall i j. 0 <= i -> 0 <= j -> not (in_bounds m i j)
444 -> mul_cell m id i j = 0 = get m i j))
447 forall m: mat. mul id m = m
449 by (forall i j. in_bounds m i j ->
450 mul_cell id m i j = get m i j
451 by let t = mul_cell_bound id m i j in
452 sum (mul_atom id m i j) 0 i = 0
453 so sum (mul_atom id m i j) (i+1) t = 0
455 = sum (mul_atom id m i j) 0 t
456 = sum (mul_atom id m i j) 0 i + mul_atom id m i j i + sum (mul_atom id m i j) (i+1) t
457 = 0 + (get id i i)*(get m i j) + 0
459 /\ (forall i j. 0 <= i -> 0 <= j -> not (in_bounds m i j)
460 -> (forall k. 0 <= k < mul_cell_bound id m i j -> mul_atom id m i j k = 0)
462 = sum (fun k -> mul_atom id m i j k) 0 (mul_cell_bound id m i j)
467 function ft1 (a b c: mat) (i j: int) : int -> int -> int =
468 fun k -> smulf (mul_atom a b i k) (get c k j)
470 function ft2 (a b c: mat) (i j: int) : int -> int -> int =
471 fun k -> smulf (mul_atom b c k j) (get a i k)
473 let lemma mul_assoc_get (a b c: mat) (i j: int)
474 requires { 0 <= i /\ 0 <= j }
475 ensures { get (mul (mul a b) c) i j = get (mul a (mul b c)) i j }
476 = let ft1 = ft1 a b c i j in
477 let ft2 = ft2 a b c i j in
480 let m_ab_c = mul_cell_bound ab c i j in
481 let m_a_bc = mul_cell_bound a bc i j in
482 fubini ft1 ft2 0 m_ab_c 0 m_a_bc;
483 assert { forall k. 0 <= k < m_ab_c -> mul_cell_bound a b i k <= m_a_bc
484 by mul_cell_bound a b i k <= row_zeros a i
485 so mul_cell_bound a b i k <= col_zeros b k
486 so col_zeros bc j = maxf (fun k -> col_zeros b k) 0 (col_zeros c j)
487 so 0 <= k < col_zeros c j
488 so col_zeros b k <= maxf (fun k -> col_zeros b k) 0 (col_zeros c j)
489 so col_zeros b k <= col_zeros bc j };
490 assert { forall k. 0 <= k < m_ab_c ->
491 sumf ft1 0 m_a_bc k = sumf ft1 0 (mul_cell_bound a b i k) k
492 by sumf ft1 0 m_a_bc k
493 = sum (ft1 k) 0 m_a_bc
494 = sum (ft1 k) 0 (mul_cell_bound a b i k)
495 + sum (ft1 k) (mul_cell_bound a b i k) m_a_bc
496 = sumf ft1 0 (mul_cell_bound a b i k) k
497 + sum (ft1 k) (mul_cell_bound a b i k) m_a_bc
498 so (forall l. l >= mul_cell_bound a b i k ->
499 mul_atom a b i k l = 0
501 so sum (ft1 k) (mul_cell_bound a b i k) m_a_bc = 0 };
502 assert { forall k. 0 <= k < m_ab_c ->
503 mul_atom ab c i j k = sumf ft1 0 m_a_bc k
504 by get ab i k = mul_cell a b i k
505 so sumf ft1 0 m_a_bc k
506 = sumf ft1 0 (mul_cell_bound a b i k) k
507 = sum (ft1 k) 0 (mul_cell_bound a b i k)
508 = sum (smulf (mul_atom a b i k) (get c k j))
509 0 (mul_cell_bound a b i k)
510 = get c k j * sum (mul_atom a b i k) 0 (mul_cell_bound a b i k)
511 = get c k j * get ab i k
512 = mul_atom ab c i j k };
513 sum_ext (mul_atom ab c i j) (sumf ft1 0 m_a_bc) 0 m_ab_c;
514 assert { get (mul ab c) i j = sum (sumf ft1 0 m_a_bc) 0 m_ab_c };
515 assert { forall k. 0 <= k < m_a_bc -> mul_cell_bound b c k j <= m_ab_c
516 by mul_cell_bound b c k j <= col_zeros c j
517 so mul_cell_bound b c k j <= row_zeros b k
518 so row_zeros ab i = maxf (fun k -> row_zeros b k) 0 (row_zeros a i)
519 so 0 <= k < row_zeros a i
520 so row_zeros b k <= maxf (fun k -> row_zeros b k) 0 (row_zeros a i)
521 so row_zeros b k <= row_zeros ab i };
522 assert { forall k. 0 <= k < m_a_bc ->
523 sumf ft2 0 m_ab_c k = sumf ft2 0 (mul_cell_bound b c k j) k
524 by sumf ft2 0 m_ab_c k
525 = sum (ft2 k) 0 m_ab_c
526 = sum (ft2 k) 0 (mul_cell_bound b c k j)
527 + sum (ft2 k) (mul_cell_bound b c k j) m_ab_c
528 = sumf ft2 0 (mul_cell_bound b c k j) k
529 + sum (ft2 k) (mul_cell_bound b c k j) m_ab_c
530 so (forall l. l >= mul_cell_bound b c k j ->
531 mul_atom b c k j l = 0
533 so sum (ft2 k) (mul_cell_bound b c k j) m_ab_c = 0 };
534 assert { forall k. 0 <= k < m_a_bc ->
535 mul_atom a bc i j k = sumf ft2 0 m_ab_c k
536 by get bc k j = mul_cell b c k j
537 so sumf ft2 0 m_ab_c k
538 = sumf ft2 0 (mul_cell_bound b c k j) k
539 = sum (ft2 k) 0 (mul_cell_bound b c k j)
540 = sum (smulf (mul_atom b c k j) (get a i k))
541 0 (mul_cell_bound b c k j)
542 = get a i k * sum (mul_atom b c k j) 0 (mul_cell_bound b c k j)
543 = get a i k * get bc k j
544 = mul_atom a bc i j k };
545 sum_ext (mul_atom a bc i j) (sumf ft2 0 m_ab_c) 0 m_a_bc;
546 assert { get (mul a bc) i j = sum (sumf ft2 0 m_ab_c) 0 m_a_bc }
552 let a_bc = mul a bc in
553 let ab_c = mul ab c in
554 a_bc = ab_c by a_bc == ab_c
557 let lemma mul_distr_right_get (a b c: mat) (i j: int)
558 requires { 0 <= i /\ 0 <= j }
559 ensures { get (mul (add a b) c) i j = get (add (mul a c) (mul b c)) i j }
560 = let mac = mul_atom a c i j in
561 let mbc = mul_atom b c i j in
562 let b_ac = mul_cell_bound a c i j in
563 let b_bc = mul_cell_bound b c i j in
564 let ma = max b_ac b_bc in
565 assert { get (add (mul a c) (mul b c)) i j = sum (addf mac mbc) 0 ma
566 by sum mac 0 ma = sum mac 0 b_ac + sum mac b_ac ma
567 so forall k. k >= b_ac -> mac k = 0
568 so sum mac b_ac ma = 0
569 so sum mac 0 b_ac = sum mac 0 ma
570 so sum mbc 0 ma = sum mbc 0 b_bc + sum mbc b_bc ma
571 so forall k. k >= b_bc -> mbc k = 0
572 so sum mbc b_bc ma = 0
573 so sum mbc 0 b_bc = sum mbc 0 ma
574 so get (mul a c) i j = sum mac 0 b_ac
575 so get (mul b c) i j = sum mbc 0 b_bc
576 so get (add (mul a c) (mul b c)) i j
577 = get (mul a c) i j + get (mul b c) i j
578 = sum mac 0 b_ac + sum mbc 0 b_bc
579 = sum mac 0 ma + sum mbc 0 ma
580 = sum (addf mac mbc) 0 ma };
581 sum_ext (addf mac mbc) (mul_atom (add a b) c i j) 0 ma;
582 assert { get (mul (add a b) c) i j = mul_cell (add a b) c i j }
585 (* External product *)
587 function extf (c: int) (a: mat): int -> int -> int =
588 fun x y -> c * (get a x y)
590 function f_extp (c: int) (a: mat) : mat =
591 create (fun i -> row_zeros a i) (fun j -> col_zeros a j) (extf c a)
593 val function extp (c: int) (a: mat) : mat
594 ensures { result = f_extp c a }
597 forall m: mat, r: int. extp r m === m
600 forall m: mat, r i j: int. 0 <= i -> 0 <= j ->
601 get (extp r m) i j = r * (get m i j)
603 lemma ext_dist_sum_mat:
604 forall x y: mat, r: int. extp r (add x y) = add (extp r x) (extp r y)
605 by extp r (add x y) == add (extp r x) (extp r y)
607 lemma ext_dist_sum_r:
608 forall x: mat, r s: int. extp (r+s) x = add (extp r x) (extp s x)
609 by extp (r+s) x == add (extp r x) (extp s x)
612 forall x: mat, r s: int. extp (r*s) x = extp r (extp s x)
613 by extp (r*s) x == extp r (extp s x)
616 forall x: mat. extp 1 x = x by extp 1 x == x
618 let lemma comm_mul_ext_ij (x y: mat) (r i j: int)
619 requires { 0 <= i /\ 0 <= j }
620 ensures { get (mul (extp r x) y) i j = r * (get (mul x y) i j) }
621 ensures { get (mul x (extp r y)) i j = r * (get (mul x y) i j) }
623 let b = mul_cell_bound x y i j in
624 assert { mul_cell_bound (extp r x) y i j = b
625 = mul_cell_bound x (extp r y) i j };
626 sum_ext (mul_atom (extp r x) y i j) (smulf (mul_atom x y i j) r) 0 b;
627 sum_ext (mul_atom x (extp r y) i j) (smulf (mul_atom x y i j) r) 0 b;
628 sum_mult (mul_atom (extp r x) y i j) 0 b r;
629 sum_mult (mul_atom x (extp r y) i j) 0 b r;
630 assert { get (mul (extp r x) y) i j
631 = r * (get (mul x y) i j)
632 = get (mul x (extp r y)) i j
633 by get (mul (extp r x) y) i j
634 = mul_cell (extp r x) y i j
635 = r * mul_cell x y i j
636 = mul_cell x (extp r y) i j
637 = get (mul x (extp r y)) i j
638 so r * mul_cell x y i j = r * (get (mul x y) i j) }
641 forall x y: mat, r: int.
642 extp r (mul x y) = mul (extp r x) y = mul x (extp r y)
643 by extp r (mul x y) == mul (extp r x) y == mul x (extp r y)
647 module InfIntMatrixDecision
652 let predicate eq0_int (x:int) = x = 0
654 clone export ringdecision.AssocAlgebraDecision with type r = int, type a = mat, val rzero = Int.zero, val rone = Int.one, val rplus = (+), val ropp = (-_), val rtimes = (*), val azero = mzero, val aone = id, val aplus = add, val aopp = opp, val atimes = mul, val asub = sub, val ($) = extp, goal AUnitary, goal ANonTrivial, goal ExtDistSumA, goal ExtDistSumR, goal AssocMulExt, goal UnitExt, goal CommMulExt, val eq0 = eq0_int, goal A.MulAssoc.Assoc, goal A.Unit_def_l, goal A.Unit_def_r, goal A.Comm, goal A.Assoc, goal A.Mul_distr_l, goal A.Mul_distr_r, goal asub_def, goal A.Inv_def_l, goal A.Inv_def_r,
655 axiom . (* FIXME: replace with "goal" and prove *)
657 meta reflection val norm_f
665 use InfIntMatrixDecision
669 function cols (a: mat) : int (* if matrix is a finite rectangle, return number of cols *)
670 function rows (a: mat) : int
672 (* lemma t: forall a: mat, r1 r2 c: int. size a r1 c -> size a r2 c -> r1 = r2*)
675 forall a: mat, r c: int. 0 <= r -> 0 <= c -> size a r c -> rows a = r
678 forall a: mat, r c: int. 0 <= r -> 0 <= c -> size a r c -> cols a = c
680 predicate is_finite (m: mat) = size m m.rows m.cols
682 function ofs2 (a: mat) (ai aj: int) : int -> int -> int
683 = fun i j -> get a (ai + i) (aj + j)
685 function block (a: mat) (r dr c dc: int) : mat =
686 fcreate dr dc (ofs2 a r c)
688 predicate c_blocks (a a1 a2: mat) =
689 0 <= a1.cols <= a.cols /\ a1 = block a 0 a.rows 0 a1.cols /\
690 a2 = block a 0 a.rows a1.cols (a.cols - a1.cols)
692 predicate r_blocks (a a1 a2: mat) =
693 0 <= a1.rows <= a.rows /\ a1 = block a 0 a1.rows 0 a.cols /\
694 a2 = block a a1.rows (a.rows - a1.rows) 0 a.cols
696 let rec lemma block_mul_ij (a a1 a2 b b1 b2: mat) (k: int)
697 requires { a.cols = b.rows /\ a1.cols = b1.rows}
698 requires { 0 <= k <= a.cols }
699 requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
700 ensures { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
702 sum (mul_atom a b i j) 0 k = sum (mul_atom a1 b1 i j) 0 k }
703 ensures { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
704 a1.cols <= k <= a.cols ->
705 sum (mul_atom a b i j) 0 k =
706 sum (mul_atom a1 b1 i j) 0 a1.cols +
707 sum (mul_atom a2 b2 i j) 0 (k - a1.cols) }
709 = if 0 < k then begin
711 assert { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
713 then mul_atom a b i j k = mul_atom a1 b1 i j k
714 else (mul_atom a b i j k = mul_atom a2 b2 i j (k - a1.cols)
715 by get a i k = get a2 i (k - a1.cols)
716 so get b k j = get b2 (k-a1.cols) j)};
717 block_mul_ij a a1 a2 b b1 b2 k
720 let lemma mul_split (a a1 a2 b b1 b2: mat) : unit
721 requires { is_finite a /\ is_finite b }
722 requires { a.cols = b.rows /\ a1.cols = b1.rows}
723 requires { 0 < a.rows /\ 0 < a.cols /\ 0 < b.cols
724 /\ 0 < a1.cols /\ 0 < a2.cols }
725 requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
726 ensures { add (mul a1 b1) (mul a2 b2) = mul a b }
727 = block_mul_ij a a1 a2 b b1 b2 a.cols;
728 mul_sizes a b a.rows a.cols b.cols;
729 mul_sizes a1 b1 a.rows a1.cols b.cols;
730 mul_sizes a2 b2 a.rows (a.cols - a1.cols) b.cols;
731 assert { add (mul a1 b1) (mul a2 b2) === mul a b
732 by size (add (mul a1 b1) (mul a2 b2)) a.rows b.cols
733 so size (mul a b) a.rows b.cols };
734 assert { forall i j. in_bounds (mul a b) i j ->
735 get (add (mul a1 b1) (mul a2 b2)) i j = get (mul a b) i j
736 by mul_cell_bound a1 b1 i j = a1.cols
737 so mul_cell_bound a2 b2 i j = a2.cols = a.cols - a1.cols
740 = sum (mul_atom a b i j) 0 a.cols
741 = sum (mul_atom a1 b1 i j) 0 a1.cols
742 + sum (mul_atom a2 b2 i j) 0 (a.cols - a1.cols)
743 = get (mul a1 b1) i j + get (mul a2 b2) i j
744 = get (add (mul a1 b1) (mul a2 b2)) i j };
745 ext_by_bounds (add (mul a1 b1) (mul a2 b2)) (mul a b)
747 let lemma mul_block_cell (a b: mat) (r dr c dc i j: int) : unit
748 requires { is_finite a /\ is_finite b }
749 requires { a.cols = b.rows }
750 requires { 0 <= r /\ r + dr <= a.rows }
751 requires { 0 <= c /\ c + dc <= b.cols }
752 requires { 0 <= i < dr /\ 0 <= j < dc }
753 ensures { ofs2 (mul a b) r c i j =
754 get (mul (block a r dr 0 a.cols) (block b 0 b.rows c dc)) i j }
755 = let a' = block a r dr 0 a.cols in
756 let b' = block b 0 b.rows c dc in
757 sum_ext (mul_atom a b (i + r) (j + c)) (mul_atom a' b' i j) 0 a.cols;
758 assert { ofs2 (mul a b) r c i j = get (mul a b) (i+r) (j+c)
759 = sum (mul_atom a b (i+r) (j+c)) 0 a.cols
760 = sum (mul_atom a' b' i j) 0 a.cols
761 = get (mul a' b') i j }
763 let lemma mul_block (a b a' b' m': mat) (r dr c dc: int)
764 requires { a.cols = b.rows }
765 requires { 0 <= r <= r + dr <= a.rows }
766 requires { 0 <= c <= c + dc <= b.cols }
767 requires { a' = block a r dr 0 a.cols }
768 requires { b' = block b 0 b.rows c dc }
769 requires { m' = block (mul a b) r dr c dc }
770 ensures { m' = mul a' b' }
771 = assert { m' == mul a' b' }
773 predicate quarters (a a11 a12 a21 a22: mat) =
774 (is_finite a /\ is_finite a11 /\ is_finite a12 /\ is_finite a21 /\ is_finite a22) /\
775 (rows a11 = rows a12 = rows a21 = rows a22 = cols a11 = cols a12 = cols a21 = cols a22) /\
776 rows a = cols a = 2 * rows a11 /\
777 a11 = block a 0 a11.rows 0 a11.cols /\ a12 = block a 0 a11.rows a11.cols a11.cols /\
778 a21 = block a a11.rows a11.rows 0 a11.cols /\ a22 = block a a11.rows a11.rows a11.cols a11.cols
780 let lemma naive_blocks (a b c a11 a12 a21 a22 b11 b12 b21 b22 c11 c12 c21 c22: mat)
781 requires { is_finite a /\ is_finite b /\ is_finite c }
782 requires { quarters a a11 a12 a21 a22 }
783 requires { quarters b b11 b12 b21 b22 }
784 requires { quarters c c11 c12 c21 c22 }
785 requires { c11 = add (mul a11 b11) (mul a12 b21) }
786 requires { c12 = add (mul a11 b12) (mul a12 b22) }
787 requires { c21 = add (mul a21 b11) (mul a22 b21) }
788 requires { c22 = add (mul a21 b12) (mul a22 b22) }
789 ensures { c = mul a b }
791 assert { c == mul a b }
795 use int.ComputerDivision
797 let ghost function cut_quarters (a: mat) : (mat, mat, mat, mat)
798 requires { is_finite a }
799 requires { rows a = cols a }
800 requires { even (rows a) }
801 returns { (a11, a12, a21, a22) -> quarters a a11 a12 a21 a22 }
803 let s = div (rows a) 2 in
804 (block a 0 s 0 s, block a 0 s s s, block a s s 0 s, block a s s s s)
806 let ghost function paste_quarters (a11 a12 a21 a22: mat): mat
807 requires { is_finite a11 /\ is_finite a12 /\ is_finite a21 /\ is_finite a22 }
808 requires { rows a11 = rows a12 = rows a21 = rows a22
809 = cols a11 = cols a12 = cols a21 = cols a22 }
810 ensures { quarters result a11 a12 a21 a22 }
813 let r = fcreate (2 * s) (2 * s)
814 (fun i j -> if i < s && j < s then get a11 i j
815 else if i < s then get a12 i (j-s)
816 else if j < s then get a21 (i-s) j
817 else get a22 (i-s) (j-s)) in
818 assert { a11 = block r 0 s 0 s by a11 == block r 0 s 0 s };
819 assert { a12 = block r 0 s s s by a12 == block r 0 s s s };
820 assert { a21 = block r s s 0 s by a21 == block r s s 0 s };
821 assert { a22 = block r s s s s by a22 == block r s s s s };
824 meta "compute_max_steps" 0x100000
826 let rec ghost function strassen_pow2 (a b: mat) (ghost k: int)
828 requires { size a (power 2 k) (power 2 k) }
829 requires { size b (power 2 k) (power 2 k) }
830 ensures { result = mul a b }
833 let cutoff = begin ensures { result >= 1 } 4 end in
834 if k <= cutoff then mul a b
836 let (a11, a12, a21, a22) = cut_quarters a in
837 let (b11, b12, b21, b22) = cut_quarters b in
838 let s = power 2 (k-1) in
839 assert { s > 0 by k-1 >= 1 so power 2 (k-1) >= power 2 1 = 2};
840 assert { size a11 s s /\ size a12 s s /\ size a21 s s /\ size a22 s s };
841 assert { size b11 s s /\ size b12 s s /\ size b21 s s /\ size b22 s s };
842 let ghost c11 = add (mul a11 b11) (mul a12 b21) in
843 let ghost c12 = add (mul a11 b12) (mul a12 b22) in
844 let ghost c21 = add (mul a21 b11) (mul a22 b21) in
845 let ghost c22 = add (mul a21 b12) (mul a22 b22) in
846 mul_sizes a11 b11 s s s;
847 assert { size c11 s s /\ size c12 s s /\ size c21 s s /\ size c22 s s };
848 let ghost c = paste_quarters c11 c12 c21 c22 in
849 assert { c = mul a b };
850 let m1 = strassen_pow2 (add a11 a22) (add b11 b22) (k-1) in
851 let m2 = strassen_pow2 (add a21 a22) b11 (k-1) in
852 let m3 = strassen_pow2 a11 (sub b12 b22) (k-1) in
853 let m4 = strassen_pow2 a22 (sub b21 b11) (k-1) in
854 let m5 = strassen_pow2 (add a11 a12) b22 (k-1) in
855 let m6 = strassen_pow2 (sub a21 a11) (add b11 b12) (k-1) in
856 let m7 = strassen_pow2 (sub a12 a22) (add b21 b22) (k-1) in
857 let s11 = add m1 (add m4 (sub m7 m5)) in
858 let s12 = add m3 m5 in
859 let s21 = add m2 m4 in
860 let s22 = add m1 (add m3 (sub m6 m2)) in
861 (* assertions proved by reflection *)
862 assert { s11 = c11 };
863 assert { s12 = c12 };
864 assert { s21 = c21 };
865 assert { s22 = c22 };
866 paste_quarters s11 s12 s21 s22