/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#ifndef GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H
#define GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H

#include <cassert>
#include <cstdint>

#include <immintrin.h>

#include "gromacs/utility/basedefinitions.h"

#include "impl_x86_avx_512_general.h"
#include "impl_x86_avx_512_simd_double.h"

namespace gmx
{
static const int c_simdBestPairAlignmentDouble = 2;
// Multiply function optimized for powers of 2, for which it is done by
// shifting. Currently up to 8 is accelerated. Could be accelerated for any
// number with a constexpr log2 function.
template <int n>
SimdDInt32 fastMultiply(SimdDInt32 x)
{
    if (n == 2)      { return _mm256_slli_epi32(x.simdInternal_, 1); }
    else if (n == 4) { return _mm256_slli_epi32(x.simdInternal_, 2); }
    else if (n == 8) { return _mm256_slli_epi32(x.simdInternal_, 3); }
    else             { return x * SimdDInt32(n); }
}
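// Illustrative sketch (hypothetical values, not part of the header itself):
// for a compile-time alignment of 4 the call
//
//   SimdDInt32 scaled = fastMultiply<4>(idx);
//
// should reduce to a single 32-bit left shift by log2(4) == 2, turning element
// indices {0, 1, 2, ...} into {0, 4, 8, ...}.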
template <int align>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const double *, SimdDInt32)
{
    //Nothing to do. Termination of recursion.
}
template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadBySimdIntTranspose(const double * base, SimdDInt32 offset, SimdDouble *v, Targs... Fargs)
{
    // The recursive call below passes align == 1, so the offsets are scaled only once.
    if (align > 1)
    {
        offset = fastMultiply<align>(offset);
    }
    v->simdInternal_ = _mm512_i32gather_pd(offset.simdInternal_, base, sizeof(double));
    gatherLoadBySimdIntTranspose<1>(base+1, offset, Fargs ...);
}
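// Sketch of how the variadic recursion above unrolls (hypothetical caller-side
// variables, assuming idx has already been loaded as a SimdDInt32):
//
//   gatherLoadBySimdIntTranspose<3>(xyz, idx, &vx, &vy, &vz);
//
// performs three gathers with the same scaled index vector, from xyz+0, xyz+1
// and xyz+2, and then hits the empty termination overload.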
template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUBySimdIntTranspose(const double *base, SimdDInt32 offset, Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, offset, Fargs ...);
}
template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadTranspose(const double *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadBySimdIntTranspose<align>(base, simdLoad(offset, SimdDInt32Tag()), Fargs ...);
}
template <int align, typename ... Targs>
static inline void gmx_simdcall
gatherLoadUTranspose(const double *base, const std::int32_t offset[], Targs... Fargs)
{
    gatherLoadTranspose<align>(base, offset, Fargs ...);
}
template <int align>
static inline void gmx_simdcall
transposeScatterStoreU(double *            base,
                       const std::int32_t  offset[],
                       SimdDouble          v0,
                       SimdDouble          v1,
                       SimdDouble          v2)
{
    SimdDInt32 simdoffset = simdLoad(offset, SimdDInt32Tag());

    simdoffset = fastMultiply<align>(simdoffset);

    _mm512_i32scatter_pd(base,   simdoffset.simdInternal_, v0.simdInternal_, sizeof(double));
    _mm512_i32scatter_pd(base+1, simdoffset.simdInternal_, v1.simdInternal_, sizeof(double));
    _mm512_i32scatter_pd(base+2, simdoffset.simdInternal_, v2.simdInternal_, sizeof(double));
}
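// Hypothetical usage sketch (names invented for illustration): with rx, ry and
// rz each holding one coordinate component for eight particles, and xyz[]
// storing packed x,y,z triplets,
//
//   transposeScatterStoreU<3>(xyz, particleIndex, rx, ry, rz);
//
// writes rx[i], ry[i] and rz[i] to xyz[3*particleIndex[i] + 0, +1, +2].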
template <int align>
static inline void gmx_simdcall
transposeScatterIncrU(double *            base,
                      const std::int32_t  offset[],
                      SimdDouble          v0,
                      SimdDouble          v1,
                      SimdDouble          v2)
{
    __m512d t[4], t5, t6, t7, t8;
    GMX_ALIGNED(std::int64_t, 8)    o[8];
    //TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
    t5    = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6    = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7    = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8    = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0]  = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[1]  = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[2]  = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[3]  = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        // Triplets may overlap the following one, so use masked read-modify-write stores.
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else if (align % 4 == 0)
    {
        // Aligned 256-bit read-modify-write stores.
        for (int i = 0; i < 4; i++)
        {
            _mm256_store_pd(base + o[0 + i],
                            _mm256_add_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
            _mm256_store_pd(base + o[4 + i],
                            _mm256_add_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
        }
    }
    else
    {
        // Unaligned 256-bit read-modify-write stores.
        for (int i = 0; i < 4; i++)
        {
            _mm256_storeu_pd(base + o[0 + i],
                             _mm256_add_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
            _mm256_storeu_pd(base + o[4 + i],
                             _mm256_add_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
        }
    }
}
template <int align>
static inline void gmx_simdcall
transposeScatterDecrU(double *            base,
                      const std::int32_t  offset[],
                      SimdDouble          v0,
                      SimdDouble          v1,
                      SimdDouble          v2)
{
    __m512d t[4], t5, t6, t7, t8;
    GMX_ALIGNED(std::int64_t, 8)    o[8];
    //TODO: should use fastMultiply
    _mm512_store_epi64(o, _mm512_cvtepi32_epi64(_mm256_mullo_epi32(_mm256_load_si256((const __m256i*)(offset)), _mm256_set1_epi32(align))));
    t5    = _mm512_unpacklo_pd(v0.simdInternal_, v1.simdInternal_);
    t6    = _mm512_unpackhi_pd(v0.simdInternal_, v1.simdInternal_);
    t7    = _mm512_unpacklo_pd(v2.simdInternal_, _mm512_setzero_pd());
    t8    = _mm512_unpackhi_pd(v2.simdInternal_, _mm512_setzero_pd());
    t[0]  = _mm512_mask_permutex_pd(t5, avx512Int2Mask(0xCC), t7, 0x4E);
    t[2]  = _mm512_mask_permutex_pd(t7, avx512Int2Mask(0x33), t5, 0x4E);
    t[1]  = _mm512_mask_permutex_pd(t6, avx512Int2Mask(0xCC), t8, 0x4E);
    t[3]  = _mm512_mask_permutex_pd(t8, avx512Int2Mask(0x33), t6, 0x4E);
    if (align < 4)
    {
        // Triplets may overlap the following one, so use masked read-modify-write stores.
        for (int i = 0; i < 4; i++)
        {
            _mm512_mask_storeu_pd(base + o[0 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i]))));
            _mm512_mask_storeu_pd(base + o[4 + i], avx512Int2Mask(7), _mm512_castpd256_pd512(
                                          _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1))));
        }
    }
    else if (align % 4 == 0)
    {
        // Aligned 256-bit read-modify-write stores.
        for (int i = 0; i < 4; i++)
        {
            _mm256_store_pd(base + o[0 + i],
                            _mm256_sub_pd(_mm256_load_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
            _mm256_store_pd(base + o[4 + i],
                            _mm256_sub_pd(_mm256_load_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
        }
    }
    else
    {
        // Unaligned 256-bit read-modify-write stores.
        for (int i = 0; i < 4; i++)
        {
            _mm256_storeu_pd(base + o[0 + i],
                             _mm256_sub_pd(_mm256_loadu_pd(base + o[0 + i]), _mm512_castpd512_pd256(t[i])));
            _mm256_storeu_pd(base + o[4 + i],
                             _mm256_sub_pd(_mm256_loadu_pd(base + o[4 + i]), _mm512_extractf64x4_pd(t[i], 1)));
        }
    }
}
static inline void gmx_simdcall
expandScalarsToTriplets(SimdDouble    scalar,
                        SimdDouble *  triplets0,
                        SimdDouble *  triplets1,
                        SimdDouble *  triplets2)
{
    triplets0->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(5, 4, 5, 4, 3, 2, 3, 2, 3, 2, 1, 0, 1, 0, 1, 0),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets1->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(11, 10, 9, 8, 9, 8, 9, 8, 7, 6, 7, 6, 7, 6, 5, 4),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
    triplets2->simdInternal_ = _mm512_castsi512_pd(_mm512_permutexvar_epi32(_mm512_set_epi32(15, 14, 15, 14, 15, 14, 13, 12, 13, 12, 13, 12, 11, 10, 11, 10),
                                                                            _mm512_castpd_si512(scalar.simdInternal_)));
}
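// For reference: with an input vector {s0, s1, ..., s7}, the 32-bit index
// permutations above are expected to yield
//   triplets0 = {s0, s0, s0, s1, s1, s1, s2, s2}
//   triplets1 = {s2, s3, s3, s3, s4, s4, s4, s5}
//   triplets2 = {s5, s5, s6, s6, s6, s7, s7, s7}
// i.e. every scalar is repeated three times across the three output registers.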
static inline double gmx_simdcall
reduceIncr4ReturnSum(double *    m,
                     SimdDouble  v0,
                     SimdDouble  v1,
                     SimdDouble  v2,
                     SimdDouble  v3)
{
    __m512d t0, t2;
    __m256d t3, t4;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permute_pd(v0.simdInternal_, 0x55));
    t2 = _mm512_add_pd(v2.simdInternal_, _mm512_permute_pd(v2.simdInternal_, 0x55));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xAA), v1.simdInternal_, _mm512_permute_pd(v1.simdInternal_, 0x55));
    t2 = _mm512_mask_add_pd(t2, avx512Int2Mask(0xAA), v3.simdInternal_, _mm512_permute_pd(v3.simdInternal_, 0x55));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xF0), t2, _mm512_shuffle_f64x2(t2, t2, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_shuffle_f64x2(t0, t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0x0C), t0, t0, 0xEE);

    t3 = _mm512_castpd512_pd256(t0);
    t4 = _mm256_load_pd(m);
    t4 = _mm256_add_pd(t4, t3);
    _mm256_store_pd(m, t4);

    t3 = _mm256_add_pd(t3, _mm256_permutex_pd(t3, 0x4E));
    t3 = _mm256_add_pd(t3, _mm256_permutex_pd(t3, 0xB1));

    return _mm_cvtsd_f64(_mm256_castpd256_pd128(t3));
}
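// Semantics (as documented for the generic GROMACS SIMD layer): the horizontal
// sum of each of v0..v3 is added to m[0..3], and the return value is the total
// of all four sums.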
static inline SimdDouble gmx_simdcall
loadDualHsimd(const double * m0,
              const double * m1)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    return {
               _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_load_pd(m0)),
                                  _mm256_load_pd(m1), 1)
    };
}
static inline SimdDouble gmx_simdcall
loadDuplicateHsimd(const double * m)
{
    assert(std::size_t(m) % 32 == 0);

    return {
               _mm512_broadcast_f64x4(_mm256_load_pd(m))
    };
}
static inline SimdDouble gmx_simdcall
loadU1DualHsimd(const double * m)
{
    return {
               _mm512_insertf64x4(_mm512_broadcastsd_pd(_mm_load_sd(m)),
                                  _mm256_broadcastsd_pd(_mm_load_sd(m+1)), 1)
    };
}
static inline void gmx_simdcall
storeDualHsimd(double *     m0,
               double *     m1,
               SimdDouble   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    _mm256_store_pd(m0, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m1, _mm512_extractf64x4_pd(a.simdInternal_, 1));
}
static inline void gmx_simdcall
incrDualHsimd(double *     m0,
              double *     m1,
              SimdDouble   a)
{
    assert(std::size_t(m0) % 32 == 0);
    assert(std::size_t(m1) % 32 == 0);

    __m256d x;

    // Add the lower half of a to the four doubles at m0
    x = _mm256_load_pd(m0);
    x = _mm256_add_pd(x, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m0, x);

    // Add the upper half of a to the four doubles at m1
    x = _mm256_load_pd(m1);
    x = _mm256_add_pd(x, _mm512_extractf64x4_pd(a.simdInternal_, 1));
    _mm256_store_pd(m1, x);
}
static inline void gmx_simdcall
decrHsimd(double *    m,
          SimdDouble  a)
{
    __m256d t;

    assert(std::size_t(m) % 32 == 0);

    // Sum the two 256-bit halves of a, then subtract the result from m
    a.simdInternal_ = _mm512_add_pd(a.simdInternal_, _mm512_shuffle_f64x2(a.simdInternal_, a.simdInternal_, 0xEE));
    t               = _mm256_load_pd(m);
    t               = _mm256_sub_pd(t, _mm512_castpd512_pd256(a.simdInternal_));
    _mm256_store_pd(m, t);
}
template <int align>
static inline void gmx_simdcall
gatherLoadTransposeHsimd(const double *       base0,
                         const double *       base1,
                         const std::int32_t   offset[],
                         SimdDouble *         v0,
                         SimdDouble *         v1)
{
    __m128i idx0, idx1;
    __m256i idx;
    __m512d tmp1, tmp2;

    assert(std::size_t(offset) % 16 == 0);
    assert(std::size_t(base0) % 16 == 0);
    assert(std::size_t(base1) % 16 == 0);

    idx0 = _mm_load_si128(reinterpret_cast<const __m128i *>(offset));

    static_assert(align == 2 || align == 4, "If more are needed use fastMultiply");
    idx0 = _mm_slli_epi32(idx0, align == 2 ? 1 : 2);

    idx1 = _mm_add_epi32(idx0, _mm_set1_epi32(1));

    idx  = _mm256_inserti128_si256(_mm256_castsi128_si256(idx0), idx1, 1);

    tmp1 = _mm512_i32gather_pd(idx, base0, sizeof(double)); //TODO: Might be faster to use individual loads
    tmp2 = _mm512_i32gather_pd(idx, base1, sizeof(double));

    v0->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0x44);
    v1->simdInternal_ = _mm512_shuffle_f64x2(tmp1, tmp2, 0xEE);
}
static inline double gmx_simdcall
reduceIncr4ReturnSumHsimd(double *     m,
                          SimdDouble   v0,
                          SimdDouble   v1)
{
    __m512d t0;
    __m256d t2, t3;

    assert(std::size_t(m) % 32 == 0);

    t0 = _mm512_add_pd(v0.simdInternal_, _mm512_permutex_pd(v0.simdInternal_, 0x4E));
    t0 = _mm512_mask_add_pd(t0, avx512Int2Mask(0xCC), v1.simdInternal_, _mm512_permutex_pd(v1.simdInternal_, 0x4E));
    t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xB1));
    t0 = _mm512_mask_shuffle_f64x2(t0, avx512Int2Mask(0xAA), t0, t0, 0xEE);

    t2 = _mm512_castpd512_pd256(t0);
    t3 = _mm256_load_pd(m);
    t3 = _mm256_add_pd(t3, t2);
    _mm256_store_pd(m, t3);

    t2 = _mm256_add_pd(t2, _mm256_permutex_pd(t2, 0x4E));
    t2 = _mm256_add_pd(t2, _mm256_permutex_pd(t2, 0xB1));

    return _mm_cvtsd_f64(_mm256_castpd256_pd128(t2));
}
static inline SimdDouble gmx_simdcall
loadU4NOffset(const double *m, int offset)
{
    return {
               _mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_loadu_pd(m)),
                                  _mm256_loadu_pd(m+offset), 1)
    };
}
}      // namespace gmx

#endif // GMX_SIMD_IMPL_X86_AVX_512_UTIL_DOUBLE_H