2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
35 #ifndef GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H
36 #define GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H
46 #include "gromacs/math/utilities.h"
56 SimdFloat(float f
) : simdInternal_(vdupq_n_f32(f
)) {}
58 // Internal utility constructor to simplify return statements
59 SimdFloat(float32x4_t simd
) : simdInternal_(simd
) {}
61 float32x4_t simdInternal_
;
69 SimdFInt32(std::int32_t i
) : simdInternal_(vdupq_n_s32(i
)) {}
71 // Internal utility constructor to simplify return statements
72 SimdFInt32(int32x4_t simd
) : simdInternal_(simd
) {}
74 int32x4_t simdInternal_
;
82 SimdFBool(bool b
) : simdInternal_(vdupq_n_u32( b
? 0xFFFFFFFF : 0)) {}
84 // Internal utility constructor to simplify return statements
85 SimdFBool(uint32x4_t simd
) : simdInternal_(simd
) {}
87 uint32x4_t simdInternal_
;
95 SimdFIBool(bool b
) : simdInternal_(vdupq_n_u32( b
? 0xFFFFFFFF : 0)) {}
97 // Internal utility constructor to simplify return statements
98 SimdFIBool(uint32x4_t simd
) : simdInternal_(simd
) {}
100 uint32x4_t simdInternal_
;
103 static inline SimdFloat gmx_simdcall
104 simdLoad(const float *m
)
106 assert(std::size_t(m
) % 16 == 0);
112 static inline void gmx_simdcall
113 store(float *m
, SimdFloat a
)
115 assert(std::size_t(m
) % 16 == 0);
116 vst1q_f32(m
, a
.simdInternal_
);
119 static inline SimdFloat gmx_simdcall
120 simdLoadU(const float *m
)
127 static inline void gmx_simdcall
128 storeU(float *m
, SimdFloat a
)
130 vst1q_f32(m
, a
.simdInternal_
);
133 static inline SimdFloat gmx_simdcall
141 static inline SimdFInt32 gmx_simdcall
142 simdLoadFI(const std::int32_t * m
)
144 assert(std::size_t(m
) % 16 == 0);
150 static inline void gmx_simdcall
151 store(std::int32_t * m
, SimdFInt32 a
)
153 assert(std::size_t(m
) % 16 == 0);
154 vst1q_s32(m
, a
.simdInternal_
);
157 static inline SimdFInt32 gmx_simdcall
158 simdLoadUFI(const std::int32_t *m
)
165 static inline void gmx_simdcall
166 storeU(std::int32_t * m
, SimdFInt32 a
)
168 vst1q_s32(m
, a
.simdInternal_
);
171 static inline SimdFInt32 gmx_simdcall
179 template<int index
> gmx_simdcall
180 static inline std::int32_t
181 extract(SimdFInt32 a
)
183 return vgetq_lane_s32(a
.simdInternal_
, index
);
186 static inline SimdFloat gmx_simdcall
187 operator&(SimdFloat a
, SimdFloat b
)
190 vreinterpretq_f32_s32(vandq_s32(vreinterpretq_s32_f32(a
.simdInternal_
),
191 vreinterpretq_s32_f32(b
.simdInternal_
)))
195 static inline SimdFloat gmx_simdcall
196 andNot(SimdFloat a
, SimdFloat b
)
199 vreinterpretq_f32_s32(vbicq_s32(vreinterpretq_s32_f32(b
.simdInternal_
),
200 vreinterpretq_s32_f32(a
.simdInternal_
)))
204 static inline SimdFloat gmx_simdcall
205 operator|(SimdFloat a
, SimdFloat b
)
208 vreinterpretq_f32_s32(vorrq_s32(vreinterpretq_s32_f32(a
.simdInternal_
),
209 vreinterpretq_s32_f32(b
.simdInternal_
)))
213 static inline SimdFloat gmx_simdcall
214 operator^(SimdFloat a
, SimdFloat b
)
217 vreinterpretq_f32_s32(veorq_s32(vreinterpretq_s32_f32(a
.simdInternal_
),
218 vreinterpretq_s32_f32(b
.simdInternal_
)))
222 static inline SimdFloat gmx_simdcall
223 operator+(SimdFloat a
, SimdFloat b
)
226 vaddq_f32(a
.simdInternal_
, b
.simdInternal_
)
230 static inline SimdFloat gmx_simdcall
231 operator-(SimdFloat a
, SimdFloat b
)
234 vsubq_f32(a
.simdInternal_
, b
.simdInternal_
)
238 static inline SimdFloat gmx_simdcall
239 operator-(SimdFloat x
)
242 vnegq_f32(x
.simdInternal_
)
246 static inline SimdFloat gmx_simdcall
247 operator*(SimdFloat a
, SimdFloat b
)
250 vmulq_f32(a
.simdInternal_
, b
.simdInternal_
)
254 // Override for Neon-Asimd
255 #if GMX_SIMD_ARM_NEON
256 static inline SimdFloat gmx_simdcall
257 fma(SimdFloat a
, SimdFloat b
, SimdFloat c
)
260 #ifdef __ARM_FEATURE_FMA
261 vfmaq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
)
263 vmlaq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
)
268 static inline SimdFloat gmx_simdcall
269 fms(SimdFloat a
, SimdFloat b
, SimdFloat c
)
272 #ifdef __ARM_FEATURE_FMA
273 vnegq_f32(vfmsq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
))
275 vnegq_f32(vmlsq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
))
280 static inline SimdFloat gmx_simdcall
281 fnma(SimdFloat a
, SimdFloat b
, SimdFloat c
)
284 #ifdef __ARM_FEATURE_FMA
285 vfmsq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
)
287 vmlsq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
)
292 static inline SimdFloat gmx_simdcall
293 fnms(SimdFloat a
, SimdFloat b
, SimdFloat c
)
296 #ifdef __ARM_FEATURE_FMA
297 vnegq_f32(vfmaq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
))
299 vnegq_f32(vmlaq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
))
305 static inline SimdFloat gmx_simdcall
309 vrsqrteq_f32(x
.simdInternal_
)
313 static inline SimdFloat gmx_simdcall
314 rsqrtIter(SimdFloat lu
, SimdFloat x
)
317 vmulq_f32(lu
.simdInternal_
, vrsqrtsq_f32(vmulq_f32(lu
.simdInternal_
, lu
.simdInternal_
), x
.simdInternal_
))
321 static inline SimdFloat gmx_simdcall
325 vrecpeq_f32(x
.simdInternal_
)
329 static inline SimdFloat gmx_simdcall
330 rcpIter(SimdFloat lu
, SimdFloat x
)
333 vmulq_f32(lu
.simdInternal_
, vrecpsq_f32(lu
.simdInternal_
, x
.simdInternal_
))
337 static inline SimdFloat gmx_simdcall
338 maskAdd(SimdFloat a
, SimdFloat b
, SimdFBool m
)
340 b
.simdInternal_
= vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(b
.simdInternal_
),
344 vaddq_f32(a
.simdInternal_
, b
.simdInternal_
)
348 static inline SimdFloat gmx_simdcall
349 maskzMul(SimdFloat a
, SimdFloat b
, SimdFBool m
)
351 SimdFloat tmp
= a
* b
;
354 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp
.simdInternal_
),
359 static inline SimdFloat gmx_simdcall
360 maskzFma(SimdFloat a
, SimdFloat b
, SimdFloat c
, SimdFBool m
)
362 #ifdef __ARM_FEATURE_FMA
363 float32x4_t tmp
= vfmaq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
);
365 float32x4_t tmp
= vmlaq_f32(c
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
);
369 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp
),
374 static inline SimdFloat gmx_simdcall
375 maskzRsqrt(SimdFloat x
, SimdFBool m
)
377 // The result will always be correct since we mask the result with m, but
378 // for debug builds we also want to make sure not to generate FP exceptions
380 x
.simdInternal_
= vbslq_f32(m
.simdInternal_
, x
.simdInternal_
, vdupq_n_f32(1.0f
));
383 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(vrsqrteq_f32(x
.simdInternal_
)),
388 static inline SimdFloat gmx_simdcall
389 maskzRcp(SimdFloat x
, SimdFBool m
)
391 // The result will always be correct since we mask the result with m, but
392 // for debug builds we also want to make sure not to generate FP exceptions
394 x
.simdInternal_
= vbslq_f32(m
.simdInternal_
, x
.simdInternal_
, vdupq_n_f32(1.0f
));
397 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(vrecpeq_f32(x
.simdInternal_
)),
402 static inline SimdFloat gmx_simdcall
406 vabsq_f32( x
.simdInternal_
)
410 static inline SimdFloat gmx_simdcall
411 max(SimdFloat a
, SimdFloat b
)
414 vmaxq_f32(a
.simdInternal_
, b
.simdInternal_
)
418 static inline SimdFloat gmx_simdcall
419 min(SimdFloat a
, SimdFloat b
)
422 vminq_f32(a
.simdInternal_
, b
.simdInternal_
)
426 // Round and trunc operations are defined at the end of this file, since they
427 // need to use float-to-integer and integer-to-float conversions.
429 static inline SimdFloat gmx_simdcall
430 frexp(SimdFloat value
, SimdFInt32
* exponent
)
432 const int32x4_t exponentMask
= vdupq_n_s32(0x7F800000);
433 const int32x4_t mantissaMask
= vdupq_n_s32(0x807FFFFF);
434 const int32x4_t exponentBias
= vdupq_n_s32(126); // add 1 to make our definition identical to frexp()
435 const float32x4_t half
= vdupq_n_f32(0.5f
);
438 iExponent
= vandq_s32(vreinterpretq_s32_f32(value
.simdInternal_
), exponentMask
);
439 iExponent
= vsubq_s32(vshrq_n_s32(iExponent
, 23), exponentBias
);
440 exponent
->simdInternal_
= iExponent
;
443 vreinterpretq_f32_s32(vorrq_s32(vandq_s32(vreinterpretq_s32_f32(value
.simdInternal_
),
445 vreinterpretq_s32_f32(half
)))
449 template <MathOptimization opt
= MathOptimization::Safe
>
450 static inline SimdFloat gmx_simdcall
451 ldexp(SimdFloat value
, SimdFInt32 exponent
)
453 const int32x4_t exponentBias
= vdupq_n_s32(127);
454 int32x4_t iExponent
= vaddq_s32(exponent
.simdInternal_
, exponentBias
);
456 if (opt
== MathOptimization::Safe
)
458 // Make sure biased argument is not negative
459 iExponent
= vmaxq_s32(iExponent
, vdupq_n_s32(0));
462 iExponent
= vshlq_n_s32( iExponent
, 23);
465 vmulq_f32(value
.simdInternal_
, vreinterpretq_f32_s32(iExponent
))
469 // Override for Neon-Asimd
470 #if GMX_SIMD_ARM_NEON
471 static inline float gmx_simdcall
474 float32x4_t x
= a
.simdInternal_
;
475 float32x4_t y
= vextq_f32(x
, x
, 2);
478 y
= vextq_f32(x
, x
, 1);
480 return vgetq_lane_f32(x
, 0);
484 static inline SimdFBool gmx_simdcall
485 operator==(SimdFloat a
, SimdFloat b
)
488 vceqq_f32(a
.simdInternal_
, b
.simdInternal_
)
492 static inline SimdFBool gmx_simdcall
493 operator!=(SimdFloat a
, SimdFloat b
)
496 vmvnq_u32(vceqq_f32(a
.simdInternal_
, b
.simdInternal_
))
500 static inline SimdFBool gmx_simdcall
501 operator<(SimdFloat a
, SimdFloat b
)
504 vcltq_f32(a
.simdInternal_
, b
.simdInternal_
)
508 static inline SimdFBool gmx_simdcall
509 operator<=(SimdFloat a
, SimdFloat b
)
512 vcleq_f32(a
.simdInternal_
, b
.simdInternal_
)
516 static inline SimdFBool gmx_simdcall
517 testBits(SimdFloat a
)
519 uint32x4_t tmp
= vreinterpretq_u32_f32(a
.simdInternal_
);
526 static inline SimdFBool gmx_simdcall
527 operator&&(SimdFBool a
, SimdFBool b
)
531 vandq_u32(a
.simdInternal_
, b
.simdInternal_
)
535 static inline SimdFBool gmx_simdcall
536 operator||(SimdFBool a
, SimdFBool b
)
539 vorrq_u32(a
.simdInternal_
, b
.simdInternal_
)
543 // Override for Neon-Asimd
544 #if GMX_SIMD_ARM_NEON
545 static inline bool gmx_simdcall
548 uint32x4_t x
= a
.simdInternal_
;
549 uint32x4_t y
= vextq_u32(x
, x
, 2);
552 y
= vextq_u32(x
, x
, 1);
554 return (vgetq_lane_u32(x
, 0) != 0);
558 static inline SimdFloat gmx_simdcall
559 selectByMask(SimdFloat a
, SimdFBool m
)
562 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a
.simdInternal_
),
567 static inline SimdFloat gmx_simdcall
568 selectByNotMask(SimdFloat a
, SimdFBool m
)
571 vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a
.simdInternal_
),
576 static inline SimdFloat gmx_simdcall
577 blend(SimdFloat a
, SimdFloat b
, SimdFBool sel
)
580 vbslq_f32(sel
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
)
584 static inline SimdFInt32 gmx_simdcall
585 operator<<(SimdFInt32 a
, int n
)
588 vshlq_s32(a
.simdInternal_
, vdupq_n_s32(n
>= 32 ? 32 : n
))
592 static inline SimdFInt32 gmx_simdcall
593 operator>>(SimdFInt32 a
, int n
)
596 vshlq_s32(a
.simdInternal_
, vdupq_n_s32(n
>= 32 ? -32 : -n
))
600 static inline SimdFInt32 gmx_simdcall
601 operator&(SimdFInt32 a
, SimdFInt32 b
)
604 vandq_s32(a
.simdInternal_
, b
.simdInternal_
)
608 static inline SimdFInt32 gmx_simdcall
609 andNot(SimdFInt32 a
, SimdFInt32 b
)
612 vbicq_s32(b
.simdInternal_
, a
.simdInternal_
)
616 static inline SimdFInt32 gmx_simdcall
617 operator|(SimdFInt32 a
, SimdFInt32 b
)
620 vorrq_s32(a
.simdInternal_
, b
.simdInternal_
)
624 static inline SimdFInt32 gmx_simdcall
625 operator^(SimdFInt32 a
, SimdFInt32 b
)
628 veorq_s32(a
.simdInternal_
, b
.simdInternal_
)
632 static inline SimdFInt32 gmx_simdcall
633 operator+(SimdFInt32 a
, SimdFInt32 b
)
636 vaddq_s32(a
.simdInternal_
, b
.simdInternal_
)
640 static inline SimdFInt32 gmx_simdcall
641 operator-(SimdFInt32 a
, SimdFInt32 b
)
644 vsubq_s32(a
.simdInternal_
, b
.simdInternal_
)
648 static inline SimdFInt32 gmx_simdcall
649 operator*(SimdFInt32 a
, SimdFInt32 b
)
652 vmulq_s32(a
.simdInternal_
, b
.simdInternal_
)
656 static inline SimdFIBool gmx_simdcall
657 operator==(SimdFInt32 a
, SimdFInt32 b
)
660 vceqq_s32(a
.simdInternal_
, b
.simdInternal_
)
664 static inline SimdFIBool gmx_simdcall
665 testBits(SimdFInt32 a
)
668 vtstq_s32(a
.simdInternal_
, a
.simdInternal_
)
672 static inline SimdFIBool gmx_simdcall
673 operator<(SimdFInt32 a
, SimdFInt32 b
)
676 vcltq_s32(a
.simdInternal_
, b
.simdInternal_
)
680 static inline SimdFIBool gmx_simdcall
681 operator&&(SimdFIBool a
, SimdFIBool b
)
684 vandq_u32(a
.simdInternal_
, b
.simdInternal_
)
688 static inline SimdFIBool gmx_simdcall
689 operator||(SimdFIBool a
, SimdFIBool b
)
692 vorrq_u32(a
.simdInternal_
, b
.simdInternal_
)
696 // Override for Neon-Asimd
697 #if GMX_SIMD_ARM_NEON
698 static inline bool gmx_simdcall
699 anyTrue(SimdFIBool a
)
701 uint32x4_t x
= a
.simdInternal_
;
702 uint32x4_t y
= vextq_u32(x
, x
, 2);
705 y
= vextq_u32(x
, x
, 1);
707 return (vgetq_lane_u32(x
, 0) != 0);
711 static inline SimdFInt32 gmx_simdcall
712 selectByMask(SimdFInt32 a
, SimdFIBool m
)
715 vandq_s32(a
.simdInternal_
, vreinterpretq_s32_u32(m
.simdInternal_
))
719 static inline SimdFInt32 gmx_simdcall
720 selectByNotMask(SimdFInt32 a
, SimdFIBool m
)
723 vbicq_s32(a
.simdInternal_
, vreinterpretq_s32_u32(m
.simdInternal_
))
727 static inline SimdFInt32 gmx_simdcall
728 blend(SimdFInt32 a
, SimdFInt32 b
, SimdFIBool sel
)
731 vbslq_s32(sel
.simdInternal_
, b
.simdInternal_
, a
.simdInternal_
)
735 // Override for Neon-Asimd
736 #if GMX_SIMD_ARM_NEON
737 static inline SimdFInt32 gmx_simdcall
740 float32x4_t signBitOfA
= vreinterpretq_f32_u32(vandq_u32(vdupq_n_u32(0x80000000), vreinterpretq_u32_f32(a
.simdInternal_
)));
741 float32x4_t half
= vdupq_n_f32(0.5f
);
742 float32x4_t corr
= vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(half
), vreinterpretq_u32_f32(signBitOfA
)));
745 vcvtq_s32_f32(vaddq_f32(a
.simdInternal_
, corr
))
750 static inline SimdFInt32 gmx_simdcall
754 vcvtq_s32_f32(a
.simdInternal_
)
758 static inline SimdFloat gmx_simdcall
762 vcvtq_f32_s32(a
.simdInternal_
)
766 static inline SimdFIBool gmx_simdcall
774 static inline SimdFBool gmx_simdcall
775 cvtIB2B(SimdFIBool a
)
782 // Override for Neon-Asimd
783 #if GMX_SIMD_ARM_NEON
784 static inline SimdFloat gmx_simdcall
787 return cvtI2R(cvtR2I(x
));
790 static inline SimdFloat gmx_simdcall
793 return cvtI2R(cvttR2I(x
));
799 #endif // GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H