support rootn
[fpmath-consensus.git] / impl-myrddin / impl-myrddin.myr
blobfcbc4f1d1402fc3cbd4c2f3cb789da6e4268c2d3
1 use std
3 use math
5 type Fn_flt__flt = struct
6         f32 : (x : flt32 -> flt32)
7         f64 : (x : flt64 -> flt64)
8 ;;
10 type Fn_flt_flt__flt = struct
11         f32 : (x : flt32, y : flt32 -> flt32)
12         f64 : (x : flt64, y : flt64 -> flt64)
15 type Fn_flt_int16__flt = struct
16         f32 : (x : flt32, y : int16 -> flt32)
17         f64 : (x : flt64, y : int16 -> flt64)
20 type Fn_flt_uint16__flt = struct
21         f32 : (x : flt32, y : uint16 -> flt32)
22         f64 : (x : flt64, y : uint16 -> flt64)
25 type Fn_flt_flt_flt__flt = struct
26         f32 : (x : flt32, y : flt32, z : flt32 -> flt32)
27         f64 : (x : flt64, y : flt64, z : flt64 -> flt64)
30 type fn_desc = struct
31         name : byte[:]
32         f : union
33                 `Flt__flt Fn_flt__flt
34                 `Flt_int16__flt Fn_flt_int16__flt
35                 `Flt_uint16__flt Fn_flt_uint16__flt
36                 `Flt_flt__flt Fn_flt_flt__flt
37                 `Flt_flt_flt__flt Fn_flt_flt_flt__flt
38         ;;
41 type flt_prec = union
42         `Single
43         `Double
46 var available_fns : fn_desc[:] = [][:]
48 generic id : (a : @a -> @a) = {x; -> x}
49 const pownwr32 = {x : flt32, y : int16; -> math.pown(x, (y : int32))}
50 const pownwr64 = {x : flt64, y : int16; -> math.pown(x, (y : int64))}
52 const rootnwr32 = {x : flt32, y : uint16; -> math.rootn(x, (y : uint32))}
53 const rootnwr64 = {x : flt64, y : uint16; -> math.rootn(x, (y : uint64))}
55 const main = {args : byte[:][:]
56         available_fns = [
57                 [.name = "id",    .f = `Flt__flt         [ .f32 = id,         .f64 = id]],
58                 [.name = "atan",  .f = `Flt__flt         [ .f32 = math.atan,  .f64 = math.atan]],
59                 [.name = "atan2", .f = `Flt_flt__flt     [ .f32 = math.atan2, .f64 = math.atan2]],
60                 [.name = "ceil",  .f = `Flt__flt         [ .f32 = math.ceil,  .f64 = math.ceil]],
61                 [.name = "cos",   .f = `Flt__flt         [ .f32 = math.cos,   .f64 = math.cos]],
62                 [.name = "cot",   .f = `Flt__flt         [ .f32 = math.cot,   .f64 = math.cot]],
63                 [.name = "exp",   .f = `Flt__flt         [ .f32 = math.exp,   .f64 = math.exp]],
64                 [.name = "expm1", .f = `Flt__flt         [ .f32 = math.expm1, .f64 = math.expm1]],
65                 [.name = "floor", .f = `Flt__flt         [ .f32 = math.floor, .f64 = math.floor]],
66                 [.name = "fma",   .f = `Flt_flt_flt__flt [ .f32 = math.fma,   .f64 = math.fma]],
67                 [.name = "log",   .f = `Flt__flt         [ .f32 = math.log,   .f64 = math.log]],
68                 [.name = "log1p", .f = `Flt__flt         [ .f32 = math.log1p, .f64 = math.log1p]],
69                 [.name = "pown",  .f = `Flt_int16__flt   [ .f32 = pownwr32,   .f64 = pownwr64]],
70                 [.name = "powr",  .f = `Flt_flt__flt     [ .f32 = math.powr,  .f64 = math.powr]],
71                 [.name = "rootn", .f = `Flt_uint16__flt  [ .f32 = rootnwr32,  .f64 = rootnwr64]],
72                 [.name = "sqrt",  .f = `Flt__flt         [ .f32 = math.sqrt,  .f64 = math.sqrt]],
73                 [.name = "sin",   .f = `Flt__flt         [ .f32 = math.sin,   .f64 = math.sin]],
74                 [.name = "tan",   .f = `Flt__flt         [ .f32 = math.tan,   .f64 = math.tan]],
75                 [.name = "trunc", .f = `Flt__flt         [ .f32 = math.trunc, .f64 = math.trunc]],
76         ][:]
78         var p : flt_prec = `Single
79         var f : fn_desc = available_fns[0]
80         var n : std.size = 0
82         (p, f, n) = read_args(args)
84         io_loop(p, f, n)
87 const read_args = {args : byte[:][:]
88         var p : flt_prec = `Single
89         var n : std.size = 0
90         var fname : byte[:] = ""
91         var fn : fn_desc = available_fns[0]
92         var cmd = std.optparse(args, &[
93                 .argdesc = "",
94                 .opts = [
95                         [.opt = 's', .desc = "use single precision (default)"],
96                         [.opt = 'd', .desc = "use double precision"],
97                         [.opt = 'n', .arg = "N", .desc = "read/write ‘N’ entries at a time"],
98                         [.opt = 'f', .arg = "func", .desc = "use function ‘f’"],
99                 ][:]
100         ])
102         for opt : cmd.opts
103                 match opt
104                 | ('s', _): p = `Single
105                 | ('d', _): p = `Double
106                 | ('n', ns):
107                         match std.intparse(ns)
108                         | `std.Some np: n = np
109                         | `std.None:
110                                 std.fput(2, "impl-myrddin: unparsable number “{}”\n", ns)
111                                 std.exit(1)
112                         ;;
113                 | ('f', fs): fname = fs
114                 | _ : std.die("impl-myrddin: impossible\n")
115                 ;;
116         ;;
118         var good_fn : bool = false
119         for f : available_fns
120                 if std.eq(f.name, fname)
121                         fn = f
122                         good_fn = true
123                         break
124                 ;;
125         ;;
127         if !good_fn
128                 std.fput(2, "impl-myrddin: unknown function “{}”\n", fname)
129                 std.exit(1)
130         ;;
132         if n <= 0
133                 std.fput(2, "impl-myrddin: positive number of entries required\n")
134                 std.exit(1)
135         ;;
137         -> (p, fn, n)
141 const io_loop = {p : flt_prec, fn : fn_desc, n : std.size
142         var input_sz : std.size = 0
143         var output_sz : std.size = 0
144         var in_buf : byte[:] = [][:]
145         var out_buf : byte[:] = [][:]
146         var w = prec_width(p)
148         (input_sz, output_sz) = io_widths(p, fn)
150         if (((input_sz * n) / input_sz) != n) || (((output_sz * n) / output_sz) != n)
151                 std.fput(2, "impl-myrddin: overflow in i/o buffer size\n")
152                 std.exit(1)
153         ;;
155         in_buf = std.slalloc(input_sz * n)
156         out_buf = std.slalloc(output_sz * n)
158         while true
159                 match std.readall(0, in_buf)
160                 | `std.Ok _:
161                 | `std.Err e:
162                         std.fput(2, "impl-myrddin: std.readall(): {}\n", e)
163                         std.exit(1)
164                 ;;
166                 for var j = 0; j < n; ++j
167                         var ib : byte[:] = in_buf[j * input_sz:(j + 1) * input_sz]
168                         var ob : byte[:] = out_buf[j * output_sz:(j + 1) * output_sz]
169                         match (p, fn.f)
170                         | (`Single, `Flt__flt f):
171                                 var x : flt32 = std.flt32frombits(std.getle32(ib))
172                                 std.putle32(ob, std.flt32bits(f.f32(x)))
173                         | (`Double, `Flt__flt f):
174                                 var x : flt64 = std.flt64frombits(std.getle64(ib))
175                                 std.putle64(ob, std.flt64bits(f.f64(x)))
176                         | (`Single, `Flt_flt__flt f):
177                                 var x1 : flt32 = std.flt32frombits(std.getle32(ib[0: 4]))
178                                 var x2 : flt32 = std.flt32frombits(std.getle32(ib[4: 8]))
179                                 std.putle32(ob, std.flt32bits(f.f32(x1, x2)))
180                         | (`Double, `Flt_flt__flt f):
181                                 var x1 : flt64 = std.flt64frombits(std.getle64(ib[ 0: 8]))
182                                 var x2 : flt64 = std.flt64frombits(std.getle64(ib[ 8:16]))
183                                 std.putle64(ob, std.flt64bits(f.f64(x1, x2)))
184                         | (`Single, `Flt_int16__flt f):
185                                 var x1 : flt32 = std.flt32frombits(std.getle32(ib[0: 4]))
186                                 var x2 : int16 = std.getle16(ib[4: 6])
187                                 std.putle32(ob, std.flt32bits(f.f32(x1, x2)))
188                         | (`Double, `Flt_int16__flt f):
189                                 var x1 : flt64 = std.flt64frombits(std.getle64(ib[ 0: 8]))
190                                 var x2 : int16 = std.getle16(ib[ 8:10])
191                                 std.putle64(ob, std.flt64bits(f.f64(x1, x2)))
192                         | (`Single, `Flt_uint16__flt f):
193                                 var x1 : flt32 = std.flt32frombits(std.getle32(ib[0: 4]))
194                                 var x2 : uint16 = std.getle16(ib[4: 6])
195                                 std.putle32(ob, std.flt32bits(f.f32(x1, x2)))
196                         | (`Double, `Flt_uint16__flt f):
197                                 var x1 : flt64 = std.flt64frombits(std.getle64(ib[ 0: 8]))
198                                 var x2 : uint16 = std.getle16(ib[ 8:10])
199                                 std.putle64(ob, std.flt64bits(f.f64(x1, x2)))
200                         | (`Single, `Flt_flt_flt__flt f):
201                                 var x1 : flt32 = std.flt32frombits(std.getle32(ib[0: 4]))
202                                 var x2 : flt32 = std.flt32frombits(std.getle32(ib[4: 8]))
203                                 var x3 : flt32 = std.flt32frombits(std.getle32(ib[8:12]))
204                                 std.putle32(ob, std.flt32bits(f.f32(x1, x2, x3)))
205                         | (`Double, `Flt_flt_flt__flt f):
206                                 var x1 : flt64 = std.flt64frombits(std.getle64(ib[ 0: 8]))
207                                 var x2 : flt64 = std.flt64frombits(std.getle64(ib[ 8:16]))
208                                 var x3 : flt64 = std.flt64frombits(std.getle64(ib[16:24]))
209                                 std.putle64(ob, std.flt64bits(f.f64(x1, x2, x3)))
210                         ;;
211                 ;;
213                 match std.writeall(1, out_buf)
214                 | `std.Ok _:
215                 | `std.Err (_, e):
216                         std.fput(2, "impl-myrddin: std.writeall(): {}\n", e)
217                         std.exit(1)
218                 ;;
219         ;;
222 const prec_width = {p : flt_prec
223         match p
224         | `Single: -> 4
225         | `Double: -> 8
226         ;;
229 const io_widths = {p : flt_prec, fn : fn_desc
230         var w : std.size = prec_width(p)
232         match fn.f
233         | `Flt__flt _ : -> (w, w)
234         | `Flt_flt__flt _ : -> (2*w, w)
235         | `Flt_int16__flt _ : -> (w + 2, w)
236         | `Flt_uint16__flt _ : -> (w + 2, w)
237         | `Flt_flt_flt__flt _ : -> (3*w, w)
238         ;;