support log and log1p
[fpmath-consensus.git] / impl-libc / impl-libc.c
blob 345ddfa2771dd8afe40b6318135bcef27ada3ee7
1 #include <errno.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include <unistd.h>
/*
 * Floating-point format used for every argument and result:
 * P_SINGLE selects 32-bit float, P_DOUBLE selects 64-bit double.
 */
typedef enum precision {
	P_SINGLE = 0,
	P_DOUBLE = 1
} precision;
/*
 * Call signature of the selected function. Names read
 * A__<arguments>__<result>, where each FLT is one value of the chosen
 * precision: A__FLT__FLT takes one value and returns one, and
 * A__FLT_FLT_FLT__FLT takes three and returns one.
 *
 * A_UNKNOWN must remain 0: main() zero-initialises its action and
 * treats A_UNKNOWN as "no function selected".
 */
typedef enum argtype {
	A_UNKNOWN = 0,
	A__FLT__FLT = 1,
	A__FLT_FLT_FLT__FLT = 2
} argtype;
/*
 * Function-pointer shapes the tool can dispatch to, named
 * f__<argument types>__<result type>.
 */

/* One float in, one float out (e.g. sinf). */
typedef float (*f__f32__f32)(float x);

/* Three floats in, one float out (e.g. fmaf). */
typedef float (*f__f32_f32_f32__f32)(float x, float y, float z);

/* One double in, one double out (e.g. sin). */
typedef double (*f__f64__f64)(double x);

/* Three doubles in, one double out (e.g. fma). */
typedef double (*f__f64_f64_f64__f64)(double x, double y, double z);
26 /* Wrapper around a function pointer */
27 typedef struct {
28 /* */
29 precision p;
30 argtype a;
32 union {
33 /* */
34 f__f32__f32 f32__f32;
35 f__f32_f32_f32__f32 f32_f32_f32__f32;
36 } f32;
38 union {
39 /* */
40 f__f64__f64 f64__f64;
41 f__f64_f64_f64__f64 f64_f64_f64__f64;
42 } f64;
44 } action;
/* Print the command-line synopsis to stderr and terminate with status 1. */
void usage(void)
{
	static const char synopsis[] =
		"usage: impl-libc [-s|-d] -f <function_name> -n <num_inputs>\n";

	fputs(synopsis, stderr);
	_exit(1);
}
/* Single-precision identity used for "-f id": returns its argument unchanged. */
float idf(float f)
{
	float unchanged = f;

	return unchanged;
}
/* Double-precision identity used for "-f id": returns its argument unchanged. */
double idd(double d)
{
	double unchanged = d;

	return unchanged;
}
63 void determine_function(const char *f, action *a)
65 if (!strcmp(f, "zzzzzz")) {
66 a->a = A_UNKNOWN;
67 } else if (!strcmp(f, "id")) {
68 a->a = A__FLT__FLT;
69 a->f32.f32__f32 = idf;
70 a->f64.f64__f64 = idd;
71 } else if (!strcmp(f, "ceil")) {
72 a->a = A__FLT__FLT;
73 a->f32.f32__f32 = ceilf;
74 a->f64.f64__f64 = ceil;
75 } else if (!strcmp(f, "cos")) {
76 a->a = A__FLT__FLT;
77 a->f32.f32__f32 = cosf;
78 a->f64.f64__f64 = cos;
79 } else if (!strcmp(f, "floor")) {
80 a->a = A__FLT__FLT;
81 a->f32.f32__f32 = floorf;
82 a->f64.f64__f64 = floor;
83 } else if (!strcmp(f, "fma")) {
84 a->a = A__FLT_FLT_FLT__FLT;
85 a->f32.f32_f32_f32__f32 = fmaf;
86 a->f64.f64_f64_f64__f64 = fma;
87 } else if (!strcmp(f, "exp")) {
88 a->a = A__FLT__FLT;
89 a->f32.f32__f32 = expf;
90 a->f64.f64__f64 = exp;
91 } else if (!strcmp(f, "expm1")) {
92 a->a = A__FLT__FLT;
93 a->f32.f32__f32 = expm1f;
94 a->f64.f64__f64 = expm1;
95 } else if (!strcmp(f, "log")) {
96 a->a = A__FLT__FLT;
97 a->f32.f32__f32 = logf;
98 a->f64.f64__f64 = log;
99 } else if (!strcmp(f, "log1p")) {
100 a->a = A__FLT__FLT;
101 a->f32.f32__f32 = log1pf;
102 a->f64.f64__f64 = log1p;
103 } else if (!strcmp(f, "sin")) {
104 a->a = A__FLT__FLT;
105 a->f32.f32__f32 = sinf;
106 a->f64.f64__f64 = sin;
107 } else if (!strcmp(f, "sqrt")) {
108 a->a = A__FLT__FLT;
109 a->f32.f32__f32 = sqrtf;
110 a->f64.f64__f64 = sqrt;
111 } else if (!strcmp(f, "trunc")) {
112 a->a = A__FLT__FLT;
113 a->f32.f32__f32 = truncf;
114 a->f64.f64__f64 = trunc;
115 } else {
116 fprintf(stderr, "impl-libc: unknown function \"%s\"\n", f);
117 _exit(1);
/*
 * Fill b with exactly len bytes read from standard input (fd 0).
 *
 * read() may return fewer bytes than requested, so keep reading until
 * the buffer is full. A read interrupted by a signal (EINTR) is retried
 * rather than treated as fatal.
 *
 * End of input terminates the process with status 0. If it arrives
 * part-way through the buffer, the incomplete batch cannot be
 * processed; a note with the number of dropped bytes goes to stderr
 * instead of being silently lost.
 *
 * Any other read error is reported with perror() and terminates the
 * process with status 1.
 */
void read_buf(char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = read(0, b + total, (size_t) (len - total));

		if (r > 0) {
			total += r;
		} else if (r == 0) {
			/* EOF */
			if (total > 0) {
				fprintf(stderr,
				        "impl-libc: discarding %lld trailing bytes of an incomplete batch\n",
				        (long long) total);
			}
			_exit(0);
		} else if (errno != EINTR) {
			perror("impl-libc: read");
			_exit(1);
		}
		/* r == -1 with EINTR: retry the same read */
	}
}
/*
 * Write all len bytes of b to standard output (fd 1).
 *
 * write() may accept fewer bytes than offered, so keep writing until
 * everything has been sent. A write interrupted by a signal (EINTR) is
 * retried rather than treated as fatal. Any other error is reported
 * with perror() and terminates the process with status 1.
 */
void write_buf(const char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = write(1, b + total, (size_t) (len - total));

		if (r >= 0) {
			total += r;
		} else if (errno != EINTR) {
			perror("impl-libc: write");
			_exit(1);
		}
		/* r == -1 with EINTR: retry the same write */
	}
}
158 size_t input_width(argtype a, precision p)
160 size_t w = (p == P_SINGLE) ? 4 : 8;
162 switch (a) {
163 case A_UNKNOWN:
164 break;
165 case A__FLT__FLT:
167 return 1 * w;
168 case A__FLT_FLT_FLT__FLT:
170 return 3 * w;
173 return (size_t) -1;
176 size_t output_width(argtype a, precision p)
178 size_t w = (p == P_SINGLE) ? 4 : 8;
180 switch (a) {
181 case A_UNKNOWN:
182 break;
183 case A__FLT__FLT:
185 return 1 * w;
186 case A__FLT_FLT_FLT__FLT:
188 return 1 * w;
191 return (size_t) -1;
194 void io_loop(action a, size_t n)
196 char *in_buf = 0;
197 char *out_buf = 0;
198 size_t in_sz = input_width(a.a, a.p);
199 size_t out_sz = output_width(a.a, a.p);
201 if ((in_sz * n) / n != in_sz) {
202 fprintf(stderr, "impl-libc: input length overflow\n");
203 _exit(1);
206 if ((out_sz * n) / n != out_sz) {
207 fprintf(stderr, "impl-libc: output length overflow\n");
208 _exit(1);
211 if (!(in_buf = malloc(in_sz * n))) {
212 perror("impl-libc: malloc");
213 _exit(1);
216 if (!(out_buf = malloc(out_sz * n))) {
217 perror("impl-libc: malloc");
218 _exit(1);
221 while (1) {
222 read_buf(in_buf, in_sz * n);
224 switch (a.a) {
225 case A_UNKNOWN:
226 fprintf(stderr, "impl-libc: impossible\n");
227 _exit(1);
228 break;
229 case A__FLT__FLT:
231 switch (a.p) {
232 case P_SINGLE:
234 for (size_t j = 0; j < n; ++j) {
235 float *x = (float *) (in_buf + (in_sz *
236 j));
237 float *y = (float *) (out_buf +
238 (out_sz * j));
240 *y = a.f32.f32__f32(*x);
243 break;
244 case P_DOUBLE:
246 for (size_t j = 0; j < n; ++j) {
247 double *x = (double *) (in_buf +
248 (in_sz * j));
249 double *y = (double *) (out_buf +
250 (out_sz * j));
252 *y = a.f64.f64__f64(*x);
255 break;
258 break;
259 case A__FLT_FLT_FLT__FLT:
261 switch (a.p) {
262 case P_SINGLE:
264 for (size_t j = 0; j < n; ++j) {
265 float *x1 = (float *) (in_buf + (in_sz *
266 j));
267 float *x2 = (float *) (in_buf + (in_sz *
269 4));
270 float *x3 = (float *) (in_buf + (in_sz *
272 8));
273 float *y = (float *) (out_buf +
274 (out_sz * j));
276 *y = a.f32.f32_f32_f32__f32(*x1, *x2,
277 *x3);
280 break;
281 case P_DOUBLE:
283 for (size_t j = 0; j < n; ++j) {
284 double *x1 = (double *) (in_buf +
285 (in_sz * j));
286 double *x2 = (double *) (in_buf +
287 (in_sz * j) +
289 double *x3 = (double *) (in_buf +
290 (in_sz * j) +
291 16);
292 double *y = (double *) (out_buf +
293 (out_sz * j));
295 *y = a.f64.f64_f64_f64__f64(*x1, *x2,
296 *x3);
299 break;
302 break;
305 write_buf(out_buf, out_sz * n);
309 int main(int argc, char **argv)
311 int c = 0;
312 action a = { .p = P_SINGLE };
313 long long n = 0;
315 while ((c = getopt(argc, argv, "sdf:n:")) != -1) {
316 switch (c) {
317 case 's':
318 a.p = P_SINGLE;
319 break;
320 case 'd':
321 a.p = P_DOUBLE;
322 break;
323 case 'f':
324 determine_function(optarg, &a);
325 break;
326 case 'n':
327 errno = 0;
328 n = strtoll(optarg, 0, 0);
330 if (errno) {
331 perror("impl-libc: unparsable");
333 return 1;
336 break;
337 default:
338 usage();
339 break;
343 if (a.a == A_UNKNOWN) {
344 usage();
347 io_loop(a, n);