Parse negative numbers better
[libyakmo.git] / Q.myr
blob08ad2b9df39b381cdc6bb3632204b2710a07848d
1 use std
2 use math
4 use t
6 use "traits"
7 use "bigint"
8 use "Z"
10 pkg yakmo =
11         type Q = struct
12                 p : std.bigint#
13                 q : std.bigint#
14         ;;
15         impl abs_struct Q#
16         impl disposable Q#
17         impl division_struct Q#
18         impl field_struct Q#
19         impl gcd_struct Q#
20         impl group_struct Q#
21         impl real_struct Q#
22         impl ring_struct Q#
23         impl module_struct Q# -> Z#
24         impl std.equatable Q#
25         impl t.comparable Q#
26         impl t.dupable Q#
28         generic Qfrom : (p : @a, q : @a -> std.result(Q#, byte[:])) :: numeric,integral @a
29         generic QfromZ : (n : @a -> Q#) :: numeric,integral @a
30         generic QfromS : (n : byte[:], d : byte[:] -> std.result(Q#, byte[:]))
31         const QfromFlt : (f : flt64 -> std.result(Q#, byte[:]))
32         const fltfromQ : (q : Q# -> flt64)
34         generic eqQ : (q : Q#, n : @i, d : @i -> bool) :: numeric,integral @i
35         const reduceQ : (q : Q# -> void)
38 const __init__ = {
39         var q : Q#
40         q = q
41         std.fmtinstall(std.typeof(q), qfmt)
44 const qfmt = {sb, ap, opts
45         var q : Q# = std.vanext(ap)
46         std.sbfmt(sb, "{}", q.p)
47         if !std.bigeqi(q.q, 1)
48                 std.sbputs(sb, "/")
49                 std.sbfmt(sb, "{}", q.q)
50         ;;
54    TODO: Every invocation of this should be replaced with "std.free".
55    Typechecking issues that Ori will fix soon.
56  */
57 const freewrap = { a : Q#
58         std.bytefree((a : byte#), sizeof(Q))
61 impl disposable Q# =
62         __dispose__ = {a : Q#
63                 var ap : std.bigint# = (a.p : std.bigint#)
64                 var aq : std.bigint# = (a.q : std.bigint#)
65                 std.bigfree(ap)
66                 std.bigfree(aq)
67                 /* std.free(a) */
68                 freewrap(a)
69         }
72 impl group_struct Q# =
73         gadd_ip = {a : Q#, b : Q#
74                 /* Be careful in case a == b */
75                 var t : std.bigint# = std.bigdup(b.p)
76                 std.bigmul(t, a.q)
78                 std.bigmul(a.p, b.q)
79                 std.bigadd(a.p, t)
80                 std.bigfree(t)
82                 std.bigmul(a.q, b.q)
83                 reduceQ(a)
84         }
85         gadd = {a : Q#, b : Q#
86                 var aa = t.dup(a)
87                 gadd_ip(aa, b)
88                 -> aa
89         }
90         gneg_ip = {a : Q#
91                 a.p.sign *= -1
92         }
93         gneg = {a : Q#
94                 var aa = t.dup(a)
95                 aa.p.sign *= -1
96                 -> aa
97         }
98         gid = {
99                 var p : std.bigint# = std.mkbigint(0)
100                 var q : std.bigint# = std.mkbigint(1)
101                 var ret : Q = [ .p = p, .q = q]
102                 -> (std.mk(ret) : Q#) // TODO: Cast shouldn't be necessary.
103         }
105         eq_gid = {r
106                 -> std.bigiszero(r.p)
107         }
110 impl ring_struct Q# =
111         rid = {
112                 var p : std.bigint# = std.mkbigint(1)
113                 var q : std.bigint# = std.mkbigint(1)
114                 var ret  : Q = [ .p = p, .q = q]
115                 -> (std.mk(ret) : Q#) // TODO: Cast shouldn't be necessary
116         }
117         rmul_ip = {a, b
118                 std.bigmul(a.p, b.p)
119                 std.bigmul(a.q, b.q)
120                 reduceQ(a)
121         }
122         rmul = {a, b
123                 var q = t.dup(a)
124                 rmul_ip(q, b)
125                 -> q
126         }
129 impl division_struct Q# =
130         div_maybe = {a, b
131                 match finv(b)
132                 | `std.Some bi: -> `std.Some rmul(a, bi)
133                 | `std.None: -> `std.None
134                 ;;
135         }
138 impl field_struct Q# =
139         finv = {a
140                 if std.bigiszero(a.p)
141                         -> `std.None
142                 ;;
143                 var q : Q = [ .p = std.bigdup(a.q), .q = std.bigdup(a.p) ]
144                 -> `std.Some (std.mk(q) : Q#) // TODO: Cast shouldn't be necessary
145         }
148 impl abs_struct Q# =
149         abs_ip = {q; abs_ip(q.p)}
150         abs = {q;
151                 var qq = t.dup(q)
152                 abs_ip(qq)
153                 -> qq
154         }
155         cmp_zero = {q;
156                 var b : std.bigint# = (q.p : std.bigint#)
157                 if std.bigiszero(b)
158                         -> `std.Equal
159                 elif b.sign > 0
160                         -> `std.After
161                 else
162                         -> `std.Before
163                 ;;
164         }
167 impl gcd_struct Q# =
168         gcd = {a : Q#, b : Q#
169                 /*
170                    Let Thomae's function f(p/q) = 1/q, for p and q
171                    coprime. Then gcd(a,b) = a f(b/a) when a != 0, and b
172                    if a = 0.
173                  */
174                 var ret
175                 match finv(a)
176                 | `std.None:
177                         ret = t.dup(b)
178                         abs_ip(ret)
179                         -> ret
180                 | `std.Some x:
181                         rmul_ip(x, b)
182                         var old_p = x.p
183                         x.p = std.mkbigint(1)
184                         std.bigfree(old_p)
185                         rmul_ip(x, a)
186                         ret = x
187                 ;;
189                 abs_ip(ret)
190                 -> ret
191         }
194 impl real_struct Q# =
195         compconj_ip = {q;}
196         compconj = {q; -> t.dup(q)}
199 impl module_struct Q# -> Z# =
200         phi_ip = {z : Z#, q : Q#
201                 q.p = std.bigmul(q.p, (z : std.bigint#))
202                 reduceQ(q)
203         }
204         phi = {z, q
205                 var qq = t.dup(q)
206                 phi_ip(z, qq)
207                 -> qq
208         }
211 impl std.equatable Q# =
212         eq = {a, b
213                 var m = std.bigdup(a.p)
214                 var n = std.bigdup(b.p)
215                 std.bigmul(m, b.q)
216                 std.bigmul(n, a.q)
217                 var eq = std.bigeq(m, n)
218                 std.bigfree(m)
219                 std.bigfree(n)
220                 -> eq
221         }
224 impl t.comparable Q# =
225         cmp = {a,b
226                 reduceQ(a)
227                 reduceQ(b)
228                 var aa = t.dup(a)
229                 var bb = t.dup(b)
230                 var m = aa.p
231                 var n = bb.p
232                 std.bigmul(m, bb.q)
233                 std.bigmul(n, aa.q)
234                 auto aa
235                 auto bb
236                 -> std.bigcmp(m, n)
237         }
240 impl t.dupable Q# =
241         dup = {a : Q#
242                 var q : Q = [ .p = std.bigdup(a.p), .q = std.bigdup(a.q) ]
243                 -> (std.mk(q) : Q#) // TODO: Cast shouldn't be necessary
244         }
247 const reduceQ = {q
248         if q.q.sign == -1
249                 q.q.sign = 1
250                 q.p.sign *= -1
251         ;;
253         if std.bigiszero(q.p)
254                 if !std.bigeqi(q.q, 1)
255                         std.bigsteal(q.q, std.mkbigint(1))
256                 ;;
258                 -> void
259         ;;
261         var d : Z# = auto gcd((q.p : Z#), (q.q : Z#))
262         std.bigdiv(q.p, (d : std.bigint#))
263         std.bigdiv(q.q, (d : std.bigint#))
266 generic eqQ = {q, n, d
267         if d == 0
268                 -> false
269         ;;
271         var l : std.bigint# = std.bigdup(q.p)
272         l = std.bigmuli(l, d)
273         var r : std.bigint# = std.bigdup(q.q)
274         r = std.bigmuli(r, n)
276         var ret = std.bigeq(l, r)
277         std.bigfree(l)
278         std.bigfree(r)
280         -> ret
283 generic Qfrom = {p,q
284         var pp : std.bigint# = std.mkbigint(p)
285         var qq : std.bigint# = std.mkbigint(q)
286         if std.bigiszero(qq)
287                 std.bigfree(pp)
288                 std.bigfree(qq)
289                 -> `std.Err(std.fmt("Denominator is zero"))
290         ;;
291         var a = std.mk([ .p = pp, .q = qq ])
292         reduceQ(a)
293         -> `std.Ok a
296 generic QfromZ = {p
297         var pp : std.bigint# = std.mkbigint(p)
298         var qq : std.bigint# = std.mkbigint(1)
299         var a = std.mk([ .p = pp, .q = qq ])
300         reduceQ(a)
301         -> a
304 generic QfromS = {ps, qs
305         match std.bigparse(ps)
306         | `std.Some p:
307                 match std.bigparse(qs)
308                 | `std.Some q:
309                         if std.bigiszero(q)
310                                 -> `std.Err std.fmt("Denominator is zero")
311                         ;;
312                         var qbase : Q = [ .p = p, .q = q ]
313                         -> `std.Ok std.mk(qbase)
314                 | `std.None: -> `std.Err std.fmt("Unparsable denominator")
315                 ;;
316         | `std.None: -> `std.Err std.fmt("Unparsable numerator")
317         ;;
320 const QfromFlt = {f
321         var b = std.flt64bits(f)
322         var e, s
323         if (b >> 52) & 0x7fful == 0x7fful
324                 -> `std.Err std.fmt("Cannot convert infinity or NaN to rational")
325         ;;
327         var isneg : bool = (b >> 63 & 0x1) == 0x1
328         if isneg
329                 b = b & 0x7fffffffffffffff
330                 f = std.flt64frombits(b)
331         ;;
333         if b >= 0x4340000000000000
334                 /*
335                    This is ~2^53, certainly an integer. If we go much
336                    higher, we can't store the rounding in an int64, so
337                    this should be handled specially. We don't need any
338                    continued fractions, since this is certainly an
339                    integer. Since f = s * 2^(e - 52), build it up.
340                  */
341                 (_, e, s) = std.flt64explode(f)
342                 var n : std.bigint# = std.mkbigint(s)
343                 var p : std.bigint# = bigpowtwoi(e - 52)
344                 std.bigmul(n, p)
345                 std.bigfree(p)
347                 if isneg
348                         n.sign = -1
349                 ;;
351                 var ret : Q = [ .p = n, .q = std.mkbigint(1) ]
352                 -> `std.Ok (std.mk(ret) : Q#)
353         ;;
355         /*
356            It will be useful to assume later that 0 < a < 1. We can save
357            this off safely because of the check for large numbers above
358          */
359         var a0f : flt64 = math.floor(f)
360         var a0 : int64 = math.rn(a0f)
361         f = f - a0f
362         (_, e, s) = std.flt64explode(f)
364         if f == 0.0
365                 var ret : Q = [ .p = std.mkbigint(a0), .q = std.mkbigint(1) ]
366                 if isneg
367                         ret.p.sign = -1
368                 ;;
369                 -> `std.Ok (std.mk(ret) : Q#)
370         ;;
372         /*
373            Now f = s * 2^(e - 52) for 0 < s < 2 and e < 0. We think of f
374            as really representing the range (m, M), where
376                f- = floating point number immediately below f
377                f+ = floating point number immediately above f
378                m = (f + f-) / 2
379                M = (f + f+) / 2.
380         
381            We will construct a "best rational approximation" for f as
382            follows: begin to compute, in order, the continued fractions
383            [ a0; a1, a2, ... ] for each of m (coefficients ai) and M
384            (coefficients Ai). One of the following will happen first.
386             - If we obtain a k for which ak != Ak, then the best
387               rational approximation for f has continued fraction
389                 [ a0; a1, a2, ... a{k-1}, min(ak, Ak) + 1 ]
391               We augment this slightly by, when building up the
392               continued fraction, ignoring entries when we can prove
393               that the relative error is below 2^-53. In this case, 
395             - If we obtain a k for which ak or Ak are ambiguous (this
396               will happen at some point, since m and M are rational
397               numbers), we give up and return the exact value of f,
398               which is s/2^(52 - e) or something similar.
400            It is known that, were we to resolve the ambiguity in both
401            directions for both m and M, one of the four combinations of
402            continued fraction pairs would yield the best rational
403            approximation of f. However, that is tedious to calculate and
404            doesn't feel like the "right" algorithm.
405          */
407         /*
408            First, get f-, f+, m, and M. Since we've restricted the range
409            of f, we can just add and subtract bits to f to get f- and
410            f+. Irritatingly, we then have to go to bigints to get the
411            numerators and denominators for m and M.
412          */
413         var fQ : Q# = naiveQfromFlt(f)
414         var ret : Q# = fQ
415         var fminus : Q# = auto naiveQfromFlt(std.flt64frombits(std.flt64bits(f) - 1))
416         var fplus : Q# = auto naiveQfromFlt(std.flt64frombits(std.flt64bits(f) + 1))
417         var half : Q# = auto std.mk([ .p = std.mkbigint(1), .q = std.mkbigint(2) ])
418         var m : Q# = auto yakmo.gadd(fminus, fQ)
419         var M : Q# = auto yakmo.gadd(fplus, fQ)
420         yakmo.rmul_ip(m, half)
421         yakmo.rmul_ip(M, half)
423         /* Now, start computing the continued fractions for m and M */
424         var a = [][:]
425         var A = [][:]
426         while true
427                 /*
428                    Compute ak.
430                      m = p/q = 1/(a + epsilon) with epsilon in [0, 1]
431                          q/p = a + epsilon
433                    We let a = floor(q / p) and epsilon = (q % p) / p.
434                    Then setting m = epsilon allows computing the next
435                    term.
437                    If q % p == 0, then the choice is ambiguous: either
438                    (a, 0) or (a - 1, 1) would work. There should be a
439                    way to figure out what the correct choice is, but I
440                    really don't have time right now, so we bail and
441                    return fQ.
442                  */
443                 var ak, mpnext, Ak, Mpnext
444                 (ak, mpnext) = std.bigdivmod(m.q, m.p)
445                 if std.bigiszero(mpnext)
446                         std.bigfree(ak)
447                         std.bigfree(mpnext)
448                         goto naive
449                 ;;
450                 std.slpush(&a, ak)
451                 std.bigfree(m.q)
452                 m.q = m.p
453                 m.p = mpnext
455                 (Ak, Mpnext) = std.bigdivmod(M.q, M.p)
456                 if std.bigiszero(Mpnext)
457                         std.bigfree(Ak)
458                         std.bigfree(Mpnext)
459                         goto naive
460                 ;;
461                 std.slpush(&A, Ak)
462                 std.bigfree(M.q)
463                 M.q = M.p
464                 M.p = Mpnext
465                 match std.bigcmp(ak, Ak)
466                 | `std.Equal:
467                 | `std.Before:
468                         std.bigaddi(a[a.len - 1], 1)
469                         ret = buildCFrac(a0, a)
470                         goto done
471                 | `std.After:
472                         std.bigaddi(A[A.len - 1], 1)
473                         ret = buildCFrac(a0, A)
474                         goto done
475                 ;;
476         ;;
478 :naive
479         var ipart = std.mkbigint(a0)
480         std.bigmul(ipart, fQ.q)
481         std.bigadd(fQ.p, ipart)
482         reduceQ(fQ)
484 :done
485         if (ret != fQ)
486                 __dispose__(fQ)
487         ;;
489         for var k = 0; k < a.len; ++k
490                 std.bigfree(a[k])
491         ;;
492         std.slfree(a)
494         for var k = 0; k < A.len; ++k
495                 std.bigfree(A[k])
496         ;;
497         std.slfree(A)
499         if isneg
500                 ret.p.sign = -1
501         ;;
503         -> `std.Ok ret
506 const naiveQfromFlt = {f
507         /* We assume 0 < f < 1 */
508         var e, s
509         (_, e, s) = std.flt64explode(f)
511         var ret : Q = [ .p = std.mkbigint(s), .q = bigpowtwoi(52 - e) ]
512         var retp : Q# = std.mk(ret)
514         -> retp
517 const bigpowtwoi = {e
518         var ret : std.bigint# = std.mkbigint(1)
519         std.bigshli(ret, e)
520         -> ret
523 const buildCFrac = {a0 : int64, a : std.bigint#[:]
524         var q     : Q# = std.mk([ .p = std.mkbigint(1), .q = std.mkbigint(1) ])
525         var old_q : Q# = std.mk([ .p = std.mkbigint(2), .q = std.mkbigint(1) ])
527         /*
528            First, figure out the maximum error that a floating point
529            number could detect. If a0 ~ 2^e is not 0, then we can ignore
530            errors up to 2^(e - 53).
532            But if a0 == 0, we guess based on a1. The number we're
533            approximating is not less than 1/(a1 + 1), so we can pull the
534            exponent off that, then subtract 53.
535          */
536         var detectable_exp
537         if a0 != 0
538                 var a0f = (a0 : flt64)
539                 var n, e, s
540                 (n, e, s) = std.flt64explode(a0f)
541                 detectable_exp = 52 - e
542         else
543                 /*
544                    a1 is a bigint, so we have to be sloppy. If
545                    a1.dig.len > 1, pretend a1 = 2^32. If a1.dig.len ==
546                    1, use a1 = a1.dig[0]
547                  */
548                 var a1if
549                 if a[0].dig.len == 1
550                         a1if = 1.0 / (a[0].dig[0] : flt64)
551                 else
552                         a1if = 0.00000000023283064365386962890625
553                 ;;
555                 var n, e, s
556                 (n, e, s) = std.flt64explode(a1if)
557                 detectable_exp = 52 - e
558         ;;
559         if detectable_exp < 0
560                 detectable_exp = 0
561         ;;
563         var two = std.mkbigint(2)
564         std.bigshli(two, detectable_exp)
566         for var j = 1; j <= a.len; ++j
567                 /* Try and build the continued fraction using the first j terms */
568                 __dispose__(old_q)
569                 old_q = q
571                 var ares = a[:j]
572                 var qbody : Q = [ .p = std.mkbigint(1), .q = std.bigdup(ares[ares.len - 1]) ]
573                 q = std.mk(qbody)
575                 for var k = ares.len - 2; k >= 0; --k
576                         /* convert p/q to 1/( ares[k] + p/q ) = q / (ares[k]*q + p) */
577                         var newq = std.bigdup(ares[k])
578                         std.bigmul(newq, q.q)
579                         std.bigadd(newq, q.p)
580                         std.bigfree(q.p)
581                         q.p = std.bigdup(q.q)
582                         q.q = newq
583                 ;;
585                 /*
586                    Was that a detectable improvement? If |q - old_q| <
587                    1/2^detectable_exp, then we should return old_q.
588                  */
589                 var qdiff = auto gneg(q)
590                 gadd_ip(qdiff, old_q)
591                 std.bigmul(qdiff.p, two)
592                 match std.bigcmpabs(qdiff.p, qdiff.q)
593                 | `std.Before:
594                         /* std.swap(&q, &old_q) */
595                         var t = q
596                         q = old_q
597                         old_q = t
598                         break
599                 | _:
600                 ;;
601         ;;
603         std.bigfree(old_q.p)
604         std.bigfree(old_q.q)
606         var ipart : std.bigint# = std.bigdup(q.q)
607         std.bigmuli(ipart, a0)
608         std.bigadd(q.p, ipart)
609         std.bigfree(ipart)
610         std.bigfree(two)
612         reduceQ(q)
614         -> q
617 const fltfromQ = {r
618         var p : std.bigint# = std.bigdup(r.p)
619         var is_neg : bool = false
620         if std.bigiszero(p)
621                 std.bigfree(p)
622                 -> 0.0
623         elif p.sign < 0
624                 is_neg = true
625                 p.sign = 1
626         ;;
628         var q : std.bigint# = std.bigdup(r.q)
630         /*
631            We want to express
633              r = p/q = (1 + (s + eps)/2^52) * 2^e
635            where eps in [0, 1), s in [0, 2^52), and e an integer. This
636            will allow us to construct the significand out of s + eps and
637            u5e e as the exponent. First, calculate e. Luckily, we can
638            calculate this in terms of the shifting offset to make p and
639            q have the same number of digits.
641            There are much fancier ways to get the highest bit of a
642            32-bit number, I know.
643          */
644         var digits_p = 32 * (p.dig.len - 1)
645         for var t = p.dig[p.dig.len - 1]; t != 0; t >>= 1
646                 digits_p++
647         ;;
648         var digits_q = 32 * (q.dig.len - 1)
649         for var t = q.dig[q.dig.len - 1]; t != 0; t >>= 1
650                 digits_q++
651         ;;
653         /* Get p and q to have the same order of magnitude */
654         var e : int64 = digits_p - digits_q
655         if e > 0
656                 std.bigshli(q, e)
657         else
658                 std.bigshli(p, -e)
659         ;;
661         /* Now, are we looking at something like 1.01 / 1.02, or 1.02 / 1.01? */
662         while true
663                 /* I'm pretty sure this only takes one pass. */
664                 match std.bigcmp(p, q)
665                 | `std.Before:
666                         e--
667                         std.bigshli(p, 1)
668                 | _: break
669                 ;;
670         ;;
672         /*
673            Now that we know e, we may construct
675              s + eps = ((r * 2^-e) - 1) * 2^52
676                      = p'/q'
677                      = floor(p'/q') + eps
679            for e' = (p' % q') / q', so floor(p'/q') in [0, 2^52), and eps'
680            in [0, 1). We can then round based on eps'.
681         */
682         std.bigsub(p, q)
683         std.bigshli(p, 52)
685         var sb, eps_p
686         (sb, eps_p) = std.bigdivmod(p, q)
687         var s_m_1 = 0
688         if sb.dig.len > 0
689                 s_m_1 += (sb.dig[0] : uint64)
690                 if sb.dig.len > 1
691                         s_m_1 += ((sb.dig[1] : uint64) << 32)
692                 ;;
693         ;;
695         /*
696            Now, we have to check whether eps_p / q is large enough to
697            deserve rounding s up or not. To do so, compare eps_p*2 to q.
698          */
699         var was_roundup = false
700         std.bigshli(eps_p, 1)
701         match std.bigcmp(eps_p, q)
702         | `std.Before:
703         | `std.After:
704                 s_m_1++
705                 was_roundup = true
706         | `std.Equal:
707                 if s_m_1 & 0x1 == 0x1
708                         s_m_1++
709                         was_roundup = true
710                 ;;
711         ;;
713         /*
714            If we increased s so that it is now 2^52, then
715              (1 + s/2^52) * 2^e = (1 + 0) * 2^(e + 1)
716          */
717         if s_m_1 == 4503599627370496
718                 s_m_1 = 0
719                 e++
720         ;;
721         var s : uint64 = s_m_1 | (1 << 52)
723         /* If e is large enough, clean out all the NaN bits to give Inf */
724         if e > 1023
725                 e = 1024
726                 s = 0
727         ;;
729         /*
730            Finally, we have to handle subnormals. If e is low enough, we
731            should truncate s. However, this may require us to do another
732            rounding check, and this time we have to compensate for a
733            possible previous rounding.
734          */
735         if e < -1075
736                 /* This must round to 0 */
737                 s = 0
738                 e = -1023
739         elif e < -1022
740                 /*
741                    Our s is
743                           bit 52     bit 0
744                           /          /
745                      s = 1xxxxx...xxx
746                           \          \
747                          2^e        2^e-52
749                    The significand of a subnormal number is
751                           bit 52     bit 0
752                           /          /
753                      s = 0xxxxx...xxx
754                           \          \
755                          2^-1022     2^-1074
756                         
757                    We'll shift s to the right by e - (-1022), then set e to -1023 
758                  */
759                 var rshift = ((-1022 - e) : uint64)
760                 var new_s = s >> rshift
761                 e = -1023
762                 var firstcut = s & (1 << (rshift - 1))
763                 var restcut = s & ((1 << (rshift - 1)) - 1)
764                 var lastkept = s & (1 << (rshift))
765                 var roundup = firstcut != 0 && (lastkept != 0 || restcut != 0)
766                 if roundup && !was_roundup
767                         new_s++
768                         if new_s & (1 << 52) != 0
769                                 e++   
770                         ;;
771                 ;;
772                 s = new_s
773         ;;
775         std.bigfree(p)
776         std.bigfree(q)
777         std.bigfree(sb)
778         std.bigfree(eps_p)
779         -> std.flt64assem(is_neg, e, s)