Improve accuracy of SIMD exp for small args
[gromacs.git] / src / gromacs / simd / impl_x86_mic / impl_x86_mic_simd_float.h
blob9a054eea231bc3db504e1a24275aeb167bc5423a
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
36 #ifndef GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H
37 #define GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H
39 #include "config.h"
41 #include <cassert>
42 #include <cstdint>
44 #include <immintrin.h>
46 #include "gromacs/math/utilities.h"
48 namespace gmx
// Single-precision SIMD variable, stored as one __m512 register.
class SimdFloat
{
    public:
        // Default construction does not initialize the register.
        SimdFloat() {}

        // Broadcast one scalar value to all lanes.
        SimdFloat(float value) : simdInternal_(_mm512_set1_ps(value)) {}

        // Internal utility constructor to simplify return statements
        SimdFloat(__m512 reg) : simdInternal_(reg) {}

        // Raw register, read and written directly by the free functions in this file.
        __m512  simdInternal_;
};
// 32-bit integer SIMD variable (the integer companion of SimdFloat), stored as one __m512i register.
class SimdFInt32
{
    public:
        // Default construction does not initialize the register.
        SimdFInt32() {}

        // Broadcast one scalar integer to all lanes.
        SimdFInt32(std::int32_t value) : simdInternal_(_mm512_set1_epi32(value)) {}

        // Internal utility constructor to simplify return statements
        SimdFInt32(__m512i reg) : simdInternal_(reg) {}

        // Raw register, read and written directly by the free functions in this file.
        __m512i  simdInternal_;
};
// Boolean companion of SimdFloat, stored as a 16-bit lane mask.
class SimdFBool
{
    public:
        // Default construction does not initialize the mask.
        SimdFBool() {}

        // true sets all 16 mask bits (0xFFFF), false clears them.
        SimdFBool(bool flag) : simdInternal_(_mm512_int2mask( flag ? 0xFFFF : 0)) {}

        // Internal utility constructor to simplify return statements
        SimdFBool(__mmask16 mask) : simdInternal_(mask) {}

        // Raw mask, read and written directly by the free functions in this file.
        __mmask16  simdInternal_;
};
// Boolean companion of SimdFInt32, stored as a 16-bit lane mask.
class SimdFIBool
{
    public:
        // Default construction does not initialize the mask.
        SimdFIBool() {}

        // true sets all 16 mask bits (0xFFFF), false clears them.
        SimdFIBool(bool flag) : simdInternal_(_mm512_int2mask( flag ? 0xFFFF : 0)) {}

        // Internal utility constructor to simplify return statements
        SimdFIBool(__mmask16 mask) : simdInternal_(mask) {}

        // Raw mask, read and written directly by the free functions in this file.
        __mmask16  simdInternal_;
};
103 static inline SimdFloat gmx_simdcall
104 simdLoad(const float *m)
106 assert(std::size_t(m) % 64 == 0);
107 return {
108 _mm512_load_ps(m)
112 static inline void gmx_simdcall
113 store(float *m, SimdFloat a)
115 assert(std::size_t(m) % 64 == 0);
116 _mm512_store_ps(m, a.simdInternal_);
119 static inline SimdFloat gmx_simdcall
120 simdLoadU(const float *m)
122 return {
123 _mm512_loadunpackhi_ps(_mm512_loadunpacklo_ps(_mm512_undefined_ps(), m), m+16)
127 static inline void gmx_simdcall
128 storeU(float *m, SimdFloat a)
130 _mm512_packstorelo_ps(m, a.simdInternal_);
131 _mm512_packstorehi_ps(m+16, a.simdInternal_);
134 static inline SimdFloat gmx_simdcall
135 setZeroF()
137 return {
138 _mm512_setzero_ps()
142 static inline SimdFInt32 gmx_simdcall
143 simdLoadFI(const std::int32_t * m)
145 assert(std::size_t(m) % 64 == 0);
146 return {
147 _mm512_load_epi32(m)
151 static inline void gmx_simdcall
152 store(std::int32_t * m, SimdFInt32 a)
154 assert(std::size_t(m) % 64 == 0);
155 _mm512_store_epi32(m, a.simdInternal_);
158 static inline SimdFInt32 gmx_simdcall
159 simdLoadUFI(const std::int32_t *m)
161 return {
162 _mm512_loadunpackhi_epi32(_mm512_loadunpacklo_epi32(_mm512_undefined_epi32(), m), m+16)
166 static inline void gmx_simdcall
167 storeU(std::int32_t * m, SimdFInt32 a)
169 _mm512_packstorelo_epi32(m, a.simdInternal_);
170 _mm512_packstorehi_epi32(m+16, a.simdInternal_);
173 static inline SimdFInt32 gmx_simdcall
174 setZeroFI()
176 return {
177 _mm512_setzero_si512()
182 template<int index>
183 static inline std::int32_t gmx_simdcall
184 extract(SimdFInt32 a)
186 int r;
187 _mm512_mask_packstorelo_epi32(&r, _mm512_mask2int(1<<index), a.simdInternal_);
188 return r;
191 static inline SimdFloat gmx_simdcall
192 operator&(SimdFloat a, SimdFloat b)
194 return {
195 _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a.simdInternal_), _mm512_castps_si512(b.simdInternal_)))
199 static inline SimdFloat gmx_simdcall
200 andNot(SimdFloat a, SimdFloat b)
202 return {
203 _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(a.simdInternal_), _mm512_castps_si512(b.simdInternal_)))
207 static inline SimdFloat gmx_simdcall
208 operator|(SimdFloat a, SimdFloat b)
210 return {
211 _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(a.simdInternal_), _mm512_castps_si512(b.simdInternal_)))
215 static inline SimdFloat gmx_simdcall
216 operator^(SimdFloat a, SimdFloat b)
218 return {
219 _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a.simdInternal_), _mm512_castps_si512(b.simdInternal_)))
223 static inline SimdFloat gmx_simdcall
224 operator+(SimdFloat a, SimdFloat b)
226 return {
227 _mm512_add_ps(a.simdInternal_, b.simdInternal_)
231 static inline SimdFloat gmx_simdcall
232 operator-(SimdFloat a, SimdFloat b)
234 return {
235 _mm512_sub_ps(a.simdInternal_, b.simdInternal_)
239 static inline SimdFloat gmx_simdcall
240 operator-(SimdFloat x)
242 return {
243 _mm512_addn_ps(x.simdInternal_, _mm512_setzero_ps())
247 static inline SimdFloat gmx_simdcall
248 operator*(SimdFloat a, SimdFloat b)
250 return {
251 _mm512_mul_ps(a.simdInternal_, b.simdInternal_)
255 static inline SimdFloat gmx_simdcall
256 fma(SimdFloat a, SimdFloat b, SimdFloat c)
258 return {
259 _mm512_fmadd_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_)
263 static inline SimdFloat gmx_simdcall
264 fms(SimdFloat a, SimdFloat b, SimdFloat c)
266 return {
267 _mm512_fmsub_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_)
271 static inline SimdFloat gmx_simdcall
272 fnma(SimdFloat a, SimdFloat b, SimdFloat c)
274 return {
275 _mm512_fnmadd_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_)
279 static inline SimdFloat gmx_simdcall
280 fnms(SimdFloat a, SimdFloat b, SimdFloat c)
282 return {
283 _mm512_fnmsub_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_)
287 static inline SimdFloat gmx_simdcall
288 rsqrt(SimdFloat x)
290 return {
291 _mm512_rsqrt23_ps(x.simdInternal_)
295 static inline SimdFloat gmx_simdcall
296 rcp(SimdFloat x)
298 return {
299 _mm512_rcp23_ps(x.simdInternal_)
303 static inline SimdFloat gmx_simdcall
304 maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
306 return {
307 _mm512_mask_add_ps(a.simdInternal_, m.simdInternal_, a.simdInternal_, b.simdInternal_)
311 static inline SimdFloat gmx_simdcall
312 maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
314 return {
315 _mm512_mask_mul_ps(_mm512_setzero_ps(), m.simdInternal_, a.simdInternal_, b.simdInternal_)
319 static inline SimdFloat gmx_simdcall
320 maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
322 return {
323 _mm512_mask_mov_ps(_mm512_setzero_ps(), m.simdInternal_, _mm512_fmadd_ps(a.simdInternal_, b.simdInternal_, c.simdInternal_))
327 static inline SimdFloat gmx_simdcall
328 maskzRsqrt(SimdFloat x, SimdFBool m)
330 return {
331 _mm512_mask_rsqrt23_ps(_mm512_setzero_ps(), m.simdInternal_, x.simdInternal_)
335 static inline SimdFloat gmx_simdcall
336 maskzRcp(SimdFloat x, SimdFBool m)
338 return {
339 _mm512_mask_rcp23_ps(_mm512_setzero_ps(), m.simdInternal_, x.simdInternal_)
343 static inline SimdFloat gmx_simdcall
344 abs(SimdFloat x)
346 return {
347 _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(_mm512_set1_ps(GMX_FLOAT_NEGZERO)), _mm512_castps_si512(x.simdInternal_)))
351 static inline SimdFloat gmx_simdcall
352 max(SimdFloat a, SimdFloat b)
354 return {
355 _mm512_gmax_ps(a.simdInternal_, b.simdInternal_)
359 static inline SimdFloat gmx_simdcall
360 min(SimdFloat a, SimdFloat b)
362 return {
363 _mm512_gmin_ps(a.simdInternal_, b.simdInternal_)
367 static inline SimdFloat gmx_simdcall
368 round(SimdFloat x)
370 return {
371 _mm512_round_ps(x.simdInternal_, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
375 static inline SimdFloat gmx_simdcall
376 trunc(SimdFloat x)
378 return {
379 _mm512_round_ps(x.simdInternal_, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE)
383 static inline SimdFloat gmx_simdcall
384 frexp(SimdFloat value, SimdFInt32 * exponent)
386 __m512 rExponent = _mm512_getexp_ps(value.simdInternal_);
387 __m512i iExponent = _mm512_cvtfxpnt_round_adjustps_epi32(rExponent, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE);
389 exponent->simdInternal_ = _mm512_add_epi32(iExponent, _mm512_set1_epi32(1));
391 return {
392 _mm512_getmant_ps(value.simdInternal_, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src)
396 template <MathOptimization opt = MathOptimization::Safe>
397 static inline SimdFloat gmx_simdcall
398 ldexp(SimdFloat value, SimdFInt32 exponent)
400 const __m512i exponentBias = _mm512_set1_epi32(127);
401 __m512i iExponent = _mm512_add_epi32(exponent.simdInternal_, exponentBias);
403 if (opt == MathOptimization::Safe)
405 // Make sure biased argument is not negative
406 iExponent = _mm512_max_epi32(iExponent, _mm512_setzero_epi32());
409 iExponent = _mm512_slli_epi32( iExponent, 23);
411 return {
412 _mm512_mul_ps(value.simdInternal_, _mm512_castsi512_ps(iExponent))
416 static inline float gmx_simdcall
417 reduce(SimdFloat a)
419 return _mm512_reduce_add_ps(a.simdInternal_);
422 // Picky, picky, picky:
423 // icc-16 complains about "Illegal value of immediate argument to intrinsic"
424 // unless we use
425 // 1) Ordered-quiet for ==
426 // 2) Unordered-quiet for !=
427 // 3) Ordered-signaling for < and <=
429 static inline SimdFBool gmx_simdcall
430 operator==(SimdFloat a, SimdFloat b)
432 return {
433 _mm512_cmp_ps_mask(a.simdInternal_, b.simdInternal_, _CMP_EQ_OQ)
437 static inline SimdFBool gmx_simdcall
438 operator!=(SimdFloat a, SimdFloat b)
440 return {
441 _mm512_cmp_ps_mask(a.simdInternal_, b.simdInternal_, _CMP_NEQ_UQ)
445 static inline SimdFBool gmx_simdcall
446 operator<(SimdFloat a, SimdFloat b)
448 return {
449 _mm512_cmp_ps_mask(a.simdInternal_, b.simdInternal_, _CMP_LT_OS)
453 static inline SimdFBool gmx_simdcall
454 operator<=(SimdFloat a, SimdFloat b)
456 return {
457 _mm512_cmp_ps_mask(a.simdInternal_, b.simdInternal_, _CMP_LE_OS)
461 static inline SimdFBool gmx_simdcall
462 testBits(SimdFloat a)
464 return {
465 _mm512_test_epi32_mask( _mm512_castps_si512(a.simdInternal_), _mm512_castps_si512(a.simdInternal_) )
469 static inline SimdFBool gmx_simdcall
470 operator&&(SimdFBool a, SimdFBool b)
472 return {
473 _mm512_kand(a.simdInternal_, b.simdInternal_)
477 static inline SimdFBool gmx_simdcall
478 operator||(SimdFBool a, SimdFBool b)
480 return {
481 _mm512_kor(a.simdInternal_, b.simdInternal_)
485 static inline bool gmx_simdcall
486 anyTrue(SimdFBool a)
488 return _mm512_mask2int(a.simdInternal_) != 0;
491 static inline SimdFloat gmx_simdcall
492 selectByMask(SimdFloat a, SimdFBool m)
494 return {
495 _mm512_mask_mov_ps(_mm512_setzero_ps(), m.simdInternal_, a.simdInternal_)
499 static inline SimdFloat gmx_simdcall
500 selectByNotMask(SimdFloat a, SimdFBool m)
502 return {
503 _mm512_mask_mov_ps(a.simdInternal_, m.simdInternal_, _mm512_setzero_ps())
507 static inline SimdFloat gmx_simdcall
508 blend(SimdFloat a, SimdFloat b, SimdFBool sel)
510 return {
511 _mm512_mask_blend_ps(sel.simdInternal_, a.simdInternal_, b.simdInternal_)
515 static inline SimdFInt32 gmx_simdcall
516 operator<<(SimdFInt32 a, int n)
518 return {
519 _mm512_slli_epi32(a.simdInternal_, n)
523 static inline SimdFInt32 gmx_simdcall
524 operator>>(SimdFInt32 a, int n)
526 return {
527 _mm512_srli_epi32(a.simdInternal_, n)
531 static inline SimdFInt32 gmx_simdcall
532 operator&(SimdFInt32 a, SimdFInt32 b)
534 return {
535 _mm512_and_epi32(a.simdInternal_, b.simdInternal_)
539 static inline SimdFInt32 gmx_simdcall
540 andNot(SimdFInt32 a, SimdFInt32 b)
542 return {
543 _mm512_andnot_epi32(a.simdInternal_, b.simdInternal_)
547 static inline SimdFInt32 gmx_simdcall
548 operator|(SimdFInt32 a, SimdFInt32 b)
550 return {
551 _mm512_or_epi32(a.simdInternal_, b.simdInternal_)
555 static inline SimdFInt32 gmx_simdcall
556 operator^(SimdFInt32 a, SimdFInt32 b)
558 return {
559 _mm512_xor_epi32(a.simdInternal_, b.simdInternal_)
563 static inline SimdFInt32 gmx_simdcall
564 operator+(SimdFInt32 a, SimdFInt32 b)
566 return {
567 _mm512_add_epi32(a.simdInternal_, b.simdInternal_)
571 static inline SimdFInt32 gmx_simdcall
572 operator-(SimdFInt32 a, SimdFInt32 b)
574 return {
575 _mm512_sub_epi32(a.simdInternal_, b.simdInternal_)
579 static inline SimdFInt32 gmx_simdcall
580 operator*(SimdFInt32 a, SimdFInt32 b)
582 return {
583 _mm512_mullo_epi32(a.simdInternal_, b.simdInternal_)
587 static inline SimdFIBool gmx_simdcall
588 operator==(SimdFInt32 a, SimdFInt32 b)
590 return {
591 _mm512_cmp_epi32_mask(a.simdInternal_, b.simdInternal_, _MM_CMPINT_EQ)
595 static inline SimdFIBool gmx_simdcall
596 testBits(SimdFInt32 a)
598 return {
599 _mm512_test_epi32_mask( a.simdInternal_, a.simdInternal_ )
603 static inline SimdFIBool gmx_simdcall
604 operator<(SimdFInt32 a, SimdFInt32 b)
606 return {
607 _mm512_cmp_epi32_mask(a.simdInternal_, b.simdInternal_, _MM_CMPINT_LT)
611 static inline SimdFIBool gmx_simdcall
612 operator&&(SimdFIBool a, SimdFIBool b)
614 return {
615 _mm512_kand(a.simdInternal_, b.simdInternal_)
619 static inline SimdFIBool gmx_simdcall
620 operator||(SimdFIBool a, SimdFIBool b)
622 return {
623 _mm512_kor(a.simdInternal_, b.simdInternal_)
627 static inline bool gmx_simdcall
628 anyTrue(SimdFIBool a)
630 return _mm512_mask2int(a.simdInternal_) != 0;
633 static inline SimdFInt32 gmx_simdcall
634 selectByMask(SimdFInt32 a, SimdFIBool m)
636 return {
637 _mm512_mask_mov_epi32(_mm512_setzero_epi32(), m.simdInternal_, a.simdInternal_)
641 static inline SimdFInt32 gmx_simdcall
642 selectByNotMask(SimdFInt32 a, SimdFIBool m)
644 return {
645 _mm512_mask_mov_epi32(a.simdInternal_, m.simdInternal_, _mm512_setzero_epi32())
649 static inline SimdFInt32 gmx_simdcall
650 blend(SimdFInt32 a, SimdFInt32 b, SimdFIBool sel)
652 return {
653 _mm512_mask_blend_epi32(sel.simdInternal_, a.simdInternal_, b.simdInternal_)
657 static inline SimdFInt32 gmx_simdcall
658 cvtR2I(SimdFloat a)
660 return {
661 _mm512_cvtfxpnt_round_adjustps_epi32(a.simdInternal_, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
665 static inline SimdFInt32 gmx_simdcall
666 cvttR2I(SimdFloat a)
668 return {
669 _mm512_cvtfxpnt_round_adjustps_epi32(a.simdInternal_, _MM_FROUND_TO_ZERO, _MM_EXPADJ_NONE)
673 static inline SimdFloat gmx_simdcall
674 cvtI2R(SimdFInt32 a)
676 return {
677 _mm512_cvtfxpnt_round_adjustepi32_ps(a.simdInternal_, _MM_FROUND_TO_NEAREST_INT, _MM_EXPADJ_NONE)
681 static inline SimdFIBool gmx_simdcall
682 cvtB2IB(SimdFBool a)
684 return {
685 a.simdInternal_
689 static inline SimdFBool gmx_simdcall
690 cvtIB2B(SimdFIBool a)
692 return {
693 a.simdInternal_
698 template <MathOptimization opt = MathOptimization::Safe>
699 static inline SimdFloat gmx_simdcall
700 exp2(SimdFloat x)
702 return {
703 _mm512_exp223_ps(_mm512_cvtfxpnt_round_adjustps_epi32(x.simdInternal_, _MM_ROUND_MODE_NEAREST, _MM_EXPADJ_24))
707 template <MathOptimization opt = MathOptimization::Safe>
708 static inline SimdFloat gmx_simdcall
709 exp(SimdFloat x)
711 const __m512 argscale = _mm512_set1_ps(1.44269504088896341f);
712 const __m512 invargscale = _mm512_set1_ps(-0.69314718055994528623f);
714 if (opt == MathOptimization::Safe)
716 // Set the limit to gurantee flush to zero
717 const SimdFloat smallArgLimit(-88.f);
718 // Since we multiply the argument by 1.44, for the safe version we need to make
719 // sure this doesn't result in overflow
720 x = max(x, smallArgLimit);
723 __m512 xscaled = _mm512_mul_ps(x.simdInternal_, argscale);
724 __m512 r = _mm512_exp223_ps(_mm512_cvtfxpnt_round_adjustps_epi32(xscaled, _MM_ROUND_MODE_NEAREST, _MM_EXPADJ_24));
726 // exp2a23_ps provides 23 bits of accuracy, but we ruin some of that with our argument
727 // scaling. To correct this, we find the difference between the scaled argument and
728 // the true one (extended precision arithmetics does not appear to be necessary to
729 // fulfill our accuracy requirements) and then multiply by the exponent of this
730 // correction since exp(a+b)=exp(a)*exp(b).
731 // Note that this only adds two instructions (and maybe some constant loads).
733 // find the difference
734 x = _mm512_fmadd_ps(invargscale, xscaled, x.simdInternal_);
735 // x will now be a _very_ small number, so approximate exp(x)=1+x.
736 // We should thus apply the correction as r'=r*(1+x)=r+r*x
737 r = _mm512_fmadd_ps(r, x.simdInternal_, r);
738 return {
743 static inline SimdFloat gmx_simdcall
744 log(SimdFloat x)
746 return {
747 _mm512_mul_ps(_mm512_set1_ps(0.693147180559945286226764f), _mm512_log2ae23_ps(x.simdInternal_))
751 } // namespace gmx
753 #endif // GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H