Improve accuracy of SIMD exp for small args
src/gromacs/simd/impl_x86_avx_256/impl_x86_avx_256_simd_double.h
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_X86_AVX_256_SIMD_DOUBLE_H
#define GMX_SIMD_IMPL_X86_AVX_256_SIMD_DOUBLE_H

#include "config.h"

#include <cassert>
#include <cstddef>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/math/utilities.h"

#include "impl_x86_avx_256_simd_float.h"
namespace gmx
{

class SimdDouble
{
    public:
        SimdDouble() {}

        SimdDouble(double d) : simdInternal_(_mm256_set1_pd(d)) {}

        // Internal utility constructor to simplify return statements
        SimdDouble(__m256d simd) : simdInternal_(simd) {}

        __m256d  simdInternal_;
};

class SimdDInt32
{
    public:
        SimdDInt32() {}

        SimdDInt32(std::int32_t i) : simdInternal_(_mm_set1_epi32(i)) {}

        // Internal utility constructor to simplify return statements
        SimdDInt32(__m128i simd) : simdInternal_(simd) {}

        __m128i  simdInternal_;
};

class SimdDBool
{
    public:
        SimdDBool() {}

        SimdDBool(bool b) : simdInternal_(_mm256_castsi256_pd(_mm256_set1_epi32( b ? 0xFFFFFFFF : 0))) {}

        // Internal utility constructor to simplify return statements
        SimdDBool(__m256d simd) : simdInternal_(simd) {}

        __m256d  simdInternal_;
};

class SimdDIBool
{
    public:
        SimdDIBool() {}

        SimdDIBool(bool b) : simdInternal_(_mm_set1_epi32( b ? 0xFFFFFFFF : 0)) {}

        // Internal utility constructor to simplify return statements
        SimdDIBool(__m128i simd) : simdInternal_(simd) {}

        __m128i  simdInternal_;
};
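// In this AVX_256 double-precision implementation the SIMD width is 4:
// SimdDouble and SimdDBool wrap a 256-bit ymm register holding four doubles,
// while SimdDInt32 and SimdDIBool hold the four corresponding 32-bit integers
// in a 128-bit xmm register.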
static inline SimdDouble gmx_simdcall
simdLoad(const double *m)
{
    assert(std::size_t(m) % 32 == 0);
    return {
               _mm256_load_pd(m)
    };
}

static inline void gmx_simdcall
store(double *m, SimdDouble a)
{
    assert(std::size_t(m) % 32 == 0);
    _mm256_store_pd(m, a.simdInternal_);
}

static inline SimdDouble gmx_simdcall
simdLoadU(const double *m)
{
    return {
               _mm256_loadu_pd(m)
    };
}

static inline void gmx_simdcall
storeU(double *m, SimdDouble a)
{
    _mm256_storeu_pd(m, a.simdInternal_);
}

static inline SimdDouble gmx_simdcall
setZeroD()
{
    return {
               _mm256_setzero_pd()
    };
}
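// Illustrative usage sketch (hypothetical variable names, not part of this file):
//     alignas(32) double x[4] = {1.0, 2.0, 3.0, 4.0};
//     SimdDouble          v   = simdLoad(x);       // aligned load, see assert above
//     store(x, v + SimdDouble(1.0));               // element-wise add and store back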
static inline SimdDInt32 gmx_simdcall
simdLoadDI(const std::int32_t * m)
{
    assert(std::size_t(m) % 16 == 0);
    return {
               _mm_load_si128(reinterpret_cast<const __m128i *>(m))
    };
}

static inline void gmx_simdcall
store(std::int32_t * m, SimdDInt32 a)
{
    assert(std::size_t(m) % 16 == 0);
    _mm_store_si128(reinterpret_cast<__m128i *>(m), a.simdInternal_);
}

static inline SimdDInt32 gmx_simdcall
simdLoadUDI(const std::int32_t *m)
{
    return {
               _mm_loadu_si128(reinterpret_cast<const __m128i *>(m))
    };
}

static inline void gmx_simdcall
storeU(std::int32_t * m, SimdDInt32 a)
{
    _mm_storeu_si128(reinterpret_cast<__m128i *>(m), a.simdInternal_);
}

static inline SimdDInt32 gmx_simdcall
setZeroDI()
{
    return {
               _mm_setzero_si128()
    };
}

template<int index>
static inline std::int32_t gmx_simdcall
extract(SimdDInt32 a)
{
    return _mm_extract_epi32(a.simdInternal_, index);
}

static inline SimdDouble gmx_simdcall
operator&(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_and_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
andNot(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_andnot_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator|(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_or_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator^(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_xor_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator+(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_add_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator-(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_sub_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
operator-(SimdDouble x)
{
    return {
               _mm256_xor_pd(x.simdInternal_, _mm256_set1_pd(GMX_DOUBLE_NEGZERO))
    };
}

static inline SimdDouble gmx_simdcall
operator*(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_mul_pd(a.simdInternal_, b.simdInternal_)
    };
}

// Override for AVX2 and higher
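// Plain AVX has no fused multiply-add instructions, so the variants below emulate
// fma/fms/fnma/fnms with a separate multiply and add (i.e. with two roundings).
// AVX2 and higher implementations supply their own single-instruction overrides.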
#if GMX_SIMD_X86_AVX_256
static inline SimdDouble gmx_simdcall
fma(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               _mm256_add_pd(_mm256_mul_pd(a.simdInternal_, b.simdInternal_), c.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
fms(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               _mm256_sub_pd(_mm256_mul_pd(a.simdInternal_, b.simdInternal_), c.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
fnma(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               _mm256_sub_pd(c.simdInternal_, _mm256_mul_pd(a.simdInternal_, b.simdInternal_))
    };
}

static inline SimdDouble gmx_simdcall
fnms(SimdDouble a, SimdDouble b, SimdDouble c)
{
    return {
               _mm256_sub_pd(_mm256_setzero_pd(), _mm256_add_pd(_mm256_mul_pd(a.simdInternal_, b.simdInternal_), c.simdInternal_))
    };
}
#endif

static inline SimdDouble gmx_simdcall
rsqrt(SimdDouble x)
{
    return {
               _mm256_cvtps_pd(_mm_rsqrt_ps(_mm256_cvtpd_ps(x.simdInternal_)))
    };
}

static inline SimdDouble gmx_simdcall
rcp(SimdDouble x)
{
    return {
               _mm256_cvtps_pd(_mm_rcp_ps(_mm256_cvtpd_ps(x.simdInternal_)))
    };
}
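// Note: rsqrt() and rcp() above go through single precision and the hardware
// approximation instructions, so they only deliver roughly 11-12 bits of
// accuracy. The higher-level math routines are expected to refine these seeds
// with Newton-Raphson iterations where full double precision is required.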
static inline SimdDouble gmx_simdcall
maskAdd(SimdDouble a, SimdDouble b, SimdDBool m)
{
    return {
               _mm256_add_pd(a.simdInternal_, _mm256_and_pd(b.simdInternal_, m.simdInternal_))
    };
}

static inline SimdDouble gmx_simdcall
maskzMul(SimdDouble a, SimdDouble b, SimdDBool m)
{
    return {
               _mm256_and_pd(_mm256_mul_pd(a.simdInternal_, b.simdInternal_), m.simdInternal_)
    };
}

static inline SimdDouble
maskzFma(SimdDouble a, SimdDouble b, SimdDouble c, SimdDBool m)
{
    return {
               _mm256_and_pd(_mm256_add_pd(_mm256_mul_pd(a.simdInternal_, b.simdInternal_), c.simdInternal_), m.simdInternal_)
    };
}

static inline SimdDouble
maskzRsqrt(SimdDouble x, SimdDBool m)
{
#ifndef NDEBUG
    x.simdInternal_ = _mm256_blendv_pd(_mm256_set1_pd(1.0), x.simdInternal_, m.simdInternal_);
#endif
    return {
               _mm256_and_pd(_mm256_cvtps_pd(_mm_rsqrt_ps(_mm256_cvtpd_ps(x.simdInternal_))), m.simdInternal_)
    };
}

static inline SimdDouble
maskzRcp(SimdDouble x, SimdDBool m)
{
#ifndef NDEBUG
    x.simdInternal_ = _mm256_blendv_pd(_mm256_set1_pd(1.0), x.simdInternal_, m.simdInternal_);
#endif
    return {
               _mm256_and_pd(_mm256_cvtps_pd(_mm_rcp_ps(_mm256_cvtpd_ps(x.simdInternal_))), m.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
abs(SimdDouble x)
{
    return {
               _mm256_andnot_pd( _mm256_set1_pd(GMX_DOUBLE_NEGZERO), x.simdInternal_ )
    };
}

static inline SimdDouble gmx_simdcall
max(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_max_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
min(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_min_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
round(SimdDouble x)
{
    return {
               _mm256_round_pd(x.simdInternal_, _MM_FROUND_NINT)
    };
}

static inline SimdDouble gmx_simdcall
trunc(SimdDouble x)
{
    return {
               _mm256_round_pd(x.simdInternal_, _MM_FROUND_TRUNC)
    };
}

// Override for AVX2 and higher
#if GMX_SIMD_X86_AVX_256
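// Without AVX2 there are no 256-bit integer shift/shuffle operations, so frexp()
// and ldexp() below manipulate the exponent bits in two 128-bit halves and in the
// 128-bit SimdDInt32 register.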
static inline SimdDouble
frexp(SimdDouble value, SimdDInt32 * exponent)
{
    const __m256d exponentMask = _mm256_castsi256_pd( _mm256_set1_epi64x(0x7FF0000000000000LL));
    const __m256d mantissaMask = _mm256_castsi256_pd( _mm256_set1_epi64x(0x800FFFFFFFFFFFFFLL));
    const __m256d half         = _mm256_set1_pd(0.5);
    const __m128i exponentBias = _mm_set1_epi32(1022); // add 1 to make our definition identical to frexp()
    __m256i       iExponent;
    __m128i       iExponentLow, iExponentHigh;
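    // Extract the biased 11-bit exponents: mask them out, shift each 64-bit lane
    // right by 52, and shuffle the four 32-bit results from the low and high
    // 128-bit halves into a single __m128i before subtracting the bias.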
    iExponent               = _mm256_castpd_si256(_mm256_and_pd(value.simdInternal_, exponentMask));
    iExponentHigh           = _mm256_extractf128_si256(iExponent, 0x1);
    iExponentLow            = _mm256_castsi256_si128(iExponent);
    iExponentLow            = _mm_srli_epi64(iExponentLow, 52);
    iExponentHigh           = _mm_srli_epi64(iExponentHigh, 52);
    iExponentLow            = _mm_shuffle_epi32(iExponentLow, _MM_SHUFFLE(1, 1, 2, 0));
    iExponentHigh           = _mm_shuffle_epi32(iExponentHigh, _MM_SHUFFLE(2, 0, 1, 1));
    iExponentLow            = _mm_or_si128(iExponentLow, iExponentHigh);
    exponent->simdInternal_ = _mm_sub_epi32(iExponentLow, exponentBias);

    return {
               _mm256_or_pd(_mm256_and_pd(value.simdInternal_, mantissaMask), half)
    };
}

template <MathOptimization opt = MathOptimization::Safe>
static inline SimdDouble
ldexp(SimdDouble value, SimdDInt32 exponent)
{
    const __m128i exponentBias = _mm_set1_epi32(1023);
    __m128i       iExponentLow, iExponentHigh;
    __m256d       fExponent;
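    // Build 2^exponent as a double: add the IEEE-754 bias, duplicate each 32-bit
    // exponent into its 64-bit lane, and shift it into the exponent field; the
    // scaling is then applied with a single multiplication.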
    iExponentLow = _mm_add_epi32(exponent.simdInternal_, exponentBias);

    if (opt == MathOptimization::Safe)
    {
        // Make sure biased argument is not negative
        iExponentLow = _mm_max_epi32(iExponentLow, _mm_setzero_si128());
    }

    iExponentHigh = _mm_shuffle_epi32(iExponentLow, _MM_SHUFFLE(3, 3, 2, 2));
    iExponentLow  = _mm_shuffle_epi32(iExponentLow, _MM_SHUFFLE(1, 1, 0, 0));
    iExponentHigh = _mm_slli_epi64(iExponentHigh, 52);
    iExponentLow  = _mm_slli_epi64(iExponentLow, 52);
    fExponent     = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(iExponentLow), iExponentHigh, 0x1));
    return {
               _mm256_mul_pd(value.simdInternal_, fExponent)
    };
}
#endif

static inline double gmx_simdcall
reduce(SimdDouble a)
{
    __m128d a0, a1;
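    // Horizontal sum: first add the two elements within each 128-bit lane,
    // then add the low and high lanes and return the scalar result.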
    a.simdInternal_ = _mm256_add_pd(a.simdInternal_, _mm256_permute_pd(a.simdInternal_, 0b0101 ));
    a0              = _mm256_castpd256_pd128(a.simdInternal_);
    a1              = _mm256_extractf128_pd(a.simdInternal_, 0x1);
    a0              = _mm_add_sd(a0, a1);

    return *reinterpret_cast<double *>(&a0);
}

static inline SimdDBool gmx_simdcall
operator==(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_cmp_pd(a.simdInternal_, b.simdInternal_, _CMP_EQ_OQ)
    };
}

static inline SimdDBool gmx_simdcall
operator!=(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_cmp_pd(a.simdInternal_, b.simdInternal_, _CMP_NEQ_OQ)
    };
}

static inline SimdDBool gmx_simdcall
operator<(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_cmp_pd(a.simdInternal_, b.simdInternal_, _CMP_LT_OQ)
    };
}

static inline SimdDBool gmx_simdcall
operator<=(SimdDouble a, SimdDouble b)
{
    return {
               _mm256_cmp_pd(a.simdInternal_, b.simdInternal_, _CMP_LE_OQ)
    };
}

// Override for AVX2 and higher
#if GMX_SIMD_X86_AVX_256
static inline SimdDBool gmx_simdcall
testBits(SimdDouble a)
{
    // Do an or of the low/high 32 bits of each double (so the data is replicated),
    // and then use the same algorithm as we use for single precision.
    __m256 tst = _mm256_castpd_ps(a.simdInternal_);

    tst = _mm256_or_ps(tst, _mm256_permute_ps(tst, _MM_SHUFFLE(2, 3, 0, 1)));
    tst = _mm256_cvtepi32_ps(_mm256_castps_si256(tst));

    return {
               _mm256_castps_pd(_mm256_cmp_ps(tst, _mm256_setzero_ps(), _CMP_NEQ_OQ))
    };
}
#endif

static inline SimdDBool gmx_simdcall
operator&&(SimdDBool a, SimdDBool b)
{
    return {
               _mm256_and_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDBool gmx_simdcall
operator||(SimdDBool a, SimdDBool b)
{
    return {
               _mm256_or_pd(a.simdInternal_, b.simdInternal_)
    };
}

static inline bool gmx_simdcall
anyTrue(SimdDBool a) { return _mm256_movemask_pd(a.simdInternal_) != 0; }

static inline SimdDouble gmx_simdcall
selectByMask(SimdDouble a, SimdDBool mask)
{
    return {
               _mm256_and_pd(a.simdInternal_, mask.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
selectByNotMask(SimdDouble a, SimdDBool mask)
{
    return {
               _mm256_andnot_pd(mask.simdInternal_, a.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
blend(SimdDouble a, SimdDouble b, SimdDBool sel)
{
    return {
               _mm256_blendv_pd(a.simdInternal_, b.simdInternal_, sel.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator<<(SimdDInt32 a, int n)
{
    return {
               _mm_slli_epi32(a.simdInternal_, n)
    };
}

static inline SimdDInt32 gmx_simdcall
operator>>(SimdDInt32 a, int n)
{
    return {
               _mm_srli_epi32(a.simdInternal_, n)
    };
}

static inline SimdDInt32 gmx_simdcall
operator&(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_and_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
andNot(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_andnot_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator|(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_or_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator^(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_xor_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator+(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_add_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator-(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_sub_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
operator*(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_mullo_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
operator==(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_cmpeq_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
operator<(SimdDInt32 a, SimdDInt32 b)
{
    return {
               _mm_cmplt_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
testBits(SimdDInt32 a)
{
    __m128i x   = a.simdInternal_;
    __m128i res = _mm_andnot_si128( _mm_cmpeq_epi32(x, _mm_setzero_si128()), _mm_cmpeq_epi32(x, x));

    return {
               res
    };
}

static inline SimdDIBool gmx_simdcall
operator&&(SimdDIBool a, SimdDIBool b)
{
    return {
               _mm_and_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
operator||(SimdDIBool a, SimdDIBool b)
{
    return {
               _mm_or_si128(a.simdInternal_, b.simdInternal_)
    };
}
static inline bool gmx_simdcall
anyTrue(SimdDIBool a) { return _mm_movemask_epi8(a.simdInternal_) != 0; }
static inline SimdDInt32 gmx_simdcall
selectByMask(SimdDInt32 a, SimdDIBool mask)
{
    return {
               _mm_and_si128(a.simdInternal_, mask.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
selectByNotMask(SimdDInt32 a, SimdDIBool mask)
{
    return {
               _mm_andnot_si128(mask.simdInternal_, a.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
blend(SimdDInt32 a, SimdDInt32 b, SimdDIBool sel)
{
    return {
               _mm_blendv_epi8(a.simdInternal_, b.simdInternal_, sel.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
cvtR2I(SimdDouble a)
{
    return {
               _mm256_cvtpd_epi32(a.simdInternal_)
    };
}

static inline SimdDInt32 gmx_simdcall
cvttR2I(SimdDouble a)
{
    return {
               _mm256_cvttpd_epi32(a.simdInternal_)
    };
}

static inline SimdDouble gmx_simdcall
cvtI2R(SimdDInt32 a)
{
    return {
               _mm256_cvtepi32_pd(a.simdInternal_)
    };
}

static inline SimdDIBool gmx_simdcall
cvtB2IB(SimdDBool a)
{
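    // Narrow the four 64-bit double booleans into four 32-bit integer booleans:
    // pick the low 32 bits of each 64-bit element in both 128-bit halves and
    // merge the two halves with a blend.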
    __m128i a1 = _mm256_extractf128_si256(_mm256_castpd_si256(a.simdInternal_), 0x1);
    __m128i a0 = _mm256_castsi256_si128(_mm256_castpd_si256(a.simdInternal_));
    a0 = _mm_shuffle_epi32(a0, _MM_SHUFFLE(2, 0, 2, 0));
    a1 = _mm_shuffle_epi32(a1, _MM_SHUFFLE(2, 0, 2, 0));

    return {
               _mm_blend_epi16(a0, a1, 0xF0)
    };
}

static inline SimdDBool gmx_simdcall
cvtIB2B(SimdDIBool a)
{
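    // Widen the four 32-bit integer booleans back to 64-bit double booleans by
    // duplicating each element and recombining the two 128-bit halves.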
    __m128d lo = _mm_castsi128_pd(_mm_unpacklo_epi32(a.simdInternal_, a.simdInternal_));
    __m128d hi = _mm_castsi128_pd(_mm_unpackhi_epi32(a.simdInternal_, a.simdInternal_));

    return {
               _mm256_insertf128_pd(_mm256_castpd128_pd256(lo), hi, 0x1)
    };
}

static inline void gmx_simdcall
cvtF2DD(SimdFloat f, SimdDouble *d0, SimdDouble *d1)
{
    d0->simdInternal_ = _mm256_cvtps_pd(_mm256_castps256_ps128(f.simdInternal_));
    d1->simdInternal_ = _mm256_cvtps_pd(_mm256_extractf128_ps(f.simdInternal_, 0x1));
}

static inline SimdFloat gmx_simdcall
cvtDD2F(SimdDouble d0, SimdDouble d1)
{
    __m128 f0 = _mm256_cvtpd_ps(d0.simdInternal_);
    __m128 f1 = _mm256_cvtpd_ps(d1.simdInternal_);
    return {
               _mm256_insertf128_ps(_mm256_castps128_ps256(f0), f1, 0x1)
    };
}

}      // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_256_SIMD_DOUBLE_H