Improve accuracy of SIMD exp for small args
[gromacs.git] / src / gromacs / simd / impl_arm_neon / impl_arm_neon_simd_float.h
blobeb9367b503f1b337e863aff4c9115fa82be2b04a
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
35 #ifndef GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H
36 #define GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H
38 #include "config.h"
40 #include <cassert>
41 #include <cstddef>
42 #include <cstdint>
44 #include <arm_neon.h>
46 #include "gromacs/math/utilities.h"
48 namespace gmx
51 class SimdFloat
53 public:
54 SimdFloat() {}
56 SimdFloat(float f) : simdInternal_(vdupq_n_f32(f)) {}
58 // Internal utility constructor to simplify return statements
59 SimdFloat(float32x4_t simd) : simdInternal_(simd) {}
61 float32x4_t simdInternal_;
64 class SimdFInt32
66 public:
67 SimdFInt32() {}
69 SimdFInt32(std::int32_t i) : simdInternal_(vdupq_n_s32(i)) {}
71 // Internal utility constructor to simplify return statements
72 SimdFInt32(int32x4_t simd) : simdInternal_(simd) {}
74 int32x4_t simdInternal_;
77 class SimdFBool
79 public:
80 SimdFBool() {}
82 SimdFBool(bool b) : simdInternal_(vdupq_n_u32( b ? 0xFFFFFFFF : 0)) {}
84 // Internal utility constructor to simplify return statements
85 SimdFBool(uint32x4_t simd) : simdInternal_(simd) {}
87 uint32x4_t simdInternal_;
90 class SimdFIBool
92 public:
93 SimdFIBool() {}
95 SimdFIBool(bool b) : simdInternal_(vdupq_n_u32( b ? 0xFFFFFFFF : 0)) {}
97 // Internal utility constructor to simplify return statements
98 SimdFIBool(uint32x4_t simd) : simdInternal_(simd) {}
100 uint32x4_t simdInternal_;
103 static inline SimdFloat gmx_simdcall
104 simdLoad(const float *m)
106 assert(std::size_t(m) % 16 == 0);
107 return {
108 vld1q_f32(m)
112 static inline void gmx_simdcall
113 store(float *m, SimdFloat a)
115 assert(std::size_t(m) % 16 == 0);
116 vst1q_f32(m, a.simdInternal_);
119 static inline SimdFloat gmx_simdcall
120 simdLoadU(const float *m)
122 return {
123 vld1q_f32(m)
127 static inline void gmx_simdcall
128 storeU(float *m, SimdFloat a)
130 vst1q_f32(m, a.simdInternal_);
133 static inline SimdFloat gmx_simdcall
134 setZeroF()
136 return {
137 vdupq_n_f32(0.0f)
141 static inline SimdFInt32 gmx_simdcall
142 simdLoadFI(const std::int32_t * m)
144 assert(std::size_t(m) % 16 == 0);
145 return {
146 vld1q_s32(m)
150 static inline void gmx_simdcall
151 store(std::int32_t * m, SimdFInt32 a)
153 assert(std::size_t(m) % 16 == 0);
154 vst1q_s32(m, a.simdInternal_);
157 static inline SimdFInt32 gmx_simdcall
158 simdLoadUFI(const std::int32_t *m)
160 return {
161 vld1q_s32(m)
165 static inline void gmx_simdcall
166 storeU(std::int32_t * m, SimdFInt32 a)
168 vst1q_s32(m, a.simdInternal_);
171 static inline SimdFInt32 gmx_simdcall
172 setZeroFI()
174 return {
175 vdupq_n_s32(0)
179 template<int index> gmx_simdcall
180 static inline std::int32_t
181 extract(SimdFInt32 a)
183 return vgetq_lane_s32(a.simdInternal_, index);
186 static inline SimdFloat gmx_simdcall
187 operator&(SimdFloat a, SimdFloat b)
189 return {
190 vreinterpretq_f32_s32(vandq_s32(vreinterpretq_s32_f32(a.simdInternal_),
191 vreinterpretq_s32_f32(b.simdInternal_)))
195 static inline SimdFloat gmx_simdcall
196 andNot(SimdFloat a, SimdFloat b)
198 return {
199 vreinterpretq_f32_s32(vbicq_s32(vreinterpretq_s32_f32(b.simdInternal_),
200 vreinterpretq_s32_f32(a.simdInternal_)))
204 static inline SimdFloat gmx_simdcall
205 operator|(SimdFloat a, SimdFloat b)
207 return {
208 vreinterpretq_f32_s32(vorrq_s32(vreinterpretq_s32_f32(a.simdInternal_),
209 vreinterpretq_s32_f32(b.simdInternal_)))
213 static inline SimdFloat gmx_simdcall
214 operator^(SimdFloat a, SimdFloat b)
216 return {
217 vreinterpretq_f32_s32(veorq_s32(vreinterpretq_s32_f32(a.simdInternal_),
218 vreinterpretq_s32_f32(b.simdInternal_)))
222 static inline SimdFloat gmx_simdcall
223 operator+(SimdFloat a, SimdFloat b)
225 return {
226 vaddq_f32(a.simdInternal_, b.simdInternal_)
230 static inline SimdFloat gmx_simdcall
231 operator-(SimdFloat a, SimdFloat b)
233 return {
234 vsubq_f32(a.simdInternal_, b.simdInternal_)
238 static inline SimdFloat gmx_simdcall
239 operator-(SimdFloat x)
241 return {
242 vnegq_f32(x.simdInternal_)
246 static inline SimdFloat gmx_simdcall
247 operator*(SimdFloat a, SimdFloat b)
249 return {
250 vmulq_f32(a.simdInternal_, b.simdInternal_)
254 // Override for Neon-Asimd
255 #if GMX_SIMD_ARM_NEON
256 static inline SimdFloat gmx_simdcall
257 fma(SimdFloat a, SimdFloat b, SimdFloat c)
259 return {
260 #ifdef __ARM_FEATURE_FMA
261 vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)
262 #else
263 vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)
264 #endif
268 static inline SimdFloat gmx_simdcall
269 fms(SimdFloat a, SimdFloat b, SimdFloat c)
271 return {
272 #ifdef __ARM_FEATURE_FMA
273 vnegq_f32(vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_))
274 #else
275 vnegq_f32(vmlsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_))
276 #endif
280 static inline SimdFloat gmx_simdcall
281 fnma(SimdFloat a, SimdFloat b, SimdFloat c)
283 return {
284 #ifdef __ARM_FEATURE_FMA
285 vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)
286 #else
287 vmlsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)
288 #endif
292 static inline SimdFloat gmx_simdcall
293 fnms(SimdFloat a, SimdFloat b, SimdFloat c)
295 return {
296 #ifdef __ARM_FEATURE_FMA
297 vnegq_f32(vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_))
298 #else
299 vnegq_f32(vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_))
300 #endif
303 #endif
305 static inline SimdFloat gmx_simdcall
306 rsqrt(SimdFloat x)
308 return {
309 vrsqrteq_f32(x.simdInternal_)
313 static inline SimdFloat gmx_simdcall
314 rsqrtIter(SimdFloat lu, SimdFloat x)
316 return {
317 vmulq_f32(lu.simdInternal_, vrsqrtsq_f32(vmulq_f32(lu.simdInternal_, lu.simdInternal_), x.simdInternal_))
321 static inline SimdFloat gmx_simdcall
322 rcp(SimdFloat x)
324 return {
325 vrecpeq_f32(x.simdInternal_)
329 static inline SimdFloat gmx_simdcall
330 rcpIter(SimdFloat lu, SimdFloat x)
332 return {
333 vmulq_f32(lu.simdInternal_, vrecpsq_f32(lu.simdInternal_, x.simdInternal_))
337 static inline SimdFloat gmx_simdcall
338 maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
340 b.simdInternal_ = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(b.simdInternal_),
341 m.simdInternal_));
343 return {
344 vaddq_f32(a.simdInternal_, b.simdInternal_)
348 static inline SimdFloat gmx_simdcall
349 maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
351 SimdFloat tmp = a * b;
353 return {
354 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp.simdInternal_),
355 m.simdInternal_))
359 static inline SimdFloat gmx_simdcall
360 maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
362 #ifdef __ARM_FEATURE_FMA
363 float32x4_t tmp = vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
364 #else
365 float32x4_t tmp = vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
366 #endif
368 return {
369 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp),
370 m.simdInternal_))
374 static inline SimdFloat gmx_simdcall
375 maskzRsqrt(SimdFloat x, SimdFBool m)
377 // The result will always be correct since we mask the result with m, but
378 // for debug builds we also want to make sure not to generate FP exceptions
379 #ifndef NDEBUG
380 x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0f));
381 #endif
382 return {
383 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(vrsqrteq_f32(x.simdInternal_)),
384 m.simdInternal_))
388 static inline SimdFloat gmx_simdcall
389 maskzRcp(SimdFloat x, SimdFBool m)
391 // The result will always be correct since we mask the result with m, but
392 // for debug builds we also want to make sure not to generate FP exceptions
393 #ifndef NDEBUG
394 x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0f));
395 #endif
396 return {
397 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(vrecpeq_f32(x.simdInternal_)),
398 m.simdInternal_))
402 static inline SimdFloat gmx_simdcall
403 abs(SimdFloat x)
405 return {
406 vabsq_f32( x.simdInternal_ )
410 static inline SimdFloat gmx_simdcall
411 max(SimdFloat a, SimdFloat b)
413 return {
414 vmaxq_f32(a.simdInternal_, b.simdInternal_)
418 static inline SimdFloat gmx_simdcall
419 min(SimdFloat a, SimdFloat b)
421 return {
422 vminq_f32(a.simdInternal_, b.simdInternal_)
426 // Round and trunc operations are defined at the end of this file, since they
427 // need to use float-to-integer and integer-to-float conversions.
429 static inline SimdFloat gmx_simdcall
430 frexp(SimdFloat value, SimdFInt32 * exponent)
432 const int32x4_t exponentMask = vdupq_n_s32(0x7F800000);
433 const int32x4_t mantissaMask = vdupq_n_s32(0x807FFFFF);
434 const int32x4_t exponentBias = vdupq_n_s32(126); // add 1 to make our definition identical to frexp()
435 const float32x4_t half = vdupq_n_f32(0.5f);
436 int32x4_t iExponent;
438 iExponent = vandq_s32(vreinterpretq_s32_f32(value.simdInternal_), exponentMask);
439 iExponent = vsubq_s32(vshrq_n_s32(iExponent, 23), exponentBias);
440 exponent->simdInternal_ = iExponent;
442 return {
443 vreinterpretq_f32_s32(vorrq_s32(vandq_s32(vreinterpretq_s32_f32(value.simdInternal_),
444 mantissaMask),
445 vreinterpretq_s32_f32(half)))
449 template <MathOptimization opt = MathOptimization::Safe>
450 static inline SimdFloat gmx_simdcall
451 ldexp(SimdFloat value, SimdFInt32 exponent)
453 const int32x4_t exponentBias = vdupq_n_s32(127);
454 int32x4_t iExponent = vaddq_s32(exponent.simdInternal_, exponentBias);
456 if (opt == MathOptimization::Safe)
458 // Make sure biased argument is not negative
459 iExponent = vmaxq_s32(iExponent, vdupq_n_s32(0));
462 iExponent = vshlq_n_s32( iExponent, 23);
464 return {
465 vmulq_f32(value.simdInternal_, vreinterpretq_f32_s32(iExponent))
469 // Override for Neon-Asimd
470 #if GMX_SIMD_ARM_NEON
471 static inline float gmx_simdcall
472 reduce(SimdFloat a)
474 float32x4_t x = a.simdInternal_;
475 float32x4_t y = vextq_f32(x, x, 2);
477 x = vaddq_f32(x, y);
478 y = vextq_f32(x, x, 1);
479 x = vaddq_f32(x, y);
480 return vgetq_lane_f32(x, 0);
482 #endif
484 static inline SimdFBool gmx_simdcall
485 operator==(SimdFloat a, SimdFloat b)
487 return {
488 vceqq_f32(a.simdInternal_, b.simdInternal_)
492 static inline SimdFBool gmx_simdcall
493 operator!=(SimdFloat a, SimdFloat b)
495 return {
496 vmvnq_u32(vceqq_f32(a.simdInternal_, b.simdInternal_))
500 static inline SimdFBool gmx_simdcall
501 operator<(SimdFloat a, SimdFloat b)
503 return {
504 vcltq_f32(a.simdInternal_, b.simdInternal_)
508 static inline SimdFBool gmx_simdcall
509 operator<=(SimdFloat a, SimdFloat b)
511 return {
512 vcleq_f32(a.simdInternal_, b.simdInternal_)
516 static inline SimdFBool gmx_simdcall
517 testBits(SimdFloat a)
519 uint32x4_t tmp = vreinterpretq_u32_f32(a.simdInternal_);
521 return {
522 vtstq_u32(tmp, tmp)
526 static inline SimdFBool gmx_simdcall
527 operator&&(SimdFBool a, SimdFBool b)
530 return {
531 vandq_u32(a.simdInternal_, b.simdInternal_)
535 static inline SimdFBool gmx_simdcall
536 operator||(SimdFBool a, SimdFBool b)
538 return {
539 vorrq_u32(a.simdInternal_, b.simdInternal_)
543 // Override for Neon-Asimd
544 #if GMX_SIMD_ARM_NEON
545 static inline bool gmx_simdcall
546 anyTrue(SimdFBool a)
548 uint32x4_t x = a.simdInternal_;
549 uint32x4_t y = vextq_u32(x, x, 2);
551 x = vorrq_u32(x, y);
552 y = vextq_u32(x, x, 1);
553 x = vorrq_u32(x, y);
554 return (vgetq_lane_u32(x, 0) != 0);
556 #endif
558 static inline SimdFloat gmx_simdcall
559 selectByMask(SimdFloat a, SimdFBool m)
561 return {
562 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.simdInternal_),
563 m.simdInternal_))
567 static inline SimdFloat gmx_simdcall
568 selectByNotMask(SimdFloat a, SimdFBool m)
570 return {
571 vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.simdInternal_),
572 m.simdInternal_))
576 static inline SimdFloat gmx_simdcall
577 blend(SimdFloat a, SimdFloat b, SimdFBool sel)
579 return {
580 vbslq_f32(sel.simdInternal_, b.simdInternal_, a.simdInternal_)
584 static inline SimdFInt32 gmx_simdcall
585 operator<<(SimdFInt32 a, int n)
587 return {
588 vshlq_s32(a.simdInternal_, vdupq_n_s32(n >= 32 ? 32 : n))
592 static inline SimdFInt32 gmx_simdcall
593 operator>>(SimdFInt32 a, int n)
595 return {
596 vshlq_s32(a.simdInternal_, vdupq_n_s32(n >= 32 ? -32 : -n))
600 static inline SimdFInt32 gmx_simdcall
601 operator&(SimdFInt32 a, SimdFInt32 b)
603 return {
604 vandq_s32(a.simdInternal_, b.simdInternal_)
608 static inline SimdFInt32 gmx_simdcall
609 andNot(SimdFInt32 a, SimdFInt32 b)
611 return {
612 vbicq_s32(b.simdInternal_, a.simdInternal_)
616 static inline SimdFInt32 gmx_simdcall
617 operator|(SimdFInt32 a, SimdFInt32 b)
619 return {
620 vorrq_s32(a.simdInternal_, b.simdInternal_)
624 static inline SimdFInt32 gmx_simdcall
625 operator^(SimdFInt32 a, SimdFInt32 b)
627 return {
628 veorq_s32(a.simdInternal_, b.simdInternal_)
632 static inline SimdFInt32 gmx_simdcall
633 operator+(SimdFInt32 a, SimdFInt32 b)
635 return {
636 vaddq_s32(a.simdInternal_, b.simdInternal_)
640 static inline SimdFInt32 gmx_simdcall
641 operator-(SimdFInt32 a, SimdFInt32 b)
643 return {
644 vsubq_s32(a.simdInternal_, b.simdInternal_)
648 static inline SimdFInt32 gmx_simdcall
649 operator*(SimdFInt32 a, SimdFInt32 b)
651 return {
652 vmulq_s32(a.simdInternal_, b.simdInternal_)
656 static inline SimdFIBool gmx_simdcall
657 operator==(SimdFInt32 a, SimdFInt32 b)
659 return {
660 vceqq_s32(a.simdInternal_, b.simdInternal_)
664 static inline SimdFIBool gmx_simdcall
665 testBits(SimdFInt32 a)
667 return {
668 vtstq_s32(a.simdInternal_, a.simdInternal_)
672 static inline SimdFIBool gmx_simdcall
673 operator<(SimdFInt32 a, SimdFInt32 b)
675 return {
676 vcltq_s32(a.simdInternal_, b.simdInternal_)
680 static inline SimdFIBool gmx_simdcall
681 operator&&(SimdFIBool a, SimdFIBool b)
683 return {
684 vandq_u32(a.simdInternal_, b.simdInternal_)
688 static inline SimdFIBool gmx_simdcall
689 operator||(SimdFIBool a, SimdFIBool b)
691 return {
692 vorrq_u32(a.simdInternal_, b.simdInternal_)
696 // Override for Neon-Asimd
697 #if GMX_SIMD_ARM_NEON
698 static inline bool gmx_simdcall
699 anyTrue(SimdFIBool a)
701 uint32x4_t x = a.simdInternal_;
702 uint32x4_t y = vextq_u32(x, x, 2);
704 x = vorrq_u32(x, y);
705 y = vextq_u32(x, x, 1);
706 x = vorrq_u32(x, y);
707 return (vgetq_lane_u32(x, 0) != 0);
709 #endif
711 static inline SimdFInt32 gmx_simdcall
712 selectByMask(SimdFInt32 a, SimdFIBool m)
714 return {
715 vandq_s32(a.simdInternal_, vreinterpretq_s32_u32(m.simdInternal_))
719 static inline SimdFInt32 gmx_simdcall
720 selectByNotMask(SimdFInt32 a, SimdFIBool m)
722 return {
723 vbicq_s32(a.simdInternal_, vreinterpretq_s32_u32(m.simdInternal_))
727 static inline SimdFInt32 gmx_simdcall
728 blend(SimdFInt32 a, SimdFInt32 b, SimdFIBool sel)
730 return {
731 vbslq_s32(sel.simdInternal_, b.simdInternal_, a.simdInternal_)
735 // Override for Neon-Asimd
736 #if GMX_SIMD_ARM_NEON
737 static inline SimdFInt32 gmx_simdcall
738 cvtR2I(SimdFloat a)
740 float32x4_t signBitOfA = vreinterpretq_f32_u32(vandq_u32(vdupq_n_u32(0x80000000), vreinterpretq_u32_f32(a.simdInternal_)));
741 float32x4_t half = vdupq_n_f32(0.5f);
742 float32x4_t corr = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(half), vreinterpretq_u32_f32(signBitOfA)));
744 return {
745 vcvtq_s32_f32(vaddq_f32(a.simdInternal_, corr))
748 #endif
750 static inline SimdFInt32 gmx_simdcall
751 cvttR2I(SimdFloat a)
753 return {
754 vcvtq_s32_f32(a.simdInternal_)
758 static inline SimdFloat gmx_simdcall
759 cvtI2R(SimdFInt32 a)
761 return {
762 vcvtq_f32_s32(a.simdInternal_)
766 static inline SimdFIBool gmx_simdcall
767 cvtB2IB(SimdFBool a)
769 return {
770 a.simdInternal_
774 static inline SimdFBool gmx_simdcall
775 cvtIB2B(SimdFIBool a)
777 return {
778 a.simdInternal_
782 // Override for Neon-Asimd
783 #if GMX_SIMD_ARM_NEON
784 static inline SimdFloat gmx_simdcall
785 round(SimdFloat x)
787 return cvtI2R(cvtR2I(x));
790 static inline SimdFloat gmx_simdcall
791 trunc(SimdFloat x)
793 return cvtI2R(cvttR2I(x));
795 #endif
797 } // namespace gmx
799 #endif // GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H