support log and log1p
[fpmath-consensus.git] / impl-mpfr / impl-mpfr.c
blob4a60bb27304fd3002ae3815d237636b384a41269
1 #include <errno.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
6 #include <unistd.h>
8 #include <mpfr.h>
/* Width of the floating-point values being processed: 32-bit or 64-bit. */
typedef enum {
	P_SINGLE,
	P_DOUBLE
} precision;
/*
 * Calling convention of the selected function: which arguments it takes
 * and what it hands back.  The names read <inputs>__<output>.
 */
typedef enum {
	A_UNKNOWN,               /* no function has been selected */
	A__FLT__FLT,             /* one float in, one float out */
	A__FLT_FLT_FLT_RND__FLT, /* three floats and a rounding mode in, one float out */
	A__FLT_RND__FLT          /* one float and a rounding mode in, one float out */
} argtype;
22 /* Types of functions we could call */
23 typedef int (*f__flt__flt)(mpfr_ptr, mpfr_srcptr);
24 typedef int (*f__flt_flt_flt_rnd__flt)(mpfr_ptr, mpfr_srcptr, mpfr_srcptr,
25 mpfr_srcptr, mpfr_rnd_t);
26 typedef int (*f__flt_rnd__flt)(mpfr_ptr, mpfr_srcptr, mpfr_rnd_t);
28 /* Wrapper around a function pointer */
29 typedef struct {
30 /* */
31 precision p;
32 argtype a;
34 union {
35 /* */
36 f__flt__flt flt__flt;
37 f__flt_flt_flt_rnd__flt flt_flt_flt_rnd__flt;
38 f__flt_rnd__flt flt_rnd__flt;
39 } f;
41 } action;
/* Write the command-line synopsis to stderr and terminate with status 1. */
void usage(void)
{
	static const char synopsis[] =
		"usage: impl-mpfr [-s|-d] -f <function_name> -n <num_inputs>\n";

	fputs(synopsis, stderr);
	_exit(1);
}
50 void determine_function(const char *f, action *a)
52 if (!strcmp(f, "zzzzzzzzz")) {
53 a->a = A_UNKNOWN;
54 } else if (!strcmp(f, "id")) {
55 a->a = A__FLT_RND__FLT;
56 a->f.flt_rnd__flt = mpfr_set;
57 } else if (!strcmp(f, "ceil")) {
58 a->a = A__FLT__FLT;
59 a->f.flt__flt = mpfr_ceil;
60 } else if (!strcmp(f, "cos")) {
61 a->a = A__FLT_RND__FLT;
62 a->f.flt_rnd__flt = mpfr_cos;
63 } else if (!strcmp(f, "floor")) {
64 a->a = A__FLT__FLT;
65 a->f.flt__flt = mpfr_floor;
66 } else if (!strcmp(f, "exp")) {
67 a->a = A__FLT_RND__FLT;
68 a->f.flt_rnd__flt = mpfr_exp;
69 } else if (!strcmp(f, "expm1")) {
70 a->a = A__FLT_RND__FLT;
71 a->f.flt_rnd__flt = mpfr_expm1;
72 } else if (!strcmp(f, "fma")) {
73 a->a = A__FLT_FLT_FLT_RND__FLT;
74 a->f.flt_flt_flt_rnd__flt = mpfr_fma;
75 } else if (!strcmp(f, "log")) {
76 a->a = A__FLT_RND__FLT;
77 a->f.flt_rnd__flt = mpfr_log;
78 } else if (!strcmp(f, "log1p")) {
79 a->a = A__FLT_RND__FLT;
80 a->f.flt_rnd__flt = mpfr_log1p;
81 } else if (!strcmp(f, "sin")) {
82 a->a = A__FLT_RND__FLT;
83 a->f.flt_rnd__flt = mpfr_sin;
84 } else if (!strcmp(f, "sqrt")) {
85 a->a = A__FLT_RND__FLT;
86 a->f.flt_rnd__flt = mpfr_sqrt;
87 } else if (!strcmp(f, "trunc")) {
88 a->a = A__FLT__FLT;
89 a->f.flt__flt = mpfr_trunc;
90 } else {
91 fprintf(stderr, "impl-mpfr: unknown function \"%s\"\n", f);
92 _exit(1);
/*
 * Fill b with exactly len bytes read from standard input (fd 0).
 *
 * Short reads are resumed until the buffer is full, and a read that was
 * interrupted by a signal (EINTR) is retried rather than treated as an
 * error.  End of input terminates the process with status 0, even if part
 * of the buffer had already been filled.  Any other read failure is
 * reported with perror() and terminates the process with status 1.
 */
void read_buf(char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = read(0, b + total, (size_t) (len - total));

		if (r == 0) {
			/* EOF */
			_exit(0);
		}

		if (r < 0) {
			if (errno == EINTR) {
				/* Nothing was transferred; try again. */
				continue;
			}

			perror("impl-mpfr: read");
			_exit(1);
		}

		total += r;
	}
}
/*
 * Write all len bytes of b to standard output (fd 1).
 *
 * Short writes are resumed until everything has been written, and a write
 * that was interrupted by a signal (EINTR) is retried rather than treated
 * as an error.  Any other write failure is reported with perror() and
 * terminates the process with status 1.
 */
void write_buf(const char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = write(1, b + total, (size_t) (len - total));

		if (r < 0) {
			if (errno == EINTR) {
				/* Nothing was transferred; try again. */
				continue;
			}

			perror("impl-mpfr: write");
			_exit(1);
		}

		total += r;
	}
}
133 size_t input_width(argtype a, precision p)
135 size_t w = (p == P_SINGLE) ? 4 : 8;
137 switch (a) {
138 case A_UNKNOWN:
139 break;
140 case A__FLT__FLT:
142 return 1 * w;
143 case A__FLT_FLT_FLT_RND__FLT:
145 return 3 * w;
146 case A__FLT_RND__FLT:
148 return 1 * w;
151 return (size_t) -1;
154 size_t output_width(argtype a, precision p)
156 size_t w = (p == P_SINGLE) ? 4 : 8;
158 switch (a) {
159 case A_UNKNOWN:
160 break;
161 case A__FLT__FLT:
163 return 1 * w;
164 case A__FLT_FLT_FLT_RND__FLT:
166 return 1 * w;
167 case A__FLT_RND__FLT:
169 return 1 * w;
172 return (size_t) -1;
175 void io_loop(action a, size_t n)
177 char *in_buf = 0;
178 char *out_buf = 0;
179 size_t in_sz = input_width(a.a, a.p);
180 size_t out_sz = output_width(a.a, a.p);
181 mpfr_t x1;
182 mpfr_t x2;
183 mpfr_t x3;
184 mpfr_t y;
186 if ((in_sz * n) / n != in_sz) {
187 fprintf(stderr, "impl-libc: input length overflow\n");
188 _exit(1);
191 if ((out_sz * n) / n != out_sz) {
192 fprintf(stderr, "impl-libc: output length overflow\n");
193 _exit(1);
196 if (!(in_buf = malloc(in_sz * n))) {
197 perror("impl-libc: malloc");
198 _exit(1);
201 if (!(out_buf = malloc(out_sz * n))) {
202 perror("impl-libc: malloc");
203 _exit(1);
206 /* I'm pretty sure 53 precision would be enough */
207 mpfr_init2(x1, 75);
208 mpfr_init2(x2, 75);
209 mpfr_init2(x3, 75);
210 mpfr_init2(y, 75);
212 while (1) {
213 read_buf(in_buf, in_sz * n);
215 switch (a.a) {
216 case A_UNKNOWN:
217 fprintf(stderr, "impl-libc: impossible\n");
218 _exit(1);
219 break;
220 case A__FLT__FLT:
222 switch (a.p) {
223 case P_SINGLE:
225 for (size_t j = 0; j < n; ++j) {
226 float *xf1 = (float *) (in_buf +
227 (in_sz * j));
228 float *yf = (float *) (out_buf +
229 (out_sz * j));
231 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
232 a.f.flt__flt(y, x1);
233 *yf = mpfr_get_flt(y, MPFR_RNDN);
236 break;
237 case P_DOUBLE:
239 for (size_t j = 0; j < n; ++j) {
240 double *xf1 = (double *) (in_buf +
241 (in_sz * j));
242 double *yf = (double *) (out_buf +
243 (out_sz * j));
245 mpfr_set_d(x1, *xf1, MPFR_RNDN);
246 a.f.flt__flt(y, x1);
247 *yf = mpfr_get_d(y, MPFR_RNDN);
250 break;
253 break;
254 case A__FLT_FLT_FLT_RND__FLT:
256 switch (a.p) {
257 case P_SINGLE:
259 for (size_t j = 0; j < n; ++j) {
260 float *xf1 = (float *) (in_buf +
261 (in_sz * j));
262 float *xf2 = (float *) (in_buf +
263 (in_sz * j) +
265 float *xf3 = (float *) (in_buf +
266 (in_sz * j) +
268 float *yf = (float *) (out_buf +
269 (out_sz * j));
271 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
272 mpfr_set_flt(x2, *xf2, MPFR_RNDN);
273 mpfr_set_flt(x3, *xf3, MPFR_RNDN);
274 a.f.flt_flt_flt_rnd__flt(y, x1, x2, x3,
275 MPFR_RNDN);
276 *yf = mpfr_get_flt(y, MPFR_RNDN);
279 break;
280 case P_DOUBLE:
282 for (size_t j = 0; j < n; ++j) {
283 double *xf1 = (double *) (in_buf +
284 (in_sz * j));
285 double *xf2 = (double *) (in_buf +
286 (in_sz * j) +
288 double *xf3 = (double *) (in_buf +
289 (in_sz * j) +
290 16);
291 double *yf = (double *) (out_buf +
292 (out_sz * j));
294 mpfr_set_d(x1, *xf1, MPFR_RNDN);
295 mpfr_set_d(x2, *xf2, MPFR_RNDN);
296 mpfr_set_d(x3, *xf3, MPFR_RNDN);
297 a.f.flt_flt_flt_rnd__flt(y, x1, x2, x3,
298 MPFR_RNDN);
299 *yf = mpfr_get_d(y, MPFR_RNDN);
302 break;
305 break;
306 case A__FLT_RND__FLT:
308 switch (a.p) {
309 case P_SINGLE:
311 for (size_t j = 0; j < n; ++j) {
312 float *xf1 = (float *) (in_buf +
313 (in_sz * j));
314 float *yf = (float *) (out_buf +
315 (out_sz * j));
317 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
318 a.f.flt_rnd__flt(y, x1, MPFR_RNDN);
319 *yf = mpfr_get_flt(y, MPFR_RNDN);
322 break;
323 case P_DOUBLE:
325 for (size_t j = 0; j < n; ++j) {
326 double *xf1 = (double *) (in_buf +
327 (in_sz * j));
328 double *yf = (double *) (out_buf +
329 (out_sz * j));
331 mpfr_set_d(x1, *xf1, MPFR_RNDN);
332 a.f.flt_rnd__flt(y, x1, MPFR_RNDN);
333 *yf = mpfr_get_d(y, MPFR_RNDN);
336 break;
339 break;
342 write_buf(out_buf, out_sz * n);
346 int main(int argc, char **argv)
348 int c = 0;
349 action a = { .p = P_SINGLE };
350 long long n = 0;
352 while ((c = getopt(argc, argv, "sdf:n:")) != -1) {
353 switch (c) {
354 case 's':
355 a.p = P_SINGLE;
356 break;
357 case 'd':
358 a.p = P_DOUBLE;
359 break;
360 case 'f':
361 determine_function(optarg, &a);
362 break;
363 case 'n':
364 errno = 0;
365 n = strtoll(optarg, 0, 0);
367 if (errno) {
368 perror("impl-libc: unparsable");
370 return 1;
373 break;
374 default:
375 usage();
376 break;
380 if (a.a == A_UNKNOWN) {
381 usage();
384 io_loop(a, n);