Make sure frexp() returns correct results for argument 0.0
[gromacs.git] / src / gromacs / simd / impl_x86_mic / impl_x86_mic_simd_float.h
blobf1ab6da335184aa394e3d6ff9eda2f1198730f6c
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014,2015,2016,2017,2019,2020, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
36 #ifndef GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H
37 #define GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H
39 #include "config.h"
41 #include <cassert>
42 #include <cstdint>
44 #include <immintrin.h>
46 #include "gromacs/math/utilities.h"
48 namespace gmx
/*! \brief Sixteen-lane single-precision SIMD value for the MIC target.
 *
 * A thin value wrapper around a raw \c __m512 register; all operations on it
 * are provided as free functions in this header.
 */
class SimdFloat
{
public:
    //! Raw 512-bit register holding the sixteen float lanes.
    __m512 simdInternal_;

    //! Default construction leaves the lanes uninitialized.
    SimdFloat() {}

    //! Broadcast \p f into every lane.
    SimdFloat(float f) : simdInternal_(_mm512_set1_ps(f)) {}

    //! Wrap an existing register, so functions can simply return an intrinsic result.
    SimdFloat(__m512 simd) : simdInternal_(simd) {}
};
/*! \brief Sixteen-lane 32-bit integer SIMD value that pairs with SimdFloat.
 *
 * A thin value wrapper around a raw \c __m512i register.
 */
class SimdFInt32
{
public:
    //! Raw 512-bit register holding the sixteen int32 lanes.
    __m512i simdInternal_;

    //! Default construction leaves the lanes uninitialized.
    SimdFInt32() {}

    //! Broadcast \p i into every lane.
    SimdFInt32(std::int32_t i) : simdInternal_(_mm512_set1_epi32(i)) {}

    //! Wrap an existing register, so functions can simply return an intrinsic result.
    SimdFInt32(__m512i simd) : simdInternal_(simd) {}
};
/*! \brief Per-lane boolean result of SimdFloat comparisons.
 *
 * Stored as a 16-bit lane mask, one bit per float lane.
 */
class SimdFBool
{
public:
    //! One bit per lane; a set bit means "true" for that lane.
    __mmask16 simdInternal_;

    //! Default construction leaves the mask uninitialized.
    SimdFBool() {}

    //! All sixteen lanes true (mask 0xFFFF) when \p b is true, all false otherwise.
    SimdFBool(bool b) : simdInternal_(_mm512_int2mask(b ? 0xFFFF : 0)) {}

    //! Wrap an existing mask, so functions can simply return an intrinsic result.
    SimdFBool(__mmask16 simd) : simdInternal_(simd) {}
};
/*! \brief Per-lane boolean result of SimdFInt32 comparisons.
 *
 * Uses the same 16-bit lane-mask representation as SimdFBool.
 */
class SimdFIBool
{
public:
    //! One bit per lane; a set bit means "true" for that lane.
    __mmask16 simdInternal_;

    //! Default construction leaves the mask uninitialized.
    SimdFIBool() {}

    //! All sixteen lanes true (mask 0xFFFF) when \p b is true, all false otherwise.
    SimdFIBool(bool b) : simdInternal_(_mm512_int2mask(b ? 0xFFFF : 0)) {}

    //! Wrap an existing mask, so functions can simply return an intrinsic result.
    SimdFIBool(__mmask16 simd) : simdInternal_(simd) {}
};
103 static inline SimdFloat gmx_simdcall simdLoad(const float* m, SimdFloatTag = {})
105 assert(std::size_t(m) % 64 == 0);
106 return { _mm512_load_ps(m) };
109 static inline void gmx_simdcall store(float* m, SimdFloat a)
111 assert(std::size_t(m) % 64 == 0);
112 _mm512_store_ps(m, a.simdInternal_);
115 static inline SimdFloat gmx_simdcall simdLoadU(const float* m, SimdFloatTag = {})
117 return { _mm512_loadunpackhi_ps(_mm512_loadunpacklo_ps(_mm512_undefined_ps(), m), m + 16) };
120 static inline void gmx_simdcall storeU(float* m, SimdFloat a)
122 _mm512_packstorelo_ps(m, a.simdInternal_);
123 _mm512_packstorehi_ps(m + 16, a.simdInternal_);
126 static inline SimdFloat gmx_simdcall setZeroF()
128 return { _mm512_setzero_ps() };
131 static inline SimdFInt32 gmx_simdcall simdLoad(const std::int32_t* m, SimdFInt32Tag)
133 assert(std::size_t(m) % 64 == 0);
134 return { _mm512_load_epi32(m) };
137 static inline void gmx_simdcall store(std::int32_t* m, SimdFInt32 a)
139 assert(std::size_t(m) % 64 == 0);
140 _mm512_store_epi32(m, a.simdInternal_);
143 static inline SimdFInt32 gmx_simdcall simdLoadU(const std::int32_t* m, SimdFInt32Tag)
145 return { _mm512_loadunpackhi_epi32(_mm512_loadunpacklo_epi32(_mm512_undefined_epi32(), m), m + 16) };
148 static inline void gmx_simdcall storeU(std::int32_t* m, SimdFInt32 a)
150 _mm512_packstorelo_epi32(m, a.simdInternal_);
151 _mm512_packstorehi_epi32(m + 16, a.simdInternal_);
154 static inline SimdFInt32 gmx_simdcall setZeroFI()
156 return { _mm512_setzero_si512() };
160 template<int index>
161 static inline std::int32_t gmx_simdcall extract(SimdFInt32 a)
163 int r;
164 _mm512_mask_packstorelo_epi32(&r, _mm512_mask2int(1 << index), a.simdInternal_);
165 return r;
168 static inline SimdFloat gmx_simdcall operator&(SimdFloat a, SimdFloat b)
170 return { _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a.simdInternal_),
171 _mm512_castps_si512(b.simdInternal_))) };
174 static inline SimdFloat gmx_simdcall andNot(SimdFloat a, SimdFloat b)
176 return { _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(a.simdInternal_),
177 _mm512_castps_si512(b.simdInternal_))) };
180 static inline SimdFloat gmx_simdcall operator|(SimdFloat a, SimdFloat b)
182 return { _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(a.simdInternal_),
183 _mm512_castps_si512(b.simdInternal_))) };
186 static inline SimdFloat gmx_simdcall operator^(SimdFloat a, SimdFloat b)
188 return { _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a.simdInternal_),
189 _mm512_castps_si512(b.simdInternal_))) };
192 static inline SimdFloat gmx_simdcall operator+(SimdFloat a, SimdFloat b)
194 return { _mm512_add_ps(a.simdInternal_, b.simdInternal_) };
197 static inline SimdFloat gmx_simdcall operator-(SimdFloat a, SimdFloat b)
199 return { _mm512_sub_ps(a.simdInternal_, b.simdInternal_) };
202 static inline SimdFloat gmx_simdcall operator-(SimdFloat x)
204 return { _mm512_addn_ps(x.simdInternal_, _mm512_setzero_ps()) };
207 static inline SimdFloat gmx_simdcall operator*(SimdFloat a, SimdFloat b)
209 return { _mm512_mul_ps(a.simdInternal_, b.simdInternal_) };
212 static inline SimdFloat gmx_simdcall fma(SimdFloat a, SimdFloat b, SimdFloat c)
214 return { _mm512_fmadd_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_) };
217 static inline SimdFloat gmx_simdcall fms(SimdFloat a, SimdFloat b, SimdFloat c)
219 return { _mm512_fmsub_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_) };
222 static inline SimdFloat gmx_simdcall fnma(SimdFloat a, SimdFloat b, SimdFloat c)
224 return { _mm512_fnmadd_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_) };
227 static inline SimdFloat gmx_simdcall fnms(SimdFloat a, SimdFloat b, SimdFloat c)
229 return { _mm512_fnmsub_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_) };
232 static inline SimdFloat gmx_simdcall rsqrt(SimdFloat x)
234 return { _mm512_rsqrt23_ps(x.simdInternal_) };
237 static inline SimdFloat gmx_simdcall rcp(SimdFloat x)
239 return { _mm512_rcp23_ps(x.simdInternal_) };
242 static inline SimdFloat gmx_simdcall maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
244 return { _mm512_mask_add_ps(a.simdInternal_, m.simdInternal_, a.simdInternal_, b.simdInternal_) };
247 static inline SimdFloat gmx_simdcall maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
249 return { _mm512_mask_mul_ps(_mm512_setzero_ps(), m.simdInternal_, a.simdInternal_, b.simdInternal_) };
252 static inline SimdFloat gmx_simdcall maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
254 return { _mm512_mask_mov_ps(_mm512_setzero_ps(), m.simdInternal_,
255 _mm512_fmadd_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_)) };
258 static inline SimdFloat gmx_simdcall maskzRsqrt(SimdFloat x, SimdFBool m)
260 return { _mm512_mask_rsqrt23_ps(_mm512_setzero_ps(), m.simdInternal_, x.simdInternal_) };
263 static inline SimdFloat gmx_simdcall maskzRcp(SimdFloat x, SimdFBool m)
265 return { _mm512_mask_rcp23_ps(_mm512_setzero_ps(), m.simdInternal_, x.simdInternal_) };
268 static inline SimdFloat gmx_simdcall abs(SimdFloat x)
270 return { _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(_mm512_set1_ps(GMX_FLOAT_NEGZERO)),
271 _mm512_castps_si512(x.simdInternal_))) };
274 static inline SimdFloat gmx_simdcall max(SimdFloat a, SimdFloat b)
276 return { _mm512_gmax_ps(a.simdInternal_, b.simdInternal_) };
279 static inline SimdFloat gmx_simdcall min(SimdFloat a, SimdFloat b)
281 return { _mm512_gmin_ps(a.simdInternal_, b.simdInternal_) };
284 static inline SimdFloat gmx_simdcall round(SimdFloat x)
286 return { _mm512_round_ps(x.simdInternal_, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE) };
289 static inline SimdFloat gmx_simdcall trunc(SimdFloat x)
291 return { _mm512_round_ps(x.simdInternal_, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE) };
294 template<MathOptimization opt = MathOptimization::Safe>
295 static inline SimdFloat gmx_simdcall frexp(SimdFloat value, SimdFInt32* exponent)
297 __m512 rExponent;
298 __m512i iExponent;
299 __m512 result;
301 if (opt == MathOptimization::Safe)
303 // For the safe branch, we use the masked operations to only assign results if the
304 // input value was nonzero, and otherwise set exponent to 0, and the fraction to the input (+-0).
305 __mmask16 valueIsNonZero =
306 _mm512_cmp_ps_mask(_mm512_setzero_ps(), value.simdInternal_, _CMP_NEQ_OQ);
307 rExponent = _mm512_mask_getexp_ps(_mm512_setzero_ps(), valueIsNonZero, value.simdInternal_);
308 iExponent = _mm512_cvtfxpnt_round_adjustps_epi32(rExponent, _MM_FROUND_TO_NEAREST_INT,
309 _MM_EXPADJ_NONE);
310 iExponent = _mm512_mask_add_epi32(iExponent, valueIsNonZero, iExponent, _mm512_set1_epi32(1));
312 // Set result to input value when the latter is +-0
313 result = _mm512_mask_getmant_ps(value.simdInternal_, valueIsNonZero, value.simdInternal_,
314 _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src);
316 else
318 rExponent = _mm512_getexp_ps(value.simdInternal_);
319 iExponent = _mm512_cvtfxpnt_round_adjustps_epi32(rExponent, _MM_FROUND_TO_NEAREST_INT,
320 _MM_EXPADJ_NONE);
321 iExponent = _mm512_add_epi32(iExponent, _mm512_set1_epi32(1));
322 result = _mm512_getmant_ps(value.simdInternal_, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src);
325 exponent->simdInternal_ = iExponent;
327 return { result };
330 template<MathOptimization opt = MathOptimization::Safe>
331 static inline SimdFloat gmx_simdcall ldexp(SimdFloat value, SimdFInt32 exponent)
333 const __m512i exponentBias = _mm512_set1_epi32(127);
334 __m512i iExponent = _mm512_add_epi32(exponent.simdInternal_, exponentBias);
336 if (opt == MathOptimization::Safe)
338 // Make sure biased argument is not negative
339 iExponent = _mm512_max_epi32(iExponent, _mm512_setzero_epi32());
342 iExponent = _mm512_slli_epi32(iExponent, 23);
344 return { _mm512_mul_ps(value.simdInternal_, _mm512_castsi512_ps(iExponent)) };
347 static inline float gmx_simdcall reduce(SimdFloat a)
349 return _mm512_reduce_add_ps(a.simdInternal_);
352 // Picky, picky, picky:
353 // icc-16 complains about "Illegal value of immediate argument to intrinsic"
354 // unless we use
355 // 1) Ordered-quiet for ==
356 // 2) Unordered-quiet for !=
357 // 3) Ordered-signaling for < and <=
359 static inline SimdFBool gmx_simdcall operator==(SimdFloat a, SimdFloat b)
361 return { _mm512_cmp_ps_mask(a.simdInternal_, b.simdInternal_, _CMP_EQ_OQ) };
364 static inline SimdFBool gmx_simdcall operator!=(SimdFloat a, SimdFloat b)
366 return { _mm512_cmp_ps_mask(a.simdInternal_, b.simdInternal_, _CMP_NEQ_UQ) };
369 static inline SimdFBool gmx_simdcall operator<(SimdFloat a, SimdFloat b)
371 return { _mm512_cmp_ps_mask(a.simdInternal_, b.simdInternal_, _CMP_LT_OS) };
374 static inline SimdFBool gmx_simdcall operator<=(SimdFloat a, SimdFloat b)
376 return { _mm512_cmp_ps_mask(a.simdInternal_, b.simdInternal_, _CMP_LE_OS) };
379 static inline SimdFBool gmx_simdcall testBits(SimdFloat a)
381 return { _mm512_test_epi32_mask(_mm512_castps_si512(a.simdInternal_),
382 _mm512_castps_si512(a.simdInternal_)) };
385 static inline SimdFBool gmx_simdcall operator&&(SimdFBool a, SimdFBool b)
387 return { _mm512_kand(a.simdInternal_, b.simdInternal_) };
390 static inline SimdFBool gmx_simdcall operator||(SimdFBool a, SimdFBool b)
392 return { _mm512_kor(a.simdInternal_, b.simdInternal_) };
395 static inline bool gmx_simdcall anyTrue(SimdFBool a)
397 return _mm512_mask2int(a.simdInternal_) != 0;
400 static inline SimdFloat gmx_simdcall selectByMask(SimdFloat a, SimdFBool m)
402 return { _mm512_mask_mov_ps(_mm512_setzero_ps(), m.simdInternal_, a.simdInternal_) };
405 static inline SimdFloat gmx_simdcall selectByNotMask(SimdFloat a, SimdFBool m)
407 return { _mm512_mask_mov_ps(a.simdInternal_, m.simdInternal_, _mm512_setzero_ps()) };
410 static inline SimdFloat gmx_simdcall blend(SimdFloat a, SimdFloat b, SimdFBool sel)
412 return { _mm512_mask_blend_ps(sel.simdInternal_, a.simdInternal_, b.simdInternal_) };
415 static inline SimdFInt32 gmx_simdcall operator&(SimdFInt32 a, SimdFInt32 b)
417 return { _mm512_and_epi32(a.simdInternal_, b.simdInternal_) };
420 static inline SimdFInt32 gmx_simdcall andNot(SimdFInt32 a, SimdFInt32 b)
422 return { _mm512_andnot_epi32(a.simdInternal_, b.simdInternal_) };
425 static inline SimdFInt32 gmx_simdcall operator|(SimdFInt32 a, SimdFInt32 b)
427 return { _mm512_or_epi32(a.simdInternal_, b.simdInternal_) };
430 static inline SimdFInt32 gmx_simdcall operator^(SimdFInt32 a, SimdFInt32 b)
432 return { _mm512_xor_epi32(a.simdInternal_, b.simdInternal_) };
435 static inline SimdFInt32 gmx_simdcall operator+(SimdFInt32 a, SimdFInt32 b)
437 return { _mm512_add_epi32(a.simdInternal_, b.simdInternal_) };
440 static inline SimdFInt32 gmx_simdcall operator-(SimdFInt32 a, SimdFInt32 b)
442 return { _mm512_sub_epi32(a.simdInternal_, b.simdInternal_) };
445 static inline SimdFInt32 gmx_simdcall operator*(SimdFInt32 a, SimdFInt32 b)
447 return { _mm512_mullo_epi32(a.simdInternal_, b.simdInternal_) };
450 static inline SimdFIBool gmx_simdcall operator==(SimdFInt32 a, SimdFInt32 b)
452 return { _mm512_cmp_epi32_mask(a.simdInternal_, b.simdInternal_, _MM_CMPINT_EQ) };
455 static inline SimdFIBool gmx_simdcall testBits(SimdFInt32 a)
457 return { _mm512_test_epi32_mask(a.simdInternal_, a.simdInternal_) };
460 static inline SimdFIBool gmx_simdcall operator<(SimdFInt32 a, SimdFInt32 b)
462 return { _mm512_cmp_epi32_mask(a.simdInternal_, b.simdInternal_, _MM_CMPINT_LT) };
465 static inline SimdFIBool gmx_simdcall operator&&(SimdFIBool a, SimdFIBool b)
467 return { _mm512_kand(a.simdInternal_, b.simdInternal_) };
470 static inline SimdFIBool gmx_simdcall operator||(SimdFIBool a, SimdFIBool b)
472 return { _mm512_kor(a.simdInternal_, b.simdInternal_) };
475 static inline bool gmx_simdcall anyTrue(SimdFIBool a)
477 return _mm512_mask2int(a.simdInternal_) != 0;
480 static inline SimdFInt32 gmx_simdcall selectByMask(SimdFInt32 a, SimdFIBool m)
482 return { _mm512_mask_mov_epi32(_mm512_setzero_epi32(), m.simdInternal_, a.simdInternal_) };
485 static inline SimdFInt32 gmx_simdcall selectByNotMask(SimdFInt32 a, SimdFIBool m)
487 return { _mm512_mask_mov_epi32(a.simdInternal_, m.simdInternal_, _mm512_setzero_epi32()) };
490 static inline SimdFInt32 gmx_simdcall blend(SimdFInt32 a, SimdFInt32 b, SimdFIBool sel)
492 return { _mm512_mask_blend_epi32(sel.simdInternal_, a.simdInternal_, b.simdInternal_) };
495 static inline SimdFInt32 gmx_simdcall cvtR2I(SimdFloat a)
497 return { _mm512_cvtfxpnt_round_adjustps_epi32(a.simdInternal_, _MM_FROUND_TO_NEAREST_INT,
498 _MM_EXPADJ_NONE) };
501 static inline SimdFInt32 gmx_simdcall cvttR2I(SimdFloat a)
503 return { _mm512_cvtfxpnt_round_adjustps_epi32(a.simdInternal_, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE) };
506 static inline SimdFloat gmx_simdcall cvtI2R(SimdFInt32 a)
508 return { _mm512_cvtfxpnt_round_adjustepi32_ps(a.simdInternal_, _MM_FROUND_TO_NEAREST_INT,
509 _MM_EXPADJ_NONE) };
512 static inline SimdFIBool gmx_simdcall cvtB2IB(SimdFBool a)
514 return { a.simdInternal_ };
517 static inline SimdFBool gmx_simdcall cvtIB2B(SimdFIBool a)
519 return { a.simdInternal_ };
523 template<MathOptimization opt = MathOptimization::Safe>
524 static inline SimdFloat gmx_simdcall exp2(SimdFloat x)
526 return { _mm512_exp223_ps(_mm512_cvtfxpnt_round_adjustps_epi32(
527 x.simdInternal_, _MM_ROUND_MODE_NEAREST, _MM_EXPADJ_24)) };
530 template<MathOptimization opt = MathOptimization::Safe>
531 static inline SimdFloat gmx_simdcall exp(SimdFloat x)
533 const __m512 argscale = _mm512_set1_ps(1.44269504088896341F);
534 const __m512 invargscale = _mm512_set1_ps(-0.69314718055994528623F);
536 if (opt == MathOptimization::Safe)
538 // Set the limit to gurantee flush to zero
539 const SimdFloat smallArgLimit(-88.f);
540 // Since we multiply the argument by 1.44, for the safe version we need to make
541 // sure this doesn't result in overflow
542 x = max(x, smallArgLimit);
545 __m512 xscaled = _mm512_mul_ps(x.simdInternal_, argscale);
546 __m512 r = _mm512_exp223_ps(
547 _mm512_cvtfxpnt_round_adjustps_epi32(xscaled, _MM_ROUND_MODE_NEAREST, _MM_EXPADJ_24));
549 // exp2a23_ps provides 23 bits of accuracy, but we ruin some of that with our argument
550 // scaling. To correct this, we find the difference between the scaled argument and
551 // the true one (extended precision arithmetics does not appear to be necessary to
552 // fulfill our accuracy requirements) and then multiply by the exponent of this
553 // correction since exp(a+b)=exp(a)*exp(b).
554 // Note that this only adds two instructions (and maybe some constant loads).
556 // find the difference
557 x = _mm512_fmadd_ps(invargscale, xscaled, x.simdInternal_);
558 // x will now be a _very_ small number, so approximate exp(x)=1+x.
559 // We should thus apply the correction as r'=r*(1+x)=r+r*x
560 r = _mm512_fmadd_ps(r, x.simdInternal_, r);
561 return { r };
564 static inline SimdFloat gmx_simdcall log(SimdFloat x)
566 return { _mm512_mul_ps(_mm512_set1_ps(0.693147180559945286226764F),
567 _mm512_log2ae23_ps(x.simdInternal_)) };
570 } // namespace gmx
572 #endif // GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H