#include "gemmology_fwd.h"

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <tuple>

#include <xsimd/xsimd.hpp>

// Arch specific implementation of various elementary operations

namespace gemmology {

namespace kernel {

template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi8(first, second),
          _mm512_unpackhi_epi8(first, second)};
}
template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi16(first, second),
          _mm512_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi32(first, second),
          _mm512_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi64(first, second),
          _mm512_unpackhi_epi64(first, second)};
}
template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_madd_epi16(x, y);
}
template <class Arch>
inline xsimd::batch<int32_t, xsimd::avx2>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  // Form [0th 128-bit register of pack0123, 0th 128-bit register of pack4567,
  //       2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567]
  __m512i mix0 =
      _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6));
  // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567,
  //       3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567]
  __m512i mix1 =
      _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2));
  __m512i added = _mm512_add_epi32(mix0, mix1);
  // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7.
  // Fold register over itself.
  return _mm256_add_epi32(_mm512_castsi512_si256(added),
                          _mm512_extracti64x4_epi64(added, 1));
}
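// Illustrative note (not part of the original kernels): all interleave
// overloads above share the same lane-wise contract. Within each 128-bit lane,
// given first = {a0, a1, ..., a(n-1)} and second = {b0, b1, ..., b(n-1)}, the
// returned pair is {a0, b0, a1, b1, ...} (low half of the lane) and
// {a(n/2), b(n/2), ...} (high half of the lane). For example, with 32-bit
// elements in one SSE-sized lane (assumed values):
//   interleave({0, 1, 2, 3}, {4, 5, 6, 7}) == ({0, 4, 1, 5}, {2, 6, 3, 7})
// Wider registers (AVX2, AVX-512) apply the same pattern per 128-bit lane,
// which is why PermuteSummer has to fix up the cross-lane order later.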
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi8(first, second),
          _mm256_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi16(first, second),
          _mm256_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi32(first, second),
          _mm256_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi64(first, second),
          _mm256_unpackhi_epi64(first, second)};
}
template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_maddubs_epi16(xsimd::abs(x), _mm256_sign_epi8(y, x));
}
template <class Arch>
inline xsimd::batch<int32_t, Arch>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx2>) {
  // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f
  __m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21);
  // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s
  __m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0);
  return _mm256_add_epi32(rev, blended);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3,
         xsimd::kernel::requires_arch<xsimd::avx2>) {
  auto pack01 = _mm256_hadd_epi32(sum0, sum1);
  auto pack23 = _mm256_hadd_epi32(sum2, sum3);
  return _mm256_hadd_epi32(pack01, pack23);
}
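// Illustrative note: Pack0123 horizontally reduces four accumulators into one
// register. Within each 128-bit lane the result holds, in order, the partial
// sums of sum0, sum1, sum2 and sum3 restricted to that lane. With a 4-wide
// batch, for example (assumed values):
//   Pack0123({a0,a1,a2,a3}, {b0,..}, {c0,..}, {d0,..})
//     == {a0+a1+a2+a3, b0+b1+b2+b3, c0+c1+c2+c3, d0+d1+d2+d3}
// On wider registers the lanes still have to be combined afterwards, which is
// what PermuteSummer does.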
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avxvnni>) {
  return _mm256_dpbusd_avx_epi32(z, x, y);
}

#ifdef __AVX512VNNI__

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avx512vnni<xsimd::avx512bw>>) {
  return _mm512_dpbusd_epi32(z, x, y);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avx512vnni<xsimd::avx512vbmi>>) {
  return _mm512_dpbusd_epi32(z, x, y);
}

#endif
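// Illustrative note: every maddw(x, y, z) overload implements the same
// contract as the VNNI dot-product instruction it wraps. For each 32-bit
// output lane i,
//   z[i] += x[4*i+0]*y[4*i+0] + x[4*i+1]*y[4*i+1]
//         + x[4*i+2]*y[4*i+2] + x[4*i+3]*y[4*i+3]
// with x unsigned 8-bit and y signed 8-bit. Architectures without such an
// instruction emulate it through madd() plus a widening add; see the generic
// maddw overload further down in this namespace.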
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  return _mm_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3,
         xsimd::kernel::requires_arch<xsimd::ssse3>) {
  auto pack01 = _mm_hadd_epi32(sum0, sum1);
  auto pack23 = _mm_hadd_epi32(sum2, sum3);
  return _mm_hadd_epi32(pack01, pack23);
}
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_madd_epi16(x, y);
}
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
  // a = 0x00 0x01 0xFE 0x04 ...
  // b = 0x00 0x02 0x80 0x84 ...

  // To sign-extend an 8-bit value to 16 bits, the upper byte has to be 0xFF
  // for negative inputs.
  __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());
  // sign_mask_b = 0x00 0x00 0xFF 0xFF ...

  // Unpack positives with 0x00, negatives with 0xFF.
  __m128i a_epi16_l = _mm_unpacklo_epi8(a, _mm_setzero_si128());
  __m128i a_epi16_h = _mm_unpackhi_epi8(a, _mm_setzero_si128());
  __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
  __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);

  // Here - valid 16-bit integers corresponding to the 8-bit input
  // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0x00 0x04 0x00 ...

  // Get a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts.
  __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
  __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);

  // Now go back from 32-bit values to 16-bit values with signed saturation.
  return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
}
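// Illustrative scalar reference for the madd(uint8_t, int8_t) overloads above
// (all architectures implement the same contract):
//   for each 16-bit output lane i:
//     out[i] = saturate_int16(a[2*i] * b[2*i] + a[2*i+1] * b[2*i+1]);
// i.e. adjacent unsigned/signed byte pairs are multiplied, summed, and
// saturated to int16, which is exactly what pmaddubsw does on SSSE3/AVX2.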
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
  // a = 0x00 0x01 0xFE 0x04 ...
  // b = 0x00 0x02 0x80 0x84 ...

  // To sign-extend an 8-bit value to 16 bits, the upper byte has to be 0xFF
  // for negative inputs.
  __m128i sign_mask_a = _mm_cmplt_epi8(a, _mm_setzero_si128());
  __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());
  // sign_mask_a = 0x00 0x00 0xFF 0x00 ...
  // sign_mask_b = 0x00 0x00 0xFF 0xFF ...

  // Unpack positives with 0x00, negatives with 0xFF.
  __m128i a_epi16_l = _mm_unpacklo_epi8(a, sign_mask_a);
  __m128i a_epi16_h = _mm_unpackhi_epi8(a, sign_mask_a);
  __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
  __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);

  // Here - valid 16-bit signed integers corresponding to the 8-bit input
  // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...

  // Get a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts.
  __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
  __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);

  // Now go back from 32-bit values to 16-bit values with signed saturation.
  return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
}
template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {pack0123, pack4567};
}
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second));
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second));
}
template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
  int32x4_t high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

  int32x2_t low_sum = vpadd_s32(vget_low_s32(low), vget_high_s32(low));
  int32x2_t high_sum = vpadd_s32(vget_low_s32(high), vget_high_s32(high));

  return vcombine_s32(low_sum, high_sum);
}
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  // This would be much simpler if x86 would choose to zero extend OR sign
  // extend, not both. This could probably be optimized better.

  // Zero extend x: odd bytes via a logical shift right, even bytes by masking.
  int16x8_t x_odd =
      vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_u8(x), 8));
  int16x8_t x_even = vreinterpretq_s16_u16(
      vbicq_u16(vreinterpretq_u16_u8(x), vdupq_n_u16(0xff00)));

  // Sign extend y by shifting left then shifting right.
  int16x8_t y_even = vshrq_n_s16(vshlq_n_s16(vreinterpretq_s16_s8(y), 8), 8);
  int16x8_t y_odd = vshrq_n_s16(vreinterpretq_s16_s8(y), 8);

  // Multiply the even and odd pairs separately.
  int16x8_t prod1 = vmulq_s16(x_even, y_even);
  int16x8_t prod2 = vmulq_s16(x_odd, y_odd);

  // Saturating add of adjacent products, matching the x86 maddubs semantics.
  return vqaddq_s16(prod1, prod2);
}
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
  int16x8_t high = vmull_s8(vget_high_s8(x), vget_high_s8(y));

  int16x4_t low_sum = vpadd_s16(vget_low_s16(low), vget_high_s16(low));
  int16x4_t high_sum = vpadd_s16(vget_low_s16(high), vget_high_s16(high));

  return vcombine_s16(low_sum, high_sum);
}

template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::neon>) {
  return {pack0123, pack4567};
}
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s8(first, second), vzip2q_s8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s16(first, second), vzip2q_s16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s32(first, second), vzip2q_s32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s64(first, second), vzip2q_s64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  return vqmovn_high_s16(vqmovn_s16(first), second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  return vqmovn_high_s32(vqmovn_s32(first), second);
}
#ifdef __ARM_FEATURE_MATMUL_INT8
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::i8mm<xsimd::neon64>>) {
  return vusdotq_s32(z, x, y);
}
#endif
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::neon64>) {
  int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))),
                           vmovl_s8(vget_low_s8(y)));
  int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))),
                           vmovl_s8(vget_high_s8(y)));
  return vpadalq_s16(vpadalq_s16(z, tl), th);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3,
         xsimd::kernel::requires_arch<xsimd::neon64>) {
  auto pack01 = vpaddq_s32(sum0, sum1);
  auto pack23 = vpaddq_s32(sum2, sum3);
  return vpaddq_s32(pack01, pack23);
}
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::generic>) {
  return z + madd(xsimd::batch<int16_t, Arch>(1), madd(x, y, Arch{}), Arch{});
}

} // namespace kernel
// Generic dispatcher for interleave, deinterleave, madd and PermuteSummer

template <class T, class Arch>
std::tuple<xsimd::batch<T, Arch>, xsimd::batch<T, Arch>>
interleave(xsimd::batch<T, Arch> first, xsimd::batch<T, Arch> second) {
  return kernel::interleave(first, second, Arch{});
}
template <class Arch>
xsimd::batch<int8_t, Arch> deinterleave(xsimd::batch<int16_t, Arch> first,
                                        xsimd::batch<int16_t, Arch> second) {
  return kernel::deinterleave(first, second, Arch{});
}

template <class Arch>
xsimd::batch<int16_t, Arch> deinterleave(xsimd::batch<int32_t, Arch> first,
                                         xsimd::batch<int32_t, Arch> second) {
  return kernel::deinterleave(first, second, Arch{});
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> madd(xsimd::batch<int16_t, Arch> x,
                                        xsimd::batch<int16_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}

template <class Arch>
inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<int8_t, Arch> x,
                                        xsimd::batch<int8_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}

template <class Arch>
inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<uint8_t, Arch> x,
                                        xsimd::batch<int8_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}
template <class Arch>
inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
                                         xsimd::batch<int8_t, Arch> y,
                                         xsimd::batch<int32_t, Arch> z) {
  return kernel::maddw(x, y, z, Arch{});
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
                                         xsimd::batch<int8_t, Arch> y) {
  return maddw(x, y, xsimd::batch<int32_t, Arch>((int32_t)0));
}
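// Illustrative use of the dispatchers above (a sketch; `a`, `b`, `a2`, `b2`
// are assumed xsimd::batch<uint8_t, Arch> / xsimd::batch<int8_t, Arch> values
// prepared by the caller):
//   auto acc = maddw(a, b);      // start a dot-product accumulator at zero
//   acc = maddw(a2, b2, acc);    // fused multiply-accumulate of the next chunk
// The Arch-specific kernel is selected at compile time through the Arch{} tag.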
template <class Arch>
inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
                          xsimd::batch<int32_t, Arch> pack4567)
    -> decltype(kernel::PermuteSummer(pack0123, pack4567, Arch{})) {
  return kernel::PermuteSummer(pack0123, pack4567, Arch{});
}
namespace kernel {

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3,
         xsimd::kernel::requires_arch<xsimd::generic>) {
  std::tie(sum0, sum1) = interleave(sum0, sum1, Arch{});
  auto pack01 = sum0 + sum1;
  std::tie(sum2, sum3) = interleave(sum2, sum3, Arch{});
  auto pack23 = sum2 + sum3;

  auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01),
                           xsimd::bitwise_cast<int64_t>(pack23), Arch{});
  return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) +
         xsimd::bitwise_cast<int32_t>(std::get<1>(packed));
}

} // namespace kernel
template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3) {
  return kernel::Pack0123(sum0, sum1, sum2, sum3, Arch{});
}
template <class Arch>
static inline xsimd::batch<int32_t, Arch>
quantize(xsimd::batch<float, Arch> input,
         xsimd::batch<float, Arch> quant_mult) {
  return xsimd::nearbyint_as_int(input * quant_mult);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
QuantizerGrab(const float *input, xsimd::batch<float, Arch> quant_mult_reg) {
  return quantize(xsimd::batch<float, Arch>::load_unaligned(input),
                  quant_mult_reg);
}
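// Illustrative example: quantize() maps floats to rounded 32-bit integers, so
// with assumed values
//   quantize(xsimd::batch<float, Arch>(0.25f), xsimd::batch<float, Arch>(8.f))
// every lane holds nearbyint(0.25f * 8.f) == 2. QuantizerGrab is the same
// operation applied to an unaligned load from memory.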
inline __m512 Concat(const __m256 first, const __m256 second) {
  // INTGEMM_AVX512DQ but that goes with INTGEMM_AVX512BW anyway.
  return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1);
}

// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be
// controlled independently.
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set
 * INTGEMM_AVX512BW */
inline __m512i QuantizerGrabHalves(const float *input0, const float *input1,
                                   const __m512 quant_mult_reg) {
  __m512 appended = Concat(_mm256_loadu_ps(input0), _mm256_loadu_ps(input1));
  appended = _mm512_mul_ps(appended, quant_mult_reg);
  return _mm512_cvtps_epi32(appended);
}
template <class Arch>
inline xsimd::batch<int32_t, Arch>
QuantizerGrabHalves(const float *input0, const float *input1,
                    xsimd::batch<float, Arch> quant_mult_reg);
/* Read 8 floats at a time from input0, input1, input2, and input3. Quantize
 * them to 8-bit by multiplying with quant_mult_reg then rounding. Concatenate
 * the result into one register and return it. */
class QuantizeTile8 {
  template <class Arch> struct Tiler {
    static constexpr uint32_t get(std::size_t i, std::size_t n) {
      size_t factor = xsimd::batch<float, Arch>::size / 4;
      return (i % factor) * 4 + i / factor;
    }
  };

public:
  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  Consecutive(xsimd::batch<float, Arch> quant_mult, const float *input) {
    return Tile(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
                input + 1 * xsimd::batch<float, Arch>::size,
                input + 2 * xsimd::batch<float, Arch>::size,
                input + 3 * xsimd::batch<float, Arch>::size);
  }

  template <class Arch>
  static inline xsimd::batch<uint8_t, Arch>
  ConsecutiveU(xsimd::batch<float, Arch> quant_mult, const float *input) {
    return TileU(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
                 input + 1 * xsimd::batch<float, Arch>::size,
                 input + 2 * xsimd::batch<float, Arch>::size,
                 input + 3 * xsimd::batch<float, Arch>::size);
  }
  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  ConsecutiveWithWrapping(xsimd::batch<float, Arch> quant_mult,
                          const float *input, size_t cols_left, size_t cols,
                          size_t row_step) {
    using batchf32 = xsimd::batch<float, Arch>;
    const float *inputs[4];
    for (size_t i = 0; i < std::size(inputs); ++i) {
      while (cols_left < batchf32::size) {
        input += cols * (row_step - 1);
        cols_left += cols;
      }
      inputs[i] = input;
      input += batchf32::size;
      cols_left -= batchf32::size;
    }
    return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]);
  }
  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  ForReshape(xsimd::batch<float, Arch> quant_mult, const float *input,
             size_t cols) {
    using batchf32 = xsimd::batch<float, Arch>;
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;
    using ubatch32 = xsimd::batch<uint32_t, Arch>;

    // Put higher rows in the second half of the register. These will jumble
    // around in the same way then conveniently land in the right place.
    if constexpr (batchf32::size == 16) {
      const batch8 neg127(-127);
      // In reverse order: grabbing the first 32-bit values from each 128-bit
      // register, then the second 32-bit values, etc. Grab 4 registers at a
      // time in 32-bit format.
      batch32 g0 =
          QuantizerGrabHalves(input + 0 * cols, input + 2 * cols, quant_mult);
      batch32 g1 =
          QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult);
      batch32 g2 =
          QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult);
      batch32 g3 =
          QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult);
      // Pack 32-bit to 16-bit.
      batch16 packed0 = deinterleave(g0, g1);
      batch16 packed1 = deinterleave(g2, g3);
      // Pack 16-bit to 8-bit.
      batch8 packed = deinterleave(packed0, packed1);
      // Ban -128: clamp to [-127, 127].
      packed = xsimd::max(packed, neg127);
      return xsimd::bitwise_cast<int8_t>(
          xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                         xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
    } else if constexpr (batchf32::size == 8)
      return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols,
                  input + 18 * cols);
    else if constexpr (batchf32::size == 4)
      // Skip a row.
      return Tile(quant_mult, input, input + 4, input + 2 * cols,
                  input + 2 * cols + 4);
  }
  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  Tile(xsimd::batch<float, Arch> quant_mult, const float *input0,
       const float *input1, const float *input2, const float *input3) {
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;
    using ubatch32 = xsimd::batch<uint32_t, Arch>;

    const batch8 neg127(-127);
    // Grab 4 registers at a time in 32-bit format.
    batch32 g0 = QuantizerGrab(input0, quant_mult);
    batch32 g1 = QuantizerGrab(input1, quant_mult);
    batch32 g2 = QuantizerGrab(input2, quant_mult);
    batch32 g3 = QuantizerGrab(input3, quant_mult);
    // Pack 32-bit to 16-bit.
    batch16 packed0 = deinterleave(g0, g1);
    batch16 packed1 = deinterleave(g2, g3);
    // Pack 16-bit to 8-bit.
    batch8 packed = deinterleave(packed0, packed1);
    // Ban -128: clamp to [-127, 127].
    packed = xsimd::max(packed, neg127);

    if constexpr (batch32::size == 4)
      return xsimd::bitwise_cast<int8_t>(packed);
    // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
    // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7
    // Technically this could be removed so long as the rows are bigger than 16
    // and the values are only used for GEMM.
    return xsimd::bitwise_cast<int8_t>(
        xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                       xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
  }
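  // Illustrative note on the Tiler swizzle used above: for an 8-wide float
  // batch (AVX2), factor == 2 and Tiler<Arch> generates the index pattern
  // 0 4 1 5 2 6 3 7. Applied to the packed register, whose 32-bit groups are
  // currently ordered 0 2 4 6 1 3 5 7 (see the comment in Tile), the swizzle
  // restores the sequential order 0 1 2 3 4 5 6 7. The 16-wide (AVX-512) case
  // uses the analogous pattern with factor == 4.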
  // A version that produces uint8_ts
  template <class Arch>
  static inline xsimd::batch<uint8_t, Arch>
  TileU(xsimd::batch<float, Arch> quant_mult, const float *input0,
        const float *input1, const float *input2, const float *input3) {
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;
    using ubatch32 = xsimd::batch<uint32_t, Arch>;

    const batch8 neg127 = -127;
    const batch8 pos127 = +127;
    // Grab 4 registers at a time in 32-bit format.
    batch32 g0 = QuantizerGrab(input0, quant_mult);
    batch32 g1 = QuantizerGrab(input1, quant_mult);
    batch32 g2 = QuantizerGrab(input2, quant_mult);
    batch32 g3 = QuantizerGrab(input3, quant_mult);
    // Pack 32-bit to 16-bit.
    batch16 packed0 = deinterleave(g0, g1);
    batch16 packed1 = deinterleave(g2, g3);
    // Pack 16-bit to 8-bit.
    batch8 packed = deinterleave(packed0, packed1);
    packed = xsimd::max(packed, neg127); // Could be removed if we use +128
    packed = packed + pos127;
    if (batch32::size == 4)
      return xsimd::bitwise_cast<uint8_t>(packed);
    // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
    // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7
    // Technically this could be removed so long as the rows are bigger than 16
    // and the values are only used for GEMM.
    return xsimd::bitwise_cast<uint8_t>(
        xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                       xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
  }
};
template <class Arch>
inline void Transpose16InLane(
    xsimd::batch<int8_t, Arch> &r0, xsimd::batch<int8_t, Arch> &r1,
    xsimd::batch<int8_t, Arch> &r2, xsimd::batch<int8_t, Arch> &r3,
    xsimd::batch<int8_t, Arch> &r4, xsimd::batch<int8_t, Arch> &r5,
    xsimd::batch<int8_t, Arch> &r6, xsimd::batch<int8_t, Arch> &r7) {
  /* r0: columns 0 1 2 3 4 5 6 7 from row 0
     r1: columns 0 1 2 3 4 5 6 7 from row 1 */
  auto r0_16 = xsimd::bitwise_cast<int16_t>(r0);
  auto r1_16 = xsimd::bitwise_cast<int16_t>(r1);
  auto r2_16 = xsimd::bitwise_cast<int16_t>(r2);
  auto r3_16 = xsimd::bitwise_cast<int16_t>(r3);
  auto r4_16 = xsimd::bitwise_cast<int16_t>(r4);
  auto r5_16 = xsimd::bitwise_cast<int16_t>(r5);
  auto r6_16 = xsimd::bitwise_cast<int16_t>(r6);
  auto r7_16 = xsimd::bitwise_cast<int16_t>(r7);

  std::tie(r0_16, r1_16) = interleave(r0_16, r1_16);
  std::tie(r2_16, r3_16) = interleave(r2_16, r3_16);
  std::tie(r4_16, r5_16) = interleave(r4_16, r5_16);
  std::tie(r6_16, r7_16) = interleave(r6_16, r7_16);
  /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
     r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
     r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
     r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
     r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
     r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
     r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
     r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7 */
  auto r0_32 = xsimd::bitwise_cast<int32_t>(r0_16);
  auto r2_32 = xsimd::bitwise_cast<int32_t>(r2_16);
  auto r1_32 = xsimd::bitwise_cast<int32_t>(r1_16);
  auto r3_32 = xsimd::bitwise_cast<int32_t>(r3_16);
  auto r4_32 = xsimd::bitwise_cast<int32_t>(r4_16);
  auto r6_32 = xsimd::bitwise_cast<int32_t>(r6_16);
  auto r5_32 = xsimd::bitwise_cast<int32_t>(r5_16);
  auto r7_32 = xsimd::bitwise_cast<int32_t>(r7_16);

  std::tie(r0_32, r2_32) = interleave(r0_32, r2_32);
  std::tie(r1_32, r3_32) = interleave(r1_32, r3_32);
  std::tie(r4_32, r6_32) = interleave(r4_32, r6_32);
  std::tie(r5_32, r7_32) = interleave(r5_32, r7_32);
  /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
     r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
     r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
     r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
     r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
     r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
     r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
     r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7 */

  auto r0_64 = xsimd::bitwise_cast<int64_t>(r0_32);
  auto r2_64 = xsimd::bitwise_cast<int64_t>(r2_32);
  auto r1_64 = xsimd::bitwise_cast<int64_t>(r1_32);
  auto r3_64 = xsimd::bitwise_cast<int64_t>(r3_32);
  auto r4_64 = xsimd::bitwise_cast<int64_t>(r4_32);
  auto r6_64 = xsimd::bitwise_cast<int64_t>(r6_32);
  auto r5_64 = xsimd::bitwise_cast<int64_t>(r5_32);
  auto r7_64 = xsimd::bitwise_cast<int64_t>(r7_32);

  std::tie(r0_64, r4_64) = interleave(r0_64, r4_64);
  std::tie(r1_64, r5_64) = interleave(r1_64, r5_64);
  std::tie(r2_64, r6_64) = interleave(r2_64, r6_64);
  std::tie(r3_64, r7_64) = interleave(r3_64, r7_64);

  r0 = xsimd::bitwise_cast<int8_t>(r0_64);
  r1 = xsimd::bitwise_cast<int8_t>(r1_64);
  r2 = xsimd::bitwise_cast<int8_t>(r2_64);
  r3 = xsimd::bitwise_cast<int8_t>(r3_64);
  r4 = xsimd::bitwise_cast<int8_t>(r4_64);
  r5 = xsimd::bitwise_cast<int8_t>(r5_64);
  r6 = xsimd::bitwise_cast<int8_t>(r6_64);
  r7 = xsimd::bitwise_cast<int8_t>(r7_64);
  /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
     r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
     r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
     r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
     r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
     r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/
  /* Empirically gcc is able to remove these movs and just rename the outputs
     of interleave. */
}
template <class Arch, typename IntegerTy>
void SelectColumnsOfB(const xsimd::batch<int8_t, Arch> *input,
                      xsimd::batch<int8_t, Arch> *output,
                      size_t rows_bytes /* number of bytes in a row */,
                      const IntegerTy *cols_begin, const IntegerTy *cols_end) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  /* Do columns for multiples of 8.*/
  size_t register_rows = rows_bytes / batch8::size;
  const batch8 *starts[8];
  for (; cols_begin != cols_end; cols_begin += 8) {
    for (size_t k = 0; k < 8; ++k) {
      starts[k] =
          input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows;
    }
    for (size_t r = 0; r < register_rows; ++r) {
      for (size_t k = 0; k < 8; ++k) {
        *(output++) = *starts[k];
        starts[k] += 8;
      }
    }
  }
}
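// Note on the indexing above (for readers of the prepared-B layout): PrepareB
// stores B in blocks of 8 columns, each block occupying 8 interleaved
// registers per register-row, so (cols_begin[k] & ~7) * register_rows jumps to
// the block containing the requested column, (cols_begin[k] & 7) picks the
// register slot inside that block, and each row step advances by 8 registers
// to stay in the same slot.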
namespace callbacks {

template <class Arch>
xsimd::batch<float, Arch>
Unquantize::operator()(xsimd::batch<int32_t, Arch> total, size_t, size_t,
                       size_t) {
  return xsimd::batch_cast<float>(total) * unquant_mult;
}

template <class Arch>
std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
Unquantize::operator()(
    std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> total,
    size_t, size_t, size_t) {
  return std::make_tuple(
      xsimd::batch_cast<float>(std::get<0>(total)) * unquant_mult,
      xsimd::batch_cast<float>(std::get<1>(total)) * unquant_mult);
}
template <class Arch>
xsimd::batch<float, Arch> AddBias::operator()(xsimd::batch<float, Arch> total,
                                              size_t, size_t col_idx, size_t) {
  return total + xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx);
}

template <class Arch>
std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
AddBias::operator()(
    std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total,
    size_t, size_t col_idx, size_t) {
  return std::make_tuple(
      std::get<0>(total) +
          xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx + 0),
      std::get<1>(total) +
          xsimd::batch<float, Arch>::load_aligned(
              bias_addr + col_idx + xsimd::batch<float, Arch>::size));
}
template <class Arch>
void Write::operator()(xsimd::batch<float, Arch> result, size_t row_idx,
                       size_t col_idx, size_t col_size) {
  result.store_aligned(output_addr + row_idx * col_size + col_idx);
}

template <class Arch>
void Write::operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx,
                       size_t col_idx, size_t col_size) {
  xsimd::bitwise_cast<float>(result).store_aligned(
      output_addr + row_idx * col_size + col_idx);
}
template <class Arch>
void Write::operator()(
    std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result,
    size_t row_idx, size_t col_idx, size_t col_size) {
  std::get<0>(result).store_aligned(output_addr + row_idx * col_size + col_idx +
                                    0);
  std::get<1>(result).store_aligned(output_addr + row_idx * col_size + col_idx +
                                    xsimd::batch<float, Arch>::size);
}

template <class Arch>
void Write::operator()(
    std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> result,
    size_t row_idx, size_t col_idx, size_t col_size) {
  xsimd::bitwise_cast<float>(std::get<0>(result))
      .store_aligned(output_addr + row_idx * col_size + col_idx + 0);
  xsimd::bitwise_cast<float>(std::get<1>(result))
      .store_aligned(output_addr + row_idx * col_size + col_idx +
                     xsimd::batch<int32_t, Arch>::size);
}
template <class T>
void UnquantizeAndWrite::operator()(T const &total, size_t row_idx,
                                    size_t col_idx, size_t col_size) {
  auto unquantized = unquantize(total, row_idx, col_idx, col_size);
  write(unquantized, row_idx, col_idx, col_size);
}

template <class T>
void UnquantizeAndAddBiasAndWrite::operator()(T const &total, size_t row_idx,
                                              size_t col_idx, size_t col_size) {
  auto unquantized = unquantize(total, row_idx, col_idx, col_size);
  auto bias_added = add_bias(unquantized, row_idx, col_idx, col_size);
  write(bias_added, row_idx, col_idx, col_size);
}

} // namespace callbacks
template <class Arch>
void Engine<Arch>::QuantizeU(const float *input, uint8_t *output,
                             float quant_mult, size_t size) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  xsimd::batch<float, Arch> q(quant_mult);
  const float *end = input + size;
  for (; input != end; input += batch8::size, output += batch8::size) {
    auto tile = QuantizeTile8::ConsecutiveU(q, input);
    tile.store_aligned(output);
  }
}
template <class Arch>
void Engine<Arch>::Quantize(const float *const input, int8_t *const output,
                            float quant_mult, size_t size) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  const std::size_t kBatch = batch8::size;
  const std::size_t fast_end = size & ~(kBatch - 1);

  xsimd::batch<float, Arch> q(quant_mult);
  for (std::size_t i = 0; i < fast_end; i += kBatch) {
    auto tile = QuantizeTile8::Consecutive(q, input + i);
    tile.store_aligned(output + i);
  }

  std::size_t overhang = size & (kBatch - 1);
  if (!overhang)
    return;
  /* Each does size(xsimd::batch<int8_t, Arch>) / 32 == kBatch / 4 floats at a
   * time. If we're allowed to read one of them, then we can read the whole
   * register. */
  const float *inputs[4];
  std::size_t i;
  for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) {
    inputs[i] = &input[fast_end + i * (kBatch / 4)];
  }
  /* These will be clipped off. */
  for (; i < 4; ++i) {
    inputs[i] = &input[fast_end];
  }
  auto result =
      QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]);
  std::memcpy(output + (size & ~(kBatch - 1)), &result, overhang);
}
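/* Illustrative usage of Engine<Arch>::Quantize (a sketch; the buffer names and
 * the 127 / max-absolute-value scale are assumptions, not part of this file):
 *   alignas(64) float in[256];
 *   alignas(64) int8_t out[256];
 *   // ... fill `in` ...
 *   float quant_mult = 127.0f / max_abs_value_of_in;  // assumed scale
 *   gemmology::Engine<xsimd::avx2>::Quantize(in, out, quant_mult, 256);
 * The output buffer must satisfy the batch alignment since the fast path uses
 * store_aligned. */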
template <class Arch>
template <typename IntegerTy>
void Engine<Arch>::SelectColumnsB(const int8_t *input, int8_t *output,
                                  size_t rows, const IntegerTy *cols_begin,
                                  const IntegerTy *cols_end) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  SelectColumnsOfB(reinterpret_cast<const batch8 *>(input),
                   reinterpret_cast<batch8 *>(output), rows, cols_begin,
                   cols_end);
}
template <class Arch>
void Engine<Arch>::PrepareBTransposed(const float *input, int8_t *output,
                                      float quant_mult, size_t cols,
                                      size_t rows) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t RegisterElemsInt = batch8::size;
  const size_t kColStride = 8;

  xsimd::batch<float, Arch> q(quant_mult);
  auto *output_it = reinterpret_cast<batch8 *>(output);
  size_t r = 0;
  size_t c = 0;
  while (r < rows) {
    for (size_t ri = 0; ri < 8; ++ri)
      *output_it++ = QuantizeTile8::ConsecutiveWithWrapping(
          q, input + (r + ri) * cols + c, cols - c, cols, 8);
    c += RegisterElemsInt;
    while (c >= cols) {
      r += kColStride;
      c -= cols;
    }
  }
}
template <class Arch>
void Engine<Arch>::PrepareBQuantizedTransposed(const int8_t *input,
                                               int8_t *output, size_t cols,
                                               size_t rows) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t RegisterElems = batch8::size;
  const size_t kColStride = 8;

  auto *output_it = reinterpret_cast<batch8 *>(output);
  for (size_t r = 0; r < rows; r += kColStride)
    for (size_t c = 0; c < cols; c += RegisterElems)
      for (size_t ri = 0; ri < 8; ++ri)
        *output_it++ =
            *reinterpret_cast<const batch8 *>(input + (r + ri) * cols + c);
}
template <class Arch>
void Engine<Arch>::PrepareB(const float *input, int8_t *output_shadow,
                            float quant_mult, size_t rows, size_t cols) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  xsimd::batch<float, Arch> q(quant_mult);
  /* Currently all multipliers have a stride of 8 columns.*/
  const size_t kColStride = 8;
  auto *output = reinterpret_cast<batch8 *>(output_shadow);
  for (size_t c = 0; c < cols; c += kColStride) {
    for (size_t r = 0; r < rows; r += sizeof(*output), output += 8) {
      output[0] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 0) + c, cols);
      output[1] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 1) + c, cols);
      output[2] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 4) + c, cols);
      output[3] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 5) + c, cols);
      output[4] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 8) + c, cols);
      output[5] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 9) + c, cols);
      output[6] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 12) + c, cols);
      output[7] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 13) + c, cols);
      std::tie(output[0], output[1]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[0]),
                     xsimd::bitwise_cast<int8_t>(output[1]));
      std::tie(output[2], output[3]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[2]),
                     xsimd::bitwise_cast<int8_t>(output[3]));
      std::tie(output[4], output[5]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[4]),
                     xsimd::bitwise_cast<int8_t>(output[5]));
      std::tie(output[6], output[7]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[6]),
                     xsimd::bitwise_cast<int8_t>(output[7]));
      Transpose16InLane(output[0], output[1], output[2], output[3], output[4],
                        output[5], output[6], output[7]);
    }
  }
}
template <class Arch>
void Engine<Arch>::PrepareA(const float *input, int8_t *output,
                            float quant_mult, size_t rows, size_t cols) {
  Quantize(input, output, quant_mult, rows * cols);
}

template <class Arch>
void Engine<Arch>::Shift::PrepareA(const float *input, uint8_t *output,
                                   float quant_mult, size_t rows, size_t cols) {
  QuantizeU(input, output, quant_mult, rows * cols);
}
template <class Arch>
template <class Callback>
void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
                                   size_t A_rows, size_t width, size_t B_cols,
                                   Callback callback) {

  using batch8 = xsimd::batch<int8_t, Arch>;
  using ubatch8 = xsimd::batch<uint8_t, Arch>;
  using batch32 = xsimd::batch<int32_t, Arch>;

  const size_t simd_width = width / batch8::size;
  for (size_t B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
    const auto *B0_col =
        reinterpret_cast<const batch8 *>(B) + simd_width * B0_colidx;
    /* Process one row of A at a time. Doesn't seem to be faster to do multiple
     * rows of A at once.*/
    for (size_t A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
      const auto *A_row =
          reinterpret_cast<const ubatch8 *>(A + A_rowidx * width);
      /* These will be packed 16-bit integers containing sums for each row of B
         multiplied by the row of A. Iterate over shared (inner) dimension.*/
      /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
         done here. */
      size_t k = 0;
      ubatch8 a = *(A_row + k);
      batch32 isum0 = maddw(a, *(B0_col + k * 8));
      batch32 isum1 = maddw(a, *(B0_col + k * 8 + 1));
      batch32 isum2 = maddw(a, *(B0_col + k * 8 + 2));
      batch32 isum3 = maddw(a, *(B0_col + k * 8 + 3));
      batch32 isum4 = maddw(a, *(B0_col + k * 8 + 4));
      batch32 isum5 = maddw(a, *(B0_col + k * 8 + 5));
      batch32 isum6 = maddw(a, *(B0_col + k * 8 + 6));
      batch32 isum7 = maddw(a, *(B0_col + k * 8 + 7));
      for (k = 1; k < simd_width; ++k) {
        a = *(A_row + k);
        /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/
        /* Upcast to 32-bit and horizontally add.*/
        isum0 = maddw(a, *(B0_col + k * 8 + 0), isum0);
        isum1 = maddw(a, *(B0_col + k * 8 + 1), isum1);
        isum2 = maddw(a, *(B0_col + k * 8 + 2), isum2);
        isum3 = maddw(a, *(B0_col + k * 8 + 3), isum3);
        isum4 = maddw(a, *(B0_col + k * 8 + 4), isum4);
        isum5 = maddw(a, *(B0_col + k * 8 + 5), isum5);
        isum6 = maddw(a, *(B0_col + k * 8 + 6), isum6);
        isum7 = maddw(a, *(B0_col + k * 8 + 7), isum7);
      }
      /* Reduce sums within 128-bit lanes.*/
      auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
      auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);
      /* The specific implementation may need to reduce further.*/
      auto total = PermuteSummer(pack0123, pack4567);
      callback(total, A_rowidx, B0_colidx, B_cols);
    }
  }
}
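/* Illustrative end-to-end flow (a sketch, not additional API): for the
 * "shifted" 8-bit path, a caller typically
 *   1. quantizes A with Engine<Arch>::Shift::PrepareA (unsigned, offset by
 *      +127 through QuantizeU/TileU),
 *   2. prepares B with Engine<Arch>::PrepareB,
 *   3. runs Engine<Arch>::Shift::PrepareBias once per B, so the column sums of
 *      B (computed below with an all-ones row) can compensate for the +127
 *      offset via the bias callback,
 *   4. calls Engine<Arch>::Shift::Multiply with a callback such as
 *      callbacks::UnquantizeAndAddBiasAndWrite to scale, add bias and store
 *      the float result. */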
template <class Arch>
template <class Callback>
void Engine<Arch>::Shift::PrepareBias(const int8_t *B, size_t width,
                                      size_t B_cols, Callback C) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t simd_width = width / batch8::size;
  xsimd::batch<uint8_t, Arch> a(1);
  for (size_t j = 0; j < B_cols; j += 8) {
    /* Process one row of A at a time. Doesn't seem to be faster to do multiple
     * rows of A at once.*/
    const int8_t *B_j = B + j * width;

    /* Rather than initializing as zeros and adding, just initialize the
     * first. */
    /* These will be packed 16-bit integers containing sums for each column of
     * B multiplied by the row of A.*/
    /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
       done here. */
    auto isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]));
    auto isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]));
    auto isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]));
    auto isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]));
    auto isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]));
    auto isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]));
    auto isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]));
    auto isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]));

    B_j += 8 * batch8::size;

    for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) {
      isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]), isum0);
      isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]), isum1);
      isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]), isum2);
      isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]), isum3);
      isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]), isum4);
      isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]), isum5);
      isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]), isum6);
      isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]), isum7);
    }

    auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
    auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);

    auto total = PermuteSummer(pack0123, pack4567);
    C(total, 0, j, B_cols);
  }
}

} // namespace gemmology