Improve accuracy of SIMD exp for small args
[gromacs.git] / src / gromacs / simd / impl_x86_avx_512 / impl_x86_avx_512_simd_double.h
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_X86_AVX_512_SIMD_DOUBLE_H
#define GMX_SIMD_IMPL_X86_AVX_512_SIMD_DOUBLE_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/math/utilities.h"
#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_float.h"
namespace gmx
{
class SimdDouble
{
    public:
        SimdDouble() {}

        SimdDouble(double d) : simdInternal_(_mm512_set1_pd(d)) {}

        // Internal utility constructor to simplify return statements
        SimdDouble(__m512d simd) : simdInternal_(simd) {}

        __m512d  simdInternal_;
};

class SimdDInt32
{
    public:
        SimdDInt32() {}

        SimdDInt32(std::int32_t i) : simdInternal_(_mm256_set1_epi32(i)) {}

        // Internal utility constructor to simplify return statements
        SimdDInt32(__m256i simd) : simdInternal_(simd) {}

        __m256i  simdInternal_;
};

class SimdDBool
{
    public:
        SimdDBool() {}

        // Internal utility constructor to simplify return statements
        SimdDBool(__mmask8 simd) : simdInternal_(simd) {}

        __mmask8  simdInternal_;
};

class SimdDIBool
{
    public:
        SimdDIBool() {}

        // Internal utility constructor to simplify return statements
        SimdDIBool(__mmask16 simd) : simdInternal_(simd) {}

        __mmask16  simdInternal_;
};
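/* A minimal usage sketch (illustrative only; the buffer name below is
 * hypothetical, not part of this header). The wrapper classes above let
 * higher-level code use SIMD arithmetic through plain operators, with the
 * free functions below supplying loads, stores, comparisons and conversions:
 *
 *     alignas(64) double buf[8] = {1, 2, 3, 4, 5, 6, 7, 8};
 *     SimdDouble a = simdLoad(buf);            // 64-byte-aligned load
 *     SimdDouble b(2.0);                       // broadcast a scalar
 *     store(buf, fma(a, b, SimdDouble(1.0)));  // buf[i] = a[i]*2 + 1
 */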
static inline SimdDouble gmx_simdcall
simdLoad(const double *m)
{
    assert(std::size_t(m) % 64 == 0);
    return {
               _mm512_load_pd(m)
    };
}

static inline void gmx_simdcall
store(double *m, SimdDouble a)
{
    assert(std::size_t(m) % 64 == 0);
    _mm512_store_pd(m, a.simdInternal_);
}

static inline SimdDouble gmx_simdcall
simdLoadU(const double *m)
{
    return {
               _mm512_loadu_pd(m)
    };
}

static inline void gmx_simdcall
storeU(double *m, SimdDouble a)
{
    _mm512_storeu_pd(m, a.simdInternal_);
}

static inline SimdDouble gmx_simdcall
setZeroD()
{
    return {
               _mm512_setzero_pd()
    };
}

static inline SimdDInt32 gmx_simdcall
simdLoadDI(const std::int32_t * m)
{
    assert(std::size_t(m) % 32 == 0);
    return {
               _mm256_load_si256(reinterpret_cast<const __m256i *>(m))
    };
}

static inline void gmx_simdcall
store(std::int32_t * m, SimdDInt32 a)
{
    assert(std::size_t(m) % 32 == 0);
    _mm256_store_si256(reinterpret_cast<__m256i *>(m), a.simdInternal_);
}

static inline SimdDInt32 gmx_simdcall
simdLoadUDI(const std::int32_t *m)
{
    return {
               _mm256_loadu_si256(reinterpret_cast<const __m256i *>(m))
    };
}

static inline void gmx_simdcall
storeU(std::int32_t * m, SimdDInt32 a)
{
    _mm256_storeu_si256(reinterpret_cast<__m256i *>(m), a.simdInternal_);
}

static inline SimdDInt32 gmx_simdcall
setZeroDI()
{
    return {
               _mm256_setzero_si256()
    };
}
// AVX-512F provides no 512-bit floating-point bitwise instructions (those
// arrived with AVX-512DQ), so the logical operations below work in the
// integer domain via bit casts.
static inline SimdDouble gmx_simdcall
operator&(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_castsi512_pd(_mm512_and_epi32(_mm512_castpd_si512(a.simdInternal_), _mm512_castpd_si512(b.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
andNot(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_castsi512_pd(_mm512_andnot_epi32(_mm512_castpd_si512(a.simdInternal_), _mm512_castpd_si512(b.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
operator|(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_castsi512_pd(_mm512_or_epi32(_mm512_castpd_si512(a.simdInternal_), _mm512_castpd_si512(b.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
operator^(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_castsi512_pd(_mm512_xor_epi32(_mm512_castpd_si512(a.simdInternal_), _mm512_castpd_si512(b.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
operator+(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_add_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator-(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_sub_pd(a.simdInternal_, b.simdInternal_)
    };
}

// Unary negation: flip only the sign bit by XOR with -0.0
static inline SimdDouble gmx_simdcall
operator-(SimdDouble x)
{
    return {
               _mm512_castsi512_pd(_mm512_xor_epi32(_mm512_castpd_si512(x.simdInternal_), _mm512_castpd_si512(_mm512_set1_pd(GMX_DOUBLE_NEGZERO))))
    };
}

static inline SimdDouble gmx_simdcall
operator*(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_mul_pd(a.simdInternal_, b.simdInternal_)
    };
}

// fma(a, b, c) = a * b + c
static inline SimdDouble gmx_simdcall
fma(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               _mm512_fmadd_pd(a.simdInternal_, b.simdInternal_, c.simdInternal_)
    };
}

// fms(a, b, c) = a * b - c
static inline SimdDouble gmx_simdcall
fms(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               _mm512_fmsub_pd(a.simdInternal_, b.simdInternal_, c.simdInternal_)
    };
}

// fnma(a, b, c) = -a * b + c
static inline SimdDouble gmx_simdcall
fnma(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               _mm512_fnmadd_pd(a.simdInternal_, b.simdInternal_, c.simdInternal_)
    };
}

// fnms(a, b, c) = -a * b - c
static inline SimdDouble gmx_simdcall
fnms(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               _mm512_fnmsub_pd(a.simdInternal_, b.simdInternal_, c.simdInternal_)
    };
}
// Overridden for AVX-512-KNL, which has higher-accuracy (2^-28) native
// rsqrt/rcp instructions; plain AVX-512 only has the 2^-14 versions used here.
#if GMX_SIMD_X86_AVX_512
static inline SimdDouble gmx_simdcall
rsqrt(SimdDouble x)
{
    return {
               _mm512_rsqrt14_pd(x.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
rcp(SimdDouble x)
{
    return {
               _mm512_rcp14_pd(x.simdInternal_)
    };
}
#endif
static inline SimdDouble gmx_simdcall
maskAdd(SimdDouble a, SimdDouble b, SimdDBool m)
{
    return {
               _mm512_mask_add_pd(a.simdInternal_, m.simdInternal_, a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
maskzMul(SimdDouble a, SimdDouble b, SimdDBool m)
{
    return {
               _mm512_maskz_mul_pd(m.simdInternal_, a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
maskzFma(SimdDouble a, SimdDouble b, SimdDouble c, SimdDBool m)
{
    return {
               _mm512_maskz_fmadd_pd(m.simdInternal_, a.simdInternal_, b.simdInternal_, c.simdInternal_)
    };
}

// Override for AVX-512-KNL
#if GMX_SIMD_X86_AVX_512
static inline SimdDouble gmx_simdcall
maskzRsqrt(SimdDouble x, SimdDBool m)
{
    return {
               _mm512_maskz_rsqrt14_pd(m.simdInternal_, x.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
maskzRcp(SimdDouble x, SimdDBool m)
{
    return {
               _mm512_maskz_rcp14_pd(m.simdInternal_, x.simdInternal_)
    };
}
#endif
// abs() clears the sign bit by ANDNOT with -0.0
static inline SimdDouble gmx_simdcall
abs(SimdDouble x)
{
    return {
               _mm512_castsi512_pd(_mm512_andnot_epi32(_mm512_castpd_si512(_mm512_set1_pd(GMX_DOUBLE_NEGZERO)), _mm512_castpd_si512(x.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
max(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_max_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
min(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_min_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
round(SimdDouble x)
{
    return {
               _mm512_roundscale_pd(x.simdInternal_, 0)
    };
}

static inline SimdDouble gmx_simdcall
trunc(SimdDouble x)
{
#if defined(__INTEL_COMPILER) || defined(__ECC)
    return {
               _mm512_trunc_pd(x.simdInternal_)
    };
#else
    // Truncate via a round trip through 32-bit integers, which only covers
    // the value range representable in std::int32_t
    return {
               _mm512_cvtepi32_pd(_mm512_cvttpd_epi32(x.simdInternal_))
    };
#endif
}
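/* Implementation note for frexp() below: _mm512_getexp_pd() returns
 * floor(log2(|value|)), which corresponds to a mantissa normalized to
 * [1, 2), while the frexp() convention places the fraction in [0.5, 1).
 * Adding 1 to the integer exponent and extracting the mantissa with
 * _MM_MANT_NORM_p5_1 (range [0.5, 1), sign copied from the source with
 * _MM_MANT_SIGN_src) converts between the two conventions.
 */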
static inline SimdDouble
frexp(SimdDouble value, SimdDInt32 * exponent)
{
    __m512d rExponent = _mm512_getexp_pd(value.simdInternal_);
    __m256i iExponent = _mm512_cvtpd_epi32(rExponent);

    exponent->simdInternal_ = _mm256_add_epi32(iExponent, _mm256_set1_epi32(1));

    return {
               _mm512_getmant_pd(value.simdInternal_, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src)
    };
}
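/* Implementation note for ldexp() below: the biased exponent (e + 1023) is
 * duplicated into both 32-bit halves of each 64-bit lane by the permute, and
 * only the high half is kept and shifted left by 20 bits so the value lands
 * in the exponent field (bits 52-62) of a double. Each lane then holds
 * exactly 2^e, so a single multiply computes value * 2^e.
 *
 * With MathOptimization::Safe the biased exponent is clamped at zero, so a
 * very negative exponent yields 0.0 instead of wrapping into a bogus
 * exponent field; this supports exp() underflowing cleanly to zero for
 * large-magnitude negative arguments. Hand-picked illustration (not from
 * the GROMACS test suite):
 *
 *     ldexp(SimdDouble(1.5), SimdDInt32(3));      // 1.5 * 2^3 = 12.0
 *     ldexp(SimdDouble(1.0), SimdDInt32(-2000));  // Safe: 0.0 in all lanes
 */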
template <MathOptimization opt = MathOptimization::Safe>
static inline SimdDouble
ldexp(SimdDouble value, SimdDInt32 exponent)
{
    const __m256i exponentBias = _mm256_set1_epi32(1023);
    __m256i       iExponent    = _mm256_add_epi32(exponent.simdInternal_, exponentBias);
    __m512i       iExponent512;

    if (opt == MathOptimization::Safe)
    {
        // Make sure biased argument is not negative
        iExponent = _mm256_max_epi32(iExponent, _mm256_setzero_si256());
    }

    iExponent512 = _mm512_permutexvar_epi32(_mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0), _mm512_castsi256_si512(iExponent));
    iExponent512 = _mm512_mask_slli_epi32(_mm512_setzero_epi32(), avx512Int2Mask(0xAAAA), iExponent512, 20);
    return _mm512_mul_pd(_mm512_castsi512_pd(iExponent512), value.simdInternal_);
}
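/* The horizontal sum below folds the vector pairwise: the first
 * _mm512_shuffle_f64x2 (imm 0xEE) adds the upper 256 bits onto the lower,
 * the second (imm 0x11) folds the two surviving 128-bit lanes together, and
 * _mm512_permute_pd (imm 0x01) swaps the two doubles of the lowest lane so
 * the final add leaves the complete sum in element 0.
 */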
static inline double gmx_simdcall
reduce(SimdDouble a)
{
    __m512d x = a.simdInternal_;
    x = _mm512_add_pd(x, _mm512_shuffle_f64x2(x, x, 0xEE));
    x = _mm512_add_pd(x, _mm512_shuffle_f64x2(x, x, 0x11));
    x = _mm512_add_pd(x, _mm512_permute_pd(x, 0x01));
    return *reinterpret_cast<double *>(&x);
}
static inline SimdDBool gmx_simdcall
operator==(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_cmp_pd_mask(a.simdInternal_, b.simdInternal_, _CMP_EQ_OQ)
    };
}

static inline SimdDBool gmx_simdcall
operator!=(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_cmp_pd_mask(a.simdInternal_, b.simdInternal_, _CMP_NEQ_OQ)
    };
}

static inline SimdDBool gmx_simdcall
operator<(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_cmp_pd_mask(a.simdInternal_, b.simdInternal_, _CMP_LT_OQ)
    };
}

static inline SimdDBool gmx_simdcall
operator<=(SimdDouble a, SimdDouble b)
{
    return {
               _mm512_cmp_pd_mask(a.simdInternal_, b.simdInternal_, _CMP_LE_OQ)
    };
}
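/* testBits() below flags each 64-bit element whose bit pattern is nonzero;
 * _mm512_test_epi64_mask() ANDs the operand with itself and sets one mask
 * bit per nonzero element. This is a test on the raw representation rather
 * than a floating-point comparison, so e.g. -0.0 tests true.
 */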
static inline SimdDBool gmx_simdcall
testBits(SimdDouble a)
{
    return {
               _mm512_test_epi64_mask(_mm512_castpd_si512(a.simdInternal_), _mm512_castpd_si512(a.simdInternal_))
    };
}

static inline SimdDBool gmx_simdcall
operator&&(SimdDBool a, SimdDBool b)
{
    return {
               static_cast<__mmask8>(_mm512_kand(a.simdInternal_, b.simdInternal_))
    };
}

static inline SimdDBool gmx_simdcall
operator||(SimdDBool a, SimdDBool b)
{
    return {
               static_cast<__mmask8>(_mm512_kor(a.simdInternal_, b.simdInternal_))
    };
}

static inline bool gmx_simdcall
anyTrue(SimdDBool a)
{
    return (avx512Mask2Int(a.simdInternal_) != 0);
}

static inline SimdDouble gmx_simdcall
selectByMask(SimdDouble a, SimdDBool m)
{
    return {
               _mm512_mask_mov_pd(_mm512_setzero_pd(), m.simdInternal_, a.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
selectByNotMask(SimdDouble a, SimdDBool m)
{
    return {
               _mm512_mask_mov_pd(a.simdInternal_, m.simdInternal_, _mm512_setzero_pd())
    };
}

static inline SimdDouble gmx_simdcall
blend(SimdDouble a, SimdDouble b, SimdDBool sel)
{
    return {
               _mm512_mask_blend_pd(sel.simdInternal_, a.simdInternal_, b.simdInternal_)
    };
}
static inline SimdDInt32 gmx_simdcall
operator<<(SimdDInt32 a, int n)
{
    return {
               _mm256_slli_epi32(a.simdInternal_, n)
    };
}

static inline SimdDInt32 gmx_simdcall
operator>>(SimdDInt32 a, int n)
{
    return {
               _mm256_srli_epi32(a.simdInternal_, n)
    };
}

static inline SimdDInt32 gmx_simdcall
operator&(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm256_and_si256(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
andNot(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm256_andnot_si256(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator|(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm256_or_si256(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator^(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm256_xor_si256(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator+(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm256_add_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator-(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm256_sub_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator*(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm256_mullo_epi32(a.simdInternal_, b.simdInternal_)
    };
}
// Comparisons and mask operations on the 256-bit integer vectors cast up to
// 512 bits; the upper half after the cast is undefined, so the results are
// restricted to the low eight elements with avx512Int2Mask(0xFF).
static inline SimdDIBool gmx_simdcall
operator==(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm512_mask_cmp_epi32_mask(avx512Int2Mask(0xFF), _mm512_castsi256_si512(a.simdInternal_), _mm512_castsi256_si512(b.simdInternal_), _MM_CMPINT_EQ)
    };
}

static inline SimdDIBool gmx_simdcall
testBits(SimdDInt32 a)
{
    return {
               _mm512_mask_test_epi32_mask(avx512Int2Mask(0xFF), _mm512_castsi256_si512(a.simdInternal_), _mm512_castsi256_si512(a.simdInternal_))
    };
}

static inline SimdDIBool gmx_simdcall
operator<(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm512_mask_cmp_epi32_mask(avx512Int2Mask(0xFF), _mm512_castsi256_si512(a.simdInternal_), _mm512_castsi256_si512(b.simdInternal_), _MM_CMPINT_LT)
    };
}

static inline SimdDIBool gmx_simdcall
operator&&(SimdDIBool a, SimdDIBool b)
{
    return {
               _mm512_kand(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
operator||(SimdDIBool a, SimdDIBool b)
{
    return {
               _mm512_kor(a.simdInternal_, b.simdInternal_)
    };
}

static inline bool gmx_simdcall
anyTrue(SimdDIBool a)
{
    return (avx512Mask2Int(a.simdInternal_) & 0xFF) != 0;
}

static inline SimdDInt32 gmx_simdcall
selectByMask(SimdDInt32 a, SimdDIBool m)
{
    return {
               _mm512_castsi512_si256(_mm512_mask_mov_epi32(_mm512_setzero_si512(), m.simdInternal_, _mm512_castsi256_si512(a.simdInternal_)))
    };
}

static inline SimdDInt32 gmx_simdcall
selectByNotMask(SimdDInt32 a, SimdDIBool m)
{
    return {
               _mm512_castsi512_si256(_mm512_mask_mov_epi32(_mm512_castsi256_si512(a.simdInternal_), m.simdInternal_, _mm512_setzero_si512()))
    };
}

static inline SimdDInt32 gmx_simdcall
blend(SimdDInt32 a, SimdDInt32 b, SimdDIBool sel)
{
    return {
               _mm512_castsi512_si256(_mm512_mask_blend_epi32(sel.simdInternal_, _mm512_castsi256_si512(a.simdInternal_), _mm512_castsi256_si512(b.simdInternal_)))
    };
}
static inline SimdDInt32 gmx_simdcall
cvtR2I(SimdDouble a)
{
    return {
               _mm512_cvtpd_epi32(a.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
cvttR2I(SimdDouble a)
{
    return {
               _mm512_cvttpd_epi32(a.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
cvtI2R(SimdDInt32 a)
{
    return {
               _mm512_cvtepi32_pd(a.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
cvtB2IB(SimdDBool a)
{
    return {
               a.simdInternal_
    };
}

static inline SimdDBool gmx_simdcall
cvtIB2B(SimdDIBool a)
{
    return {
               static_cast<__mmask8>(a.simdInternal_)
    };
}
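/* The conversions below bridge the width mismatch between precisions: one
 * 16-element single-precision vector corresponds to two 8-element double
 * vectors. cvtF2DD() converts the low 256 bits into d0 and moves the upper
 * half down with _mm512_shuffle_f32x4 (imm 0xEE) before converting into d1;
 * cvtDD2F() concatenates the two narrowed 256-bit halves with imm 0x44.
 */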
static inline void gmx_simdcall
cvtF2DD(SimdFloat f, SimdDouble *d0, SimdDouble *d1)
{
    d0->simdInternal_ = _mm512_cvtps_pd(_mm512_castps512_ps256(f.simdInternal_));
    d1->simdInternal_ = _mm512_cvtps_pd(_mm512_castps512_ps256(_mm512_shuffle_f32x4(f.simdInternal_, f.simdInternal_, 0xEE)));
}

static inline SimdFloat gmx_simdcall
cvtDD2F(SimdDouble d0, SimdDouble d1)
{
    __m512 f0 = _mm512_castps256_ps512(_mm512_cvtpd_ps(d0.simdInternal_));
    __m512 f1 = _mm512_castps256_ps512(_mm512_cvtpd_ps(d1.simdInternal_));
    return {
               _mm512_shuffle_f32x4(f0, f1, 0x44)
    };
}

}      // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_SIMD_DOUBLE_H