PME-gather: 4xN SIMD
src/gromacs/simd/impl_x86_avx_512/impl_x86_avx_512_util_double.h
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H

#include "config.h"

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_double.h"
namespace gmx
{

static const int c_simdBestPairAlignmentDouble = 2;

namespace
{
// Multiply function optimized for powers of 2, for which it is done by
// shifting. Currently up to 8 is accelerated. Could be accelerated for any
// number with a constexpr log2 function.
template<int n>
SimdDInt32 fastMultiply(SimdDInt32 x)
{
    if (n == 2)
    {
        return x << 1;
    }
    else if (n == 4)
    {
        return x << 2;
    }
    else if (n == 8)
    {
        return x << 3;
    }
    else
    {
        return x * n;
    }
}
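
// Illustrative example (not exhaustive): fastMultiply<4>(offset) compiles to a
// single shift, offset << 2, while a non-power-of-two value such as
// fastMultiply<3>(offset) falls through to the generic integer multiply.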

template<int align>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const double *, SimdDInt32)
{
    // Nothing to do. Termination of recursion.
}
}   // namespace

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const double * base, SimdDInt32 offset, SimdDouble *v, Targs... Fargs)
{
    if (align > 1)
    {
        offset = fastMultiply<align>(offset);
    }
    v->simdInternal_ = _mm512_i32gather_pd(offset.simdInternal_, base, sizeof(double));
    gatherLoadBySimdIntTranspose<1>(base+1, offset, Fargs ...);
}
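
// Sketch of the recursion above for three output vectors: the first call scales
// the offsets by align and gathers v0[i] = base[align*offset[i]]; it then
// recurses with align == 1 and base+1, so v1[i] = base[align*offset[i] + 1],
// and likewise v2[i] = base[align*offset[i] + 2].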

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUBySimdIntTranspose(const double *base, SimdDInt32 offset, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadTranspose(const double *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), Fargs ...);
}

template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUTranspose(const double *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadTranspose<align>(base, offset, Fargs ...);
}
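
// Hypothetical usage of the transpose loads above, e.g. for a coordinate array
// xyz stored as x/y/z triplets with stride align == 3 (offset[] is assumed to be
// suitably aligned for simdLoad in the aligned variant):
//
//     SimdDouble x, y, z;
//     gatherLoadTranspose<3>(xyz, offset, &x, &y, &z); // x[i] = xyz[3*offset[i]+0], etc.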

template <int align>
static inline void gmx_simdcall
transposeScatterStoreU(double *            base,
                       const std::int32_t  offset[],
                       SimdDouble          v0,
                       SimdDouble          v1,
                       SimdDouble          v2)
{
    SimdDInt32 simdoffset = simdLoad(offset, SimdDInt32Tag());
    if (align > 1)
    {
        simdoffset = fastMultiply<align>(simdoffset);
    }
    _mm512_i32scatter_pd(base,   simdoffset.simdInternal_, v0.simdInternal_, sizeof(double));
    _mm512_i32scatter_pd(base+1, simdoffset.simdInternal_, v1.simdInternal_, sizeof(double));
    _mm512_i32scatter_pd(base+2, simdoffset.simdInternal_, v2.simdInternal_, sizeof(double));
}
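
// In scalar form, the scatter store above is equivalent to
//
//     for (int i = 0; i < 8; i++)
//     {
//         base[align*offset[i] + 0] = v0[i];
//         base[align*offset[i] + 1] = v1[i];
//         base[align*offset[i] + 2] = v2[i];
//     }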

template <int align>
static inline void gmx_simdcall
transposeScatterIncrU(double *            base,
                      const std::int32_t  offset[],
                      SimdDouble          v0,
                      SimdDouble          v1,
                      SimdDouble          v2)
{
    __m512d t[4], t5, t6, t7, t8;
    GMX_ALIGNED(std::int64_t, 8) o[8];
    // TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else
    {
        if (align % 4 == 0)
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_store_pd(base + o[0 + i],
                                _mm256_add_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
                _mm256_store_pd(base + o[4 + i],
                                _mm256_add_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
        else
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_storeu_pd(base + o[0 + i],
                                 _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
                _mm256_storeu_pd(base + o[4 + i],
                                 _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
    }
}
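
// Note on the three store paths above (the decrement variant below mirrors them):
// for align < 4 a 3-element masked store updates exactly {v0,v1,v2}; for
// align % 4 == 0 a full, aligned 4-double load/add/store is used, where the
// fourth lane only adds zero to the padding element; otherwise the same 4-wide
// update is done with unaligned loads and stores.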

template <int align>
static inline void gmx_simdcall
transposeScatterDecrU(double *            base,
                      const std::int32_t  offset[],
                      SimdDouble          v0,
                      SimdDouble          v1,
                      SimdDouble          v2)
{
    __m512d t[4], t5, t6, t7, t8;
    GMX_ALIGNED(std::int64_t, 8) o[8];
    // TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
    t5   = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6   = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7   = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8   = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0] = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[2] = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[1] = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[3] = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else
    {
        if (align % 4 == 0)
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_store_pd(base + o[0 + i],
                                _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
                _mm256_store_pd(base + o[4 + i],
                                _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
        else
        {
            for (int i = 0; i < 4; i++)
            {
                _mm256_storeu_pd(base + o[0 + i],
                                 _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
                _mm256_storeu_pd(base + o[4 + i],
                                 _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
            }
        }
    }
}

static inline void gmx_simdcall
expandScalarsToTriplets(SimdDouble    scalar,
                        SimdDouble *  triplets0,
                        SimdDouble *  triplets1,
                        SimdDouble *  triplets2)
{
    triplets0->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(5, 4, 5, 4, 3, 2, 3, 2, 3, 2, 1, 0, 1, 0, 1, 0),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets1->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(11, 10, 9, 8, 9, 8, 9, 8, 7, 6, 7, 6, 7, 6, 5, 4),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets2->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(15, 14, 15, 14, 15, 14, 13, 12, 13, 12, 13, 12, 11, 10, 11, 10),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
}
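
// Worked example for the permutations above: with scalar = {s0, ..., s7},
//     triplets0 = {s0, s0, s0, s1, s1, s1, s2, s2}
//     triplets1 = {s2, s3, s3, s3, s4, s4, s4, s5}
//     triplets2 = {s5, s5, s6, s6, s6, s7, s7, s7}
// i.e. every input element is replicated three times across the three outputs.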

static inline double gmx_simdcall
reduceIncr4ReturnSum(double *    m,
                     SimdDouble  v0,
                     SimdDouble  v1,
                     SimdDouble  v2,
                     SimdDouble  v3)
{
    __m512d t0, t2;
    __m256d t3, t4;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permute_pd(v0.simdInternal_, 0x55));
    t2 = _mm512_add_pd(v2.simdInternal_, _mm512_permute_pd(v2.simdInternal_, 0x55));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xAA), v1.simdInternal_, _mm512_permute_pd(v1.simdInternal_, 0x55));
    t2 = _mm512_mask_add_pd(t2, avx512Int2Mask(0xAA), v3.simdInternal_, _mm512_permute_pd(v3.simdInternal_, 0x55));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xF0), t2, _mm512_shuffle_f64x2(t2, t2, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0x0C), t0, t0, 0xEE);

    t3 = _mm512_castpd512_pd256(t0);
    t4 = _mm256_load_pd(m);
    t4 = _mm256_add_pd(t4, t3);
    _mm256_store_pd(m, t4);

    t3 = _mm256_add_pd(t3, _mm256_permutex_pd(t3, 0x4E));
    t3 = _mm256_add_pd(t3, _mm256_permutex_pd(t3, 0xB1));

    return _mm_cvtsd_f64(_mm256_castpd256_pd128(t3));
}
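
// In scalar form, the reduction above performs m[k] += sum_{i=0..7} vk[i] for
// k = 0..3 and returns the grand total over all four vectors (excluding the
// previous contents of m).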

static inline SimdDouble gmx_simdcall
loadDualHsimd(const double * m0,
              const double * m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return {
               _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(m0)),
                                  _mm256_load_pd(m1), 1)
    };
}

static inline SimdDouble gmx_simdcall
loadDuplicateHsimd(const double * m)
{
    assert(std::size_t(m) % 32 == 0);

    return {
               _mm512_broadcast_f64x4(_mm256_load_pd(m))
    };
}

static inline SimdDouble gmx_simdcall
loadU1DualHsimd(const double * m)
{
    return {
               _mm512_insertf64x4(_mm512_broadcastsd_pd(_mm_load_sd(m)),
                                  _mm256_broadcastsd_pd(_mm_load_sd(m+1)), 1)
    };
}
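
// The three half-SIMD loads above fill the 512-bit register as two 256-bit
// halves: loadDualHsimd gives {m0[0..3], m1[0..3]}, loadDuplicateHsimd gives
// {m[0..3], m[0..3]}, and loadU1DualHsimd broadcasts m[0] into the low half
// and m[1] into the high half.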

static inline void gmx_simdcall
storeDualHsimd(double *     m0,
               double *     m1,
               SimdDouble   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
}

static inline void gmx_simdcall
incrDualHsimd(double *     m0,
              double *     m1,
              SimdDouble   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256d x;

    // Lower half
    x = _mm256_load_pd(m0);
    x = _mm256_add_pd(x, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m0, x);

    // Upper half
    x = _mm256_load_pd(m1);
    x = _mm256_add_pd(x, _mm512_extractf64x4_pd(a.simdInternal_, 1));
    _mm256_store_pd(m1, x);
}

static inline void gmx_simdcall
decrHsimd(double *    m,
          SimdDouble  a)
{
    __m256d t;

    assert(std::size_t(m) % 32 == 0);

    a.simdInternal_ = _mm512_add_pd(a.simdInternal_, _mm512_shuffle_f64x2(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_pd(m);
    t               = _mm256_sub_pd(t, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m, t);
}
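
// decrHsimd above folds the two 256-bit halves together first, so in scalar
// form it computes m[i] -= (a[i] + a[4+i]) for i = 0..3.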

template <int align>
static inline void gmx_simdcall
gatherLoadTransposeHsimd(const double *       base0,
                         const double *       base1,
                         const std::int32_t   offset[],
                         SimdDouble *         v0,
                         SimdDouble *         v1)
{
    __m128i idx0, idx1;
    __m256i idx;
    __m512d tmp1, tmp2;

    assert(std::size_t(offset) % 16 == 0);
    assert(std::size_t(base0) % 16 == 0);
    assert(std::size_t(base1) % 16 == 0);

    idx0 = _mm_load_si128(reinterpret_cast<const __m128i*>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    idx0 = _mm_slli_epi32(idx0, align == 2 ? 1 : 2);

    idx1 = _mm_add_epi32(idx0, _mm_set1_epi32(1));

    idx = _mm256_inserti128_si256(_mm256_castsi128_si256(idx0), idx1, 1);

    tmp1 = _mm512_i32gather_pd(idx, base0, sizeof(double)); // TODO: Might be faster to use individual loads
    tmp2 = _mm512_i32gather_pd(idx, base1, sizeof(double));

    v0->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0x44);
    v1->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0xEE);
}
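
// For the half-SIMD gather above, with o[i] = align*offset[i] for i = 0..3:
// v0 = {base0[o[0..3]], base1[o[0..3]]} and v1 = {base0[o[0..3]+1], base1[o[0..3]+1]},
// i.e. the first and second double of each pair, with base0 data in the low
// 256-bit half and base1 data in the high half.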

static inline double gmx_simdcall
reduceIncr4ReturnSumHsimd(double *     m,
                          SimdDouble   v0,
                          SimdDouble   v1)
{
    __m512d t0;
    __m256d t2, t3;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permutex_pd(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xCC), v1.simdInternal_, _mm512_permutex_pd(v1.simdInternal_, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0xAA), t0, t0, 0xEE);

    t2 = _mm512_castpd512_pd256(t0);
    t3 = _mm256_load_pd(m);
    t3 = _mm256_add_pd(t3, t2);
    _mm256_store_pd(m, t3);

    t2 = _mm256_add_pd(t2, _mm256_permutex_pd(t2, 0x4E));
    t2 = _mm256_add_pd(t2, _mm256_permutex_pd(t2, 0xB1));

    return _mm_cvtsd_f64(_mm256_castpd256_pd128(t2));
}
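
// In scalar form, the half-SIMD reduction above performs
//     m[0] += sum(v0[0..3]);  m[1] += sum(v0[4..7]);
//     m[2] += sum(v1[0..3]);  m[3] += sum(v1[4..7]);
// and returns the total sum over all elements of v0 and v1.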

static inline SimdDouble gmx_simdcall
loadU4NOffset(const double *m, int offset)
{
    return {
               _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_loadu_pd(m)),
                                  _mm256_loadu_pd(m+offset), 1)
    };
}

}      // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H