support log and log1p
[fpmath-consensus.git] / impl-myrddin / impl-myrddin.myr
blob76aefefcb3b7253ef6728adfe05cd856f2749deb
1 use std
3 use math
5 type Fn_flt__flt = struct
6         f32 : (f : flt32 -> flt32)
7         f64 : (f : flt64 -> flt64)
8 ;;
10 type Fn_flt_flt_flt__flt = struct
11         f32 : (f1 : flt32, f2 : flt32, f3 : flt32 -> flt32)
12         f64 : (f1 : flt64, f2 : flt64, f3 : flt64 -> flt64)
15 type fn_desc = struct
16         name : byte[:]
17         f : union
18                 `Flt__flt Fn_flt__flt
19                 `Flt_flt_flt__flt Fn_flt_flt_flt__flt
20         ;;
23 type flt_prec = union
24         `Single
25         `Double
28 var available_fns : fn_desc[:] = [][:]
30 generic id : (a : @a -> @a) = {x; -> x}
32 const main = {args : byte[:][:]
33         available_fns = [
34                 [.name = "id",    .f = `Flt__flt         [ .f32 = id,         .f64 = id]],
35                 [.name = "ceil",  .f = `Flt__flt         [ .f32 = math.ceil,  .f64 = math.ceil]],
36                 [.name = "exp",   .f = `Flt__flt         [ .f32 = math.exp,   .f64 = math.exp]],
37                 [.name = "expm1", .f = `Flt__flt         [ .f32 = math.expm1, .f64 = math.expm1]],
38                 [.name = "floor", .f = `Flt__flt         [ .f32 = math.floor, .f64 = math.floor]],
39                 [.name = "fma",   .f = `Flt_flt_flt__flt [ .f32 = math.fma,   .f64 = math.fma]],
40                 [.name = "log",   .f = `Flt__flt         [ .f32 = math.log,   .f64 = math.log]],
41                 [.name = "log1p", .f = `Flt__flt         [ .f32 = math.log1p, .f64 = math.log1p]],
42                 [.name = "sqrt",  .f = `Flt__flt         [ .f32 = math.sqrt,  .f64 = math.sqrt]],
43                 [.name = "trunc", .f = `Flt__flt         [ .f32 = math.trunc, .f64 = math.trunc]],
44         ][:]
46         var p : flt_prec = `Single
47         var f : fn_desc = available_fns[0]
48         var n : std.size = 0
50         (p, f, n) = read_args(args)
52         io_loop(p, f, n)
55 const read_args = {args : byte[:][:]
56         var p : flt_prec = `Single
57         var n : std.size = 0
58         var fname : byte[:] = ""
59         var fn : fn_desc = available_fns[0]
60         var cmd = std.optparse(args, &[
61                 .argdesc = "",
62                 .opts = [
63                         [.opt = 's', .desc = "use single precision (default)"],
64                         [.opt = 'd', .desc = "use double precision"],
65                         [.opt = 'n', .arg = "N", .desc = "read/write ‘N’ entries at a time"],
66                         [.opt = 'f', .arg = "func", .desc = "use function ‘f’"],
67                 ][:]
68         ])
70         for opt : cmd.opts
71                 match opt
72                 | ('s', _): p = `Single
73                 | ('d', _): p = `Double
74                 | ('n', ns):
75                         match std.intparse(ns)
76                         | `std.Some np: n = np
77                         | `std.None:
78                                 std.put("impl-myrddin: unparsable number “{}”\n", ns)
79                                 std.exit(1)
80                         ;;
81                 | ('f', fs): fname = fs
82                 | _ : std.die("impl-myrddin: impossible\n")
83                 ;;
84         ;;
86         var good_fn : bool = false
87         for f : available_fns
88                 if std.eq(f.name, fname)
89                         fn = f
90                         good_fn = true
91                         break
92                 ;;
93         ;;
95         if !good_fn
96                 std.put("impl-myrddin: unknown function “{}”\n", fname)
97                 std.exit(1)
98         ;;
100         if n <= 0
101                 std.put("impl-myrddin: positive number of entries required\n")
102                 std.exit(1)
103         ;;
105         -> (p, fn, n)
109 const io_loop = {p : flt_prec, fn : fn_desc, n : std.size
110         var input_sz : std.size = 0
111         var output_sz : std.size = 0
112         var in_buf : byte[:] = [][:]
113         var out_buf : byte[:] = [][:]
114         var w = prec_width(p)
116         (input_sz, output_sz) = io_widths(p, fn)
118         if (((input_sz * n) / input_sz) != n) || (((output_sz * n) / output_sz) != n)
119                 std.put("impl-myrddin: overflow in i/o buffer size\n")
120                 std.exit(1)
121         ;;
123         in_buf = std.slalloc(input_sz * n)
124         out_buf = std.slalloc(output_sz * n)
126         while true
127                 match std.readall(0, in_buf)
128                 | `std.Ok _:
129                 | `std.Err e:
130                         std.put("impl-myrddin: std.readall(): {}\n", e)
131                         std.exit(1)
132                 ;;
134                 match (p, fn.f)
135                 | (`Single, `Flt__flt f):
136                         for var j = 0; j < n; ++j
137                                 var ib : byte[:] = in_buf[j * w:(j + 1) * w]
138                                 var ob : byte[:] = out_buf[j * w:(j + 1) * w]
139                                 var x : flt32 = std.flt32frombits(std.getle32(ib))
140                                 std.putle32(ob, std.flt32bits(f.f32(x)))
141                         ;;
142                 | (`Double, `Flt__flt f):
143                         for var j = 0; j < n; ++j
144                                 var ib : byte[:] = in_buf[j * w:(j + 1) * w]
145                                 var ob : byte[:] = out_buf[j * w:(j + 1) * w]
146                                 var x : flt64 = std.flt64frombits(std.getle64(ib))
147                                 std.putle64(ob, std.flt64bits(f.f64(x)))
148                         ;;
149                 | (`Single, `Flt_flt_flt__flt f):
150                         for var j = 0; j < n; ++j
151                                 var ib : byte[:] = in_buf[j * 3 * w:(j + 1) * 3 * w]
152                                 var ob : byte[:] = out_buf[j * w:(j + 1) * w]
153                                 var x1 : flt32 = std.flt32frombits(std.getle32(ib[0: 4]))
154                                 var x2 : flt32 = std.flt32frombits(std.getle32(ib[4: 8]))
155                                 var x3 : flt32 = std.flt32frombits(std.getle32(ib[8:12]))
156                                 std.putle32(ob, std.flt32bits(f.f32(x1, x2, x3)))
157                         ;;
158                 | (`Double, `Flt_flt_flt__flt f):
159                         for var j = 0; j < n; ++j
160                                 var ib : byte[:] = in_buf[j * 3 * w:(j + 1) * 3 * w]
161                                 var ob : byte[:] = out_buf[j * w:(j + 1) * w]
162                                 var x1 : flt64 = std.flt64frombits(std.getle64(ib[ 0: 8]))
163                                 var x2 : flt64 = std.flt64frombits(std.getle64(ib[ 8:16]))
164                                 var x3 : flt64 = std.flt64frombits(std.getle64(ib[16:24]))
165                                 std.putle64(ob, std.flt64bits(f.f64(x1, x2, x3)))
166                         ;;
167                 ;;
169                 match std.writeall(1, out_buf)
170                 | (_, `std.None):
171                 | (_, `std.Some e):
172                         std.put("impl-myrddin: std.writeall(): {}\n", e)
173                         std.exit(1)
174                 ;;
175         ;;
178 const prec_width = {p : flt_prec
179         match p
180         | `Single: -> 4
181         | `Double: -> 8
182         ;;
185 const io_widths = {p : flt_prec, fn : fn_desc
186         var w : std.size = prec_width(p)
188         match fn.f
189         | `Flt__flt _ : -> (w, w)
190         | `Flt_flt_flt__flt _ : -> (3*w, w)
191         ;;