Improve accuracy of SIMD exp for small args
src/gromacs/simd/impl_x86_sse2/impl_x86_sse2_simd_float.h
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_X86_SSE2_SIMD_FLOAT_H
#define GMX_SIMD_IMPL_X86_SSE2_SIMD_FLOAT_H

#include "config.h"

#include <cassert>
#include <cstddef>
#include <cstdint>

#include <emmintrin.h>

#include "gromacs/math/utilities.h"
namespace gmx
{

class SimdFloat
{
    public:
        SimdFloat() {}

        SimdFloat(float f) : simdInternal_(_mm_set1_ps(f)) {}

        // Internal utility constructor to simplify return statements
        SimdFloat(__m128 simd) : simdInternal_(simd) {}

        __m128  simdInternal_;
};
class SimdFInt32
{
    public:
        SimdFInt32() {}

        SimdFInt32(std::int32_t i) : simdInternal_(_mm_set1_epi32(i)) {}

        // Internal utility constructor to simplify return statements
        SimdFInt32(__m128i simd) : simdInternal_(simd) {}

        __m128i  simdInternal_;
};
class SimdFBool
{
    public:
        SimdFBool() {}

        SimdFBool(bool b) : simdInternal_(_mm_castsi128_ps(_mm_set1_epi32(b ? 0xFFFFFFFF : 0))) {}

        // Internal utility constructor to simplify return statements
        SimdFBool(__m128 simd) : simdInternal_(simd) {}

        __m128  simdInternal_;
};
class SimdFIBool
{
    public:
        SimdFIBool() {}

        SimdFIBool(bool b) : simdInternal_(_mm_set1_epi32(b ? 0xFFFFFFFF : 0)) {}

        // Internal utility constructor to simplify return statements
        SimdFIBool(__m128i simd) : simdInternal_(simd) {}

        __m128i  simdInternal_;
};
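
// Illustrative sketch (not part of this header): the four wrapper classes
// above give the raw __m128/__m128i registers distinct C++ types, so that
// float data, integer data, and the two kinds of comparison masks cannot be
// mixed up at compile time. Assuming a 4-wide SSE build, typical use is:
//
//     SimdFloat x(1.0f);        // broadcast 1.0f into all 4 lanes
//     SimdFloat y = x + x;      // lane-wise add, y = {2, 2, 2, 2}
//     SimdFBool m = (y == x);   // per-lane mask; all lanes false here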
static inline SimdFloat gmx_simdcall
simdLoad(const float *m)
{
    assert(std::size_t(m) % 16 == 0);
    return {
               _mm_load_ps(m)
    };
}

static inline void gmx_simdcall
store(float *m, SimdFloat a)
{
    assert(std::size_t(m) % 16 == 0);
    _mm_store_ps(m, a.simdInternal_);
}

static inline SimdFloat gmx_simdcall
simdLoadU(const float *m)
{
    return {
               _mm_loadu_ps(m)
    };
}

static inline void gmx_simdcall
storeU(float *m, SimdFloat a) { _mm_storeu_ps(m, a.simdInternal_); }

static inline SimdFloat gmx_simdcall
setZeroF()
{
    return {
               _mm_setzero_ps()
    };
}

static inline SimdFInt32 gmx_simdcall
simdLoadFI(const std::int32_t * m)
{
    assert(std::size_t(m) % 16 == 0);
    return {
               _mm_load_si128(reinterpret_cast<const __m128i *>(m))
    };
}

static inline void gmx_simdcall
store(std::int32_t * m, SimdFInt32 a)
{
    assert(std::size_t(m) % 16 == 0);
    _mm_store_si128(reinterpret_cast<__m128i *>(m), a.simdInternal_);
}

static inline SimdFInt32 gmx_simdcall
simdLoadUFI(const std::int32_t *m)
{
    return {
               _mm_loadu_si128(reinterpret_cast<const __m128i *>(m))
    };
}

static inline void gmx_simdcall
storeU(std::int32_t * m, SimdFInt32 a)
{
    _mm_storeu_si128(reinterpret_cast<__m128i *>(m), a.simdInternal_);
}

static inline SimdFInt32 gmx_simdcall
setZeroFI()
{
    return {
               _mm_setzero_si128()
    };
}
// Override for SSE4.1 and higher
#if GMX_SIMD_X86_SSE2
template<int index>
static inline std::int32_t gmx_simdcall
extract(SimdFInt32 a)
{
    return _mm_cvtsi128_si32( _mm_srli_si128(a.simdInternal_, 4 * index) );
}
#endif
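
// Illustrative trace of extract() above (hypothetical values): the template
// parameter selects the lane at compile time, since SSE2 has no
// variable-index extract. For a = {10, 20, 30, 40}, extract<2>(a) shifts the
// register right by 2*4 bytes and returns the low element, i.e. 30.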
static inline SimdFloat gmx_simdcall
operator&(SimdFloat a, SimdFloat b)
{
    return {
               _mm_and_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
andNot(SimdFloat a, SimdFloat b)
{
    return {
               _mm_andnot_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
operator|(SimdFloat a, SimdFloat b)
{
    return {
               _mm_or_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
operator^(SimdFloat a, SimdFloat b)
{
    return {
               _mm_xor_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
operator+(SimdFloat a, SimdFloat b)
{
    return {
               _mm_add_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
operator-(SimdFloat a, SimdFloat b)
{
    return {
               _mm_sub_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
operator-(SimdFloat x)
{
    return {
               _mm_xor_ps(x.simdInternal_, _mm_set1_ps(GMX_FLOAT_NEGZERO))
    };
}

static inline SimdFloat gmx_simdcall
operator*(SimdFloat a, SimdFloat b)
{
    return {
               _mm_mul_ps(a.simdInternal_, b.simdInternal_)
    };
}
// Override for AVX-128-FMA and higher
#if GMX_SIMD_X86_SSE2 || GMX_SIMD_X86_SSE4_1
static inline SimdFloat gmx_simdcall
fma(SimdFloat a, SimdFloat b, SimdFloat c)
{
    return {
               _mm_add_ps(_mm_mul_ps(a.simdInternal_, b.simdInternal_), c.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
fms(SimdFloat a, SimdFloat b, SimdFloat c)
{
    return {
               _mm_sub_ps(_mm_mul_ps(a.simdInternal_, b.simdInternal_), c.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
fnma(SimdFloat a, SimdFloat b, SimdFloat c)
{
    return {
               _mm_sub_ps(c.simdInternal_, _mm_mul_ps(a.simdInternal_, b.simdInternal_))
    };
}

static inline SimdFloat gmx_simdcall
fnms(SimdFloat a, SimdFloat b, SimdFloat c)
{
    return {
               _mm_sub_ps(_mm_setzero_ps(), _mm_add_ps(_mm_mul_ps(a.simdInternal_, b.simdInternal_), c.simdInternal_))
    };
}
#endif
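
// Note on the block above: SSE2 has no fused multiply-add, so fma(a, b, c)
// is emulated as a*b + c with two separately rounded operations; on hardware
// with a real FMA unit (AVX-128-FMA and higher) the override computes it
// with a single rounding. A minimal usage sketch, assuming 4-wide lanes:
//
//     SimdFloat d = fma(a, b, c);    // d[i] =  a[i]*b[i] + c[i]
//     SimdFloat e = fnma(a, b, c);   // e[i] = -a[i]*b[i] + c[i]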
static inline SimdFloat gmx_simdcall
rsqrt(SimdFloat x)
{
    return {
               _mm_rsqrt_ps(x.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
rcp(SimdFloat x)
{
    return {
               _mm_rcp_ps(x.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
{
    return {
               _mm_add_ps(a.simdInternal_, _mm_and_ps(b.simdInternal_, m.simdInternal_))
    };
}

static inline SimdFloat gmx_simdcall
maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
{
    return {
               _mm_and_ps(_mm_mul_ps(a.simdInternal_, b.simdInternal_), m.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
{
    return {
               _mm_and_ps(_mm_add_ps(_mm_mul_ps(a.simdInternal_, b.simdInternal_), c.simdInternal_), m.simdInternal_)
    };
}

// Override for SSE4.1 and higher
#if GMX_SIMD_X86_SSE2
static inline SimdFloat gmx_simdcall
maskzRsqrt(SimdFloat x, SimdFBool m)
{
#ifndef NDEBUG
    // In debug builds, substitute 1.0 into the masked-out lanes so rsqrt
    // never sees potentially invalid (zero or negative) inputs there.
    x.simdInternal_ = _mm_or_ps(_mm_andnot_ps(m.simdInternal_, _mm_set1_ps(1.0f)), _mm_and_ps(m.simdInternal_, x.simdInternal_));
#endif
    return {
               _mm_and_ps(_mm_rsqrt_ps(x.simdInternal_), m.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
maskzRcp(SimdFloat x, SimdFBool m)
{
#ifndef NDEBUG
    // Same substitution as in maskzRsqrt(): masked-out lanes become 1.0.
    x.simdInternal_ = _mm_or_ps(_mm_andnot_ps(m.simdInternal_, _mm_set1_ps(1.0f)), _mm_and_ps(m.simdInternal_, x.simdInternal_));
#endif
    return {
               _mm_and_ps(_mm_rcp_ps(x.simdInternal_), m.simdInternal_)
    };
}
#endif
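
// Usage sketch for the masked reciprocal helpers above (illustrative only;
// the names r2 and valid are hypothetical): a common pattern is to avoid
// dividing by zero in padded lanes, e.g.
//
//     SimdFBool valid = (r2 != setZeroF());   // lanes holding real data
//     SimdFloat inv   = maskzRcp(r2, valid);  // ~1/r2 where valid, 0.0 elsewhere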
static inline SimdFloat gmx_simdcall
abs(SimdFloat x)
{
    return {
               _mm_andnot_ps( _mm_set1_ps(GMX_FLOAT_NEGZERO), x.simdInternal_ )
    };
}

static inline SimdFloat gmx_simdcall
max(SimdFloat a, SimdFloat b)
{
    return {
               _mm_max_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
min(SimdFloat a, SimdFloat b)
{
    return {
               _mm_min_ps(a.simdInternal_, b.simdInternal_)
    };
}

// Override for SSE4.1 and higher
#if GMX_SIMD_X86_SSE2
static inline SimdFloat gmx_simdcall
round(SimdFloat x)
{
    return {
               _mm_cvtepi32_ps( _mm_cvtps_epi32(x.simdInternal_) )
    };
}

static inline SimdFloat gmx_simdcall
trunc(SimdFloat x)
{
    return {
               _mm_cvtepi32_ps( _mm_cvttps_epi32(x.simdInternal_) )
    };
}
#endif
static inline SimdFloat gmx_simdcall
frexp(SimdFloat value, SimdFInt32 * exponent)
{
    const __m128  exponentMask = _mm_castsi128_ps(_mm_set1_epi32(0x7F800000));
    const __m128  mantissaMask = _mm_castsi128_ps(_mm_set1_epi32(0x807FFFFF));
    const __m128i exponentBias = _mm_set1_epi32(126); // add 1 to make our definition identical to frexp()
    const __m128  half         = _mm_set1_ps(0.5f);
    __m128i       iExponent;

    iExponent               = _mm_castps_si128(_mm_and_ps(value.simdInternal_, exponentMask));
    iExponent               = _mm_sub_epi32(_mm_srli_epi32(iExponent, 23), exponentBias);
    exponent->simdInternal_ = iExponent;

    return {
               _mm_or_ps( _mm_and_ps(value.simdInternal_, mantissaMask), half)
    };
}
// Override for SSE4.1
#if GMX_SIMD_X86_SSE2
template <MathOptimization opt = MathOptimization::Safe>
static inline SimdFloat gmx_simdcall
ldexp(SimdFloat value, SimdFInt32 exponent)
{
    const __m128i exponentBias = _mm_set1_epi32(127);
    __m128i       iExponent;

    iExponent = _mm_add_epi32(exponent.simdInternal_, exponentBias);

    if (opt == MathOptimization::Safe)
    {
        // Make sure biased argument is not negative
        iExponent = _mm_and_si128(iExponent, _mm_cmpgt_epi32(iExponent, _mm_setzero_si128()));
    }

    iExponent = _mm_slli_epi32( iExponent, 23);

    return {
               _mm_mul_ps(value.simdInternal_, _mm_castsi128_ps(iExponent))
    };
}
#endif
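
// Round-trip sketch for frexp()/ldexp() above (illustrative, assuming normal
// inputs): frexp() returns the mantissa in [0.5, 1) and writes the integer
// exponent, and ldexp() undoes it by adding the bias back and shifting the
// result into the exponent field:
//
//     SimdFInt32 e;
//     SimdFloat  m = frexp(SimdFloat(6.0f), &e); // m = 0.75, e = 3 in all lanes
//     SimdFloat  v = ldexp(m, e);                // v = 0.75 * 2^3 = 6.0 again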
// Override for AVX-128-FMA and higher
#if GMX_SIMD_X86_SSE2 || GMX_SIMD_X86_SSE4_1
static inline float gmx_simdcall
reduce(SimdFloat a)
{
    // Shuffle has latency 1/throughput 1, followed by add with latency 3, t-put 1.
    // This is likely faster than using _mm_hadd_ps, which has latency 5, t-put 2.
    a.simdInternal_ = _mm_add_ps(a.simdInternal_, _mm_shuffle_ps(a.simdInternal_, a.simdInternal_, _MM_SHUFFLE(1, 0, 3, 2)));
    a.simdInternal_ = _mm_add_ss(a.simdInternal_, _mm_shuffle_ps(a.simdInternal_, a.simdInternal_, _MM_SHUFFLE(0, 3, 2, 1)));
    return *reinterpret_cast<float *>(&a);
}
#endif
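
// Illustrative trace of reduce() above for a = {1, 2, 3, 4}: the first
// shuffle/add folds the upper half onto the lower half, giving
// {1+3, 2+4, 1+3, 2+4}; the second shuffle/add_ss folds the remaining pair
// into element 0, (1+3) + (2+4) = 10, which is returned as the scalar sum.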
static inline SimdFBool gmx_simdcall
operator==(SimdFloat a, SimdFloat b)
{
    return {
               _mm_cmpeq_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFBool gmx_simdcall
operator!=(SimdFloat a, SimdFloat b)
{
    return {
               _mm_cmpneq_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFBool gmx_simdcall
operator<(SimdFloat a, SimdFloat b)
{
    return {
               _mm_cmplt_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFBool gmx_simdcall
operator<=(SimdFloat a, SimdFloat b)
{
    return {
               _mm_cmple_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFBool gmx_simdcall
testBits(SimdFloat a)
{
    __m128i ia  = _mm_castps_si128(a.simdInternal_);
    __m128i res = _mm_andnot_si128( _mm_cmpeq_epi32(ia, _mm_setzero_si128()), _mm_cmpeq_epi32(ia, ia));

    return {
               _mm_castsi128_ps(res)
    };
}

static inline SimdFBool gmx_simdcall
operator&&(SimdFBool a, SimdFBool b)
{
    return {
               _mm_and_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFBool gmx_simdcall
operator||(SimdFBool a, SimdFBool b)
{
    return {
               _mm_or_ps(a.simdInternal_, b.simdInternal_)
    };
}

static inline bool gmx_simdcall
anyTrue(SimdFBool a) { return _mm_movemask_ps(a.simdInternal_) != 0; }
static inline SimdFloat gmx_simdcall
selectByMask(SimdFloat a, SimdFBool mask)
{
    return {
               _mm_and_ps(a.simdInternal_, mask.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
selectByNotMask(SimdFloat a, SimdFBool mask)
{
    return {
               _mm_andnot_ps(mask.simdInternal_, a.simdInternal_)
    };
}

// Override for SSE4.1 and higher
#if GMX_SIMD_X86_SSE2
static inline SimdFloat gmx_simdcall
blend(SimdFloat a, SimdFloat b, SimdFBool sel)
{
    return {
               _mm_or_ps(_mm_andnot_ps(sel.simdInternal_, a.simdInternal_), _mm_and_ps(sel.simdInternal_, b.simdInternal_))
    };
}
#endif
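
// Illustrative sketch of the SSE2 blend() above: with no native blendv
// instruction, the result is assembled as (a AND NOT sel) OR (b AND sel),
// i.e. lanes come from b where the mask is set and from a elsewhere:
//
//     SimdFloat r = blend(a, b, sel);   // r[i] = sel[i] ? b[i] : a[i]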
static inline SimdFInt32 gmx_simdcall
operator<<(SimdFInt32 a, int n)
{
    return {
               _mm_slli_epi32(a.simdInternal_, n)
    };
}

static inline SimdFInt32 gmx_simdcall
operator>>(SimdFInt32 a, int n)
{
    return {
               _mm_srli_epi32(a.simdInternal_, n)
    };
}

static inline SimdFInt32 gmx_simdcall
operator&(SimdFInt32 a, SimdFInt32 b)
{
    return {
               _mm_and_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFInt32 gmx_simdcall
andNot(SimdFInt32 a, SimdFInt32 b)
{
    return {
               _mm_andnot_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFInt32 gmx_simdcall
operator|(SimdFInt32 a, SimdFInt32 b)
{
    return {
               _mm_or_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFInt32 gmx_simdcall
operator^(SimdFInt32 a, SimdFInt32 b)
{
    return {
               _mm_xor_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFInt32 gmx_simdcall
operator+(SimdFInt32 a, SimdFInt32 b)
{
    return {
               _mm_add_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFInt32 gmx_simdcall
operator-(SimdFInt32 a, SimdFInt32 b)
{
    return {
               _mm_sub_epi32(a.simdInternal_, b.simdInternal_)
    };
}
// Override for SSE4.1 and higher
#if GMX_SIMD_X86_SSE2
static inline SimdFInt32 gmx_simdcall
operator*(SimdFInt32 a, SimdFInt32 b)
{
    __m128i a1 = _mm_srli_si128(a.simdInternal_, 4); // - a[3] a[2] a[1]
    __m128i b1 = _mm_srli_si128(b.simdInternal_, 4); // - b[3] b[2] b[1]
    __m128i c  = _mm_mul_epu32(a.simdInternal_, b.simdInternal_);
    __m128i c1 = _mm_mul_epu32(a1, b1);

    c  = _mm_shuffle_epi32(c, _MM_SHUFFLE(3, 1, 2, 0));  // - - a[2]*b[2] a[0]*b[0]
    c1 = _mm_shuffle_epi32(c1, _MM_SHUFFLE(3, 1, 2, 0)); // - - a[3]*b[3] a[1]*b[1]

    return {
               _mm_unpacklo_epi32(c, c1)
    };
}
#endif
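
// Note on operator* above: SSE2 lacks a 32-bit low multiply
// (_mm_mullo_epi32 only arrives with SSE4.1), so it is emulated with
// _mm_mul_epu32, which multiplies the even-indexed 32-bit lanes into 64-bit
// products. Lanes 0/2 are multiplied directly, lanes 1/3 after the 4-byte
// shift; the shuffles then gather the low 32 bits of each product, and the
// unpack interleaves them back into lane order. Keeping only the low half
// makes the result correct for signed values as well.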
static inline SimdFIBool gmx_simdcall
operator==(SimdFInt32 a, SimdFInt32 b)
{
    return {
               _mm_cmpeq_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFIBool gmx_simdcall
testBits(SimdFInt32 a)
{
    __m128i x   = a.simdInternal_;
    __m128i res = _mm_andnot_si128( _mm_cmpeq_epi32(x, _mm_setzero_si128()), _mm_cmpeq_epi32(x, x));

    return {
               res
    };
}
static inline SimdFIBool gmx_simdcall
operator<(SimdFInt32 a, SimdFInt32 b)
{
    return {
               _mm_cmplt_epi32(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFIBool gmx_simdcall
operator&&(SimdFIBool a, SimdFIBool b)
{
    return {
               _mm_and_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline SimdFIBool gmx_simdcall
operator||(SimdFIBool a, SimdFIBool b)
{
    return {
               _mm_or_si128(a.simdInternal_, b.simdInternal_)
    };
}

static inline bool gmx_simdcall
anyTrue(SimdFIBool a) { return _mm_movemask_epi8(a.simdInternal_) != 0; }
static inline SimdFInt32 gmx_simdcall
selectByMask(SimdFInt32 a, SimdFIBool mask)
{
    return {
               _mm_and_si128(a.simdInternal_, mask.simdInternal_)
    };
}

static inline SimdFInt32 gmx_simdcall
selectByNotMask(SimdFInt32 a, SimdFIBool mask)
{
    return {
               _mm_andnot_si128(mask.simdInternal_, a.simdInternal_)
    };
}

// Override for SSE4.1 and higher
#if GMX_SIMD_X86_SSE2
static inline SimdFInt32 gmx_simdcall
blend(SimdFInt32 a, SimdFInt32 b, SimdFIBool sel)
{
    return {
               _mm_or_si128(_mm_andnot_si128(sel.simdInternal_, a.simdInternal_), _mm_and_si128(sel.simdInternal_, b.simdInternal_))
    };
}
#endif
static inline SimdFInt32 gmx_simdcall
cvtR2I(SimdFloat a)
{
    return {
               _mm_cvtps_epi32(a.simdInternal_)
    };
}

static inline SimdFInt32 gmx_simdcall
cvttR2I(SimdFloat a)
{
    return {
               _mm_cvttps_epi32(a.simdInternal_)
    };
}

static inline SimdFloat gmx_simdcall
cvtI2R(SimdFInt32 a)
{
    return {
               _mm_cvtepi32_ps(a.simdInternal_)
    };
}

static inline SimdFIBool gmx_simdcall
cvtB2IB(SimdFBool a)
{
    return {
               _mm_castps_si128(a.simdInternal_)
    };
}

static inline SimdFBool gmx_simdcall
cvtIB2B(SimdFIBool a)
{
    return {
               _mm_castsi128_ps(a.simdInternal_)
    };
}
}      // namespace gmx

#endif // GMX_SIMD_IMPL_X86_SSE2_SIMD_FLOAT_H