Make sure frexp() returns correct results for argument 0.0
[gromacs.git] / src / gromacs / simd / impl_arm_neon / impl_arm_neon_simd_float.h
blob888a4df3fd52fcbcaaf415d822761c09a12c4881
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014,2015,2016,2017,2019,2020, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
35 #ifndef GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H
36 #define GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H
38 #include "config.h"
40 #include <cassert>
41 #include <cstddef>
42 #include <cstdint>
44 #include <arm_neon.h>
46 #include "gromacs/math/utilities.h"
48 namespace gmx
// Four single-precision lanes stored in a NEON float32x4_t value.
class SimdFloat
{
public:
    // Default construction leaves the lanes uninitialized.
    SimdFloat() {}

    // Broadcast constructor: every lane receives the value f.
    SimdFloat(float f) : simdInternal_{ vdupq_n_f32(f) } {}

    // Internal utility constructor wrapping a raw NEON vector, so that the
    // implementation functions below can simply write `return { ... };`.
    SimdFloat(float32x4_t simd) : simdInternal_{ simd } {}

    float32x4_t simdInternal_;
};

// Four signed 32-bit integer lanes stored in a NEON int32x4_t value.
class SimdFInt32
{
public:
    // Default construction leaves the lanes uninitialized.
    SimdFInt32() {}

    // Broadcast constructor: every lane receives the value i.
    SimdFInt32(std::int32_t i) : simdInternal_{ vdupq_n_s32(i) } {}

    // Internal utility constructor wrapping a raw NEON vector.
    SimdFInt32(int32x4_t simd) : simdInternal_{ simd } {}

    int32x4_t simdInternal_;
};

// Per-lane boolean for SimdFloat. A lane is all ones for true and all zeros
// for false, as produced by the bool constructor and the NEON compare intrinsics.
class SimdFBool
{
public:
    // Default construction leaves the lanes uninitialized.
    SimdFBool() {}

    // Broadcast constructor: all-ones lanes for true, all-zero lanes for false.
    SimdFBool(bool b) : simdInternal_{ vdupq_n_u32(b ? 0xFFFFFFFFU : 0U) } {}

    // Internal utility constructor wrapping a raw NEON mask vector.
    SimdFBool(uint32x4_t simd) : simdInternal_{ simd } {}

    uint32x4_t simdInternal_;
};

// Per-lane boolean for SimdFInt32, using the same all-ones / all-zeros lane
// encoding as SimdFBool.
class SimdFIBool
{
public:
    // Default construction leaves the lanes uninitialized.
    SimdFIBool() {}

    // Broadcast constructor: all-ones lanes for true, all-zero lanes for false.
    SimdFIBool(bool b) : simdInternal_{ vdupq_n_u32(b ? 0xFFFFFFFFU : 0U) } {}

    // Internal utility constructor wrapping a raw NEON mask vector.
    SimdFIBool(uint32x4_t simd) : simdInternal_{ simd } {}

    uint32x4_t simdInternal_;
};
// Loads four floats from m. The address must be 16-byte aligned; this is only
// checked by assert(), i.e. in builds without NDEBUG.
static inline SimdFloat gmx_simdcall simdLoad(const float* m, SimdFloatTag = {})
{
    assert(reinterpret_cast<std::size_t>(m) % 16 == 0);
    return SimdFloat(vld1q_f32(m));
}

// Stores the four lanes of a to m, which must be 16-byte aligned (assert-checked).
static inline void gmx_simdcall store(float* m, SimdFloat a)
{
    assert(reinterpret_cast<std::size_t>(m) % 16 == 0);
    vst1q_f32(m, a.simdInternal_);
}

// Loads four floats from m; no alignment check is performed.
static inline SimdFloat gmx_simdcall simdLoadU(const float* m, SimdFloatTag = {})
{
    return SimdFloat(vld1q_f32(m));
}

// Stores the four lanes of a to m; no alignment check is performed.
static inline void gmx_simdcall storeU(float* m, SimdFloat a)
{
    vst1q_f32(m, a.simdInternal_);
}

// Returns a vector whose lanes are all 0.0F (via the broadcast constructor).
static inline SimdFloat gmx_simdcall setZeroF()
{
    return SimdFloat(0.0F);
}
// Loads four int32 values from m. The address must be 16-byte aligned; this is
// only checked by assert(), i.e. in builds without NDEBUG.
static inline SimdFInt32 gmx_simdcall simdLoad(const std::int32_t* m, SimdFInt32Tag)
{
    assert(reinterpret_cast<std::size_t>(m) % 16 == 0);
    return SimdFInt32(vld1q_s32(m));
}

// Stores the four int32 lanes of a to m, which must be 16-byte aligned (assert-checked).
static inline void gmx_simdcall store(std::int32_t* m, SimdFInt32 a)
{
    assert(reinterpret_cast<std::size_t>(m) % 16 == 0);
    vst1q_s32(m, a.simdInternal_);
}

// Loads four int32 values from m; no alignment check is performed.
static inline SimdFInt32 gmx_simdcall simdLoadU(const std::int32_t* m, SimdFInt32Tag)
{
    return SimdFInt32(vld1q_s32(m));
}

// Stores the four int32 lanes of a to m; no alignment check is performed.
static inline void gmx_simdcall storeU(std::int32_t* m, SimdFInt32 a)
{
    vst1q_s32(m, a.simdInternal_);
}

// Returns an integer vector whose lanes are all 0.
static inline SimdFInt32 gmx_simdcall setZeroFI()
{
    return SimdFInt32(vdupq_n_s32(0));
}
// Returns lane number `index` of a. The index is a template parameter because
// the NEON lane-get intrinsic requires a compile-time constant.
template<int index>
gmx_simdcall static inline std::int32_t extract(SimdFInt32 a)
{
    const std::int32_t lane = vgetq_lane_s32(a.simdInternal_, index);
    return lane;
}
// Bitwise AND of the raw float bit patterns.
static inline SimdFloat gmx_simdcall operator&(SimdFloat a, SimdFloat b)
{
    const uint32x4_t bitsA = vreinterpretq_u32_f32(a.simdInternal_);
    const uint32x4_t bitsB = vreinterpretq_u32_f32(b.simdInternal_);
    return SimdFloat(vreinterpretq_f32_u32(vandq_u32(bitsA, bitsB)));
}

// Bitwise (~a) & b on the raw float bit patterns.
static inline SimdFloat gmx_simdcall andNot(SimdFloat a, SimdFloat b)
{
    const uint32x4_t bitsA = vreinterpretq_u32_f32(a.simdInternal_);
    const uint32x4_t bitsB = vreinterpretq_u32_f32(b.simdInternal_);
    // vbic computes first & ~second, so b comes first to clear the bits set in a.
    return SimdFloat(vreinterpretq_f32_u32(vbicq_u32(bitsB, bitsA)));
}

// Bitwise OR of the raw float bit patterns.
static inline SimdFloat gmx_simdcall operator|(SimdFloat a, SimdFloat b)
{
    const uint32x4_t bitsA = vreinterpretq_u32_f32(a.simdInternal_);
    const uint32x4_t bitsB = vreinterpretq_u32_f32(b.simdInternal_);
    return SimdFloat(vreinterpretq_f32_u32(vorrq_u32(bitsA, bitsB)));
}

// Bitwise XOR of the raw float bit patterns.
static inline SimdFloat gmx_simdcall operator^(SimdFloat a, SimdFloat b)
{
    const uint32x4_t bitsA = vreinterpretq_u32_f32(a.simdInternal_);
    const uint32x4_t bitsB = vreinterpretq_u32_f32(b.simdInternal_);
    return SimdFloat(vreinterpretq_f32_u32(veorq_u32(bitsA, bitsB)));
}
// Lane-wise a + b.
static inline SimdFloat gmx_simdcall operator+(SimdFloat a, SimdFloat b)
{
    const float32x4_t sum = vaddq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFloat(sum);
}

// Lane-wise a - b.
static inline SimdFloat gmx_simdcall operator-(SimdFloat a, SimdFloat b)
{
    const float32x4_t difference = vsubq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFloat(difference);
}

// Lane-wise negation.
static inline SimdFloat gmx_simdcall operator-(SimdFloat x)
{
    const float32x4_t negated = vnegq_f32(x.simdInternal_);
    return SimdFloat(negated);
}

// Lane-wise a * b.
static inline SimdFloat gmx_simdcall operator*(SimdFloat a, SimdFloat b)
{
    const float32x4_t product = vmulq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFloat(product);
}
// Override for Neon-Asimd
#if GMX_SIMD_ARM_NEON
// Multiply-add family. When the compiler defines __ARM_FEATURE_FMA, the
// vfma/vfms intrinsics are used. Otherwise the vmla/vmls forms are used
// (NOTE(review): presumably non-fused, i.e. the product is rounded before the
// add — confirm against the Arm intrinsics reference).

// Returns a * b + c.
static inline SimdFloat gmx_simdcall fma(SimdFloat a, SimdFloat b, SimdFloat c)
{
#    ifdef __ARM_FEATURE_FMA
    const float32x4_t cPlusAb = vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#    else
    const float32x4_t cPlusAb = vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#    endif
    return SimdFloat(cPlusAb);
}

// Returns a * b - c, computed as -(c - a * b).
static inline SimdFloat gmx_simdcall fms(SimdFloat a, SimdFloat b, SimdFloat c)
{
#    ifdef __ARM_FEATURE_FMA
    const float32x4_t cMinusAb = vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#    else
    const float32x4_t cMinusAb = vmlsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#    endif
    return SimdFloat(vnegq_f32(cMinusAb));
}

// Returns -a * b + c, i.e. c - a * b.
static inline SimdFloat gmx_simdcall fnma(SimdFloat a, SimdFloat b, SimdFloat c)
{
#    ifdef __ARM_FEATURE_FMA
    const float32x4_t cMinusAb = vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#    else
    const float32x4_t cMinusAb = vmlsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#    endif
    return SimdFloat(cMinusAb);
}

// Returns -a * b - c, computed as -(c + a * b).
static inline SimdFloat gmx_simdcall fnms(SimdFloat a, SimdFloat b, SimdFloat c)
{
#    ifdef __ARM_FEATURE_FMA
    const float32x4_t cPlusAb = vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#    else
    const float32x4_t cPlusAb = vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#    endif
    return SimdFloat(vnegq_f32(cPlusAb));
}
#endif
// Hardware estimate of 1/sqrt(x), meant to be refined with rsqrtIter().
static inline SimdFloat gmx_simdcall rsqrt(SimdFloat x)
{
    const float32x4_t estimate = vrsqrteq_f32(x.simdInternal_);
    return SimdFloat(estimate);
}

// The native refinement step below seems to overflow when lu is squared for
// values close to the largest float. When it is disabled, the version in
// simd_math.h is used instead, which is probably slightly slower.
#if GMX_SIMD_HAVE_NATIVE_RSQRT_ITER_FLOAT
// One refinement step of the 1/sqrt(x) estimate lu.
static inline SimdFloat gmx_simdcall rsqrtIter(SimdFloat lu, SimdFloat x)
{
    const float32x4_t luSquared  = vmulq_f32(lu.simdInternal_, lu.simdInternal_);
    const float32x4_t correction = vrsqrtsq_f32(luSquared, x.simdInternal_);
    return SimdFloat(vmulq_f32(lu.simdInternal_, correction));
}
#endif

// Hardware estimate of 1/x, meant to be refined with rcpIter().
static inline SimdFloat gmx_simdcall rcp(SimdFloat x)
{
    const float32x4_t estimate = vrecpeq_f32(x.simdInternal_);
    return SimdFloat(estimate);
}

// One refinement step of the 1/x estimate lu.
static inline SimdFloat gmx_simdcall rcpIter(SimdFloat lu, SimdFloat x)
{
    const float32x4_t correction = vrecpsq_f32(lu.simdInternal_, x.simdInternal_);
    return SimdFloat(vmulq_f32(lu.simdInternal_, correction));
}
// Returns a + b in lanes where m is true, and a + 0.0F in the other lanes.
static inline SimdFloat gmx_simdcall maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
{
    // Bit-select b where the mask bits are set and +0.0F elsewhere; for every
    // bit this equals b & m, so masked-out lanes contribute zero.
    const float32x4_t maskedB = vbslq_f32(m.simdInternal_, b.simdInternal_, vdupq_n_f32(0.0F));
    return SimdFloat(vaddq_f32(a.simdInternal_, maskedB));
}

// Returns a * b in lanes where m is true and all-zero bits (+0.0F) elsewhere.
static inline SimdFloat gmx_simdcall maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
{
    const float32x4_t product = vmulq_f32(a.simdInternal_, b.simdInternal_);
    const uint32x4_t  kept    = vandq_u32(vreinterpretq_u32_f32(product), m.simdInternal_);
    return SimdFloat(vreinterpretq_f32_u32(kept));
}

// Returns a * b + c in lanes where m is true and all-zero bits (+0.0F) elsewhere.
static inline SimdFloat gmx_simdcall maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
{
#ifdef __ARM_FEATURE_FMA
    const float32x4_t sum = vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#else
    const float32x4_t sum = vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
#endif
    const uint32x4_t kept = vandq_u32(vreinterpretq_u32_f32(sum), m.simdInternal_);
    return SimdFloat(vreinterpretq_f32_u32(kept));
}
// 1/sqrt(x) estimate in lanes where m is true, all-zero bits elsewhere.
static inline SimdFloat gmx_simdcall maskzRsqrt(SimdFloat x, SimdFBool m)
{
#ifndef NDEBUG
    // The masked-out lanes are discarded below regardless. In debug builds they
    // are replaced by 1.0F first so that they cannot generate FP exceptions.
    x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0F));
#endif
    const uint32x4_t estimateBits = vreinterpretq_u32_f32(vrsqrteq_f32(x.simdInternal_));
    return SimdFloat(vreinterpretq_f32_u32(vandq_u32(estimateBits, m.simdInternal_)));
}

// 1/x estimate in lanes where m is true, all-zero bits elsewhere.
static inline SimdFloat gmx_simdcall maskzRcp(SimdFloat x, SimdFBool m)
{
#ifndef NDEBUG
    // The masked-out lanes are discarded below regardless. In debug builds they
    // are replaced by 1.0F first so that they cannot generate FP exceptions.
    x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0F));
#endif
    const uint32x4_t estimateBits = vreinterpretq_u32_f32(vrecpeq_f32(x.simdInternal_));
    return SimdFloat(vreinterpretq_f32_u32(vandq_u32(estimateBits, m.simdInternal_)));
}
// Lane-wise absolute value.
static inline SimdFloat gmx_simdcall abs(SimdFloat x)
{
    const float32x4_t magnitude = vabsq_f32(x.simdInternal_);
    return SimdFloat(magnitude);
}

// Lane-wise maximum of a and b.
static inline SimdFloat gmx_simdcall max(SimdFloat a, SimdFloat b)
{
    const float32x4_t larger = vmaxq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFloat(larger);
}

// Lane-wise minimum of a and b.
static inline SimdFloat gmx_simdcall min(SimdFloat a, SimdFloat b)
{
    const float32x4_t smaller = vminq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFloat(smaller);
}
343 // Round and trunc operations are defined at the end of this file, since they
344 // need to use float-to-integer and integer-to-float conversions.
// Splits each lane into a fraction and a power-of-two exponent, like
// std::frexp(): for normal nonzero input, value == fraction * 2^exponent with
// 0.5 <= |fraction| < 1. The exponent is written to *exponent and the fraction
// is returned.
//
// With MathOptimization::Safe (the default), a zero argument (+0.0F or -0.0F)
// returns the argument itself as the fraction and stores exponent 0, which
// matches std::frexp(). With MathOptimization::Unsafe that check is skipped,
// and zero lanes produce fraction +-0.5F and exponent -126.
// NOTE(review): denormal, infinite and NaN arguments are not handled specially
// by either path.
template<MathOptimization opt = MathOptimization::Safe>
static inline SimdFloat gmx_simdcall frexp(SimdFloat value, SimdFInt32* exponent)
{
    const int32x4_t exponentMask = vdupq_n_s32(0x7F800000);
    // Keeps the sign bit and the 23 mantissa bits.
    const int32x4_t mantissaMask = vdupq_n_s32(0x807FFFFF);
    // add 1 to make our definition identical to frexp()
    const int32x4_t   exponentBias = vdupq_n_s32(126);
    const float32x4_t half         = vdupq_n_f32(0.5F);
    const int32x4_t   valueBits    = vreinterpretq_s32_f32(value.simdInternal_);

    // Shift the biased exponent field down and remove the bias.
    int32x4_t iExponent = vandq_s32(valueBits, exponentMask);
    iExponent           = vsubq_s32(vshrq_n_s32(iExponent, 23), exponentBias);

    // Keep sign and mantissa, and replace the exponent field with that of 0.5.
    float32x4_t result = vreinterpretq_f32_s32(
            vorrq_s32(vandq_s32(valueBits, mantissaMask), vreinterpretq_s32_f32(half)));

    if (opt == MathOptimization::Safe)
    {
        // The bit manipulation above turns +-0.0 into +-0.5 with exponent -126.
        // std::frexp() instead returns the zero itself and exponent 0.
        const uint32x4_t valueIsZero = vceqq_f32(value.simdInternal_, vdupq_n_f32(0.0F));
        iExponent                    = vbicq_s32(iExponent, vreinterpretq_s32_u32(valueIsZero));
        result                       = vbslq_f32(valueIsZero, value.simdInternal_, result);
    }

    exponent->simdInternal_ = iExponent;
    return { result };
}
// Returns value * 2^exponent per lane, by building the factor 2^exponent
// directly in the float exponent field.
// With MathOptimization::Safe, a negative biased exponent is clamped to 0, so
// any exponent <= -127 gives a factor of 0.0F.
// NOTE(review): there is no upper clamp; exponent + 127 > 255 is not range-checked.
template<MathOptimization opt = MathOptimization::Safe>
static inline SimdFloat gmx_simdcall ldexp(SimdFloat value, SimdFInt32 exponent)
{
    int32x4_t biasedExponent = vaddq_s32(exponent.simdInternal_, vdupq_n_s32(127));

    if (opt == MathOptimization::Safe)
    {
        // Make sure biased argument is not negative
        biasedExponent = vmaxq_s32(biasedExponent, vdupq_n_s32(0));
    }

    const float32x4_t scale = vreinterpretq_f32_s32(vshlq_n_s32(biasedExponent, 23));
    return SimdFloat(vmulq_f32(value.simdInternal_, scale));
}
// Override for Neon-Asimd
#if GMX_SIMD_ARM_NEON
// Horizontal sum of the four lanes, evaluated as (x0 + x2) + (x1 + x3).
static inline float gmx_simdcall reduce(SimdFloat a)
{
    // Fold the high half onto the low half: [x0 + x2, x1 + x3].
    const float32x2_t folded =
            vadd_f32(vget_low_f32(a.simdInternal_), vget_high_f32(a.simdInternal_));
    // Pairwise-add the two remaining lanes.
    const float32x2_t total = vpadd_f32(folded, folded);
    return vget_lane_f32(total, 0);
}
#endif
// Lane-wise a == b.
static inline SimdFBool gmx_simdcall operator==(SimdFloat a, SimdFloat b)
{
    const uint32x4_t equal = vceqq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFBool(equal);
}

// Lane-wise a != b, implemented as the bitwise complement of operator==.
static inline SimdFBool gmx_simdcall operator!=(SimdFloat a, SimdFloat b)
{
    const uint32x4_t equal = vceqq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFBool(vmvnq_u32(equal));
}

// Lane-wise a < b.
static inline SimdFBool gmx_simdcall operator<(SimdFloat a, SimdFloat b)
{
    const uint32x4_t less = vcltq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFBool(less);
}

// Lane-wise a <= b.
static inline SimdFBool gmx_simdcall operator<=(SimdFloat a, SimdFloat b)
{
    const uint32x4_t lessOrEqual = vcleq_f32(a.simdInternal_, b.simdInternal_);
    return SimdFBool(lessOrEqual);
}

// True in every lane whose raw bit pattern is nonzero (so -0.0F counts as true).
static inline SimdFBool gmx_simdcall testBits(SimdFloat a)
{
    const uint32x4_t bits = vreinterpretq_u32_f32(a.simdInternal_);
    return SimdFBool(vtstq_u32(bits, bits));
}
// Lane-wise logical AND of two float masks.
static inline SimdFBool gmx_simdcall operator&&(SimdFBool a, SimdFBool b)
{
    return SimdFBool(vandq_u32(a.simdInternal_, b.simdInternal_));
}

// Lane-wise logical OR of two float masks.
static inline SimdFBool gmx_simdcall operator||(SimdFBool a, SimdFBool b)
{
    return SimdFBool(vorrq_u32(a.simdInternal_, b.simdInternal_));
}

// Override for Neon-Asimd
#if GMX_SIMD_ARM_NEON
// Returns true if any bit of any lane of a is set.
static inline bool gmx_simdcall anyTrue(SimdFBool a)
{
    // OR the high half onto the low half, then combine the two remaining lanes.
    const uint32x2_t folded = vorr_u32(vget_low_u32(a.simdInternal_), vget_high_u32(a.simdInternal_));
    return (vget_lane_u32(folded, 0) | vget_lane_u32(folded, 1)) != 0;
}
#endif
// Keeps a in lanes where m is true and clears the other lanes to all-zero bits.
static inline SimdFloat gmx_simdcall selectByMask(SimdFloat a, SimdFBool m)
{
    const uint32x4_t bits = vreinterpretq_u32_f32(a.simdInternal_);
    return SimdFloat(vreinterpretq_f32_u32(vandq_u32(bits, m.simdInternal_)));
}

// Keeps a in lanes where m is false and clears the other lanes to all-zero bits.
static inline SimdFloat gmx_simdcall selectByNotMask(SimdFloat a, SimdFBool m)
{
    const uint32x4_t bits = vreinterpretq_u32_f32(a.simdInternal_);
    return SimdFloat(vreinterpretq_f32_u32(vbicq_u32(bits, m.simdInternal_)));
}

// Takes b in lanes where sel is true and a elsewhere.
static inline SimdFloat gmx_simdcall blend(SimdFloat a, SimdFloat b, SimdFBool sel)
{
    const float32x4_t chosen = vbslq_f32(sel.simdInternal_, b.simdInternal_, a.simdInternal_);
    return SimdFloat(chosen);
}
// Lane-wise bitwise a & b.
static inline SimdFInt32 gmx_simdcall operator&(SimdFInt32 a, SimdFInt32 b)
{
    return SimdFInt32(vandq_s32(a.simdInternal_, b.simdInternal_));
}

// Lane-wise bitwise (~a) & b; vbic clears in its first operand the bits set in its second.
static inline SimdFInt32 gmx_simdcall andNot(SimdFInt32 a, SimdFInt32 b)
{
    return SimdFInt32(vbicq_s32(b.simdInternal_, a.simdInternal_));
}

// Lane-wise bitwise a | b.
static inline SimdFInt32 gmx_simdcall operator|(SimdFInt32 a, SimdFInt32 b)
{
    return SimdFInt32(vorrq_s32(a.simdInternal_, b.simdInternal_));
}

// Lane-wise bitwise a ^ b.
static inline SimdFInt32 gmx_simdcall operator^(SimdFInt32 a, SimdFInt32 b)
{
    return SimdFInt32(veorq_s32(a.simdInternal_, b.simdInternal_));
}

// Lane-wise integer a + b.
static inline SimdFInt32 gmx_simdcall operator+(SimdFInt32 a, SimdFInt32 b)
{
    return SimdFInt32(vaddq_s32(a.simdInternal_, b.simdInternal_));
}

// Lane-wise integer a - b.
static inline SimdFInt32 gmx_simdcall operator-(SimdFInt32 a, SimdFInt32 b)
{
    return SimdFInt32(vsubq_s32(a.simdInternal_, b.simdInternal_));
}

// Lane-wise integer a * b.
static inline SimdFInt32 gmx_simdcall operator*(SimdFInt32 a, SimdFInt32 b)
{
    return SimdFInt32(vmulq_s32(a.simdInternal_, b.simdInternal_));
}
// Lane-wise integer a == b.
static inline SimdFIBool gmx_simdcall operator==(SimdFInt32 a, SimdFInt32 b)
{
    const uint32x4_t equal = vceqq_s32(a.simdInternal_, b.simdInternal_);
    return SimdFIBool(equal);
}

// True in every lane whose value is nonzero.
static inline SimdFIBool gmx_simdcall testBits(SimdFInt32 a)
{
    const uint32x4_t nonZero = vtstq_s32(a.simdInternal_, a.simdInternal_);
    return SimdFIBool(nonZero);
}

// Lane-wise signed a < b.
static inline SimdFIBool gmx_simdcall operator<(SimdFInt32 a, SimdFInt32 b)
{
    const uint32x4_t less = vcltq_s32(a.simdInternal_, b.simdInternal_);
    return SimdFIBool(less);
}

// Lane-wise logical AND of two integer masks.
static inline SimdFIBool gmx_simdcall operator&&(SimdFIBool a, SimdFIBool b)
{
    return SimdFIBool(vandq_u32(a.simdInternal_, b.simdInternal_));
}

// Lane-wise logical OR of two integer masks.
static inline SimdFIBool gmx_simdcall operator||(SimdFIBool a, SimdFIBool b)
{
    return SimdFIBool(vorrq_u32(a.simdInternal_, b.simdInternal_));
}
// Override for Neon-Asimd
#if GMX_SIMD_ARM_NEON
// Returns true if any bit of any lane of a is set.
static inline bool gmx_simdcall anyTrue(SimdFIBool a)
{
    // OR the high half onto the low half, then combine the two remaining lanes.
    const uint32x2_t folded = vorr_u32(vget_low_u32(a.simdInternal_), vget_high_u32(a.simdInternal_));
    return (vget_lane_u32(folded, 0) | vget_lane_u32(folded, 1)) != 0;
}
#endif
// Keeps a in lanes where m is true and sets the other lanes to 0.
static inline SimdFInt32 gmx_simdcall selectByMask(SimdFInt32 a, SimdFIBool m)
{
    const int32x4_t mask = vreinterpretq_s32_u32(m.simdInternal_);
    return SimdFInt32(vandq_s32(a.simdInternal_, mask));
}

// Keeps a in lanes where m is false and sets the other lanes to 0.
static inline SimdFInt32 gmx_simdcall selectByNotMask(SimdFInt32 a, SimdFIBool m)
{
    const int32x4_t mask = vreinterpretq_s32_u32(m.simdInternal_);
    return SimdFInt32(vbicq_s32(a.simdInternal_, mask));
}

// Takes b in lanes where sel is true and a elsewhere.
static inline SimdFInt32 gmx_simdcall blend(SimdFInt32 a, SimdFInt32 b, SimdFIBool sel)
{
    const int32x4_t chosen = vbslq_s32(sel.simdInternal_, b.simdInternal_, a.simdInternal_);
    return SimdFInt32(chosen);
}
// Override for Neon-Asimd
#if GMX_SIMD_ARM_NEON
// Rounds to the nearest integer by adding +-0.5F (with the sign copied from a)
// and then converting with the same truncating conversion that cvttR2I() uses.
// Exact halves therefore round away from zero.
// NOTE(review): the float addition itself rounds, so inputs just below a half
// (e.g. 0.49999997F) can round up — confirm callers tolerate this.
static inline SimdFInt32 gmx_simdcall cvtR2I(SimdFloat a)
{
    const uint32x4_t aBits    = vreinterpretq_u32_f32(a.simdInternal_);
    const uint32x4_t signOfA  = vandq_u32(aBits, vdupq_n_u32(0x80000000));
    const uint32x4_t halfBits = vreinterpretq_u32_f32(vdupq_n_f32(0.5F));
    // copysign(0.5F, a) built from bits.
    const float32x4_t corr = vreinterpretq_f32_u32(vorrq_u32(halfBits, signOfA));

    return SimdFInt32(vcvtq_s32_f32(vaddq_f32(a.simdInternal_, corr)));
}
#endif
// Float to int32 conversion with truncation (the basis of trunc() in this file).
static inline SimdFInt32 gmx_simdcall cvttR2I(SimdFloat a)
{
    const int32x4_t truncated = vcvtq_s32_f32(a.simdInternal_);
    return SimdFInt32(truncated);
}

// Int32 to float conversion.
static inline SimdFloat gmx_simdcall cvtI2R(SimdFInt32 a)
{
    const float32x4_t converted = vcvtq_f32_s32(a.simdInternal_);
    return SimdFloat(converted);
}

// Float mask to integer mask. Both types hold the same uint32x4_t lane masks,
// so this is a plain copy.
static inline SimdFIBool gmx_simdcall cvtB2IB(SimdFBool a)
{
    return SimdFIBool(a.simdInternal_);
}

// Integer mask to float mask; a plain copy for the same reason as cvtB2IB().
static inline SimdFBool gmx_simdcall cvtIB2B(SimdFIBool a)
{
    return SimdFBool(a.simdInternal_);
}
// Override for Neon-Asimd
#if GMX_SIMD_ARM_NEON
// Rounds to the nearest integer via a round trip through int32 (cvtR2I()).
// NOTE(review): only meaningful for values representable as int32.
static inline SimdFloat gmx_simdcall round(SimdFloat x)
{
    const SimdFInt32 nearest = cvtR2I(x);
    return cvtI2R(nearest);
}

// Truncates toward zero via a round trip through int32 (cvttR2I()).
// NOTE(review): only meaningful for values representable as int32.
static inline SimdFloat gmx_simdcall trunc(SimdFloat x)
{
    const SimdFInt32 truncated = cvttR2I(x);
    return cvtI2R(truncated);
}
#endif
597 } // namespace gmx
599 #endif // GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H