support sqrt
[fpmath-consensus.git] / impl-mpfr / impl-mpfr.c
blob991d48f615a5e97b53e04937e76a2df4fc8c6d92
1 #include <errno.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
6 #include <unistd.h>
8 #include <mpfr.h>
/* Which floating-point width the raw input and output values use. */
typedef enum {
	P_SINGLE, /* 32-bit floats */
	P_DOUBLE  /* 64-bit floats */
} precision;
/*
 * Calling convention of a wrapped function: which arguments it expects
 * and what it gives back.
 */
typedef enum {
	/* No function selected; also the value of a zero-initialized argtype. */
	A_UNKNOWN = 0,
	/* One float in, one float out, no rounding-mode argument. */
	A__FLT__FLT,
	/* Three floats plus a rounding mode in, one float out. */
	A__FLT_FLT_FLT_RND__FLT,
	/* One float plus a rounding mode in, one float out. */
	A__FLT_RND__FLT
} argtype;
22 /* Types of functions we could call */
23 typedef int (*f__flt__flt)(mpfr_ptr, mpfr_srcptr);
24 typedef int (*f__flt_flt_flt_rnd__flt)(mpfr_ptr, mpfr_srcptr, mpfr_srcptr,
25 mpfr_srcptr, mpfr_rnd_t);
26 typedef int (*f__flt_rnd__flt)(mpfr_ptr, mpfr_srcptr, mpfr_rnd_t);
28 /* Wrapper around a function pointer */
29 typedef struct {
30 /* */
31 precision p;
32 argtype a;
34 union {
35 /* */
36 f__flt__flt flt__flt;
37 f__flt_flt_flt_rnd__flt flt_flt_flt_rnd__flt;
38 f__flt_rnd__flt flt_rnd__flt;
39 } f;
41 } action;
/*
 * Print the command-line synopsis on stderr and terminate the process
 * with exit status 1.  Does not return.
 */
void usage(void)
{
	static const char synopsis[] =
		"usage: impl-mpfr [-s|-d] -f <function_name> -n <num_inputs>\n";

	fputs(synopsis, stderr);
	_exit(1);
}
50 void determine_function(const char *f, action *a)
52 if (!strcmp(f, "zzzzzzzzz")) {
53 a->a = A_UNKNOWN;
54 } else if (!strcmp(f, "id")) {
55 a->a = A__FLT_RND__FLT;
56 a->f.flt_rnd__flt = mpfr_set;
57 } else if (!strcmp(f, "ceil")) {
58 a->a = A__FLT__FLT;
59 a->f.flt__flt = mpfr_ceil;
60 } else if (!strcmp(f, "cos")) {
61 a->a = A__FLT_RND__FLT;
62 a->f.flt_rnd__flt = mpfr_cos;
63 } else if (!strcmp(f, "floor")) {
64 a->a = A__FLT__FLT;
65 a->f.flt__flt = mpfr_floor;
66 } else if (!strcmp(f, "sin")) {
67 a->a = A__FLT_RND__FLT;
68 a->f.flt_rnd__flt = mpfr_sin;
69 } else if (!strcmp(f, "trunc")) {
70 a->a = A__FLT__FLT;
71 a->f.flt__flt = mpfr_trunc;
72 } else if (!strcmp(f, "fma")) {
73 a->a = A__FLT_FLT_FLT_RND__FLT;
74 a->f.flt_flt_flt_rnd__flt = mpfr_fma;
75 } else if (!strcmp(f, "sqrt")) {
76 a->a = A__FLT_RND__FLT;
77 a->f.flt_rnd__flt = mpfr_sqrt;
78 } else {
79 fprintf(stderr, "impl-mpfr: unknown function \"%s\"\n", f);
80 _exit(1);
/*
 * Fill b with exactly len bytes from standard input (fd 0), looping over
 * short reads.
 *
 * A read interrupted by a signal (EINTR) is retried instead of being
 * treated as fatal.  End-of-file terminates the process with exit status
 * 0, even if part of the buffer has already been filled.  Any other read
 * error is reported through perror() and terminates the process with
 * exit status 1.
 */
void read_buf(char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = read(0, b + total, (size_t) (len - total));

		if (r > 0) {
			total += r;
		} else if (r == 0) {
			/* EOF */
			_exit(0);
		} else if (errno != EINTR) {
			perror("impl-mpfr: read");
			_exit(1);
		}
		/* else: interrupted before any data arrived; try again */
	}
}
/*
 * Write exactly len bytes from b to standard output (fd 1), looping over
 * short writes.
 *
 * A write interrupted by a signal (EINTR) is retried instead of being
 * treated as fatal.  Any other write error is reported through perror()
 * and terminates the process with exit status 1.
 */
void write_buf(const char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = write(1, b + total, (size_t) (len - total));

		if (r >= 0) {
			total += r;
		} else if (errno != EINTR) {
			perror("impl-mpfr: write");
			_exit(1);
		}
		/* else: interrupted before anything was written; try again */
	}
}
121 size_t input_width(argtype a, precision p)
123 size_t w = (p == P_SINGLE) ? 4 : 8;
125 switch (a) {
126 case A_UNKNOWN:
127 break;
128 case A__FLT__FLT:
130 return 1 * w;
131 case A__FLT_FLT_FLT_RND__FLT:
133 return 3 * w;
134 case A__FLT_RND__FLT:
136 return 1 * w;
139 return (size_t) -1;
142 size_t output_width(argtype a, precision p)
144 size_t w = (p == P_SINGLE) ? 4 : 8;
146 switch (a) {
147 case A_UNKNOWN:
148 break;
149 case A__FLT__FLT:
151 return 1 * w;
152 case A__FLT_FLT_FLT_RND__FLT:
154 return 1 * w;
155 case A__FLT_RND__FLT:
157 return 1 * w;
160 return (size_t) -1;
163 void io_loop(action a, size_t n)
165 char *in_buf = 0;
166 char *out_buf = 0;
167 size_t in_sz = input_width(a.a, a.p);
168 size_t out_sz = output_width(a.a, a.p);
169 mpfr_t x1;
170 mpfr_t x2;
171 mpfr_t x3;
172 mpfr_t y;
174 if ((in_sz * n) / n != in_sz) {
175 fprintf(stderr, "impl-libc: input length overflow\n");
176 _exit(1);
179 if ((out_sz * n) / n != out_sz) {
180 fprintf(stderr, "impl-libc: output length overflow\n");
181 _exit(1);
184 if (!(in_buf = malloc(in_sz * n))) {
185 perror("impl-libc: malloc");
186 _exit(1);
189 if (!(out_buf = malloc(out_sz * n))) {
190 perror("impl-libc: malloc");
191 _exit(1);
194 /* I'm pretty sure 53 precision would be enough */
195 mpfr_init2(x1, 75);
196 mpfr_init2(x2, 75);
197 mpfr_init2(x3, 75);
198 mpfr_init2(y, 75);
200 while (1) {
201 read_buf(in_buf, in_sz * n);
203 switch (a.a) {
204 case A_UNKNOWN:
205 fprintf(stderr, "impl-libc: impossible\n");
206 _exit(1);
207 break;
208 case A__FLT__FLT:
210 switch (a.p) {
211 case P_SINGLE:
213 for (size_t j = 0; j < n; ++j) {
214 float *xf1 = (float *) (in_buf +
215 (in_sz * j));
216 float *yf = (float *) (out_buf +
217 (out_sz * j));
219 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
220 a.f.flt__flt(y, x1);
221 *yf = mpfr_get_flt(y, MPFR_RNDN);
224 break;
225 case P_DOUBLE:
227 for (size_t j = 0; j < n; ++j) {
228 double *xf1 = (double *) (in_buf +
229 (in_sz * j));
230 double *yf = (double *) (out_buf +
231 (out_sz * j));
233 mpfr_set_d(x1, *xf1, MPFR_RNDN);
234 a.f.flt__flt(y, x1);
235 *yf = mpfr_get_d(y, MPFR_RNDN);
238 break;
241 break;
242 case A__FLT_FLT_FLT_RND__FLT:
244 switch (a.p) {
245 case P_SINGLE:
247 for (size_t j = 0; j < n; ++j) {
248 float *xf1 = (float *) (in_buf +
249 (in_sz * j));
250 float *xf2 = (float *) (in_buf +
251 (in_sz * j) +
253 float *xf3 = (float *) (in_buf +
254 (in_sz * j) +
256 float *yf = (float *) (out_buf +
257 (out_sz * j));
259 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
260 mpfr_set_flt(x2, *xf2, MPFR_RNDN);
261 mpfr_set_flt(x3, *xf3, MPFR_RNDN);
262 a.f.flt_flt_flt_rnd__flt(y, x1, x2, x3,
263 MPFR_RNDN);
264 *yf = mpfr_get_flt(y, MPFR_RNDN);
267 break;
268 case P_DOUBLE:
270 for (size_t j = 0; j < n; ++j) {
271 double *xf1 = (double *) (in_buf +
272 (in_sz * j));
273 double *xf2 = (double *) (in_buf +
274 (in_sz * j) +
276 double *xf3 = (double *) (in_buf +
277 (in_sz * j) +
278 16);
279 double *yf = (double *) (out_buf +
280 (out_sz * j));
282 mpfr_set_d(x1, *xf1, MPFR_RNDN);
283 mpfr_set_d(x2, *xf2, MPFR_RNDN);
284 mpfr_set_d(x3, *xf3, MPFR_RNDN);
285 a.f.flt_flt_flt_rnd__flt(y, x1, x2, x3,
286 MPFR_RNDN);
287 *yf = mpfr_get_d(y, MPFR_RNDN);
290 break;
293 break;
294 case A__FLT_RND__FLT:
296 switch (a.p) {
297 case P_SINGLE:
299 for (size_t j = 0; j < n; ++j) {
300 float *xf1 = (float *) (in_buf +
301 (in_sz * j));
302 float *yf = (float *) (out_buf +
303 (out_sz * j));
305 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
306 a.f.flt_rnd__flt(y, x1, MPFR_RNDN);
307 *yf = mpfr_get_flt(y, MPFR_RNDN);
310 break;
311 case P_DOUBLE:
313 for (size_t j = 0; j < n; ++j) {
314 double *xf1 = (double *) (in_buf +
315 (in_sz * j));
316 double *yf = (double *) (out_buf +
317 (out_sz * j));
319 mpfr_set_d(x1, *xf1, MPFR_RNDN);
320 a.f.flt_rnd__flt(y, x1, MPFR_RNDN);
321 *yf = mpfr_get_d(y, MPFR_RNDN);
324 break;
327 break;
330 write_buf(out_buf, out_sz * n);
334 int main(int argc, char **argv)
336 int c = 0;
337 action a = { .p = P_SINGLE };
338 long long n = 0;
340 while ((c = getopt(argc, argv, "sdf:n:")) != -1) {
341 switch (c) {
342 case 's':
343 a.p = P_SINGLE;
344 break;
345 case 'd':
346 a.p = P_DOUBLE;
347 break;
348 case 'f':
349 determine_function(optarg, &a);
350 break;
351 case 'n':
352 errno = 0;
353 n = strtoll(optarg, 0, 0);
355 if (errno) {
356 perror("impl-libc: unparsable");
358 return 1;
361 break;
362 default:
363 usage();
364 break;
368 if (a.a == A_UNKNOWN) {
369 usage();
372 io_loop(a, n);