Merge branch 'why3tools-register-main' into 'master'
[why3.git] / examples / multiprecision / sqrtrem.mlw
blob1c2afb186f80a5e175b6130eaed8af25dbd93d61
1 module Sqrt
3   use array.Array
4   use map.Map
5   use mach.c.C
6   use ref.Ref
7   use mach.int.Int32
8   use import mach.int.UInt64GMP as Limb
9   use int.EuclideanDivision
10   use int.Int
11   use int.Power
12   use types.Types
13   use types.Int32Eq
14   use types.UInt64Eq
15   use lemmas.Lemmas
16   use compare.Compare
17   use util.UtilOld
18   use add_1.Add_1
19   use add.AddOld
20   use sub_1.Sub_1
21   use sub.SubOld
22   use mul.Mul
23   use logical.LogicalUtil
24   use logical.Logical
25   use div.Div
26   use sqrt.Sqrt1
27   use ptralias.Alias
29   use real.ExpLog (* have to use real to be able to remove log2/log10... *)
30   meta remove_prop lemma log10_increasing
31   meta remove_prop lemma log2_increasing
32   meta remove_logic function log10 (* kills CVC3 *)
33   meta remove_logic function log2  (* same *)
34   meta remove_prop axiom log_increasing
35   meta remove_prop lemma exp_sum_opposite
36   meta remove_prop axiom exp_increasing
37   meta remove_prop axiom exp_positive
38   meta remove_prop axiom exp_inv
40   let lemma same_mod (a b:int)
41     requires { 0 <= a }
42     requires { 0 < b }
43     ensures  { ComputerDivision.mod a b = EuclideanDivision.mod a b }
44   = ()
46   meta remove_prop axiom same_mod
48   let wmpn_sqrtrem2 (sp rp np: ptr limb) : limb
49     requires { valid rp 1 }
50     requires { valid sp 1 }
51     requires { valid np 2 }
52     requires { (pelts np)[offset np + 1] >= power 2 (Limb.length - 2) }
53     requires { writable sp /\ writable rp }
54     ensures  { value np 2
55                = (pelts sp)[offset sp] * (pelts sp)[offset sp]
56                   + result * radix + (pelts rp)[offset rp] }
57     ensures  { (pelts rp)[offset rp] + result * radix <= 2 * (pelts sp)[offset sp] }
58     ensures  { 0 <= result <= 1 }
59   =
60     let np0 = C.get np in
61     let ghost np1 = C.get_ofs np 1 in
62     let ref sp0 = sqrt1 rp (C.get_ofs np 1) in
63     let ref rp0 = C.get rp in
64     let ghost orp = pure { rp0 } in
65     let ghost osp = pure { sp0 } in
66     let prec = (Limb.of_int Limb.length) / 2 in (* prec = 32 *)
67     assert { power 2 prec * power 2 prec = radix };
68     assert { sp0 * sp0 + rp0 = np1 };
69     assert { sp0 >= power 2 (prec - 1)
70              by np1 >= power 2 (Limb.length - 2)
71              so ((sp0 + 1) * (sp0 + 1) > np1
72                 by (sp0 + 1) * (sp0 + 1) > sp0 * sp0 + 2 * sp0 >= np1)
73              so (power 2 (prec - 1)) * (power 2 (prec - 1))
74                  = power 2 (Limb.length - 2) };
75     assert { sp0 < power 2 prec
76              by sp0 * sp0 <= np1 < radix = power 2 (Limb.length)
77              so (power 2 prec) * (power 2 prec) = power 2 (Limb.length) };
78     let nph = lsr_mod np0 (prec + 1) in
79     assert { nph < power 2 (prec - 1)
80              by nph = div np0 (power 2 (prec + 1))
81              so nph * power 2 (prec + 1) <= np0 < radix
82              so power 2 (prec - 1) * power 2 (prec + 1) = radix };
83     assert { power 2 (prec - 1) * rp0 + nph < radix
84              by rp0 < power 2 (prec + 1)
85              so rp0 <= power 2 (prec + 1) - 1
86              so power 2 (prec - 1) * rp0
87                 <= power 2 (prec + prec) - power 2 (prec - 1)
88                 = radix - power 2 (prec - 1)
89              so nph < power 2 (prec - 1) };
90     rp0 <- lsl rp0 (prec - 1) + nph;
91     label Div in
92     let ref q = Limb.(/) rp0 sp0 in
93     assert { q <= power 2 prec
94              by rp0 = power 2 (prec - 1) * orp + nph
95              so orp <= 2 * sp0
96              so nph < power 2 (prec - 1) <= sp0
97              so rp0 < power 2 prec * sp0 + sp0
98                      = (power 2 prec + 1) * sp0
99              so 0 <= mod rp0 sp0
100              so rp0 = sp0 * q + mod rp0 sp0
101              so q * sp0 <= rp0
102              so q * sp0 < (power 2 prec + 1) * sp0
103              so q < power 2 prec + 1 };
104     assert { q = div rp0 osp };
105     begin
106       ensures { if old q = power 2 prec
107                 then q = power 2 prec - 1
108                 else q = old q }
109       ensures { q < power 2 prec }
110       let rq = lsr_mod q prec in
111       assert { q = power 2 prec -> rq = div q q = 1 by mod q q = 0 };
112       q <- q - rq
113     end;
114     assert { q * sp0 < radix by q < power 2 prec so sp0 < power 2 prec
115              so q * sp0 < power 2 prec * power 2 prec = radix };
116     assert { rp0 - q * sp0 >= 0
117              by q <= div rp0 sp0
118              so rp0 = div rp0 sp0 * sp0 + mod rp0 sp0
119              so 0 <= mod rp0 sp0
120              so rp0 >= div rp0 sp0 * sp0
121              so q * sp0 <= div rp0 sp0 * sp0 };
122     let u = rp0 - (q * sp0) in
123     assert { sp0 * power 2 prec < radix - q
124              by sp0 <= power 2 prec - 1
125              so sp0 * power 2 prec <= power 2 prec * power 2 prec - power 2 prec
126                 = radix - power 2 prec < radix - q };
127     assert { q <> power 2 prec - 1 -> u <= osp - 1
128              by div rp0 osp = q
129              so rp0 = osp * div rp0 osp + mod rp0 osp
130              so u = mod rp0 osp < osp };
131     assert { q = power 2 prec - 1 -> u <= osp + nph
132              by rp0 = power 2 (prec - 1) * orp + nph
133              so orp <= 2 * osp
134              so rp0 <= power 2 prec * osp + nph
135              so q = power 2 prec - 1
136              so u = rp0 - (power 2 prec - 1) * osp };
137     sp0 <- lsl sp0 prec + q;
138     assert { sp0 = osp * power 2 prec + q };
139     let uh = lsr_mod u (prec - 1) in
140     assert { uh <= power 2 (prec + 1)
141              by uh = div u (power 2 (prec - 1))
142              so uh * power 2 (prec - 1) = u - mod u (power 2 (prec - 1))
143                 <= u < radix
144              so power 2 (prec + 1) * power 2 (prec - 1)
145                 = power 2 (prec + 1 + prec - 1) = radix };
146     let ref cc = to_int64 uh in
147     let npl = np0 % (lsl 1 (prec + 1)) in
148     assert { np0 = power 2 (prec + 1) * nph + npl
149              by np0 = (power 2 (prec + 1)) * div np0 (power 2 (prec + 1))
150                       + mod np0 (power 2 (prec + 1))
151              so npl = mod np0 (power 2 (prec + 1)) };
152     let ul = lsl_mod_ext u (prec + 1) in
153     rp0 <- ul + npl;
154     assert { q * q < radix by q < power 2 prec
155              so q * q < power 2 prec * power 2 prec = radix };
156     let q2 = q * q in
157     assert { ul + radix * uh = power 2 (prec + 1) * u
158              by
159                let p = u * power 2 (prec + 1) in
160                let m = mod u (power 2 (prec - 1)) in
161                mod p radix = ul
162                so m < power 2 (prec - 1)
163                so power 2 (prec + 1) * m
164                   < power 2 (prec + 1) * power 2 (prec - 1) = radix
165                so u = power 2 (prec - 1) * uh + m
166                so p = power 2 (prec + 1) * power 2 (prec - 1) * uh
167                       + power 2 (prec + 1) * m
168                     = uh * radix + power 2 (prec + 1) * m
169                     < uh * radix + radix
170                so uh * radix <= p
171                so div p radix = uh
172                so p = uh * radix + ul };
173     assert { rp0 + radix * cc = npl + power 2 (prec + 1) * u
174              by rp0 + radix * cc = npl + ul + radix * uh };
175     begin ensures { rp0 + radix * cc = old (rp0 + radix * cc) - q2 }
176       label S in
177       if rp0 < q2 then cc <- Int64.(-) cc 1;
178       rp0 <- sub_mod rp0 q2;
179     end;
180     assert { sp0 * sp0 + rp0 + radix * cc = np0 + radix * np1
181              by rp0 + radix * cc = (power 2 (prec + 1)) * u + npl - q * q
182              so sp0 * sp0 = ((power 2 prec) * osp + q)
183                             * ((power 2 prec) * osp + q)
184                           = power 2 prec * power 2 prec * osp * osp
185                             + q * q
186                             + 2 * power 2 prec * osp * q
187                           = radix * osp * osp + q * q
188                             + 2 * power 2 prec * osp * q
189              so osp * q = rp0 at Div - u
190              so osp * osp = np1 - orp
191              so rp0 at Div = power 2 (prec - 1) * orp + nph };
192     assert { rp0 + radix * cc <= 2 * sp0
193              by rp0 + radix * cc = power 2 (prec + 1) * u + npl - q * q
194              so 2 * sp0 = 2 * (power 2 prec * osp + q)
195                         >= power 2 (prec + 1) * osp
196              so npl < power 2 (prec + 1)
197              so if q = power 2 prec - 1
198                 then
199                   u <= osp + nph
200                   so power 2 (prec + 1) * u
201                      <= power 2 (prec + 1) * (osp + nph)
202                   so rp0 + radix * cc
203                      <= power 2 (prec + 1) * osp
204                         + power 2 (prec + 1) * nph + npl - q * q
205                       = power 2 (prec + 1) * osp + np0 - q * q
206                   so 2 * sp0 = power 2 (prec + 1) * osp + 2 * q
207                   so q * q = (power 2 prec - 1) * (power 2 prec - 1)
208                            = radix - power 2 (prec + 1) + 1
209                   so rp0 + radix * cc - 2 * sp0
210                      <= np0 - q * q - 2 * q
211                      <= radix - 1 - q * q - 2 * q
212                      = power 2 (prec + 1) - 2 - 2 * q
213                      = 0
214                 else
215                   rp0 + radix * cc <= power 2 (prec + 1) * (osp - 1) + npl
216                                    <= power 2 (prec + 1) * osp };
217     label Adjust in
218     let ghost sp0a = pure { sp0 } in
219     if Int64.(<) cc 0 (* cc = -1 *)
220     then begin
221       assert { cc = -1 };
222       assert { sp0 + sp0 > radix
223                by sp0 * sp0 + rp0 - radix = np0 + radix * np1
224                so np1 >= power 2 (Limb.length - 2)
225                so rp0 < radix
226                so sp0 * sp0 > np0 + radix * np1 >= radix * np1
227                   >= power 2 (Limb.length) * power 2 (Limb.length - 2)
228                   = power 2 (Limb.length + Limb.length - 2)
229                   = power 2 (Limb.length - 1) * power 2 (Limb.length - 1)
230                so sp0 > power 2 (Limb.length - 1) };
231       begin ensures { rp0 + radix * cc = old (rp0 + radix * cc) + sp0 }
232             ensures { cc >= 0 \/ rp0 = old rp0 + sp0 }
233         rp0 <- add_mod rp0 sp0;
234         if rp0 < sp0 then cc <- Int64.(+) cc 1;
235       end;
236       sp0 <- sp0 - 1;
237       begin ensures { rp0 + radix * cc = old (rp0 + radix * cc) + sp0 }
238             ensures { cc >= 0 }
239         label A2 in
240         rp0 <- add_mod rp0 sp0;
241         if rp0 < sp0 then cc <- Int64.(+) cc 1
242       end;
243       assert { sp0 * sp0 + rp0 + radix * cc
244                = (sp0 * sp0 + rp0 + radix * cc) at Adjust
245                by sp0 = sp0a - 1
246                so sp0 * sp0 = sp0a * sp0a - sp0a - sp0a + 1
247                so rp0 + radix * cc = rp0 + radix * cc at Adjust + sp0 + sp0a };
248     end;
249     C.set rp rp0;
250     C.set sp sp0;
251     assert { value np 2 = np0 + radix * np1 };
252     of_int64 cc
254   use toom.Toom
256   let rec wmpn_dc_sqrtrem (sp np: ptr limb) (n:int32) (scratch: ptr limb) : limb
257     requires { valid np (n+n) }
258     requires { valid sp n }
259     requires { 1 <= n }
260     requires { valid scratch (1 + div n 2) }
261     requires { (pelts np)[offset np + n + n - 1] >= power 2 (Limb.length - 2) }
262     requires { writable sp /\ writable scratch /\ writable np }
263     requires { 4 * n < max_int32 }
264 (*    writes   { np, sp, scratch }*)
265     ensures  { (value sp n) * (value sp n)
266                + value np n + (power radix n) * result
267                = old value np (n+n) }
268     ensures  { value np n + power radix n * result <= 2 * value sp n }
269     ensures  { (pelts sp)[offset sp + n-1] >= power 2 (Limb.length - 1) }
270     ensures  { 0 <= result <= 1 }
271     ensures  { max np = old max np }
272     ensures  { min np = old min np }
273     ensures  { plength np = old plength np }
274     ensures  { max scratch = old max scratch }
275     ensures  { min scratch = old min scratch }
276     ensures  { plength scratch = old plength scratch }
277     ensures  { max sp = old max sp }
278     ensures  { min sp = old min sp }
279     ensures  { plength sp = old plength sp }
280     variant  { n }
281   = label Start in
282     if n = 1
283     then
284       let r = wmpn_sqrtrem2 sp scratch np in
285       C.set np (C.get scratch);
286       r
287     else begin
288       let l = n / 2 in
289       assert { 1 <= l };
290       let h = n - l in
291       let ghost vn = value np (int32'int n + int32'int n) in
292       value_concat np (l+l) (n+n);
293       let np' = C.incr_split np (l+l) in
294       value_concat np l (l+l);
295       let ghost n0 = value np (int32'int l) in
296       let ghost n1 = value_sub (pelts np) (offset np + int32'int l)
297                                (offset np + int32'int l + int32'int l) in
298       let ghost n'' = pure { n0 + power radix l * n1 } in
299       let ghost n' = pure { value np' (h+h) } in
300       assert { value np (l+l) = n''};
301       assert { vn = n'' + power radix (l+l) * n' };
302       begin ensures { power radix (n+n) <= 4 * vn }
303         value_tail np (n+n-1);
304         assert { 4 * vn >= power radix (n+n)
305                  by vn = value np (n+n-1)
306                          + power radix (n+n-1) * (pelts np)[offset np + (n+n-1)]
307                        >= power radix (n+n-1) * (pelts np)[offset np + (n+n-1)]
308                        >= power radix (n+n-1) * power 2 (Limb.length - 2)
309                  so 4 * power 2 (Limb.length - 2) = radix
310                  so 4 * vn >= power radix (n+n-1) * radix = power radix (n+n) };
311       end;
312       let spl = C.incr_split sp l in
313       label Rec in
314       let ref q = wmpn_dc_sqrtrem spl np' h scratch in
315       assert { n' = value spl h * value spl h
316                     + value np' h + power radix h * q };
317       begin ensures { power radix l <= 2 * value spl h }
318         assert { power radix (h+h) * power radix (l+l)
319                  = power radix (n+n)
320                  <= 4 * (n'' + power radix (l+l) * n')
321                  < 4 * (n' + 1) * power radix (l+l)
322                  by n'' < power radix (l+l) };
323         assert { power radix (h+h) <= 4 * n' };
324         assert { (value spl h + 1) * (value spl h + 1) > n' };
325         let ghost ts = pure { 2 * (value spl h + 1) } in
326         assert { power radix h < ts
327                  by power radix h * power radix h <= 4 * n' < ts * ts
328                  so 0 < power radix h * power radix h < ts * ts
329                  so 0 < power radix h so 0 < ts
330                  so 0 < (ts - power radix h) * (ts + power radix h)
331                  so 0 < ts + power radix h
332                  so 0 < ts - power radix h };
333         assert { power radix l <= 2 * value spl h
334                  by power radix l <= power radix h
335                  so power radix h < ts
336                  so power radix l < ts};
337       end;
338       let ghost r' = value np' (int32'int h) + power radix (int32'int h) * (l2i q) in
339       assert { r' <= 2 * value spl h };
340       label Sub in
341       begin ensures {    (q = 1 /\ value np' h = r' - value spl h)
342                       \/ (q = 0 /\ value np' h = r') }
343       if (q <> 0) then begin
344         assert { q = 1 };
345         assert { value np' h = r' - power radix h };
346         assert { value np' h < value spl h
347                  by value np' h + power radix h = value np' h + power radix h * q
348                     <= 2 * value spl h
349                  so value np' h < power radix h
350                  so value np' h + value np' h < 2 * value spl h };
351         let ghost b = wmpn_sub_n_in_place np' spl h in
352         assert { b = 1 };
353         assert { value np' h = r' - value spl h
354                  by value np' h - power radix h = value np' h at Sub - value spl h
355                     = r' - power radix h - value spl h };
356         end
357       end;
358       label Join1 in
359       let ghost onp = { np } in
360       let ghost onp' = { np' } in
361       join np np';
362       value_sub_frame (pelts np) (pelts onp') (offset np + p2i l + p2i l)
363                                       (offset np + p2i l + p2i l + p2i h);
364       value_sub_frame (pelts np) (pelts onp) (offset np) (offset np + p2i l);
365       value_sub_frame (pelts np) (pelts onp) (offset np + p2i l) (offset np + p2i l + p2i l);
366       assert { value_sub (pelts np) (offset np + l + l) (offset np + l + l + h)
367                = value onp' h };
368       let npl = C.incr_split np l in
369       assert { value_sub (pelts npl) (offset npl + l) (offset npl + l + h)
370                = value np' h at Join1 };
371       value_concat npl l n;
372       assert { value npl n = n1 + power radix l * value onp' h };
373       label DivS in
374       wmpn_tdiv_qr_in_place scratch 0 npl n spl h;
375       assert { n1 + power radix l * (r' - q * value spl h)
376                = value scratch (l+1) * value spl h + value npl h };
377       value_tail scratch l;
378       value_tail spl h;
379       assert { 0 <= (pelts scratch)[offset scratch + l] <= 1
380                by value (npl at DivS) n < power radix n
381                so (pelts spl)[offset spl + h-1] >= power 2 (Limb.length - 1)
382                so 2 * (pelts spl)[offset spl + h - 1] >= radix
383                so value spl h
384                   >= power radix (h-1) * (pelts spl)[offset spl + h-1]
385                so 2 * value spl h
386                   >= power radix (h-1) * 2 * (pelts spl)[offset spl + h - 1]
387                   >= power radix (h-1) * radix = power radix h
388                so 0 <= value npl h
389                so value spl h * value scratch (l+1) < power radix n
390                   = power radix h * power radix l
391                   <= value spl h * 2 * power radix l
392                so value scratch (l+1) < 2 * power radix l
393                so value scratch (l+1)
394                   = value scratch l + power radix l
395                     * (pelts scratch)[offset scratch + l]
396                   >= power radix l * (pelts scratch)[offset scratch + l]
397                so (pelts scratch)[offset scratch + l] < 2 };
398       let sl = get_ofs scratch l in
399       value_concat scratch l (l+1);
400       assert { value scratch (l+1) = value scratch l + power radix l * sl };
401       q <- q + sl;
402       assert { 0 <= q <= 2 };
403       assert { n1 + power radix l * r'
404                = (value scratch l + power radix l * q) * value spl h
405                  + value npl h };
406       let sh = C.get scratch in
407       let ref c = to_int64 (sh % 2) in
408       value_concat scratch 1 l;
409       assert { c = mod (value scratch l) 2
410                by let st = value_sub (pelts scratch) (offset scratch + 1) (offset scratch + l) in
411                   value scratch l = sh + radix * st
412                so sh = value scratch 1
413                so let q = div sh 2 in
414                   c = mod sh 2
415                so sh = c + 2 * q
416                so value scratch l = c + 2 * q + radix * st
417                   = c + 2 * (q + power 2 (Limb.length - 1) * st) };
418       let ghost r = wmpn_rshift_sep sp scratch l 1 in
419       label Div2 in
420       assert { 2 * value sp l + c = value scratch l
421                by r + radix * value sp l = value scratch l * power 2 (Limb.length - 1)
422                so let p = power 2 (Limb.length - 1) in
423                   2 * p = radix
424                so p * (2 * value sp l) + r = p * value scratch l };
425       let st = C.get_ofs sp (l-1) in
426       value_tail sp (l-1);
427       assert { value sp l = value sp (l-1) + power radix (l-1) * st };
428       assert { st + power 2 (Limb.length - 1) < radix
429                by 2 * value sp l <= value scratch l < power radix l
430                so value sp l  >= power radix (l-1) * st
431                so (2 * st) * power radix (l-1) < power radix l
432                      = radix * power radix (l-1)
433                so 2 * st < radix
434                so st < power 2 (Limb.length - 1) };
435       let ql = lsl_mod_ext q (Limb.of_int Limb.length - 1) in
436       let qh = lsr_mod q 1 in
437       assert { 0 <= qh <= 1
438                by 0 <= q <= 2
439                so qh = div q 2 };
440       assert { ql + radix * qh = power 2 (Limb.length - 1) * q
441                by qh = div q 2
442                so q = 2 * qh + mod q 2
443                so power 2 (Limb.length - 1) * q
444                   = radix * qh + power 2 (Limb.length - 1) * mod q 2
445                so (0 <= power 2 (Limb.length - 1) * mod q 2 < radix
446                    by mod q 2 = 0 \/ mod q 2 = 1)
447                so ql = mod (q * power 2 (Limb.length - 1)) radix
448                      = mod (radix * qh + power 2 (Limb.length - 1) * mod q 2) radix
449                      = mod (power 2 (Limb.length - 1) * mod q 2) radix
450                      = power 2 (Limb.length - 1) * mod q 2
451                };
452       value_sub_update_no_change (pelts sp) (sp.offset + p2i l - 1)
453                                  (sp.offset) (sp.offset + p2i l - 1) (st + ql);
454       C.set_ofs sp (l-1) (st + ql);
455       value_tail sp (l-1);
456       assert { value sp l = value sp l at Div2 + power radix (l-1) * ql
457                by value sp l = value sp (l-1) + power radix (l-1) * (st + ql)
458                so value sp (l-1) = value sp (l-1) at Div2 };
459       (* TODO if (UNLIKELY ((sp[0] & approx) != 0)) /* (sp[0] & mask) > 1 */
460         return 1; /* Remainder is non-zero */ *)
461       q <- qh;
462       assert { 2 * (value sp l + power radix l * q) + c
463                  = value scratch l + power radix l * (q at Div2) };
464       assert { n1 + power radix l * r'
465                = 2 * value spl h * (value sp l + power radix l * q)
466                  + value spl h * c
467                  + value npl h
468                by 2 * value spl h * (value sp l + power radix l * q)
469                   + value spl h * c
470                   = value spl h * (2 * (value sp l + power radix l * q) + c)
471                   = value spl h
472                     * (value scratch l + power radix l * (q at Div2)) };
473       assert { value npl h < value spl h };
474       begin
475         ensures { n1 + power radix l * r'
476                   = 2 * value spl h * (value sp l + power radix l * q)
477                     + value npl h + power radix h * c }
478         ensures { 0 <= c <= 1 }
479         ensures { 0 <= value npl h + power radix h * c < 2 * value spl h }
480         if not (Int64.(=) c 0)
481         then begin
482           assert { c = 1 };
483           assert { n1 + power radix l * r'
484                = 2 * value spl h * (value sp l + power radix l * q)
485                  + value spl h + value npl h };
486           let c' = wmpn_add_n_in_place npl spl h in
487           c <- to_int64 c';
488           end
489       end;
490       let ghost dq = pure { value sp l + power radix l * q } in
491       let ghost s' = pure { value spl h } in
492       let ghost r'' = pure { value npl h + power radix h * c } in
493       assert { n1 + power radix l * r'
494                   = (2 * s') * dq + r''};
495       assert { r'' < 2 * s' };
496       assert { 0 <= dq <= power radix l
497                by n1 < power radix l <= 2 * s'
498                so r' <= 2 * s'
499                so n1 + power radix l * r'
500                   < 2 * s' + power radix l * r'
501                   <= 2 * s' + power radix l * (2 * s')
502                   = 2 * s' * (1 + power radix l)
503                so 0 <= r''
504                so (2 * s') * dq <= (2 * s') * dq + r'' = n1 + power radix l * r'
505                so (2 * s') * dq < (2 * s') * (1 + power radix l)
506                so dq < 1 + power radix l
507                so 0 <= value sp l
508                so 0 <= q };
509       let ghost onp = pure { np } in
510       let ghost onpl = pure { npl } in
511       join np npl;
512       value_sub_frame (pelts np) (pelts onpl)
513                       (offset np + p2i l) (offset np + p2i n);
514       value_sub_frame (pelts np) (pelts onp) (offset np) (offset np + p2i l);
515       assert { value_sub (pelts np) (offset np + l) (offset np + n)
516                = value onpl h by offset npl + h = offset np + n};
517       assert { value np l = value onp l = n0 };
518       value_concat np l n;
519       assert { value np n + power radix n * c = n0 + power radix l * r''
520                by value np n = n0 + power radix l * value onpl h
521                so power radix n = power radix (l+h)
522                   = power radix l * power radix h
523                so power radix n * c = power radix l * (power radix h * c)
524                so value np n + power radix n * c
525                   = n0 + power radix l * value onpl h
526                     + power radix l * (power radix h * c) };
527       let npn = C.incr_split np n in
528       let ghost _ = wmpn_mul npn sp l sp l 64 in
529       let ll = 2 * l in
530       assert { value npn ll + power radix ll * q = dq * dq
531                by 0 <= q <= 1 so q * q = q
532                so dq <= power radix l
533                so power radix ll = power radix l * power radix l
534                so value sp l + power radix l * q <= power radix l
535                so (value sp l = 0 \/ q = 0
536                    by 0 <= value sp l
537                    so if q = 1
538                       then value sp l = 0
539                       else q = 0)
540                so value sp l * q = 0
541                so dq * dq = value sp l * value sp l
542                   + (power radix l * q) * (power radix l * q)
543                   = value sp l * value sp l + power radix ll * q };
544       label Sub2 in
545       let ghost onp = pure { np } in
546       value_concat np ll n;
547       let bo = wmpn_sub_n_in_place np npn ll in
548       value_concat np ll n;
549       value_sub_frame (pelts np) (pelts onp) (offset np + int32'int ll)
550                                              (offset np + int32'int n);
551       let b = q + bo in
552       assert { value np ll - power radix ll * b
553                = value np ll at Sub2 - dq * dq };
554       assert { value np n - power radix ll * b
555                = value np n at Sub2 - dq * dq
556                by value np n = value np ll + power radix ll
557                     * value_sub (pelts np) (offset np + ll) (offset np + n)
558                so value_sub (pelts np) (offset np + ll) (offset np + n)
559                   = value_sub (pelts onp) (offset np + ll) (offset np + n) };
560       begin ensures { value np n + power radix n * c
561                       = n0 + power radix l * r'' - dq * dq }
562             ensures { - 1 <= c <= 1 }
563         if l = h
564         then begin
565           assert { n = ll };
566           assert { value np n - power radix n * b
567                    = value np n at Sub2 - dq * dq };
568           c <- Int64.(-) c (to_int64 b);
569           assert { value np n + power radix n * c
570                    = value np n - power radix n * b
571                      + power radix n * (c at Sub2)
572                    = (value np n + power radix n * c at Sub2) - dq * dq
573                    = n0 + power radix l * r'' - dq * dq };
574           assert { -1 <= c
575                    by dq * dq <= power radix l * power radix l = power radix n
576                    so 0 <= n0 so 0 <= r''
577                    so 0 <= n0 + power radix l * r''
578                    so - power radix n <= n0 + power radix l * r'' - dq * dq
579                    so value np n < power radix n
580                    so - power radix n <= value np n + power radix n * c
581                       < power radix n + power radix n * c
582                    so - 2 * power radix n < power radix n * c
583                    so -2 < c };
584         end
585         else begin
586           assert { h = l + 1
587                    by n = 2 * l + ComputerDivision.mod n 2
588                    so h = l + ComputerDivision.mod n 2
589                    so h <= l + 1 };
590           assert { n = ll + 1 };
591           let nll = C.incr np ll in
592           label Borrow in
593           let ghost onp = pure { np } in
594           value_concat np ll n;
595           let bo = wmpn_sub_1_in_place nll 1 b in
596           value_sub_frame (pelts np) (pelts onp) (offset np)
597                           (offset np + int32'int ll);
598           value_concat np ll n;
599           assert { value nll 1 = value_sub (pelts np) (offset np + ll) (offset np + n) };
600           assert { value np n - power radix n * bo
601                    = value np n at Borrow - power radix ll * b
602                    by value np ll = value np ll at Borrow
603                    so value nll 1 - radix * bo = value nll 1 at Borrow - b
604                    so value np n - power radix n * bo
605                       = value np ll + power radix ll * value nll 1 - power radix n * bo
606                       = value np ll + power radix ll * (value nll 1 - radix * bo)
607                       = value np ll + power radix ll * (value nll 1 at Borrow - b)};
608           c <- Int64.(-) c (to_int64 bo);
609           assert { value np n + power radix n * c
610                    = value np n - power radix n * bo
611                      + power radix n * (c at Sub2)
612                    = value np n at Borrow - power radix ll * b
613                      + power radix n * c at Sub2
614                    = (value np n + power radix n * c at Sub2) - dq * dq
615                    = n0 + power radix l * r'' - dq * dq };
616         end
617       end;
618       let ghost vs = pure { dq + power radix l * s' } in
619       let ghost vr = pure { power radix l * r'' + n0 - dq * dq } in
620       assert { vn = vs * vs + vr
621                by vn = n' * power radix (l+l) + n1 * power radix l + n0
622                so n' = s' * s' + r'
623                so power radix (l+l) = power radix l * power radix l
624                so vn = s' * s' * power radix (l+l) + r' * power radix (l+l)
625                      + n1 * power radix l + n0
626                      = (s' * power radix l) * (s' * power radix l)
627                        + r' * power radix (l+l) + n1 * power radix l + n0
628                      = (s' * power radix l) * (s' * power radix l)
629                        + power radix l * (r' * power radix l + n1) + n0
630                so r' * power radix l + n1 = 2 * s' * dq + r''
631                so vn = (s' * power radix l) * (s' * power radix l)
632                        + power radix l * (2 * s' * dq + r'') + n0
633                      = (s' * power radix l) * (s' * power radix l)
634                        + 2 * (s' * power radix l * dq)
635                        + dq * dq + power radix l * r'' + n0 - dq * dq
636                so (s' * power radix l) * (s' * power radix l)
637                        + 2 * (s' * power radix l * dq)
638                        + dq * dq
639                    = vs * vs
640                so vn = vs * vs + power radix l * r'' + n0 - dq * dq };
641       assert { vr <= 2 * vs
642                by n0 < power radix l
643                so r'' <= 2 * s' - 1
644                so r'' * power radix l + n0
645                     < (2 * s' - 1) * power radix l + power radix l
646                     = 2 * s' * power radix l
647                so dq * dq >= 0
648                so vr <= r'' * power radix l + n0 };
649       label Adjust in
650       assert { dq = value sp l + power radix l * q };
651       assert { vr = value np n + power radix n * c };
652       assert { value spl h = s' };
653       assert { power radix n = power radix l * power radix h };
654       assert { (vs - 1) * (vs - 1) <= vn
655                by vn = (vs - 1) * (vs - 1) + vr + 2 * vs - 1
656                so dq - 1 <= power radix l <= 2 * s'
657                so (dq - 1) * (dq - 1) <= (2 * s') * (power radix l)
658                so 0 <= n0
659                so power radix l * r'' >= 0
660                so n0 >= 0
661                so vr + 2 * vs - 1
662                   = power radix l * r'' + n0 - dq * dq + 2 * dq
663                     + 2 * s' * power radix l - 1
664                   = power radix l * r'' + n0
665                     + 2 * s' * power radix l - (dq - 1) * (dq - 1)
666                   >= power radix l * r'' + n0
667                   >= 0 };
668       assert { vs <= power radix n
669                by (vs - 1) * (vs - 1) <= vn
670                   < power radix (n+n) = power radix n *  power radix n
671                so if vs - 1 < power radix n
672                   then true
673                   else false
674                        by power radix n * power radix n <= (vs - 1) * (vs - 1) };
675       if (Int64.(<) c 0)
676       then begin
677         assert { vr < 0
678                  by value np n < power radix n
679                  so power radix n * c <= - power radix n };
680         q <- wmpn_add_1_in_place spl h q;
681         assert { q = 0 \/ q = 1 };
682         assert { value sp l + power radix l * value spl h + power radix n * q
683                    = vs
684                  by value spl h + power radix h * q
685                     = s' + q at Adjust
686                  so value sp l + power radix l * value spl h + power radix n * q
687                     = value sp l
688                       + power radix l * (value spl h + power radix h * q)
689                     = value sp l + power radix l * s'
690                       + power radix l * (q at Adjust)
691                     = dq + power radix l * s' = vs };
692         let ghost osp = pure { sp } in
693         let ghost ospl = pure { spl } in
694         join sp spl;
695         value_sub_frame (pelts sp) (pelts osp) (offset sp)
696                                 (offset sp + int32'int l);
697         value_sub_frame (pelts sp) (pelts ospl) (offset sp + int32'int l)
698                                                 (offset sp + int32'int n);
699         value_concat sp l n;
700         assert { value sp n = value osp l + power radix l * value ospl h
701                  by value ospl h = value_sub (pelts ospl) (offset sp + l)
702                     (offset sp + n)};
703         assert { value sp n + power radix n * q = vs };
704         assert { q = 0 \/ value sp n = 0
705                  by 0 <= value sp n
706                  so if q = 1
707                     then value sp n = 0
708                     else q = 0 };
709         let c' = wmpn_addmul_1 np sp n 2 in
710         assert { c' = 0 \/ (q = 0 /\ c' <= 2)
711                  by if q = 1
712                     then value sp n = 0
713                          so value np n + power radix n * c'
714                             = value np n at Adjust < power radix n
715                          so power radix n * c' < power radix n * 1
716                          so c' = 0
717                     else value np n + power radix n * c'
718                            = value np n at Adjust + 2 * vs
719                          so value np n at Adjust < power radix n
720                          so vs <= power radix n
721                          so value np n + power radix n * c' < 3 * power radix n
722                          so 0 <= value np n
723                          so c' < 3 };
724         c <- Int64.(+) c (to_int64 (2 * q + c'));
725         assert { value np n + power radix n * c
726                  = value np n at Adjust + 2 * value sp n
727                    + power radix n * (2 * q)
728                    + power radix n * (c at Adjust)
729                  = vr + power radix n * (2 * q)
730                       + 2 * value sp n
731                  = vr + 2 * vs };
732         c <- Int64.(-) c (to_int64 (wmpn_sub_1_in_place np n 1));
733         assert { value np n + power radix n * c = vr + 2 * vs - 1 };
734         assert { 0 <= c
735                  by 0 <= vr + 2 * vs - 1
736                  so 0 <= value np n + power radix n * c
737                  so value np n < power radix n
738                  so -1 < c };
739         label AdjS in
740         let bo = wmpn_sub_1_in_place sp n 1 in
741         assert { bo = 1 -> q = 1
742                  by value sp n - power radix n * bo
743                     = value sp n at AdjS - 1
744                     so value sp n < power radix n
745                     so value sp n - power radix n * bo < 0
746                     so value sp n at AdjS = 0
747                     so vs = power radix n * q
748                     so 0 < vs };
749         assert { q = 1 -> bo = 1
750                  by value sp n at AdjS + power radix n = vs <= power radix n
751                  so value sp n at AdjS = 0
752                  so value sp n - power radix n * bo = - 1
753                  so 0 <= value sp n };
754         q <- q - bo;
755         assert { q = 0 };
756         assert { value sp n = vs - 1 };
757         assert { (value sp n) * (value sp n) + value np n + power radix n * c
758                    = vn
759                  by (value sp n) * (value sp n) = (vs - 1) * (vs - 1)
760                     = vs * vs - 2 * vs + 1
761                  so value np n + power radix n * c = vr + 2 * vs - 1 };
762         assert { value np n + power radix n * c <= 2 * value sp n
763                  by value np n + power radix n * c = vr + 2 * vs - 1
764                     <= 2 * vs - 1 };
765       end
766       else begin
767         assert { 0 <= vr
768                  by 0 <= value np n
769                  so 0 <= power radix n * c };
770         let ghost osp = pure { sp } in
771         let ghost ospl = pure { spl } in
772         join sp spl;
773         value_sub_frame (pelts sp) (pelts osp) (offset sp)
774                                  (offset sp + int32'int l);
775         value_sub_frame (pelts sp) (pelts ospl) (offset sp + int32'int l)
776                                                 (offset sp + int32'int n);
777         value_concat sp l n;
778         assert { value sp n = value osp l + power radix l * s'
779                  by s' = value ospl h
780                     = value_sub (pelts ospl) (offset sp + l) (offset sp + n) };
781         assert { vs = value sp n + power radix l * q };
782         assert { dq * dq < power radix l * (r'' + 1)
783                  by 0 <= vr
784                  so dq * dq <= power radix l * r'' + n0
785                  so n0 < power radix l
786                  so power radix l * r'' + n0 < power radix l * (r'' + 1) };
787         assert { q = 1 -> dq = power radix l };
788         assert { q = 1 -> r'' < power radix l
789                  by r' * power radix l + n1 = 2 * s' * dq + r''
790                  so dq = power radix l
791                  so r' * power radix l + n1 = (2 * s') * power radix l + r''
792                  so r' <= 2 * s'
793                  so r' * power radix l <= (2 * s') * power radix l
794                  so r'' <= n1 < power radix l };
795         assert { q = 1 -> false
796                  by r'' + 1 <= power radix l
797                  so power radix l * power radix l = dq * dq
798                     < power radix l * (r'' + 1)
799                     <= power radix l * power radix l };
800         assert { q = 0 };
801         assert { vs = value sp n };
802       end;
803       let ghost onp = pure { np } in
804       join np npn;
805       value_sub_frame (pelts np) (pelts onp) (offset np)
806                               (offset np + int32'int n);
807       assert { value np n = value onp n };
808       value_tail sp (n-1);
809       let ghost ms = C.get_ofs sp (n-1) in
810       let ghost sqrt = pure { value sp n } in
811       assert { sqrt = value sp (n-1) + power radix (n-1) * ms };
812       assert { (sqrt + 1) * (sqrt + 1) > vn };
813       assert { ms >= power 2 (Limb.length - 1)
814                by power radix (n+n) <= 4 * vn
815                so if (2 * (sqrt + 1)) <= power radix n
816                   then (false
817                         by 4 * vn <= (2 * (sqrt + 1)) * (2 * (sqrt + 1))
818                                   <= power radix n * power radix n
819                                   = power radix (n+n))
820                   else true
821                so power radix n < 2 * (sqrt + 1)
822                so value sp (n-1) < power radix (n-1)
823                so 1 + value sp (n-1) <= power radix (n-1)
824                so sqrt + 1 <= power radix (n-1) * (ms + 1)
825                so power radix n = power radix (n-1) * radix
826                so power radix (n-1) * radix < power radix (n-1) * (2 * (ms + 1))
827                so 0 < power radix (n-1)
828                so 0 <= radix so 0 <= 2 * (ms + 1)
829                so radix < 2 * (ms + 1)
830                so radix = 2 * power 2 (Limb.length - 1) };
831       of_int64 c
832     end
834   let ghost function ceilhalf (n:int)
835     ensures { 2 * result >= n }
836     ensures { 2 * (result + 1) > n }
837   = ComputerDivision.div (n+1) 2
839   (* TODO rp = NULL case? *)
841   let lemma sqrt_norm (n nn c s s0 s1 : int)
842     requires { 0 <= c }
843     requires { 0 < n }
844     requires { 0 <= s }
845     requires { 0 <= s0 < power 2 c }
846     requires { nn = power 2 (2 * c) * n }
847     requires { s1 = power 2 c * s + s0 }
848     requires { s1 * s1 <= nn < (s1 + 1) * (s1 + 1) }
849     ensures  { s * s <= n < (s+1) * (s+1) }
850   =
851     assert { power 2 (2 * c) = power 2 c * power 2 c };
852     assert { s * s <= n < (s + 1) * (s + 1)
853              by 0 <= s so 0 <= power 2 c
854              so power 2 (2 * c) * (s * s)
855                 = (power 2 c * s) * (power 2 c * s)
856                 <= s1 * s1 <= nn = power 2 (2 * c) * n
857              so power 2 (2 * c) * (s * s) <= power 2 (2 * c) * n
858              so 0 <= power 2 (2 * c) so 0 <= s * s so 0 <= n
859              so s * s <= n
860              so s0 < power 2 c
861              so (s1 + 1) = power 2 c * s + s0 + 1
862                 <= power 2 c * s + power 2 c
863                 = power 2 c * (s + 1)
864              so power 2 (2 * c) * n = nn < (s1 + 1) * (s1 + 1)
865                 <= (power 2 c * (s + 1)) * (power 2 c * (s + 1))
866                 = power 2 (2 * c) * ((s + 1) * (s + 1))
867              so power 2 (2 * c) * n
868                 < power 2 (2 * c) * ((s + 1) * (s + 1))
869              so 0 < power 2 (2 * c) so 0 < (s + 1) * (s + 1) so 0 < n
870              so n < (s + 1) * (s + 1) }
872   let rec wmpn_sqrtrem (sp rp np: ptr limb) (n: int32) : int32
873     requires { valid sp (ceilhalf n) }
874     requires { valid rp n }
875     requires { valid np n }
876     requires { writable sp /\ writable rp /\ writable np }
877     requires { 1 <= n }
878     requires { 4 * n < max_int32 }
879     requires { (pelts np)[offset np + n - 1] > 0 }
880     ensures  { value np n = value sp (ceilhalf n) * value sp (ceilhalf n)
881                             + value rp result }
882     ensures  { 0 <= result <= n }
883     ensures  { value rp result <= 2 * value sp (ceilhalf n) }
884     ensures  { result > 0 -> (pelts rp)[offset rp + result - 1] > 0 }
885     ensures  { forall j. (pelts np)[j] = old (pelts np)[j] }
886     ensures  { max np = old max np }
887     ensures  { min np = old min np }
888     ensures  { plength np = old plength np }
889     ensures  { max rp = old max rp }
890     ensures  { min rp = old min rp }
891     ensures  { plength rp = old plength rp }
892     ensures  { max sp = old max sp }
893     ensures  { min sp = old min sp }
894     ensures  { plength sp = old plength sp }
895     variant  { n }
896   =
897     label Start in
898     let ghost k = ceilhalf (int32'int n) in
899     let high = C.get_ofs np (n-1) in
900     let ref c = (of_int32 (count_leading_zeros high)) / 2 in
901     assert { power 2 (2 * c) * high < radix };
902     assert { power 2 (2 * c) * high <= radix - power 2 (2 * c)
903              by let p = power 2 (2 * c) in
904                 let q = power 2 (64 - 2 * c) in
905                 let r = p * high in
906                 radix = p * q
907                 so mod r p = mod (p * high + 0) p = 0
908                 so r = p * div r p
909                 so r < p * q
910                 so div r p < q
911                 so r <= p * (q - 1) = radix - power 2 (2 * c) };
912     assert { 4 * power 2 (2 * c) * high >= radix };
913     if n = 1
914     then begin
915       assert { k = 1 };
916       value_tail np 0;
917       assert { value np n = high };
918       if c = 0
919       (* TODO if high & 0xc000_0000_0000_0000 *)
920       then begin
921         let s = sqrt1 rp high in
922         C.set sp s;
923         value_tail sp 0;
924         assert { value sp k = s };
925       end
926       else begin
927         let nh = lsl high (2 * c) in
928         assert { nh = power 2 (2 * c) * high
929                  so 4 * nh >= radix   };
930         let ncc = sqrt1 rp nh in
931         let cc = lsr_mod ncc c in
932         let ghost s0 = pure { mod ncc (power 2 c) } in
933         assert { ncc = power 2 c * cc + s0 };
934         assert { power 2 c * cc <= ncc
935                  by 0 <= s0 };
936         sqrt_norm (uint64'int high) (uint64'int nh) (uint64'int c)
937                   (uint64'int cc) s0 (uint64'int ncc);
938         C.set sp cc;
939         value_tail sp 0;
940         assert { value sp k = cc };
941         C.set rp (high - cc * cc);
942         assert { value rp 1 = high - cc * cc };
943       end;
944       let res = if C.get rp = 0 then 0 else 1 in
945       value_tail rp 0;
946       assert { value rp res = value rp 1 = (pelts rp)[offset rp] };
947       return res
948     end;
949     let ref tn = (n + 1) / 2 in
950     assert { tn = k };
951     let ref rn : int32 = 0 in
952     let adj = to_int32 ((of_int32 n) % 2) in
953     assert { 2 * tn = n + adj };
954     let scratch = salloc (UInt32.(+) (UInt32.(/) (UInt32.of_int32 tn) 2) 1) in
955     if (adj <> 0 || c <> 0)
956     then begin
957       let ref tp = salloc (UInt32.(*) 2 (UInt32.of_int32 tn)) in
958       C.set tp 0;
959       begin ensures { value tp (n+adj)
960                       = power 2 (2 * c) * power radix adj * value np n }
961             ensures { 4 * value tp (n+adj) >= power radix (n+adj) }
962             ensures { max tp = old max tp }
963             ensures { plength tp = old plength tp }
964             ensures { min np = old min np /\ max np = old max np
965                       /\ plength np = old plength np }
966             ensures  { forall j. (pelts np)[j] = old (pelts np)[j] }
967         assert { value tp adj = 0
968                  by value tp 1 = 0
969                  so adj = 0 \/ adj = 1 };
970         let ghost otp = pure { tp } in
971         let tpa = C.incr_split tp adj in
972         label Shift in
973         (if c <> 0
974         then begin
975           value_tail np (n-1);
976           assert { value np n * power 2 (2 * c) < power radix n
977                    by value np n = value np (n-1) + power radix (n-1) * high
978                    so high * power 2 (2 * c) <= radix - power 2 (2 * c)
979                    so power 2 (2 * c) * value np (n-1)
980                       < power 2 (2 * c) * power radix (n-1)
981                    so power radix (n-1) * (high * power 2 (2 * c))
982                       <= power radix (n-1) * (radix - power 2 (2 * c))
983                    so value np n * power 2 (2 * c)
984                       = value np (n-1) * power 2 (2 * c)
985                         + power radix (n-1) * high * power 2 (2 * c)
986                       < power 2 (2 * c) * power radix (n-1)
987                         + power radix (n-1) * (radix - power 2 (2 * c))
988                       = power radix (n-1) * radix = power radix n };
989           label Shift in
990           let ghost h = wmpn_lshift_sep tpa np n (2 * c) in
991           value_sub_frame (pelts np) (pure { pelts np at Shift })
992                           (offset np) (offset np + int32'int n);
993           assert { value np n = value np n at Shift };
994           assert { h = 0
995                    by value np n * power 2 (2 * c) < power radix n
996                    so value tpa n + power radix n * h < power radix n
997                    so 0 <= value tpa n
998                    so h < 1 };
999           assert { 4 * value tpa n >= power radix n
1000                    by value np n >= power radix (n-1) * high
1001                    so value tpa n
1002                       = power 2 (2 * c) * value np n
1003                       >= power 2 (2 * c) * power radix (n-1) * high
1004                    so 4 * power 2 (2 * c) * high >= radix
1005                    so power radix n
1006                       = power radix (n-1) * radix
1007                       <= power radix (n-1) * (4 * power 2 (2 * c) * high)
1008                       <= 4 * value tpa n };
1009         end
1010         else begin
1011           wmpn_copyi tpa np n;
1012           assert { 4 * high >= radix };
1013           assert { 4 * value tpa n >= power radix n
1014                    by value np n >= power radix (n-1) * high
1015                    so power radix n
1016                       = power radix (n-1) * radix
1017                       <= power radix (n-1) * 4 * high
1018                       <= 4 * value np n = 4 * value tpa n };
1019           assert { value tpa n = value np n };
1020         end);
1021         let otpa = pure { tpa } in
1022         join tp tpa;
1023         value_sub_frame (pelts tp) (pelts otp) 0 (int32'int adj);
1024         value_sub_frame (pelts tp) (pelts otpa) (int32'int adj)
1025                                   (int32'int adj + int32'int n);
1026         assert { value_sub (pelts tp) adj (n+adj) = value otpa n };
1027         assert { value tp adj = 0 };
1028         value_concat tp adj (n+adj);
1029         assert { value tp (n+adj) = power radix adj * value otpa n };
1030         assert { 4 * value tp (n+adj) >= power radix (n + adj)
1031                  by power radix (n+adj) = power radix adj * power radix n
1032                  <= power radix adj * (4 * value otpa n)
1033                  = 4 * value tp (n+adj) };
1034       end;
1035       c <- c + (if adj <> 0 then 32 else 0);
1036       assert { 0 <= c <= 63 };
1037       assert { value tp (n+adj) = power 2 (2 * c) * value np n };
1038       (*let mask = lsl 1 c - 1 in*)
1039       value_tail tp (tn + tn - 1);
1040       let ghost h = pure { (pelts tp)[tn + tn - 1] } in
1041       assert { h >= power 2 (Limb.length - 2)
1042                by value tp (n+adj)
1043                   = value tp (tn + tn - 1) + power radix (tn + tn - 1) * h
1044                   < power radix (tn + tn - 1) + power radix (tn + tn - 1) * h
1045                   = power radix (tn + tn - 1) * (h+1)
1046                so power radix (tn + tn) <= 4 * value tp (n+adj)
1047                   < power radix (tn + tn - 1) * 4 * (h+1)
1048                so power radix (tn + tn) = power radix (tn + tn - 1) * radix
1049                so power radix (tn + tn - 1) * radix
1050                   < power radix (tn + tn - 1) * (4 * (h+1))
1051                so radix < 4 * (h+1)
1052                so radix = 4 * power 2 (Limb.length - 2)
1053                so power 2 (Limb.length - 2) < h+1 };
1054       let ghost vn = pure { value np n } in
1055       let ghost vn1 = pure { value tp (n+adj) } in
1056       assert { vn1 = power 2 (2 * c) * vn };
1057       let ref rl = wmpn_dc_sqrtrem sp tp tn scratch in
1058       let ghost vs = pure { value sp tn } in
1059       let ghost vr = pure { value tp tn + power radix tn * rl } in
1060       assert { 0 <= vr
1061                by 0 <= value tp tn
1062                so 0 <= rl
1063                so 0 <= power radix tn };
1064       assert { vn1 = vs * vs + vr };
1065       let ghost vs0 = pure { mod vs (power 2 c) } in
1066       assert { vn1 = (vs - vs0) * (vs - vs0) + 2 * vs0 * vs - vs0 * vs0 + vr };
1067       let s0 = salloc 1 in
1068       value_concat sp 1 tn;
1069       let s00 = (C.get sp) % (lsl 1 c) in
1070       assert { s00 = vs0
1071                by radix = power 2 Limb.length
1072                   = power 2 c * power 2 (Limb.length - c)
1073                so let q = value_sub (pelts sp) (offset sp + 1)
1074                                               (offset sp + tn) in
1075                   vs = value sp tn = (pelts sp)[offset sp] + radix * q
1076                      = power 2 c * (power 2 (Limb.length - c) * q)
1077                        + (pelts sp)[offset sp]
1078                   so mod vs (power 2 c)
1079                      = mod (power 2 c * (power 2 (Limb.length - c) * q)
1080                             + (pelts sp)[offset sp])
1081                            (power 2 c)
1082                      = mod (pelts sp)[offset sp] (power 2 c)
1083                      = s00 };
1084       C.set s0 s00;
1085       assert { value s0 1 = s00 };
1086       let rc = wmpn_addmul_1 tp sp tn (2 * s00) in
1087       assert { value tp tn + power radix tn * (rl + rc) = vr + 2 * vs0 * vs };
1088       assert { rl + rc < radix
1089                by vr <= 2 * vs
1090                so vr + 2 * vs0 * vs <= 2 * vs * (vs0 + 1)
1091                so vs0 < power 2 c <= power 2 63
1092                so 2 * vs * (vs0 + 1) <= 2 * vs * power 2 63 = radix * vs
1093                so vs < power radix tn
1094                so radix * vs < radix * power radix tn
1095                so power radix tn * radix > value tp tn + power radix tn * (rl + rc)
1096                   >= power radix tn * (rl + rc)
1097                so power radix tn * (rl + rc) < power radix tn * radix };
1098       rl <- rl + rc;
1099       assert { value tp tn + power radix tn * rl = vr + 2 * vs0 * vs };
1100       value_concat tp 1 tn;
1101       let ghost otp = pure { tp } in
1102       let ref cc = wmpn_submul_1 tp s0 1 s00 in
1103       value_sub_frame (pelts tp) (pelts otp) (offset tp + 1)
1104                                   (offset tp + int32'int tn);
1105       value_concat tp 1 tn;
1106       assert { value tp tn - radix * cc = value otp tn - s00 * s00 };
1107       assert { value tp tn + power radix tn * rl - radix * cc
1108                = vr + 2 * vs0 * vs - vs0 * vs0 };
1109       begin ensures { value tp tn + power radix tn * rl
1110                       = vr + 2 * vs0 * vs - vs0 * vs0 }
1111         if tn > 1
1112         then begin
1113           label Sub in
1114           value_concat tp 1 tn;
1115           let tp1 = C.incr tp 1 in
1116           let ghost otp = pure { tp } in
1117           let ghost otp1 = pure { tp1 } in
1118           assert { value tp tn = value tp 1 + radix * value tp1 (tn-1) };
1119           cc <- wmpn_sub_1_in_place tp1 (tn - 1) cc;
1120           value_sub_frame (pelts tp) (pelts otp1) (offset tp)
1121                                              (offset tp + 1);
1122           assert { value tp 1 = value tp 1 at Sub };
1123           value_concat tp 1 tn;
1124           assert { value tp tn - power radix tn * cc
1125                    = value otp tn - radix * (cc at Sub)
1126                    by value tp1 (tn - 1) - power radix (tn - 1) * cc
1127                       = value otp1 (tn - 1) - cc at Sub
1128                    so value tp tn = value tp 1 + radix * value tp1 (tn - 1)
1129                    so power radix tn = radix * power radix (tn - 1)
1130                    so value tp tn - power radix tn * cc
1131                       = value tp 1
1132                         + radix * (value tp1 (tn - 1) - power radix (tn-1) * cc)
1133                       = value tp 1 + radix * (value otp1 (tn-1) - (cc at Sub))
1134                       = value tp 1 + radix * value otp1 (tn-1)
1135                         - radix * (cc at Sub)
1136                       = value otp tn - radix * (cc at Sub) };
1137         end
1138         else begin
1139           assert { tn = 1 };
1140         end;
1141         assert { value tp tn + power radix tn * (rl - cc)
1142                    = vr + 2 * vs0 * vs - vs0 * vs0 };
1143         assert { 0 <= rl - cc
1144                  by (vs0 = mod vs (power 2 c) <= vs
1145                      by vs = div vs (power 2 c) * power 2 c + vs0
1146                      so div vs (power 2 c) >= 0
1147                      so power 2 c >= 0
1148                      so div vs (power 2 c) * power 2 c >= 0)
1149                  so vs0 * vs0 <= vs0 * vs
1150                  so 2 * vs0 * vs - vs0 * vs0 >= 0
1151                  so 0 <= vr
1152                  so 0 <= value tp tn + power radix tn * (rl - cc)
1153                  so value tp tn < power radix tn
1154                  so power radix tn * (rl - cc) >= - (power radix tn)
1155                  so rl - cc > - 1 };
1156         rl <- rl - cc
1157       end;
1158       let ghost r = wmpn_rshift sp sp tn c in
1159       let ghost vsq = pure { div vs (power 2 c) } in
1160       assert { vs = vsq * power 2 c + vs0 };
1161       assert { value sp tn * radix + r = vs * power 2 (Limb.length - c) };
1162       assert { mod r (power 2 (Limb.length - c)) = 0
1163                by let q = power 2 (Limb.length - c) in
1164                   let p = power 2 c in
1165                   p * q = radix
1166                so r = vs * q - value sp tn * p * q
1167                     = q * (vs - value sp tn * p)
1168                so mod r q
1169                   = mod (q * (vs - value sp tn * p) + 0) q
1170                   = 0 };
1171       let ghost q = pure { div r (power 2 (Limb.length - c)) } in
1172       assert { r = power 2 (Limb.length - c) * q
1173                by let p = power 2 (Limb.length - c) in
1174                r = p * div r p + mod r p
1175                so div r p = q so mod r p = 0 };
1176       assert { value sp tn * power 2 c + q = vs
1177                by radix = power 2 c * power 2 (Limb.length - c)
1178                so value sp tn * power 2 c * power 2 (Limb.length - c)
1179                     + power 2 (Limb.length - c) * q
1180                   = (value sp tn * power 2 c + q) * power 2 (Limb.length - c)
1181                   = value sp tn * radix + power 2 (Limb.length - c) * q
1182                   = vs * power 2 (Limb.length - c)
1183                so (value sp tn * power 2 c + q) * power 2 (Limb.length - c)
1184                   = vs * power 2 (Limb.length - c)
1185                so power 2 (Limb.length - c) <> 0 };
1186       assert { q = vs0
1187                by vs0 = mod vs (power 2 c)
1188                   = mod (value sp tn * power 2 c + q) (power 2 c)
1189                   = mod q (power 2 c)
1190                so 0 <= q
1191                so q * power 2 (Limb.length - c) = r < radix
1192                   = power 2 c * power 2 (Limb.length - c)
1193                so let p = power 2 (Limb.length - c) in
1194                   q * p < power 2 c * p
1195                so 0 <= q so 0 < p
1196                so q < power 2 c
1197                so div q (power 2 c) = 0
1198                so mod q (power 2 c) = q };
1199       assert { value sp tn * power 2 c = vs - vs0 };
1200       assert { value tp tn + power radix tn * rl
1201                = vr + 2 * vs0 * vs - vs0 * vs0 };
1202       value_sub_update_no_change (pelts tp) (offset tp + int32'int tn)
1203                                  (offset tp) (offset tp + int32'int tn) rl;
1204       label Set in
1205       C.set_ofs tp tn rl;
1206       value_tail tp tn;
1207       assert { value tp (tn + 1) = vr + 2 * vs0 * vs - vs0 * vs0
1208                by value tp (tn + 1) = value tp tn + power radix tn * rl
1209                   = value (tp at Set) tn + power radix tn * rl };
1210       assert { vn1 = value tp (tn + 1) + (vs - vs0) * (vs - vs0) };
1211       assert { power 2 (2 * c) * vn = value tp (tn + 1)
1212                + power 2 (2 * c) * value sp tn * value sp tn
1213                by power 2 (2 * c) = power 2 c * power 2 c
1214                so power 2 (2 * c) * value sp tn * value sp tn
1215                = (vs - vs0) * (vs - vs0) };
1216       let ghost vsp = pure { value sp tn } in
1217       begin ensures { 0 < vn }
1218         value_tail np (n-1);
1219         assert { 0 < vn
1220                  by vn = value np (n-1)
1221                        + power radix (n-1) * (pelts np)[offset np + n - 1]
1222                  so 0 <= value np (n-1)
1223                  so 0 < (pelts np at Start)[offset np + n - 1]
1224                       = (pelts np)[offset np + n - 1]
1225                  so 0 < power radix (n-1)
1226                  so 0 < power radix (n-1) * (pelts np)[offset np + n - 1] };
1227       end;
1228       sqrt_norm vn vn1 (uint64'int c) vsp vs0 vs;
1229       let ref c2 = lsl c 1 in
1230       assert { c2 = 2 * c };
1231       assert { value tp (tn + 1) = power 2 c2 * (vn - vsp * vsp) };
1232       begin ensures { power 2 c2 * (vn - vsp * vsp)
1233                       = value tp tn }
1234             ensures { 0 <= c2 < 64 }
1235             ensures { valid tp tn }
1236             ensures { 0 < tn <= k+1 }
1237         if c2 < 64
1238         then tn <- tn + 1
1239         else begin
1240           value_concat tp 1 (tn + 1);
1241           let tp1 = C.incr tp 1 in
1242           assert { value tp (tn + 1) = value tp 1 + radix * value tp1 tn };
1243           assert { power 2 c2 = radix * power 2 (c2 - 64)
1244                    by radix = power 2 64 so radix * power 2 (c2 - 64) = power 2 c2 };
1245           assert { value tp (tn + 1)
1246                    = power 2 c2 * (vn - vsp * vsp)
1247                    = radix * power 2 (c2 - 64) * (vn - vsp * vsp) };
1248           assert { value tp 1 = 0
1249                    by value tp (tn + 1)
1250                       = radix * power 2 (c2 - 64) * (vn - vsp * vsp)
1251                    so mod (value tp (tn + 1)) radix
1252                       = mod (radix * (power 2 (c2 - 64) * (vn - vsp * vsp)) + 0)
1253                             radix
1254                       = 0
1255                    so mod (value tp (tn + 1)) radix
1256                       = mod (value tp 1) radix
1257                    so 0 <= value tp 1 < radix
1258                    so value tp 1 = mod (value tp 1) radix };
1259           assert { value tp1 tn = power 2 (c2 - 64) * (vn - vsp * vsp)
1260                    by radix * value tp1 tn
1261                       = radix * (power 2 (c2 - 64) * (vn - vsp * vsp)) };
1262           c2 <- c2 - 64;
1263           tp <- tp1
1264         end
1265       end;
1266       begin ensures { value rp rn = vn - vsp * vsp }
1267             ensures { 0 < rn <= k+1 }
1268             ensures { min rp = old min rp /\ max rp = old max rp
1269                       /\ plength rp = old plength rp }
1270         if (not (c2 = 0))
1271         then begin
1272           label Shift in
1273           let ghost b = wmpn_rshift_sep rp tp tn c2 in
1274           value_sub_frame (pelts tp) (pure { pelts tp at Shift })
1275                           (offset tp) (offset tp + int32'int tn);
1276           assert { value tp tn = power 2 c2 * value rp tn
1277                    by radix = power 2 c2 * power 2 (64 - c2)
1278                    so b + radix * value rp tn
1279                       = value tp tn * power 2 (64 - c2)
1280                       = power 2 c2 * (vn - vsp * vsp) * power 2 (64 - c2)
1281                       = radix * (vn - vsp * vsp)
1282                    so mod (radix * (vn - vsp * vsp)) radix = 0
1283                    so 0 = mod (b + radix * value rp tn) radix
1284                       = mod b radix
1285                    so 0 <= b < radix
1286                    so b = mod b radix = 0
1287                    so value rp tn * power 2 c2 * power 2 (64 - c2)
1288                       = radix * value rp tn
1289                       = value tp tn * power 2 (64 - c2) };
1290           assert { value rp tn = vn - vsp * vsp
1291                    by value tp tn = power 2 c2 * value rp tn
1292                    so value tp tn = power 2 c2 * (vn - vsp * vsp) };
1293         end
1294         else wmpn_copyi rp tp tn;
1295         rn <- tn
1296       end
1297     end
1298     else begin
1299       wmpn_copyi rp np n;
1300       assert { (pelts rp)[offset rp + tn + tn - 1] >= power 2 (Limb.length - 2)
1301                by tn + tn = n
1302                so c = 0
1303                so (pelts rp)[offset rp + tn + tn - 1]
1304                   = (pelts rp)[offset rp + (n - 1)]
1305                   = (pelts np)[offset np + (n - 1)] = high
1306                so 4 * high >= radix };
1307       assert { value np n = value rp (tn + tn)
1308                by tn + tn = n };
1309       let h = wmpn_dc_sqrtrem sp rp tn scratch in
1310       value_sub_update_no_change (pelts rp) (offset rp + int32'int tn) (offset rp)
1311                                  (offset rp + int32'int tn) h;
1312       C.set_ofs rp tn h;
1313       value_tail rp tn;
1314       assert { value rp (tn+1) + value sp tn * value sp tn = value np n
1315                by value np n = value sp tn * value sp tn + value rp tn + power radix tn * h
1316                so value rp (tn + 1) = value rp tn + power radix tn * h };
1317       assert { value rp (tn+h) = value rp (tn + 1)
1318                by [@case_split] (h = 0 \/ h = 1) };
1319       rn <- tn + to_int32 h;
1320     end;
1321     let ghost orp = pure { rp } in
1322     let ghost orn = pure { rn } in
1323     assert { value np n = value sp k * value sp k + value orp orn };
1324     assert { 1 <= rn <= n };
1325     while C.get_ofs rp (rn - 1) = 0 do
1326       invariant { value rp rn = value orp orn }
1327       invariant { 1 <= rn <= orn }
1328       variant   { rn }
1329       value_tail rp (rn-1);
1330       assert { value rp (rn - 1) = value rp rn };
1331       rn <- rn - 1;
1332       if rn = 0
1333       then begin
1334         assert { value orp orn = 0
1335                  by 0 = value rp 0 = value rp 1 = value orp orn };
1336         break
1337       end
1338     done;
1339     rn