/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_DOUBLE_H
#define GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_DOUBLE_H

#include "config.h"

#include <cassert>
#include <cstddef>  // std::size_t, used in the load/store alignment asserts
#include <cstdint>  // std::int32_t, used for the integer SIMD lanes

#include <arm_neon.h>

#include "gromacs/math/utilities.h"

#include "impl_arm_neon_asimd_simd_float.h"

namespace gmx
{
class SimdDouble
{
    public:
        SimdDouble() {}

        SimdDouble(double d) : simdInternal_(vdupq_n_f64(d)) {}

        // Internal utility constructor to simplify return statements
        SimdDouble(float64x2_t simd) : simdInternal_(simd) {}

        float64x2_t  simdInternal_;
};
class SimdDInt32
{
    public:
        SimdDInt32() {}

        SimdDInt32(std::int32_t i) : simdInternal_(vdup_n_s32(i)) {}

        // Internal utility constructor to simplify return statements
        SimdDInt32(int32x2_t simd) : simdInternal_(simd) {}

        int32x2_t  simdInternal_;
};
class SimdDBool
{
    public:
        SimdDBool() {}

        SimdDBool(bool b) : simdInternal_(vdupq_n_u64( b ? 0xFFFFFFFFFFFFFFFF : 0)) {}

        // Internal utility constructor to simplify return statements
        SimdDBool(uint64x2_t simd) : simdInternal_(simd) {}

        uint64x2_t  simdInternal_;
};
class SimdDIBool
{
    public:
        SimdDIBool() {}

        SimdDIBool(bool b) : simdInternal_(vdup_n_u32( b ? 0xFFFFFFFF : 0)) {}

        // Internal utility constructor to simplify return statements
        SimdDIBool(uint32x2_t simd) : simdInternal_(simd) {}

        uint32x2_t  simdInternal_;
};
static inline SimdDouble gmx_simdcall
simdLoad(const double *m)
{
    assert(std::size_t(m) % 16 == 0);
    return {
               vld1q_f64(m)
    };
}

static inline void gmx_simdcall
store(double *m, SimdDouble a)
{
    assert(std::size_t(m) % 16 == 0);
    vst1q_f64(m, a.simdInternal_);
}

static inline SimdDouble gmx_simdcall
simdLoadU(const double *m)
{
    return {
               vld1q_f64(m)
    };
}

static inline void gmx_simdcall
storeU(double *m, SimdDouble a)
{
    vst1q_f64(m, a.simdInternal_);
}

static inline SimdDouble gmx_simdcall
setZeroD()
{
    return {
               vdupq_n_f64(0.0)
    };
}

static inline SimdDInt32 gmx_simdcall
simdLoadDI(const std::int32_t * m)
{
    assert(std::size_t(m) % 8 == 0);
    return {
               vld1_s32(m)
    };
}

static inline void gmx_simdcall
store(std::int32_t * m, SimdDInt32 a)
{
    assert(std::size_t(m) % 8 == 0);
    vst1_s32(m, a.simdInternal_);
}

static inline SimdDInt32 gmx_simdcall
simdLoadUDI(const std::int32_t *m)
{
    return {
               vld1_s32(m)
    };
}

static inline void gmx_simdcall
storeU(std::int32_t * m, SimdDInt32 a)
{
    vst1_s32(m, a.simdInternal_);
}

static inline SimdDInt32 gmx_simdcall
setZeroDI()
{
    return {
               vdup_n_s32(0)
    };
}
template<int index> gmx_simdcall
static inline std::int32_t
extract(SimdDInt32 a)
{
    return vget_lane_s32(a.simdInternal_, index);
}
static inline SimdDouble gmx_simdcall
operator&(SimdDouble a, SimdDouble b)
{
    return {
               float64x2_t(vandq_s64(int64x2_t(a.simdInternal_), int64x2_t(b.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
andNot(SimdDouble a, SimdDouble b)
{
    return {
               float64x2_t(vbicq_s64(int64x2_t(b.simdInternal_), int64x2_t(a.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
operator|(SimdDouble a, SimdDouble b)
{
    return {
               float64x2_t(vorrq_s64(int64x2_t(a.simdInternal_), int64x2_t(b.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
operator^(SimdDouble a, SimdDouble b)
{
    return {
               float64x2_t(veorq_s64(int64x2_t(a.simdInternal_), int64x2_t(b.simdInternal_)))
    };
}
static inline SimdDouble gmx_simdcall
operator+(SimdDouble a, SimdDouble b)
{
    return {
               vaddq_f64(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator-(SimdDouble a, SimdDouble b)
{
    return {
               vsubq_f64(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator-(SimdDouble x)
{
    return {
               vnegq_f64(x.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator*(SimdDouble a, SimdDouble b)
{
    return {
               vmulq_f64(a.simdInternal_, b.simdInternal_)
    };
}
static inline SimdDouble gmx_simdcall
fma(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               vfmaq_f64(c.simdInternal_, b.simdInternal_, a.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
fms(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               vnegq_f64(vfmsq_f64(c.simdInternal_, b.simdInternal_, a.simdInternal_))
    };
}

static inline SimdDouble gmx_simdcall
fnma(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               vfmsq_f64(c.simdInternal_, b.simdInternal_, a.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
fnms(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               vnegq_f64(vfmaq_f64(c.simdInternal_, b.simdInternal_, a.simdInternal_))
    };
}
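
// The four fused-multiply variants above differ only in sign conventions.
// Note the operand order: gmx fma(a, b, c) computes a * b + c, while the
// underlying vfmaq_f64(acc, x, y) computes acc + x * y, so the accumulator
// is passed first. A scalar sketch of what each variant returns:
//
//   fma(a, b, c)  ==  a * b + c
//   fms(a, b, c)  ==  a * b - c
//   fnma(a, b, c) == -a * b + c
//   fnms(a, b, c) == -a * b - c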
static inline SimdDouble gmx_simdcall
rsqrt(SimdDouble x)
{
    return {
               vrsqrteq_f64(x.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
rsqrtIter(SimdDouble lu, SimdDouble x)
{
    return {
               vmulq_f64(lu.simdInternal_, vrsqrtsq_f64(vmulq_f64(lu.simdInternal_, lu.simdInternal_), x.simdInternal_))
    };
}
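
// vrsqrteq_f64 returns only a coarse (roughly 8-bit) estimate, so callers
// are expected to refine it with rsqrtIter; each Newton-Raphson step
// roughly doubles the number of correct bits. A usage sketch (illustration
// only, not part of this header's API):
//
//   SimdDouble lu = rsqrt(x);
//   lu = rsqrtIter(lu, x);
//   lu = rsqrtIter(lu, x);   // repeat until the target precision is reached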
static inline SimdDouble gmx_simdcall
rcp(SimdDouble x)
{
    return {
               vrecpeq_f64(x.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
rcpIter(SimdDouble lu, SimdDouble x)
{
    return {
               vmulq_f64(lu.simdInternal_, vrecpsq_f64(lu.simdInternal_, x.simdInternal_))
    };
}
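
// As with rsqrt/rsqrtIter, vrecpeq_f64 only provides an initial estimate of
// 1/x, and rcpIter applies one Newton-Raphson refinement step per call; the
// same iterate-to-precision pattern sketched above applies here.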
static inline SimdDouble gmx_simdcall
maskAdd(SimdDouble a, SimdDouble b, SimdDBool m)
{
    float64x2_t addend = float64x2_t(vandq_u64(uint64x2_t(b.simdInternal_), m.simdInternal_));

    return {
               vaddq_f64(a.simdInternal_, addend)
    };
}

static inline SimdDouble gmx_simdcall
maskzMul(SimdDouble a, SimdDouble b, SimdDBool m)
{
    float64x2_t prod = vmulq_f64(a.simdInternal_, b.simdInternal_);
    return {
               float64x2_t(vandq_u64(uint64x2_t(prod), m.simdInternal_))
    };
}

static inline SimdDouble gmx_simdcall
maskzFma(SimdDouble a, SimdDouble b, SimdDouble c, SimdDBool m)
{
    float64x2_t prod = vfmaq_f64(c.simdInternal_, b.simdInternal_, a.simdInternal_);

    return {
               float64x2_t(vandq_u64(uint64x2_t(prod), m.simdInternal_))
    };
}
static inline SimdDouble gmx_simdcall
maskzRsqrt(SimdDouble x, SimdDBool m)
{
    // The result will always be correct since we mask the result with m, but
    // for debug builds we also want to make sure not to generate FP exceptions
#ifndef NDEBUG
    x.simdInternal_ = vbslq_f64(m.simdInternal_, x.simdInternal_, vdupq_n_f64(1.0));
#endif
    return {
               float64x2_t(vandq_u64(uint64x2_t(vrsqrteq_f64(x.simdInternal_)), m.simdInternal_))
    };
}

static inline SimdDouble gmx_simdcall
maskzRcp(SimdDouble x, SimdDBool m)
{
    // The result will always be correct since we mask the result with m, but
    // for debug builds we also want to make sure not to generate FP exceptions
#ifndef NDEBUG
    x.simdInternal_ = vbslq_f64(m.simdInternal_, x.simdInternal_, vdupq_n_f64(1.0));
#endif
    return {
               float64x2_t(vandq_u64(uint64x2_t(vrecpeq_f64(x.simdInternal_)), m.simdInternal_))
    };
}
static inline SimdDouble gmx_simdcall
abs(SimdDouble x)
{
    return {
               vabsq_f64( x.simdInternal_ )
    };
}

static inline SimdDouble gmx_simdcall
max(SimdDouble a, SimdDouble b)
{
    return {
               vmaxq_f64(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
min(SimdDouble a, SimdDouble b)
{
    return {
               vminq_f64(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
round(SimdDouble x)
{
    return {
               vrndnq_f64(x.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
trunc(SimdDouble x)
{
    return {
               vrndq_f64( x.simdInternal_ )
    };
}
static inline SimdDouble
frexp(SimdDouble value, SimdDInt32 * exponent)
{
    const float64x2_t exponentMask = float64x2_t( vdupq_n_s64(0x7FF0000000000000LL) );
    const float64x2_t mantissaMask = float64x2_t( vdupq_n_s64(0x800FFFFFFFFFFFFFLL) );

    const int64x2_t   exponentBias = vdupq_n_s64(1022); // add 1 to make our definition identical to frexp()
    const float64x2_t half         = vdupq_n_f64(0.5);
    int64x2_t         iExponent;

    iExponent               = vandq_s64( int64x2_t(value.simdInternal_), int64x2_t(exponentMask) );
    iExponent               = vsubq_s64(vshrq_n_s64(iExponent, 52), exponentBias);
    exponent->simdInternal_ = vmovn_s64(iExponent);

    return {
               float64x2_t(vorrq_s64(vandq_s64(int64x2_t(value.simdInternal_), int64x2_t(mantissaMask)), int64x2_t(half)))
    };
}
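
// frexp() splits value into mantissa * 2^exponent with the mantissa in
// [0.5, 1), working directly on the IEEE-754 binary64 layout: bits 52-62
// hold the biased exponent, so after the shift by 52 we subtract 1022
// rather than the usual bias of 1023 to match frexp()'s convention, and
// the mantissa is rebuilt by OR-ing the fraction bits with the bit pattern
// of 0.5 (which has biased exponent 1022 and an empty fraction).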
template <MathOptimization opt = MathOptimization::Safe>
static inline SimdDouble
ldexp(SimdDouble value, SimdDInt32 exponent)
{
    const int32x2_t exponentBias = vdup_n_s32(1023);
    int32x2_t       iExponent    = vadd_s32(exponent.simdInternal_, exponentBias);
    int64x2_t       iExponent64;

    if (opt == MathOptimization::Safe)
    {
        // Make sure biased argument is not negative
        iExponent = vmax_s32(iExponent, vdup_n_s32(0));
    }

    iExponent64 = vmovl_s32(iExponent);
    iExponent64 = vshlq_n_s64(iExponent64, 52);

    return {
               vmulq_f64(value.simdInternal_, float64x2_t(iExponent64))
    };
}
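
// ldexp() constructs the factor 2^exponent directly in the exponent field:
// adding the bias 1023 and shifting left by 52 yields the bit pattern of
// 2^exponent, which is applied with a single multiply. In Safe mode the
// biased exponent is clamped at zero because a negative value would shift
// into the sign bit; strongly negative exponents therefore flush the
// result to zero instead of producing an invalid bit pattern.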
static inline double gmx_simdcall
reduce(SimdDouble a)
{
    float64x2_t b = vpaddq_f64(a.simdInternal_, a.simdInternal_);
    return vgetq_lane_f64(b, 0);
}
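
// With only two lanes, a single pairwise add suffices: vpaddq_f64(a, a)
// places a[0] + a[1] in both lanes, and lane 0 is returned.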
static inline SimdDBool gmx_simdcall
operator==(SimdDouble a, SimdDouble b)
{
    return {
               vceqq_f64(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDBool gmx_simdcall
operator!=(SimdDouble a, SimdDouble b)
{
    return {
               vreinterpretq_u64_u32(vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(a.simdInternal_, b.simdInternal_))))
    };
}

static inline SimdDBool gmx_simdcall
operator<(SimdDouble a, SimdDouble b)
{
    return {
               vcltq_f64(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDBool gmx_simdcall
operator<=(SimdDouble a, SimdDouble b)
{
    return {
               vcleq_f64(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDBool gmx_simdcall
testBits(SimdDouble a)
{
    return {
               vtstq_s64( int64x2_t(a.simdInternal_), int64x2_t(a.simdInternal_) )
    };
}
static inline SimdDBool gmx_simdcall
operator&&(SimdDBool a, SimdDBool b)
{
    return {
               vandq_u64(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDBool gmx_simdcall
operator||(SimdDBool a, SimdDBool b)
{
    return {
               vorrq_u64(a.simdInternal_, b.simdInternal_)
    };
}

static inline bool gmx_simdcall
anyTrue(SimdDBool a)
{
    return (vmaxvq_u32((uint32x4_t)(a.simdInternal_)) != 0);
}
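
// anyTrue() reinterprets the two 64-bit masks as four 32-bit lanes and
// takes a horizontal maximum, which is nonzero iff any mask bit is set.
// A hypothetical usage sketch:
//
//   if (anyTrue(a < b)) { /* at least one lane satisfied the comparison */ }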
static inline SimdDouble gmx_simdcall
selectByMask(SimdDouble a, SimdDBool m)
{
    return {
               float64x2_t(vandq_u64(uint64x2_t(a.simdInternal_), m.simdInternal_))
    };
}

static inline SimdDouble gmx_simdcall
selectByNotMask(SimdDouble a, SimdDBool m)
{
    return {
               float64x2_t(vbicq_u64(uint64x2_t(a.simdInternal_), m.simdInternal_))
    };
}

static inline SimdDouble gmx_simdcall
blend(SimdDouble a, SimdDouble b, SimdDBool sel)
{
    return {
               vbslq_f64(sel.simdInternal_, b.simdInternal_, a.simdInternal_)
    };
}
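
// blend() is a per-lane select, equivalent to (sel ? b : a) in every lane;
// vbslq_f64 takes bits from b where the mask is set and from a elsewhere.
// For example, a branch-free clamp could be written as (hypothetical usage):
//
//   SimdDouble clamped = blend(x, limit, limit < x);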
static inline SimdDInt32 gmx_simdcall
operator<<(SimdDInt32 a, int n)
{
    return {
               vshl_s32(a.simdInternal_, vdup_n_s32(n >= 32 ? 32 : n))
    };
}

static inline SimdDInt32 gmx_simdcall
operator>>(SimdDInt32 a, int n)
{
    return {
               vshl_s32(a.simdInternal_, vdup_n_s32(n >= 32 ? -32 : -n))
    };
}
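
// NEON has no variable right-shift instruction, so operator>> is expressed
// as vshl_s32 with a negated count; both operators clamp the magnitude at
// the element width of 32 so the shift count stays architecturally defined.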
static inline SimdDInt32 gmx_simdcall
operator&(SimdDInt32 a, SimdDInt32 b)
{
    return {
               vand_s32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
andNot(SimdDInt32 a, SimdDInt32 b)
{
    return {
               vbic_s32(b.simdInternal_, a.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator|(SimdDInt32 a, SimdDInt32 b)
{
    return {
               vorr_s32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator^(SimdDInt32 a, SimdDInt32 b)
{
    return {
               veor_s32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator+(SimdDInt32 a, SimdDInt32 b)
{
    return {
               vadd_s32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator-(SimdDInt32 a, SimdDInt32 b)
{
    return {
               vsub_s32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator*(SimdDInt32 a, SimdDInt32 b)
{
    return {
               vmul_s32(a.simdInternal_, b.simdInternal_)
    };
}
static inline SimdDIBool gmx_simdcall
operator==(SimdDInt32 a, SimdDInt32 b)
{
    return {
               vceq_s32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
testBits(SimdDInt32 a)
{
    return {
               vtst_s32( a.simdInternal_, a.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
operator<(SimdDInt32 a, SimdDInt32 b)
{
    return {
               vclt_s32(a.simdInternal_, b.simdInternal_)
    };
}
static inline SimdDIBool gmx_simdcall
operator&&(SimdDIBool a, SimdDIBool b)
{
    return {
               vand_u32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
operator||(SimdDIBool a, SimdDIBool b)
{
    return {
               vorr_u32(a.simdInternal_, b.simdInternal_)
    };
}

static inline bool gmx_simdcall
anyTrue(SimdDIBool a)
{
    return (vmaxv_u32(a.simdInternal_) != 0);
}
static inline SimdDInt32 gmx_simdcall
selectByMask(SimdDInt32 a, SimdDIBool m)
{
    return {
               vand_s32(a.simdInternal_, vreinterpret_s32_u32(m.simdInternal_))
    };
}

static inline SimdDInt32 gmx_simdcall
selectByNotMask(SimdDInt32 a, SimdDIBool m)
{
    return {
               vbic_s32(a.simdInternal_, vreinterpret_s32_u32(m.simdInternal_))
    };
}

static inline SimdDInt32 gmx_simdcall
blend(SimdDInt32 a, SimdDInt32 b, SimdDIBool sel)
{
    return {
               vbsl_s32(sel.simdInternal_, b.simdInternal_, a.simdInternal_)
    };
}
static inline SimdDInt32 gmx_simdcall
cvtR2I(SimdDouble a)
{
    return {
               vmovn_s64(vcvtnq_s64_f64(a.simdInternal_))
    };
}

static inline SimdDInt32 gmx_simdcall
cvttR2I(SimdDouble a)
{
    return {
               vmovn_s64(vcvtq_s64_f64(a.simdInternal_))
    };
}

static inline SimdDouble gmx_simdcall
cvtI2R(SimdDInt32 a)
{
    return {
               vcvtq_f64_s64(vmovl_s32(a.simdInternal_))
    };
}
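
// The two double-to-integer conversions differ only in rounding mode:
// cvtR2I uses vcvtnq_s64_f64 (round to nearest, ties to even) while
// cvttR2I uses vcvtq_s64_f64 (truncate toward zero); both then narrow the
// 64-bit results into the 32-bit lanes of SimdDInt32 with vmovn_s64.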
static inline SimdDIBool gmx_simdcall
cvtB2IB(SimdDBool a)
{
    return {
               vqmovn_u64(a.simdInternal_)
    };
}

static inline SimdDBool gmx_simdcall
cvtIB2B(SimdDIBool a)
{
    return {
               vorrq_u64(vmovl_u32(a.simdInternal_), vshlq_n_u64(vmovl_u32(a.simdInternal_), 32))
    };
}
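
// The boolean conversions only change the mask width: cvtB2IB narrows each
// 64-bit mask to 32 bits (vqmovn_u64 saturates all-ones to all-ones), and
// cvtIB2B widens each 32-bit mask back to 64 bits by duplicating it into
// both halves of the lane via the shift-and-OR above.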
static inline void gmx_simdcall
cvtF2DD(SimdFloat f, SimdDouble *d0, SimdDouble *d1)
{
    d0->simdInternal_ = vcvt_f64_f32(vget_low_f32(f.simdInternal_));
    d1->simdInternal_ = vcvt_high_f64_f32(f.simdInternal_);
}

static inline SimdFloat gmx_simdcall
cvtDD2F(SimdDouble d0, SimdDouble d1)
{
    return {
               vcvt_high_f32_f64(vcvt_f32_f64(d0.simdInternal_), d1.simdInternal_)
    };
}
}      // namespace gmx

#endif // GMX_SIMD_IMPL_ARM_NEON_ASIMD_SIMD_DOUBLE_H