2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014,2015,2016,2017,2019, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
36 #ifndef GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H
37 #define GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H
44 #include <immintrin.h>
46 #include "gromacs/math/utilities.h"
56 SimdFloat(float f
) : simdInternal_(_mm512_set1_ps(f
)) {}
58 // Internal utility constructor to simplify return statements
59 SimdFloat(__m512 simd
) : simdInternal_(simd
) {}
69 SimdFInt32(std::int32_t i
) : simdInternal_(_mm512_set1_epi32(i
)) {}
71 // Internal utility constructor to simplify return statements
72 SimdFInt32(__m512i simd
) : simdInternal_(simd
) {}
74 __m512i simdInternal_
;
82 SimdFBool(bool b
) : simdInternal_(_mm512_int2mask( b
? 0xFFFF : 0)) {}
84 // Internal utility constructor to simplify return statements
85 SimdFBool(__mmask16 simd
) : simdInternal_(simd
) {}
87 __mmask16 simdInternal_
;
95 SimdFIBool(bool b
) : simdInternal_(_mm512_int2mask( b
? 0xFFFF : 0)) {}
97 // Internal utility constructor to simplify return statements
98 SimdFIBool(__mmask16 simd
) : simdInternal_(simd
) {}
100 __mmask16 simdInternal_
;
103 static inline SimdFloat gmx_simdcall
104 simdLoad(const float *m
, SimdFloatTag
= {})
106 assert(std::size_t(m
) % 64 == 0);
112 static inline void gmx_simdcall
113 store(float *m
, SimdFloat a
)
115 assert(std::size_t(m
) % 64 == 0);
116 _mm512_store_ps(m
, a
.simdInternal_
);
119 static inline SimdFloat gmx_simdcall
120 simdLoadU(const float *m
, SimdFloatTag
= {})
123 _mm512_loadunpackhi_ps(_mm512_loadunpacklo_ps(_mm512_undefined_ps(), m
), m
+16)
127 static inline void gmx_simdcall
128 storeU(float *m
, SimdFloat a
)
130 _mm512_packstorelo_ps(m
, a
.simdInternal_
);
131 _mm512_packstorehi_ps(m
+16, a
.simdInternal_
);
134 static inline SimdFloat gmx_simdcall
142 static inline SimdFInt32 gmx_simdcall
143 simdLoad(const std::int32_t * m
, SimdFInt32Tag
)
145 assert(std::size_t(m
) % 64 == 0);
151 static inline void gmx_simdcall
152 store(std::int32_t * m
, SimdFInt32 a
)
154 assert(std::size_t(m
) % 64 == 0);
155 _mm512_store_epi32(m
, a
.simdInternal_
);
158 static inline SimdFInt32 gmx_simdcall
159 simdLoadU(const std::int32_t *m
, SimdFInt32Tag
)
162 _mm512_loadunpackhi_epi32(_mm512_loadunpacklo_epi32(_mm512_undefined_epi32(), m
), m
+16)
166 static inline void gmx_simdcall
167 storeU(std::int32_t * m
, SimdFInt32 a
)
169 _mm512_packstorelo_epi32(m
, a
.simdInternal_
);
170 _mm512_packstorehi_epi32(m
+16, a
.simdInternal_
);
173 static inline SimdFInt32 gmx_simdcall
177 _mm512_setzero_si512()
183 static inline std::int32_t gmx_simdcall
184 extract(SimdFInt32 a
)
187 _mm512_mask_packstorelo_epi32(&r
, _mm512_mask2int(1<<index
), a
.simdInternal_
);
191 static inline SimdFloat gmx_simdcall
192 operator&(SimdFloat a
, SimdFloat b
)
195 _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a
.simdInternal_
), _mm512_castps_si512(b
.simdInternal_
)))
199 static inline SimdFloat gmx_simdcall
200 andNot(SimdFloat a
, SimdFloat b
)
203 _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(a
.simdInternal_
), _mm512_castps_si512(b
.simdInternal_
)))
207 static inline SimdFloat gmx_simdcall
208 operator|(SimdFloat a
, SimdFloat b
)
211 _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(a
.simdInternal_
), _mm512_castps_si512(b
.simdInternal_
)))
215 static inline SimdFloat gmx_simdcall
216 operator^(SimdFloat a
, SimdFloat b
)
219 _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a
.simdInternal_
), _mm512_castps_si512(b
.simdInternal_
)))
223 static inline SimdFloat gmx_simdcall
224 operator+(SimdFloat a
, SimdFloat b
)
227 _mm512_add_ps(a
.simdInternal_
, b
.simdInternal_
)
231 static inline SimdFloat gmx_simdcall
232 operator-(SimdFloat a
, SimdFloat b
)
235 _mm512_sub_ps(a
.simdInternal_
, b
.simdInternal_
)
239 static inline SimdFloat gmx_simdcall
240 operator-(SimdFloat x
)
243 _mm512_addn_ps(x
.simdInternal_
, _mm512_setzero_ps())
247 static inline SimdFloat gmx_simdcall
248 operator*(SimdFloat a
, SimdFloat b
)
251 _mm512_mul_ps(a
.simdInternal_
, b
.simdInternal_
)
255 static inline SimdFloat gmx_simdcall
256 fma(SimdFloat a
, SimdFloat b
, SimdFloat c
)
259 _mm512_fmadd_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
)
263 static inline SimdFloat gmx_simdcall
264 fms(SimdFloat a
, SimdFloat b
, SimdFloat c
)
267 _mm512_fmsub_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
)
271 static inline SimdFloat gmx_simdcall
272 fnma(SimdFloat a
, SimdFloat b
, SimdFloat c
)
275 _mm512_fnmadd_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
)
279 static inline SimdFloat gmx_simdcall
280 fnms(SimdFloat a
, SimdFloat b
, SimdFloat c
)
283 _mm512_fnmsub_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
)
287 static inline SimdFloat gmx_simdcall
291 _mm512_rsqrt23_ps(x
.simdInternal_
)
295 static inline SimdFloat gmx_simdcall
299 _mm512_rcp23_ps(x
.simdInternal_
)
303 static inline SimdFloat gmx_simdcall
304 maskAdd(SimdFloat a
, SimdFloat b
, SimdFBool m
)
307 _mm512_mask_add_ps(a
.simdInternal_
, m
.simdInternal_
, a
.simdInternal_
, b
.simdInternal_
)
311 static inline SimdFloat gmx_simdcall
312 maskzMul(SimdFloat a
, SimdFloat b
, SimdFBool m
)
315 _mm512_mask_mul_ps(_mm512_setzero_ps(), m
.simdInternal_
, a
.simdInternal_
, b
.simdInternal_
)
319 static inline SimdFloat gmx_simdcall
320 maskzFma(SimdFloat a
, SimdFloat b
, SimdFloat c
, SimdFBool m
)
323 _mm512_mask_mov_ps(_mm512_setzero_ps(), m
.simdInternal_
, _mm512_fmadd_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
))
327 static inline SimdFloat gmx_simdcall
328 maskzRsqrt(SimdFloat x
, SimdFBool m
)
331 _mm512_mask_rsqrt23_ps(_mm512_setzero_ps(), m
.simdInternal_
, x
.simdInternal_
)
335 static inline SimdFloat gmx_simdcall
336 maskzRcp(SimdFloat x
, SimdFBool m
)
339 _mm512_mask_rcp23_ps(_mm512_setzero_ps(), m
.simdInternal_
, x
.simdInternal_
)
343 static inline SimdFloat gmx_simdcall
347 _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(_mm512_set1_ps(GMX_FLOAT_NEGZERO
)), _mm512_castps_si512(x
.simdInternal_
)))
351 static inline SimdFloat gmx_simdcall
352 max(SimdFloat a
, SimdFloat b
)
355 _mm512_gmax_ps(a
.simdInternal_
, b
.simdInternal_
)
359 static inline SimdFloat gmx_simdcall
360 min(SimdFloat a
, SimdFloat b
)
363 _mm512_gmin_ps(a
.simdInternal_
, b
.simdInternal_
)
367 static inline SimdFloat gmx_simdcall
371 _mm512_round_ps(x
.simdInternal_
, _MM_FROUND_TO_NEAREST_INT
, _MM_EXPADJ_NONE
)
375 static inline SimdFloat gmx_simdcall
379 _mm512_round_ps(x
.simdInternal_
, _MM_FROUND_TO_ZERO
, _MM_EXPADJ_NONE
)
383 static inline SimdFloat gmx_simdcall
384 frexp(SimdFloat value
, SimdFInt32
* exponent
)
386 __m512 rExponent
= _mm512_getexp_ps(value
.simdInternal_
);
387 __m512i iExponent
= _mm512_cvtfxpnt_round_adjustps_epi32(rExponent
, _MM_FROUND_TO_NEAREST_INT
, _MM_EXPADJ_NONE
);
389 exponent
->simdInternal_
= _mm512_add_epi32(iExponent
, _mm512_set1_epi32(1));
392 _mm512_getmant_ps(value
.simdInternal_
, _MM_MANT_NORM_p5_1
, _MM_MANT_SIGN_src
)
396 template <MathOptimization opt
= MathOptimization::Safe
>
397 static inline SimdFloat gmx_simdcall
398 ldexp(SimdFloat value
, SimdFInt32 exponent
)
400 const __m512i exponentBias
= _mm512_set1_epi32(127);
401 __m512i iExponent
= _mm512_add_epi32(exponent
.simdInternal_
, exponentBias
);
403 if (opt
== MathOptimization::Safe
)
405 // Make sure biased argument is not negative
406 iExponent
= _mm512_max_epi32(iExponent
, _mm512_setzero_epi32());
409 iExponent
= _mm512_slli_epi32( iExponent
, 23);
412 _mm512_mul_ps(value
.simdInternal_
, _mm512_castsi512_ps(iExponent
))
416 static inline float gmx_simdcall
419 return _mm512_reduce_add_ps(a
.simdInternal_
);
422 // Picky, picky, picky:
423 // icc-16 complains about "Illegal value of immediate argument to intrinsic"
425 // 1) Ordered-quiet for ==
426 // 2) Unordered-quiet for !=
427 // 3) Ordered-signaling for < and <=
429 static inline SimdFBool gmx_simdcall
430 operator==(SimdFloat a
, SimdFloat b
)
433 _mm512_cmp_ps_mask(a
.simdInternal_
, b
.simdInternal_
, _CMP_EQ_OQ
)
437 static inline SimdFBool gmx_simdcall
438 operator!=(SimdFloat a
, SimdFloat b
)
441 _mm512_cmp_ps_mask(a
.simdInternal_
, b
.simdInternal_
, _CMP_NEQ_UQ
)
445 static inline SimdFBool gmx_simdcall
446 operator<(SimdFloat a
, SimdFloat b
)
449 _mm512_cmp_ps_mask(a
.simdInternal_
, b
.simdInternal_
, _CMP_LT_OS
)
453 static inline SimdFBool gmx_simdcall
454 operator<=(SimdFloat a
, SimdFloat b
)
457 _mm512_cmp_ps_mask(a
.simdInternal_
, b
.simdInternal_
, _CMP_LE_OS
)
461 static inline SimdFBool gmx_simdcall
462 testBits(SimdFloat a
)
465 _mm512_test_epi32_mask( _mm512_castps_si512(a
.simdInternal_
), _mm512_castps_si512(a
.simdInternal_
) )
469 static inline SimdFBool gmx_simdcall
470 operator&&(SimdFBool a
, SimdFBool b
)
473 _mm512_kand(a
.simdInternal_
, b
.simdInternal_
)
477 static inline SimdFBool gmx_simdcall
478 operator||(SimdFBool a
, SimdFBool b
)
481 _mm512_kor(a
.simdInternal_
, b
.simdInternal_
)
485 static inline bool gmx_simdcall
488 return _mm512_mask2int(a
.simdInternal_
) != 0;
491 static inline SimdFloat gmx_simdcall
492 selectByMask(SimdFloat a
, SimdFBool m
)
495 _mm512_mask_mov_ps(_mm512_setzero_ps(), m
.simdInternal_
, a
.simdInternal_
)
499 static inline SimdFloat gmx_simdcall
500 selectByNotMask(SimdFloat a
, SimdFBool m
)
503 _mm512_mask_mov_ps(a
.simdInternal_
, m
.simdInternal_
, _mm512_setzero_ps())
507 static inline SimdFloat gmx_simdcall
508 blend(SimdFloat a
, SimdFloat b
, SimdFBool sel
)
511 _mm512_mask_blend_ps(sel
.simdInternal_
, a
.simdInternal_
, b
.simdInternal_
)
515 static inline SimdFInt32 gmx_simdcall
516 operator&(SimdFInt32 a
, SimdFInt32 b
)
519 _mm512_and_epi32(a
.simdInternal_
, b
.simdInternal_
)
523 static inline SimdFInt32 gmx_simdcall
524 andNot(SimdFInt32 a
, SimdFInt32 b
)
527 _mm512_andnot_epi32(a
.simdInternal_
, b
.simdInternal_
)
531 static inline SimdFInt32 gmx_simdcall
532 operator|(SimdFInt32 a
, SimdFInt32 b
)
535 _mm512_or_epi32(a
.simdInternal_
, b
.simdInternal_
)
539 static inline SimdFInt32 gmx_simdcall
540 operator^(SimdFInt32 a
, SimdFInt32 b
)
543 _mm512_xor_epi32(a
.simdInternal_
, b
.simdInternal_
)
547 static inline SimdFInt32 gmx_simdcall
548 operator+(SimdFInt32 a
, SimdFInt32 b
)
551 _mm512_add_epi32(a
.simdInternal_
, b
.simdInternal_
)
555 static inline SimdFInt32 gmx_simdcall
556 operator-(SimdFInt32 a
, SimdFInt32 b
)
559 _mm512_sub_epi32(a
.simdInternal_
, b
.simdInternal_
)
563 static inline SimdFInt32 gmx_simdcall
564 operator*(SimdFInt32 a
, SimdFInt32 b
)
567 _mm512_mullo_epi32(a
.simdInternal_
, b
.simdInternal_
)
571 static inline SimdFIBool gmx_simdcall
572 operator==(SimdFInt32 a
, SimdFInt32 b
)
575 _mm512_cmp_epi32_mask(a
.simdInternal_
, b
.simdInternal_
, _MM_CMPINT_EQ
)
579 static inline SimdFIBool gmx_simdcall
580 testBits(SimdFInt32 a
)
583 _mm512_test_epi32_mask( a
.simdInternal_
, a
.simdInternal_
)
587 static inline SimdFIBool gmx_simdcall
588 operator<(SimdFInt32 a
, SimdFInt32 b
)
591 _mm512_cmp_epi32_mask(a
.simdInternal_
, b
.simdInternal_
, _MM_CMPINT_LT
)
595 static inline SimdFIBool gmx_simdcall
596 operator&&(SimdFIBool a
, SimdFIBool b
)
599 _mm512_kand(a
.simdInternal_
, b
.simdInternal_
)
603 static inline SimdFIBool gmx_simdcall
604 operator||(SimdFIBool a
, SimdFIBool b
)
607 _mm512_kor(a
.simdInternal_
, b
.simdInternal_
)
611 static inline bool gmx_simdcall
612 anyTrue(SimdFIBool a
)
614 return _mm512_mask2int(a
.simdInternal_
) != 0;
617 static inline SimdFInt32 gmx_simdcall
618 selectByMask(SimdFInt32 a
, SimdFIBool m
)
621 _mm512_mask_mov_epi32(_mm512_setzero_epi32(), m
.simdInternal_
, a
.simdInternal_
)
625 static inline SimdFInt32 gmx_simdcall
626 selectByNotMask(SimdFInt32 a
, SimdFIBool m
)
629 _mm512_mask_mov_epi32(a
.simdInternal_
, m
.simdInternal_
, _mm512_setzero_epi32())
633 static inline SimdFInt32 gmx_simdcall
634 blend(SimdFInt32 a
, SimdFInt32 b
, SimdFIBool sel
)
637 _mm512_mask_blend_epi32(sel
.simdInternal_
, a
.simdInternal_
, b
.simdInternal_
)
641 static inline SimdFInt32 gmx_simdcall
645 _mm512_cvtfxpnt_round_adjustps_epi32(a
.simdInternal_
, _MM_FROUND_TO_NEAREST_INT
, _MM_EXPADJ_NONE
)
649 static inline SimdFInt32 gmx_simdcall
653 _mm512_cvtfxpnt_round_adjustps_epi32(a
.simdInternal_
, _MM_FROUND_TO_ZERO
, _MM_EXPADJ_NONE
)
657 static inline SimdFloat gmx_simdcall
661 _mm512_cvtfxpnt_round_adjustepi32_ps(a
.simdInternal_
, _MM_FROUND_TO_NEAREST_INT
, _MM_EXPADJ_NONE
)
665 static inline SimdFIBool gmx_simdcall
673 static inline SimdFBool gmx_simdcall
674 cvtIB2B(SimdFIBool a
)
682 template <MathOptimization opt
= MathOptimization::Safe
>
683 static inline SimdFloat gmx_simdcall
687 _mm512_exp223_ps(_mm512_cvtfxpnt_round_adjustps_epi32(x
.simdInternal_
, _MM_ROUND_MODE_NEAREST
, _MM_EXPADJ_24
))
691 template <MathOptimization opt
= MathOptimization::Safe
>
692 static inline SimdFloat gmx_simdcall
695 const __m512 argscale
= _mm512_set1_ps(1.44269504088896341F
);
696 const __m512 invargscale
= _mm512_set1_ps(-0.69314718055994528623F
);
698 if (opt
== MathOptimization::Safe
)
700 // Set the limit to gurantee flush to zero
701 const SimdFloat
smallArgLimit(-88.f
);
702 // Since we multiply the argument by 1.44, for the safe version we need to make
703 // sure this doesn't result in overflow
704 x
= max(x
, smallArgLimit
);
707 __m512 xscaled
= _mm512_mul_ps(x
.simdInternal_
, argscale
);
708 __m512 r
= _mm512_exp223_ps(_mm512_cvtfxpnt_round_adjustps_epi32(xscaled
, _MM_ROUND_MODE_NEAREST
, _MM_EXPADJ_24
));
710 // exp2a23_ps provides 23 bits of accuracy, but we ruin some of that with our argument
711 // scaling. To correct this, we find the difference between the scaled argument and
712 // the true one (extended precision arithmetics does not appear to be necessary to
713 // fulfill our accuracy requirements) and then multiply by the exponent of this
714 // correction since exp(a+b)=exp(a)*exp(b).
715 // Note that this only adds two instructions (and maybe some constant loads).
717 // find the difference
718 x
= _mm512_fmadd_ps(invargscale
, xscaled
, x
.simdInternal_
);
719 // x will now be a _very_ small number, so approximate exp(x)=1+x.
720 // We should thus apply the correction as r'=r*(1+x)=r+r*x
721 r
= _mm512_fmadd_ps(r
, x
.simdInternal_
, r
);
727 static inline SimdFloat gmx_simdcall
731 _mm512_mul_ps(_mm512_set1_ps(0.693147180559945286226764F
), _mm512_log2ae23_ps(x
.simdInternal_
))
737 #endif // GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H