support atan and atan2
[fpmath-consensus.git] / checker / checker.myr
blobb368e2872bb183bed9d91bc1506fc4a7a486773b
1 use std
3 use bio
4 use iter
5 use fileutil
6 use sys
8 type impl_prog = struct
9         name : byte[:]
10         pid : std.pid
11         stdin : std.fd
12         stdout : std.fd
13         alive : bool
15         output_bits : byte[:]
18 var rng : std.rng#
20 /* Flt is ``whatever precision we're testing''. */
21 type fp_type = union
22         `Flt
25 type fn_desc = struct
26         name : byte[:]
27         inputs : fp_type[:]
28         outputs : fp_type[:]
31 type flt_prec = union
32         `Single
33         `Double
36 type exactness = union
37         `Exact
38         `Inexact uint
41 /* (name, number-of-flt-args, constant-extra-bytes) */
42 var available_fns : fn_desc[:] = [][:]
44 const nop = {;}
46 const main = {args : byte[:][:]
47         available_fns = [
48                 [.name = "id",    .inputs = [`Flt][:], .outputs = [`Flt][:]],
49                 [.name = "atan",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
50                 [.name = "atan2", .inputs = [`Flt, `Flt][:], .outputs = [`Flt][:]],
51                 [.name = "ceil",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
52                 [.name = "cos",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
53                 [.name = "cot",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
54                 [.name = "exp",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
55                 [.name = "expm1", .inputs = [`Flt][:], .outputs = [`Flt][:]],
56                 [.name = "floor", .inputs = [`Flt][:], .outputs = [`Flt][:]],
57                 [.name = "fma",   .inputs = [`Flt, `Flt, `Flt][:], .outputs = [`Flt][:]],
58                 [.name = "log",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
59                 [.name = "log1p", .inputs = [`Flt][:], .outputs = [`Flt][:]],
60                 [.name = "powr",  .inputs = [`Flt, `Flt][:], .outputs = [`Flt][:]],
61                 [.name = "sin",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
62                 [.name = "sincos",.inputs = [`Flt][:], .outputs = [`Flt, `Flt][:]],
63                 [.name = "sqrt",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
64                 [.name = "tan",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
65                 [.name = "trunc", .inputs = [`Flt][:], .outputs = [`Flt][:]],
66         ][:]
68         var old
69         var sa = [
70                 .handler = (nop : byte#),
71                 .flags = sys.Saresethand,
72         ]
73         sys.sigaction(sys.Sigpipe, &sa, &old)
75         var fn : fn_desc = available_fns[0]
76         var next_bits_fn : (b : byte[:], n : std.size -> bool) = next_rand
78         var precision : flt_prec = `Single
79         var exactness : exactness = `Exact
80         var impls : impl_prog#[:]
81         rng = std.mksrng((std.now() : uint32))
83         (precision, exactness, fn, next_bits_fn) = read_args(args)
85         var input_sz = args_width(fn.inputs, precision)
86         var num_inputs = (1 << 18)
87         var buf_sz = num_inputs * input_sz
89         impls = start_impls([prec_arg(precision), "-f", fn.name, "-n", std.fmt("{}", num_inputs)][:])
91         io_loop(impls, precision, exactness, num_inputs, fn, next_bits_fn)
92         std.put("\n")
95 const prec_width = {p
96         match p
97         | `Single: -> 4
98         | `Double: -> 8
99         ;;
102 const prec_arg = {p
103         match p
104         | `Single: -> "-s"
105         | `Double: -> "-d"
106         ;;
109 const args_width = {ts : fp_type[:], p : flt_prec
110         var w : std.size = 0
112         for t : ts
113                 match t
114                 | `Flt: w += prec_width(p)
115                 ;;
116         ;;
118         -> w
121 const read_args = {args : byte[:][:]
122         var exactness : exactness = `Exact
123         var precision : flt_prec = `Single
124         var fn_name : byte[:] = "UNSPECIFIED"
125         var fn : fn_desc = available_fns[0]
126         var next_bits_fn : (b : byte[:], n : std.size -> bool) = next_rand
128         var cmd = std.optparse(args, &[
129                 .argdesc = "",
130                 .opts = [
131                         [.opt = 's', .desc = "use single precision (default)"],
132                         [.opt = 'd', .desc = "use double precision"],
133                         [.opt = 'l', .desc = "list available functions"],
134                         [.opt = 'f', .arg = "f", .desc = "test function ‘f’"],
135                         [.opt = 'r', .arg = "s", .desc = "choose inputs randomly, with seed ‘s’"],
136                         [.opt = 'e', .desc = "exhaust input space"],
137                         [.opt = 'i', .arg = "i", .desc = "allow inexact matches (by ‘i’, bitwise)"],
138                 ][:]
139         ])
141         for opt : cmd.opts
142                 match opt
143                 | ('s', _):
144                         precision = `Single
145                 | ('d', _):
146                         precision = `Double
147                 | ('i', arg):
148                         match std.intparse(arg)
149                         | `std.Some n:
150                                 if n >= 0
151                                         exactness = `Inexact (n : uint)
152                                 else
153                                         std.put("unacceptable inexactness “{}”\n", arg)
154                                 ;;
155                         | `std.None:
156                                 std.put("cannot parse inexactness “{}”\n", arg)
157                                 std.exit(1)
158                         ;;
159                 | ('l', _):
160                         list_functions()
161                         std.exit(0)
162                 | ('f', arg): fn_name = arg
163                 | ('r', arg):
164                         next_bits_fn = next_rand
165                         match std.intparse(arg)
166                         | `std.Some n:
167                                 rng = std.mksrng(n)
168                         | `std.None:
169                                 std.put("cannot parse seed “{}”\n", arg)
170                                 std.exit(1)
171                         ;;
172                 | ('e', _): next_bits_fn = next_exhaust
173                 | _: std.die("impossible\n")
174                 ;;
175         ;;
177         var good_fn : bool = false
178         for f : available_fns
179                 if std.eq(f.name, fn_name)
180                         fn = f
181                         good_fn = true
182                         break
183                 ;;
184         ;;
186         if !good_fn
187                 std.put("unknown function “{}”\n", fn_name)
188                 std.exit(1)
189         ;;
191         -> (precision, exactness, fn, next_bits_fn)
194 const io_loop = {impls : impl_prog#[:], p : flt_prec, x : exactness, num_inputs : std.size, fn : fn_desc, next_bits_fn : (b : byte[:], n : std.size -> bool)
195         var input_sz = args_width(fn.inputs, p)
196         var output_sz = args_width(fn.outputs, p)
197         var draw_line : bool = false
198         var n = 0
199         var bits : byte[:] = std.slalloc(num_inputs * input_sz)
200         var last_bits : byte[:] = std.slalloc(num_inputs * input_sz)
201         std.slfill(bits, 0)
202         std.slfill(last_bits, 0)
203         for i : impls
204                 i.output_bits = std.slalloc(num_inputs * output_sz)
205                 std.slfill(i.output_bits, 0)
206         ;;
208         /* Now, loop perhaps infinitely with the comparisons */
209 :again
210         draw_line = false
212         /* Send question */
213         for i : impls
214                 match std.writeall(i.stdin, bits)
215                 | `std.Ok _:
216                 | `std.Err (_, e):
217                         std.put("CRASH: {w=20} [{w=6}] failed to receive data\n", i.name, i.pid)
218                         i.alive = false
219                         draw_line = true
220                 ;;
221         ;;
223         /* Reap zombies */
224         for var j = 0; j < impls.len; ++j
225                 if impls[j].alive
226                         continue
227                 ;;
228                 impls[j] = impls[impls.len - 1]
229                 std.slgrow(&impls, impls.len - 1)
230         ;;
232         /* Gather consensus on last time's answers */
233         consensus(last_bits, num_inputs, fn, p, x, impls)
234         if n % 100 == 0
235                 if n % 8000 == 0
236                         std.put("\x1b[1G\x1b[0K")
237                 ;;
238                 std.put(".")
239         ;;
240         n++
242         std.slcp(last_bits, bits)
244         /* Receive new answers */
245         for i : impls
246                 if !i.alive
247                         continue
248                 ;;
250                 match std.readall(i.stdout, i.output_bits)
251                 | `std.Ok _:
252                 | `std.Err e:
253                         std.put("CRASH: {w=20} [{w=6}] failed to send data\n", i.name, i.pid)
254                         i.alive = false
255                         draw_line = true
256                 |_:
257                 ;;
258         ;;
260         if impls.len < 2
261                 std.put("Less than 2 implementations left. Consensus impossible.\n")
262                 std.put("----------\n")
263                 std.exit(1)
264         ;;
266         if draw_line
267                 std.put("----------\n")
268         ;;
271         /* Onward */
272         if next_bits_fn(bits, num_inputs)
273                 goto again
274         ;;
277 const next_rand = { b : byte[:], n : std.size
278         std.rngrandbytes(rng, b)
279         -> true
282 const next_exhaust = { b : byte[:], n : std.size
283         var one_arg : byte[:] = std.slalloc(b.len / n)
284         std.slcp(one_arg, b[b.len - one_arg.len:])
285         var finished : bool = false
287         /* n is the number of total argument groups */
288         for var i = 0; i < n; ++i
289                 /* Increment this particular argument */
290                 var j = one_arg.len - 1
291                 while j >= 0
292                         one_arg[j]++
293                         if (one_arg[j] != 0)
294                                 break
295                         ;;
296                         j--
297                 ;;
299                 finished = finished || j < 0
300                 var z = one_arg.len * i
301                 std.slcp(b[z:z + one_arg.len], one_arg)
302         ;;
304         -> !finished
307 const list_functions = {
308         std.put("Available functions:\n")
309         std.put("--------------------\n")
310         for f : available_fns
311                 std.put("  {}\n", f)
312         ;;
315 const start_impls = {opts : byte[:][:]
316         var cmd : byte[:][:] = std.slalloc(opts.len + 1)
317         var nice_name : byte[:] = [][:]
318         var started_impls : impl_prog#[:] = [][:]
319         var survived_impls : impl_prog#[:] = [][:]
321         for var j = 0; j < opts.len; ++j
322                 cmd[j + 1] = opts[j]
323         ;;
325         /* Start everything */
326         for f  : fileutil.bywalk(".")
327                 match std.strrfind(f, "/")
328                 | `std.Some j: nice_name = std.sldup(f[j+1:])
329                 | `std.None: nice_name = std.sldup(f)
330                 ;;
332                 if nice_name.len < 5 || !std.eq(nice_name[:5], "impl-") || \
333                         std.eq(nice_name[nice_name.len - 2:nice_name.len - 1], ".")
334                         std.slfree(nice_name)
335                         continue
336                 ;;
338                 cmd[0] = f
339                 match std.spork(cmd)
340                 | `std.Ok (p, fi, fo) :
341                         std.slpush(&started_impls, std.mk([
342                                 .name = nice_name,
343                                 .pid = p,
344                                 .stdin = fi,
345                                 .stdout = fo,
346                                 .alive = true,
347                         ]))
348                 | `std.Err e: std.slfree(nice_name)
349                 ;;
350         ;;
352         /* Give them a bit of time to die */
353         std.usleep(500_000)
355         /* Reap the zombies */
356         var z
357         var l
358         var WNOHANG = 1 /* HACK */
359         while ((z = sys.waitpid(-1, &l, WNOHANG)) > 0)
360                 match sys.waitstatus(l)
361                 | `sys.Waitexit _:
362                 | `sys.Waitsig _:
363                 | `sys.Waitfail _:
364                 | `sys.Waitstop _: continue
365                 ;;
366                 for i : started_impls
367                         if i.pid == (z : std.pid)
368                                 i.alive = false
369                         ;;
370                 ;;
371         ;;
373         /* What remains? */
374         for i : started_impls
375                 if !i.alive
376                         continue
377                 ;;
379                 std.slpush(&survived_impls, i)
380         ;;
382         match survived_impls.len
383         | 0:
384                 std.put("No implementations found. Try running from fpmath-consensus root dir.\n")
385                 std.exit(1)
386         | 1:
387                 std.put("Only one implementation found. Comparisons will be impossible.\n")
388                 std.exit(1)
389         | _:
390         ;;
392         std.put("Executing:\n")
393         std.put("----------\n")
394         for i : survived_impls
395                 std.put("  [{w=6}] {w=20}", i.pid, i.name)
396                 for o : opts
397                         std.put(" {}", o)
398                 ;;
399                 std.put("\n")
400         ;;
401         std.put("----------\n")
403         std.slfree(started_impls)
405         -> survived_impls
408 const consensus = { input : byte[:], num_inputs : std.size, fn : fn_desc, p : flt_prec, x : exactness, impls : impl_prog#[:]
409         var all_agree : bool = true
410         var flt_sz : std.size = prec_width(p)
411         var inputs_sz = args_width(fn.inputs, p)
412         var outputs_sz = args_width(fn.outputs, p)
414         for var z = 0; z < num_inputs; ++z
415                 all_agree = true
417                 /* The input is possibly multiple entries */
418                 var i_start = z * inputs_sz
419                 var i_end = (z + 1) * inputs_sz
421                 /* The output might also be strange */
422                 var a_start = z * outputs_sz
423                 var a_end = (z + 1) * outputs_sz
425                 for var j = 0; j + 1 < impls.len; ++j
426                         match x
427                         | `Exact:
428                                 if std.sleq(impls[j].output_bits[a_start:a_end], impls[j+1].output_bits[a_start:a_end])
429                                         continue
430                                 ;;
432                                 /* The memory patterns don't agree. But perhaps this is due to NaN? */
433                                 if detailed_eq(impls[j].output_bits[a_start:a_end], impls[j+1].output_bits[a_start:a_end], fn.outputs, p, 0)
434                                         continue
435                                 ;;
436                         | `Inexact i:
437                                 var good : bool = true
439                                 for var k = j + 1; k < impls.len; ++k
440                                         good = good && detailed_eq(impls[j].output_bits[a_start:a_end], impls[k].output_bits[a_start:a_end], fn.outputs, p, i)
441                                 ;;
443                                 if good
444                                         continue
445                                 ;;
446                         ;;
448                         all_agree = false
449                         break
450                 ;;
452                 if all_agree
453                         continue
454                 ;;
456                 
457                 std.put("For input: ")
458                 extract_and_describe(input[i_start:i_end], fn.inputs, p)
459                 std.put("\n")
460                 for i : impls
461                         std.put("  [{w=6}] {w=20}: ", i.pid, i.name)
462                         extract_and_describe(i.output_bits[a_start:a_end], fn.outputs, p)
463                 ;;
464                 std.put("----------\n")
465         ;;
468 const extract_and_describe = {bits : byte[:], ts : fp_type[:], p : flt_prec
469         var w : std.size = prec_width(p)
471         if ts.len > 1
472                 std.put("\n")
473         ;;
475         for t : ts
476                 if ts.len > 1
477                         std.put("    ")
478                 ;;
480                 match t
481                 | `Flt:
482                         match p
483                         | `Single:
484                                 var u = std.getle32(bits[:w])
485                                 std.put("0x{w=8,p=0,x} ({})\n", u, std.flt32frombits(u))
486                         | `Double:
487                                 var u = std.getle64(bits[:w])
488                                 std.put("0x{w=16,p=0,x} ({})\n", u, std.flt64frombits(u))
489                         ;;
491                         bits = bits[w:]
492                 ;;
493         ;;
496 const detailed_eq = {a : byte[:], b : byte[:], ts : fp_type[:], p : flt_prec, i : uint
497         var w : std.size = prec_width(p)
499         for t : ts
500                 match t
501                 | `Flt:
502                         match p
503                         | `Single:
504                                 var u1 = std.getle32(a[:w])
505                                 var u2 = std.getle32(b[:w])
506                                 var f1 = std.flt32frombits(u1)
507                                 var f2 = std.flt32frombits(u2)
508                                 if !((f1 == f2) || (std.isnan(f1) && std.isnan(f2)))
509                                         if i == 0 || ((u1 - u2 > (i : uint32)) && (u2 - u1 > (i : uint32)))
510                                                 -> false
511                                         ;;
512                                 ;;
513                         | `Double:
514                                 var u1 = std.getle64(a[:w])
515                                 var u2 = std.getle64(b[:w])
516                                 var f1 = std.flt64frombits(u1)
517                                 var f2 = std.flt64frombits(u2)
518                                 if !((f1 == f2) || (std.isnan(f1) && std.isnan(f2)))
519                                         if i == 0 || ((u1 - u2 > (i : uint64)) && (u2 - u1 > (i : uint64)))
520                                                 -> false
521                                         ;;
522                                 ;;
523                         ;;
525                         a = a[w:]
526                         b = b[w:]
527                 ;;
528         ;;
530         -> true