4 #include "gemmology_fwd.h"
10 #include <xsimd/xsimd.hpp>
17 // Arch specific implementation of various elementary operations
24 std::tuple
<xsimd::batch
<int8_t, Arch
>, xsimd::batch
<int8_t, Arch
>>
25 interleave(xsimd::batch
<int8_t, Arch
> first
, xsimd::batch
<int8_t, Arch
> second
,
26 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
27 return {_mm512_unpacklo_epi8(first
, second
),
28 _mm512_unpackhi_epi8(first
, second
)};
32 std::tuple
<xsimd::batch
<int16_t, Arch
>, xsimd::batch
<int16_t, Arch
>>
33 interleave(xsimd::batch
<int16_t, Arch
> first
,
34 xsimd::batch
<int16_t, Arch
> second
,
35 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
36 return {_mm512_unpacklo_epi16(first
, second
),
37 _mm512_unpackhi_epi16(first
, second
)};
41 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
42 interleave(xsimd::batch
<int32_t, Arch
> first
,
43 xsimd::batch
<int32_t, Arch
> second
,
44 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
45 return {_mm512_unpacklo_epi32(first
, second
),
46 _mm512_unpackhi_epi32(first
, second
)};
50 std::tuple
<xsimd::batch
<int64_t, Arch
>, xsimd::batch
<int64_t, Arch
>>
51 interleave(xsimd::batch
<int64_t, Arch
> first
,
52 xsimd::batch
<int64_t, Arch
> second
,
53 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
54 return {_mm512_unpacklo_epi64(first
, second
),
55 _mm512_unpackhi_epi64(first
, second
)};
59 xsimd::batch
<int8_t, Arch
>
60 deinterleave(xsimd::batch
<int16_t, Arch
> first
,
61 xsimd::batch
<int16_t, Arch
> second
,
62 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
63 return _mm512_packs_epi16(first
, second
);
67 xsimd::batch
<int16_t, Arch
>
68 deinterleave(xsimd::batch
<int32_t, Arch
> first
,
69 xsimd::batch
<int32_t, Arch
> second
,
70 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
71 return _mm512_packs_epi32(first
, second
);
75 inline xsimd::batch
<int32_t, Arch
>
76 madd(xsimd::batch
<int16_t, Arch
> x
, xsimd::batch
<int16_t, Arch
> y
,
77 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
78 return _mm512_madd_epi16(x
, y
);
82 inline xsimd::batch
<int16_t, Arch
>
83 madd(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
84 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
85 return _mm512_maddubs_epi16(x
, y
);
89 inline xsimd::batch
<int16_t, Arch
>
90 madd(xsimd::batch
<int8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
91 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
92 return _mm512_madd_epi16(x
, y
);
96 inline xsimd::batch
<int32_t, xsimd::avx2
>
97 PermuteSummer(xsimd::batch
<int32_t, Arch
> pack0123
,
98 xsimd::batch
<int32_t, Arch
> pack4567
,
99 xsimd::kernel::requires_arch
<xsimd::avx512bw
>) {
100 // Form [0th 128-bit register of pack0123, 0st 128-bit register of pack4567,
101 // 2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567]
103 _mm512_mask_permutex_epi64(pack0123
, 0xcc, pack4567
, (0 << 4) | (1 << 6));
104 // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567,
105 // 3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567]
107 _mm512_mask_permutex_epi64(pack4567
, 0x33, pack0123
, 2 | (3 << 2));
108 __m512i added
= _mm512_add_epi32(mix0
, mix1
);
109 // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7.
110 // Fold register over itself.
111 return _mm256_add_epi32(_mm512_castsi512_si256(added
),
112 _mm512_extracti64x4_epi64(added
, 1));
117 template <class Arch
>
118 std::tuple
<xsimd::batch
<int8_t, Arch
>, xsimd::batch
<int8_t, Arch
>>
119 interleave(xsimd::batch
<int8_t, Arch
> first
, xsimd::batch
<int8_t, Arch
> second
,
120 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
121 return {_mm256_unpacklo_epi8(first
, second
),
122 _mm256_unpackhi_epi8(first
, second
)};
125 template <class Arch
>
126 std::tuple
<xsimd::batch
<int16_t, Arch
>, xsimd::batch
<int16_t, Arch
>>
127 interleave(xsimd::batch
<int16_t, Arch
> first
,
128 xsimd::batch
<int16_t, Arch
> second
,
129 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
130 return {_mm256_unpacklo_epi16(first
, second
),
131 _mm256_unpackhi_epi16(first
, second
)};
134 template <class Arch
>
135 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
136 interleave(xsimd::batch
<int32_t, Arch
> first
,
137 xsimd::batch
<int32_t, Arch
> second
,
138 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
139 return {_mm256_unpacklo_epi32(first
, second
),
140 _mm256_unpackhi_epi32(first
, second
)};
143 template <class Arch
>
144 std::tuple
<xsimd::batch
<int64_t, Arch
>, xsimd::batch
<int64_t, Arch
>>
145 interleave(xsimd::batch
<int64_t, Arch
> first
,
146 xsimd::batch
<int64_t, Arch
> second
,
147 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
148 return {_mm256_unpacklo_epi64(first
, second
),
149 _mm256_unpackhi_epi64(first
, second
)};
152 template <class Arch
>
153 xsimd::batch
<int8_t, Arch
>
154 deinterleave(xsimd::batch
<int16_t, Arch
> first
,
155 xsimd::batch
<int16_t, Arch
> second
,
156 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
157 return _mm256_packs_epi16(first
, second
);
160 template <class Arch
>
161 xsimd::batch
<int16_t, Arch
>
162 deinterleave(xsimd::batch
<int32_t, Arch
> first
,
163 xsimd::batch
<int32_t, Arch
> second
,
164 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
165 return _mm256_packs_epi32(first
, second
);
168 template <class Arch
>
169 inline xsimd::batch
<int32_t, Arch
>
170 madd(xsimd::batch
<int16_t, Arch
> x
, xsimd::batch
<int16_t, Arch
> y
,
171 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
172 return _mm256_madd_epi16(x
, y
);
175 template <class Arch
>
176 inline xsimd::batch
<int16_t, Arch
>
177 madd(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
178 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
179 return _mm256_maddubs_epi16(x
, y
);
182 template <class Arch
>
183 inline xsimd::batch
<int16_t, Arch
>
184 madd(xsimd::batch
<int8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
185 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
186 return _mm256_maddubs_epi16(xsimd::abs(x
), _mm256_sign_epi8(y
, x
));
189 template <class Arch
>
190 inline xsimd::batch
<int32_t, Arch
>
191 PermuteSummer(xsimd::batch
<int32_t, Arch
> pack0123
,
192 xsimd::batch
<int32_t, Arch
> pack4567
,
193 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
194 // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f
195 __m256i rev
= _mm256_permute2f128_si256(pack0123
, pack4567
, 0x21);
196 // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s
197 __m256i blended
= _mm256_blend_epi32(pack0123
, pack4567
, 0xf0);
198 return _mm256_add_epi32(rev
, blended
);
201 template <class Arch
>
202 inline xsimd::batch
<int32_t, Arch
> Pack0123(xsimd::batch
<int32_t, Arch
> sum0
,
203 xsimd::batch
<int32_t, Arch
> sum1
,
204 xsimd::batch
<int32_t, Arch
> sum2
,
205 xsimd::batch
<int32_t, Arch
> sum3
,
206 xsimd::kernel::requires_arch
<xsimd::avx2
>) {
207 auto pack01
= _mm256_hadd_epi32(sum0
, sum1
);
208 auto pack23
= _mm256_hadd_epi32(sum2
, sum3
);
209 return _mm256_hadd_epi32(pack01
, pack23
);
214 template <class Arch
>
215 inline xsimd::batch
<int32_t, Arch
>
216 maddw(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
217 xsimd::batch
<int32_t, Arch
> z
,
218 xsimd::kernel::requires_arch
<xsimd::avxvnni
>) {
219 return _mm256_dpbusd_avx_epi32(z
, x
, y
);
223 #ifdef __AVX512VNNI__
225 template <class Arch
>
226 inline xsimd::batch
<int32_t, Arch
>
227 maddw(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
228 xsimd::batch
<int32_t, Arch
> z
,
229 xsimd::kernel::requires_arch
<xsimd::avx512vnni
<xsimd::avx512bw
>>) {
230 return _mm512_dpbusd_epi32(z
, x
, y
);
233 template <class Arch
>
234 inline xsimd::batch
<int32_t, Arch
>
235 maddw(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
236 xsimd::batch
<int32_t, Arch
> z
,
237 xsimd::kernel::requires_arch
<xsimd::avx512vnni
<xsimd::avx512vbmi
>>) {
238 return _mm512_dpbusd_epi32(z
, x
, y
);
246 template <class Arch
>
247 inline xsimd::batch
<int16_t, Arch
>
248 madd(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
249 xsimd::kernel::requires_arch
<xsimd::ssse3
>) {
250 return _mm_maddubs_epi16(x
, y
);
253 template <class Arch
>
254 inline xsimd::batch
<int16_t, Arch
>
255 madd(xsimd::batch
<int8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
256 xsimd::kernel::requires_arch
<xsimd::ssse3
>) {
257 return _mm_maddubs_epi16(xsimd::abs(x
), _mm_sign_epi8(y
, x
));
260 template <class Arch
>
261 inline xsimd::batch
<int32_t, Arch
> Pack0123(xsimd::batch
<int32_t, Arch
> sum0
,
262 xsimd::batch
<int32_t, Arch
> sum1
,
263 xsimd::batch
<int32_t, Arch
> sum2
,
264 xsimd::batch
<int32_t, Arch
> sum3
,
265 xsimd::kernel::requires_arch
<xsimd::ssse3
>) {
266 auto pack01
= _mm_hadd_epi32(sum0
, sum1
);
267 auto pack23
= _mm_hadd_epi32(sum2
, sum3
);
268 return _mm_hadd_epi32(pack01
, pack23
);
273 template <class Arch
>
274 std::tuple
<xsimd::batch
<int8_t, Arch
>, xsimd::batch
<int8_t, Arch
>>
275 interleave(xsimd::batch
<int8_t, Arch
> first
, xsimd::batch
<int8_t, Arch
> second
,
276 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
277 return {xsimd::zip_lo(first
, second
), xsimd::zip_hi(first
, second
)};
280 template <class Arch
>
281 std::tuple
<xsimd::batch
<int16_t, Arch
>, xsimd::batch
<int16_t, Arch
>>
282 interleave(xsimd::batch
<int16_t, Arch
> first
,
283 xsimd::batch
<int16_t, Arch
> second
,
284 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
285 return {xsimd::zip_lo(first
, second
), xsimd::zip_hi(first
, second
)};
288 template <class Arch
>
289 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
290 interleave(xsimd::batch
<int32_t, Arch
> first
,
291 xsimd::batch
<int32_t, Arch
> second
,
292 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
293 return {xsimd::zip_lo(first
, second
), xsimd::zip_hi(first
, second
)};
296 template <class Arch
>
297 std::tuple
<xsimd::batch
<int64_t, Arch
>, xsimd::batch
<int64_t, Arch
>>
298 interleave(xsimd::batch
<int64_t, Arch
> first
,
299 xsimd::batch
<int64_t, Arch
> second
,
300 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
301 return {xsimd::zip_lo(first
, second
), xsimd::zip_hi(first
, second
)};
304 template <class Arch
>
305 xsimd::batch
<int8_t, Arch
>
306 deinterleave(xsimd::batch
<int16_t, Arch
> first
,
307 xsimd::batch
<int16_t, Arch
> second
,
308 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
309 return _mm_packs_epi16(first
, second
);
312 template <class Arch
>
313 xsimd::batch
<int16_t, Arch
>
314 deinterleave(xsimd::batch
<int32_t, Arch
> first
,
315 xsimd::batch
<int32_t, Arch
> second
,
316 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
317 return _mm_packs_epi32(first
, second
);
320 template <class Arch
>
321 inline xsimd::batch
<int32_t, Arch
>
322 madd(xsimd::batch
<int16_t, Arch
> x
, xsimd::batch
<int16_t, Arch
> y
,
323 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
324 return _mm_madd_epi16(x
, y
);
327 template <class Arch
>
328 inline xsimd::batch
<int16_t, Arch
>
329 madd(xsimd::batch
<uint8_t, Arch
> a
, xsimd::batch
<int8_t, Arch
> b
,
330 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
332 // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
333 // a = 0x00 0x01 0xFE 0x04 ...
334 // b = 0x00 0x02 0x80 0x84 ...
336 // To extend signed 8-bit value, MSB has to be set to 0xFF
337 __m128i sign_mask_b
= _mm_cmplt_epi8(b
, _mm_setzero_si128());
339 // sign_mask_b = 0x00 0x00 0xFF 0xFF ...
341 // Unpack positives with 0x00, negatives with 0xFF
342 __m128i a_epi16_l
= _mm_unpacklo_epi8(a
, _mm_setzero_si128());
343 __m128i a_epi16_h
= _mm_unpackhi_epi8(a
, _mm_setzero_si128());
344 __m128i b_epi16_l
= _mm_unpacklo_epi8(b
, sign_mask_b
);
345 __m128i b_epi16_h
= _mm_unpackhi_epi8(b
, sign_mask_b
);
347 // Here - valid 16-bit signed integers corresponding to the 8-bit input
348 // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...
350 // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
351 __m128i madd_epi32_l
= _mm_madd_epi16(a_epi16_l
, b_epi16_l
);
352 __m128i madd_epi32_h
= _mm_madd_epi16(a_epi16_h
, b_epi16_h
);
354 // Now go back from 32-bit values to 16-bit values & signed saturate
355 return _mm_packs_epi32(madd_epi32_l
, madd_epi32_h
);
358 template <class Arch
>
359 inline xsimd::batch
<int16_t, Arch
>
360 madd(xsimd::batch
<int8_t, Arch
> a
, xsimd::batch
<int8_t, Arch
> b
,
361 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
363 // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
364 // a = 0x00 0x01 0xFE 0x04 ...
365 // b = 0x00 0x02 0x80 0x84 ...
367 // To extend signed 8-bit value, MSB has to be set to 0xFF
368 __m128i sign_mask_a
= _mm_cmplt_epi8(a
, _mm_setzero_si128());
369 __m128i sign_mask_b
= _mm_cmplt_epi8(b
, _mm_setzero_si128());
371 // sign_mask_a = 0x00 0x00 0xFF 0x00 ...
372 // sign_mask_b = 0x00 0x00 0xFF 0xFF ...
374 // Unpack positives with 0x00, negatives with 0xFF
375 __m128i a_epi16_l
= _mm_unpacklo_epi8(a
, sign_mask_a
);
376 __m128i a_epi16_h
= _mm_unpackhi_epi8(a
, sign_mask_a
);
377 __m128i b_epi16_l
= _mm_unpacklo_epi8(b
, sign_mask_b
);
378 __m128i b_epi16_h
= _mm_unpackhi_epi8(b
, sign_mask_b
);
380 // Here - valid 16-bit signed integers corresponding to the 8-bit input
381 // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...
383 // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
384 __m128i madd_epi32_l
= _mm_madd_epi16(a_epi16_l
, b_epi16_l
);
385 __m128i madd_epi32_h
= _mm_madd_epi16(a_epi16_h
, b_epi16_h
);
387 // Now go back from 32-bit values to 16-bit values & signed saturate
388 return _mm_packs_epi32(madd_epi32_l
, madd_epi32_h
);
391 template <class Arch
>
392 inline std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
393 PermuteSummer(xsimd::batch
<int32_t, Arch
> pack0123
,
394 xsimd::batch
<int32_t, Arch
> pack4567
,
395 xsimd::kernel::requires_arch
<xsimd::sse2
>) {
396 return {pack0123
, pack4567
};
402 template <class Arch
>
403 std::tuple
<xsimd::batch
<int8_t, Arch
>, xsimd::batch
<int8_t, Arch
>>
404 interleave(xsimd::batch
<int8_t, Arch
> first
, xsimd::batch
<int8_t, Arch
> second
,
405 xsimd::kernel::requires_arch
<xsimd::neon
>) {
406 return {xsimd::zip_lo(first
, second
), xsimd::zip_hi(first
, second
)};
409 template <class Arch
>
410 std::tuple
<xsimd::batch
<int16_t, Arch
>, xsimd::batch
<int16_t, Arch
>>
411 interleave(xsimd::batch
<int16_t, Arch
> first
,
412 xsimd::batch
<int16_t, Arch
> second
,
413 xsimd::kernel::requires_arch
<xsimd::neon
>) {
414 return {xsimd::zip_lo(first
, second
), xsimd::zip_hi(first
, second
)};
417 template <class Arch
>
418 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
419 interleave(xsimd::batch
<int32_t, Arch
> first
,
420 xsimd::batch
<int32_t, Arch
> second
,
421 xsimd::kernel::requires_arch
<xsimd::neon
>) {
422 return {xsimd::zip_lo(first
, second
), xsimd::zip_hi(first
, second
)};
425 template <class Arch
>
426 std::tuple
<xsimd::batch
<int64_t, Arch
>, xsimd::batch
<int64_t, Arch
>>
427 interleave(xsimd::batch
<int64_t, Arch
> first
,
428 xsimd::batch
<int64_t, Arch
> second
,
429 xsimd::kernel::requires_arch
<xsimd::neon
>) {
430 return {xsimd::zip_lo(first
, second
), xsimd::zip_hi(first
, second
)};
433 template <class Arch
>
434 xsimd::batch
<int8_t, Arch
>
435 deinterleave(xsimd::batch
<int16_t, Arch
> first
,
436 xsimd::batch
<int16_t, Arch
> second
,
437 xsimd::kernel::requires_arch
<xsimd::neon
>) {
439 return vcombine_s8(vqmovn_s16(first
), vqmovn_s16(second
));
442 template <class Arch
>
443 xsimd::batch
<int16_t, Arch
>
444 deinterleave(xsimd::batch
<int32_t, Arch
> first
,
445 xsimd::batch
<int32_t, Arch
> second
,
446 xsimd::kernel::requires_arch
<xsimd::neon
>) {
447 return vcombine_s16(vqmovn_s32(first
), vqmovn_s32(second
));
450 template <class Arch
>
451 inline xsimd::batch
<int32_t, Arch
>
452 madd(xsimd::batch
<int16_t, Arch
> x
, xsimd::batch
<int16_t, Arch
> y
,
453 xsimd::kernel::requires_arch
<xsimd::neon
>) {
455 int32x4_t low
= vmull_s16(vget_low_s16(x
), vget_low_s16(y
));
456 int32x4_t high
= vmull_s16(vget_high_s16(x
), vget_high_s16(y
));
458 int32x2_t low_sum
= vpadd_s32(vget_low_s32(low
), vget_high_s32(low
));
459 int32x2_t high_sum
= vpadd_s32(vget_low_s32(high
), vget_high_s32(high
));
461 return vcombine_s32(low_sum
, high_sum
);
464 template <class Arch
>
465 inline xsimd::batch
<int16_t, Arch
>
466 madd(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
467 xsimd::kernel::requires_arch
<xsimd::neon
>) {
469 // This would be much simpler if x86 would choose to zero extend OR sign
470 // extend, not both. This could probably be optimized better.
474 vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_u8(x
), 8));
475 int16x8_t x_even
= vreinterpretq_s16_u16(
476 vbicq_u16(vreinterpretq_u16_u8(x
), vdupq_n_u16(0xff00)));
478 // Sign extend by shifting left then shifting right.
479 int16x8_t y_even
= vshrq_n_s16(vshlq_n_s16(vreinterpretq_s16_s8(y
), 8), 8);
480 int16x8_t y_odd
= vshrq_n_s16(vreinterpretq_s16_s8(y
), 8);
483 int16x8_t prod1
= vmulq_s16(x_even
, y_even
);
484 int16x8_t prod2
= vmulq_s16(x_odd
, y_odd
);
487 return vqaddq_s16(prod1
, prod2
);
490 template <class Arch
>
491 inline xsimd::batch
<int16_t, Arch
>
492 madd(xsimd::batch
<int8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
493 xsimd::kernel::requires_arch
<xsimd::neon
>) {
494 int16x8_t low
= vmull_s8(vget_low_s8(x
), vget_low_s8(y
));
495 int16x8_t high
= vmull_s8(vget_high_s8(x
), vget_high_s8(y
));
497 int16x4_t low_sum
= vpadd_s16(vget_low_s16(low
), vget_high_s16(low
));
498 int16x4_t high_sum
= vpadd_s16(vget_low_s16(high
), vget_high_s16(high
));
500 return vcombine_s16(low_sum
, high_sum
);
503 template <class Arch
>
504 inline std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
505 PermuteSummer(xsimd::batch
<int32_t, Arch
> pack0123
,
506 xsimd::batch
<int32_t, Arch
> pack4567
,
507 xsimd::kernel::requires_arch
<xsimd::neon
>) {
508 return {pack0123
, pack4567
};
513 template <class Arch
>
514 std::tuple
<xsimd::batch
<int8_t, Arch
>, xsimd::batch
<int8_t, Arch
>>
515 interleave(xsimd::batch
<int8_t, Arch
> first
, xsimd::batch
<int8_t, Arch
> second
,
516 xsimd::kernel::requires_arch
<xsimd::neon64
>) {
517 return {vzip1q_s8(first
, second
), vzip2q_s8(first
, second
)};
520 template <class Arch
>
521 std::tuple
<xsimd::batch
<int16_t, Arch
>, xsimd::batch
<int16_t, Arch
>>
522 interleave(xsimd::batch
<int16_t, Arch
> first
,
523 xsimd::batch
<int16_t, Arch
> second
,
524 xsimd::kernel::requires_arch
<xsimd::neon64
>) {
525 return {vzip1q_s16(first
, second
), vzip2q_s16(first
, second
)};
528 template <class Arch
>
529 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
530 interleave(xsimd::batch
<int32_t, Arch
> first
,
531 xsimd::batch
<int32_t, Arch
> second
,
532 xsimd::kernel::requires_arch
<xsimd::neon64
>) {
533 return {vzip1q_s32(first
, second
), vzip2q_s32(first
, second
)};
536 template <class Arch
>
537 std::tuple
<xsimd::batch
<int64_t, Arch
>, xsimd::batch
<int64_t, Arch
>>
538 interleave(xsimd::batch
<int64_t, Arch
> first
,
539 xsimd::batch
<int64_t, Arch
> second
,
540 xsimd::kernel::requires_arch
<xsimd::neon64
>) {
541 return {vzip1q_s64(first
, second
), vzip2q_s64(first
, second
)};
544 template <class Arch
>
545 xsimd::batch
<int8_t, Arch
>
546 deinterleave(xsimd::batch
<int16_t, Arch
> first
,
547 xsimd::batch
<int16_t, Arch
> second
,
548 xsimd::kernel::requires_arch
<xsimd::neon64
>) {
550 return vqmovn_high_s16(vqmovn_s16(first
), second
);
553 template <class Arch
>
554 xsimd::batch
<int16_t, Arch
>
555 deinterleave(xsimd::batch
<int32_t, Arch
> first
,
556 xsimd::batch
<int32_t, Arch
> second
,
557 xsimd::kernel::requires_arch
<xsimd::neon64
>) {
558 return vqmovn_high_s32(vqmovn_s32(first
), second
);
561 #ifdef __ARM_FEATURE_MATMUL_INT8
562 template <class Arch
>
563 inline xsimd::batch
<int32_t, Arch
>
564 maddw(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
565 xsimd::batch
<int32_t, Arch
> z
,
566 xsimd::kernel::requires_arch
<xsimd::i8mm
<xsimd::neon64
>>) {
567 return vusdotq_s32(z
, x
, y
);
571 template <class Arch
>
572 inline xsimd::batch
<int32_t, Arch
>
573 maddw(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
574 xsimd::batch
<int32_t, Arch
> z
,
575 xsimd::kernel::requires_arch
<xsimd::neon64
>) {
576 int16x8_t tl
= vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x
))),
577 vmovl_s8(vget_low_s8(y
)));
578 int16x8_t th
= vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x
))),
579 vmovl_s8(vget_high_s8(y
)));
580 return vpadalq_s16(vpadalq_s16(z
, tl
), th
);
583 template <class Arch
>
584 inline xsimd::batch
<int32_t, Arch
> Pack0123(xsimd::batch
<int32_t, Arch
> sum0
,
585 xsimd::batch
<int32_t, Arch
> sum1
,
586 xsimd::batch
<int32_t, Arch
> sum2
,
587 xsimd::batch
<int32_t, Arch
> sum3
,
588 xsimd::kernel::requires_arch
<xsimd::neon64
>) {
589 auto pack01
= vpaddq_s32(sum0
, sum1
);
590 auto pack23
= vpaddq_s32(sum2
, sum3
);
591 return vpaddq_s32(pack01
, pack23
);
596 template <class Arch
>
597 inline xsimd::batch
<int32_t, Arch
>
598 maddw(xsimd::batch
<uint8_t, Arch
> x
, xsimd::batch
<int8_t, Arch
> y
,
599 xsimd::batch
<int32_t, Arch
> z
,
600 xsimd::kernel::requires_arch
<xsimd::generic
>) {
601 return z
+ madd(xsimd::batch
<int16_t, Arch
>(1), madd(x
, y
, Arch
{}), Arch
{});
604 } // namespace kernel
// Generic dispatcher for interleave, deinterleave, madd and PermuteSummer
610 template <class T
, class Arch
>
611 std::tuple
<xsimd::batch
<T
, Arch
>, xsimd::batch
<T
, Arch
>>
612 interleave(xsimd::batch
<T
, Arch
> first
, xsimd::batch
<T
, Arch
> second
) {
613 return kernel::interleave(first
, second
, Arch
{});
616 template <class Arch
>
617 xsimd::batch
<int8_t, Arch
> deinterleave(xsimd::batch
<int16_t, Arch
> first
,
618 xsimd::batch
<int16_t, Arch
> second
) {
619 return kernel::deinterleave(first
, second
, Arch
{});
621 template <class Arch
>
622 xsimd::batch
<int16_t, Arch
> deinterleave(xsimd::batch
<int32_t, Arch
> first
,
623 xsimd::batch
<int32_t, Arch
> second
) {
624 return kernel::deinterleave(first
, second
, Arch
{});
627 template <class Arch
>
628 inline xsimd::batch
<int32_t, Arch
> madd(xsimd::batch
<int16_t, Arch
> x
,
629 xsimd::batch
<int16_t, Arch
> y
) {
630 return kernel::madd(x
, y
, Arch
{});
632 template <class Arch
>
633 inline xsimd::batch
<int16_t, Arch
> madd(xsimd::batch
<int8_t, Arch
> x
,
634 xsimd::batch
<int8_t, Arch
> y
) {
635 return kernel::madd(x
, y
, Arch
{});
637 template <class Arch
>
638 inline xsimd::batch
<int16_t, Arch
> madd(xsimd::batch
<uint8_t, Arch
> x
,
639 xsimd::batch
<int8_t, Arch
> y
) {
640 return kernel::madd(x
, y
, Arch
{});
642 template <class Arch
>
643 inline xsimd::batch
<int32_t, Arch
> maddw(xsimd::batch
<uint8_t, Arch
> x
,
644 xsimd::batch
<int8_t, Arch
> y
,
645 xsimd::batch
<int32_t, Arch
> z
647 return kernel::maddw(x
, y
, z
, Arch
{});
649 template <class Arch
>
650 inline xsimd::batch
<int32_t, Arch
> maddw(xsimd::batch
<uint8_t, Arch
> x
,
651 xsimd::batch
<int8_t, Arch
> y
653 return maddw(x
, y
, xsimd::batch
<int32_t, Arch
>((int32_t)0));
656 template <class Arch
>
657 inline auto PermuteSummer(xsimd::batch
<int32_t, Arch
> pack0123
,
658 xsimd::batch
<int32_t, Arch
> pack4567
)
659 -> decltype(kernel::PermuteSummer(pack0123
, pack4567
, Arch
{})) {
660 return kernel::PermuteSummer(pack0123
, pack4567
, Arch
{});
666 template <class Arch
>
667 inline xsimd::batch
<int32_t, Arch
> Pack0123(xsimd::batch
<int32_t, Arch
> sum0
,
668 xsimd::batch
<int32_t, Arch
> sum1
,
669 xsimd::batch
<int32_t, Arch
> sum2
,
670 xsimd::batch
<int32_t, Arch
> sum3
,
671 xsimd::kernel::requires_arch
<xsimd::generic
>) {
673 std::tie(sum0
, sum1
) = interleave(sum0
, sum1
, Arch
{});
674 auto pack01
= sum0
+ sum1
;
675 std::tie(sum2
, sum3
) = interleave(sum2
, sum3
, Arch
{});
676 auto pack23
= sum2
+ sum3
;
678 auto packed
= interleave(xsimd::bitwise_cast
<int64_t>(pack01
),
679 xsimd::bitwise_cast
<int64_t>(pack23
),
681 return xsimd::bitwise_cast
<int32_t>(std::get
<0>(packed
)) +
682 xsimd::bitwise_cast
<int32_t>(std::get
<1>(packed
));
686 template <class Arch
>
687 inline xsimd::batch
<int32_t, Arch
> Pack0123(xsimd::batch
<int32_t, Arch
> sum0
,
688 xsimd::batch
<int32_t, Arch
> sum1
,
689 xsimd::batch
<int32_t, Arch
> sum2
,
690 xsimd::batch
<int32_t, Arch
> sum3
) {
691 return kernel::Pack0123(sum0
, sum1
, sum2
, sum3
, Arch
{});
694 template <class Arch
>
695 static inline xsimd::batch
<int32_t, Arch
>
696 quantize(xsimd::batch
<float, Arch
> input
,
697 xsimd::batch
<float, Arch
> quant_mult
) {
698 return xsimd::nearbyint_as_int(input
* quant_mult
);
701 template <class Arch
>
702 inline xsimd::batch
<int32_t, Arch
>
703 QuantizerGrab(const float *input
, xsimd::batch
<float, Arch
> quant_mult_reg
) {
704 return quantize(xsimd::batch
<float, Arch
>::load_unaligned(input
),
709 inline __m512
Concat(const __m256 first
, const __m256 second
) {
710 // INTGEMM_AVX512DQ but that goes with INTGEMM_AVX512BW anyway.
711 return _mm512_insertf32x8(_mm512_castps256_ps512(first
), second
, 1);
714 // Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be
715 // controlled independently.
716 /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set
717 * INTGEMM_AVX512BW */
718 inline __m512i
QuantizerGrabHalves(const float *input0
, const float *input1
,
719 const __m512 quant_mult_reg
) {
720 __m512 appended
= Concat(_mm256_loadu_ps(input0
), _mm256_loadu_ps(input1
));
721 appended
= _mm512_mul_ps(appended
, quant_mult_reg
);
722 return _mm512_cvtps_epi32(appended
);
725 template <class Arch
>
726 inline xsimd::batch
<int32_t, Arch
>
727 QuantizerGrabHalves(const float *input0
, const float *input1
,
728 xsimd::batch
<float, Arch
> quant_mult_reg
);
/* Read 8 floats at a time from input0, input1, input2, and input3. Quantize
 * them to 8-bit by multiplying with quant_mult_reg then rounding. Concatenate
 * the result into one register and return it.
 */
735 class QuantizeTile8
{
736 template <class Arch
> struct Tiler
{
737 static constexpr uint32_t get(std::size_t i
, std::size_t n
) {
738 size_t factor
= xsimd::batch
<float, Arch
>::size
/ 4;
739 return (i
% factor
) * 4 + i
/ factor
;
744 template <class Arch
>
745 static inline xsimd::batch
<int8_t, Arch
>
746 Consecutive(xsimd::batch
<float, Arch
> quant_mult
, const float *input
) {
747 return Tile(quant_mult
, input
+ 0 * xsimd::batch
<float, Arch
>::size
,
748 input
+ 1 * xsimd::batch
<float, Arch
>::size
,
749 input
+ 2 * xsimd::batch
<float, Arch
>::size
,
750 input
+ 3 * xsimd::batch
<float, Arch
>::size
);
753 template <class Arch
>
754 static inline xsimd::batch
<uint8_t, Arch
>
755 ConsecutiveU(xsimd::batch
<float, Arch
> quant_mult
, const float *input
) {
756 return TileU(quant_mult
, input
+ 0 * xsimd::batch
<float, Arch
>::size
,
757 input
+ 1 * xsimd::batch
<float, Arch
>::size
,
758 input
+ 2 * xsimd::batch
<float, Arch
>::size
,
759 input
+ 3 * xsimd::batch
<float, Arch
>::size
);
762 template <class Arch
>
763 static inline xsimd::batch
<int8_t, Arch
>
764 ConsecutiveWithWrapping(xsimd::batch
<float, Arch
> quant_mult
,
765 const float *input
, size_t cols_left
, size_t cols
,
767 using batchf32
= xsimd::batch
<float, Arch
>;
768 const float *inputs
[4];
769 for (size_t i
= 0; i
< std::size(inputs
); ++i
) {
770 while (cols_left
< batchf32::size
) {
771 input
+= cols
* (row_step
- 1);
775 input
+= batchf32::size
;
776 cols_left
-= batchf32::size
;
778 return Tile(quant_mult
, inputs
[0], inputs
[1], inputs
[2], inputs
[3]);
781 template <class Arch
>
782 static inline xsimd::batch
<int8_t, Arch
>
783 ForReshape(xsimd::batch
<float, Arch
> quant_mult
, const float *input
,
785 using batchf32
= xsimd::batch
<float, Arch
>;
786 using batch8
= xsimd::batch
<int8_t, Arch
>;
787 using batch16
= xsimd::batch
<int16_t, Arch
>;
788 using batch32
= xsimd::batch
<int32_t, Arch
>;
790 // Put higher rows in the second half of the register. These will jumble
791 // around in the same way then conveniently land in the right place.
792 if constexpr (batchf32::size
== 16) {
793 const batch8
neg127(-127);
794 // In reverse order: grabbing the first 32-bit values from each 128-bit
795 // register, then the second 32-bit values, etc. Grab 4 registers at a
796 // time in 32-bit format.
798 QuantizerGrabHalves(input
+ 0 * cols
, input
+ 2 * cols
, quant_mult
);
800 QuantizerGrabHalves(input
+ 16 * cols
, input
+ 18 * cols
, quant_mult
);
802 QuantizerGrabHalves(input
+ 32 * cols
, input
+ 34 * cols
, quant_mult
);
804 QuantizerGrabHalves(input
+ 48 * cols
, input
+ 50 * cols
, quant_mult
);
806 // Pack 32-bit to 16-bit.
807 batch16 packed0
= deinterleave(g0
, g1
);
808 batch16 packed1
= deinterleave(g2
, g3
);
809 // Pack 16-bit to 8-bit.
810 batch8 packed
= deinterleave(packed0
, packed1
);
812 packed
= xsimd::max(packed
, neg127
);
814 return xsimd::bitwise_cast
<int8_t>(
815 xsimd::swizzle(xsimd::bitwise_cast
<int32_t>(packed
),
816 xsimd::make_batch_constant
<uint32_t, Arch
, Tiler
<Arch
>>()));
817 } else if constexpr (batchf32::size
== 8)
818 return Tile(quant_mult
, input
, input
+ 2 * cols
, input
+ 16 * cols
,
820 else if constexpr (batchf32::size
== 4)
822 return Tile(quant_mult
, input
, input
+ 4, input
+ 2 * cols
,
823 input
+ 2 * cols
+ 4);
828 template <class Arch
>
829 static inline xsimd::batch
<int8_t, Arch
>
830 Tile(xsimd::batch
<float, Arch
> quant_mult
, const float *input0
,
831 const float *input1
, const float *input2
, const float *input3
) {
832 using batch8
= xsimd::batch
<int8_t, Arch
>;
833 using batch16
= xsimd::batch
<int16_t, Arch
>;
834 using batch32
= xsimd::batch
<int32_t, Arch
>;
836 const batch8
neg127(-127);
837 // Grab 4 registers at a time in 32-bit format.
838 batch32 g0
= QuantizerGrab(input0
, quant_mult
);
839 batch32 g1
= QuantizerGrab(input1
, quant_mult
);
840 batch32 g2
= QuantizerGrab(input2
, quant_mult
);
841 batch32 g3
= QuantizerGrab(input3
, quant_mult
);
842 // Pack 32-bit to 16-bit.
843 batch16 packed0
= deinterleave(g0
, g1
);
844 batch16 packed1
= deinterleave(g2
, g3
);
845 // Pack 16-bit to 8-bit.
846 batch8 packed
= deinterleave(packed0
, packed1
);
848 packed
= xsimd::max(packed
, neg127
);
850 if constexpr (batch32::size
== 4)
852 // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
853 // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7
854 // Technically this could be removed so long as the rows are bigger than 16
855 // and the values are only used for GEMM.
856 return xsimd::bitwise_cast
<int8_t>(
857 xsimd::swizzle(xsimd::bitwise_cast
<int32_t>(packed
),
858 xsimd::make_batch_constant
<uint32_t, Arch
, Tiler
<Arch
>>()));
862 // A version that produces uint8_ts
863 template <class Arch
>
864 static inline xsimd::batch
<uint8_t, Arch
>
865 TileU(xsimd::batch
<float, Arch
> quant_mult
, const float *input0
,
866 const float *input1
, const float *input2
, const float *input3
) {
867 using batch8
= xsimd::batch
<int8_t, Arch
>;
868 using batch16
= xsimd::batch
<int16_t, Arch
>;
869 using batch32
= xsimd::batch
<int32_t, Arch
>;
871 const batch8 neg127
= -127;
872 const batch8 pos127
= +127;
873 // Grab 4 registers at a time in 32-bit format.
874 batch32 g0
= QuantizerGrab(input0
, quant_mult
);
875 batch32 g1
= QuantizerGrab(input1
, quant_mult
);
876 batch32 g2
= QuantizerGrab(input2
, quant_mult
);
877 batch32 g3
= QuantizerGrab(input3
, quant_mult
);
878 // Pack 32-bit to 16-bit.
879 batch16 packed0
= deinterleave(g0
, g1
);
880 batch16 packed1
= deinterleave(g2
, g3
);
881 // Pack 16-bit to 8-bit.
882 batch8 packed
= deinterleave(packed0
, packed1
);
884 packed
= xsimd::max(packed
, neg127
); // Could be removed if we use +128
885 packed
= packed
+ pos127
;
886 if (batch32::size
== 4)
887 return xsimd::bitwise_cast
<uint8_t>(packed
);
888 // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
889 // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7
890 // Technically this could be removed so long as the rows are bigger than 16
891 // and the values are only used for GEMM.
892 return xsimd::bitwise_cast
<uint8_t>(
893 xsimd::swizzle(xsimd::bitwise_cast
<int32_t>(packed
),
894 xsimd::make_batch_constant
<uint32_t, Arch
, Tiler
<Arch
>>()));
898 template <class Arch
>
899 inline void Transpose16InLane(
900 xsimd::batch
<int8_t, Arch
> &r0
, xsimd::batch
<int8_t, Arch
> &r1
,
901 xsimd::batch
<int8_t, Arch
> &r2
, xsimd::batch
<int8_t, Arch
> &r3
,
902 xsimd::batch
<int8_t, Arch
> &r4
, xsimd::batch
<int8_t, Arch
> &r5
,
903 xsimd::batch
<int8_t, Arch
> &r6
, xsimd::batch
<int8_t, Arch
> &r7
) {
904 /* r0: columns 0 1 2 3 4 5 6 7 from row 0
905 r1: columns 0 1 2 3 4 5 6 7 from row 1*/
906 auto r0_16
= xsimd::bitwise_cast
<int16_t>(r0
);
907 auto r1_16
= xsimd::bitwise_cast
<int16_t>(r1
);
908 auto r2_16
= xsimd::bitwise_cast
<int16_t>(r2
);
909 auto r3_16
= xsimd::bitwise_cast
<int16_t>(r3
);
910 auto r4_16
= xsimd::bitwise_cast
<int16_t>(r4
);
911 auto r5_16
= xsimd::bitwise_cast
<int16_t>(r5
);
912 auto r6_16
= xsimd::bitwise_cast
<int16_t>(r6
);
913 auto r7_16
= xsimd::bitwise_cast
<int16_t>(r7
);
915 std::tie(r0_16
, r1_16
) = interleave(r0_16
, r1_16
);
916 std::tie(r2_16
, r3_16
) = interleave(r2_16
, r3_16
);
917 std::tie(r4_16
, r5_16
) = interleave(r4_16
, r5_16
);
918 std::tie(r6_16
, r7_16
) = interleave(r6_16
, r7_16
);
919 /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
920 r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
921 r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
922 r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
923 r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
924 r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
925 r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
926 r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/
927 auto r0_32
= xsimd::bitwise_cast
<int32_t>(r0_16
);
928 auto r2_32
= xsimd::bitwise_cast
<int32_t>(r2_16
);
929 auto r1_32
= xsimd::bitwise_cast
<int32_t>(r1_16
);
930 auto r3_32
= xsimd::bitwise_cast
<int32_t>(r3_16
);
931 auto r4_32
= xsimd::bitwise_cast
<int32_t>(r4_16
);
932 auto r6_32
= xsimd::bitwise_cast
<int32_t>(r6_16
);
933 auto r5_32
= xsimd::bitwise_cast
<int32_t>(r5_16
);
934 auto r7_32
= xsimd::bitwise_cast
<int32_t>(r7_16
);
936 std::tie(r0_32
, r2_32
) = interleave(r0_32
, r2_32
);
937 std::tie(r1_32
, r3_32
) = interleave(r1_32
, r3_32
);
938 std::tie(r4_32
, r6_32
) = interleave(r4_32
, r6_32
);
939 std::tie(r5_32
, r7_32
) = interleave(r5_32
, r7_32
);
940 /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
941 r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
942 r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
943 r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
944 r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
945 r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
946 r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
947 r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/
949 auto r0_64
= xsimd::bitwise_cast
<int64_t>(r0_32
);
950 auto r2_64
= xsimd::bitwise_cast
<int64_t>(r2_32
);
951 auto r1_64
= xsimd::bitwise_cast
<int64_t>(r1_32
);
952 auto r3_64
= xsimd::bitwise_cast
<int64_t>(r3_32
);
953 auto r4_64
= xsimd::bitwise_cast
<int64_t>(r4_32
);
954 auto r6_64
= xsimd::bitwise_cast
<int64_t>(r6_32
);
955 auto r5_64
= xsimd::bitwise_cast
<int64_t>(r5_32
);
956 auto r7_64
= xsimd::bitwise_cast
<int64_t>(r7_32
);
958 std::tie(r0_64
, r4_64
) = interleave(r0_64
, r4_64
);
959 std::tie(r1_64
, r5_64
) = interleave(r1_64
, r5_64
);
960 std::tie(r2_64
, r6_64
) = interleave(r2_64
, r6_64
);
961 std::tie(r3_64
, r7_64
) = interleave(r3_64
, r7_64
);
963 r0
= xsimd::bitwise_cast
<int8_t>(r0_64
);
964 r1
= xsimd::bitwise_cast
<int8_t>(r1_64
);
965 r2
= xsimd::bitwise_cast
<int8_t>(r2_64
);
966 r3
= xsimd::bitwise_cast
<int8_t>(r3_64
);
967 r4
= xsimd::bitwise_cast
<int8_t>(r4_64
);
968 r5
= xsimd::bitwise_cast
<int8_t>(r5_64
);
969 r6
= xsimd::bitwise_cast
<int8_t>(r6_64
);
970 r7
= xsimd::bitwise_cast
<int8_t>(r7_64
);
971 /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
972 r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
973 r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
974 r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
975 r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
976 r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/
977 /* Empirically gcc is able to remove these movs and just rename the outputs of
983 template <class Arch
, typename IntegerTy
>
984 void SelectColumnsOfB(const xsimd::batch
<int8_t, Arch
> *input
,
985 xsimd::batch
<int8_t, Arch
> *output
,
986 size_t rows_bytes
/* number of bytes in a row */,
987 const IntegerTy
*cols_begin
, const IntegerTy
*cols_end
) {
988 using batch8
= xsimd::batch
<int8_t, Arch
>;
989 /* Do columns for multiples of 8.*/
990 size_t register_rows
= rows_bytes
/ batch8::size
;
991 const batch8
*starts
[8];
992 for (; cols_begin
!= cols_end
; cols_begin
+= 8) {
993 for (size_t k
= 0; k
< 8; ++k
) {
995 input
+ (cols_begin
[k
] & 7) + (cols_begin
[k
] & ~7) * register_rows
;
997 for (size_t r
= 0; r
< register_rows
; ++r
) {
998 for (size_t k
= 0; k
< 8; ++k
) {
999 *(output
++) = *starts
[k
];
1008 namespace callbacks
{
1009 template <class Arch
>
1010 xsimd::batch
<float, Arch
> Unquantize::operator()(xsimd::batch
<int32_t, Arch
> total
, size_t, size_t,
1012 return xsimd::batch_cast
<float>(total
) * unquant_mult
;
1015 template <class Arch
>
1016 std::tuple
<xsimd::batch
<float, Arch
>, xsimd::batch
<float, Arch
>> Unquantize::operator()(
1017 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>> total
,
1018 size_t, size_t, size_t) {
1019 return std::make_tuple(
1020 xsimd::batch_cast
<float>(std::get
<0>(total
)) * unquant_mult
,
1021 xsimd::batch_cast
<float>(std::get
<1>(total
)) * unquant_mult
);
1024 template <class Arch
>
1025 xsimd::batch
<float, Arch
> AddBias::operator()(xsimd::batch
<float, Arch
> total
, size_t,
1026 size_t col_idx
, size_t) {
1027 return total
+ xsimd::batch
<float, Arch
>::load_aligned(bias_addr
+ col_idx
);
1030 template <class Arch
>
1031 std::tuple
<xsimd::batch
<float, Arch
>, xsimd::batch
<float, Arch
>>
1032 AddBias::operator()(
1033 std::tuple
<xsimd::batch
<float, Arch
>, xsimd::batch
<float, Arch
>> total
,
1034 size_t, size_t col_idx
, size_t) {
1035 return std::make_tuple(
1036 std::get
<0>(total
) + xsimd::batch
<float, Arch
>::load_aligned(bias_addr
+ col_idx
+ 0),
1037 std::get
<1>(total
) +
1038 xsimd::batch
<float, Arch
>::load_aligned(bias_addr
+ col_idx
+
1039 xsimd::batch
<float, Arch
>::size
));
1042 template <class Arch
>
1043 void Write::operator()(xsimd::batch
<float, Arch
> result
, size_t row_idx
,
1044 size_t col_idx
, size_t col_size
) {
1045 result
.store_aligned(output_addr
+ row_idx
* col_size
+ col_idx
);
1048 template <class Arch
>
1049 void Write::operator()(xsimd::batch
<int32_t, Arch
> result
, size_t row_idx
,
1050 size_t col_idx
, size_t col_size
) {
1051 xsimd::bitwise_cast
<float>(result
).store_aligned(
1052 output_addr
+ row_idx
* col_size
+ col_idx
);
1055 template <class Arch
>
1056 void Write::operator()(
1057 std::tuple
<xsimd::batch
<float, Arch
>, xsimd::batch
<float, Arch
>> result
,
1058 size_t row_idx
, size_t col_idx
, size_t col_size
) {
1059 std::get
<0>(result
).store_aligned(output_addr
+ row_idx
* col_size
+ col_idx
+
1061 std::get
<1>(result
).store_aligned(output_addr
+ row_idx
* col_size
+ col_idx
+
1062 xsimd::batch
<float, Arch
>::size
);
1065 template <class Arch
>
1066 void Write::operator()(
1067 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>> result
,
1068 size_t row_idx
, size_t col_idx
, size_t col_size
) {
1069 xsimd::bitwise_cast
<float>(std::get
<0>(result
))
1070 .store_aligned(output_addr
+ row_idx
* col_size
+ col_idx
+ 0);
1071 xsimd::bitwise_cast
<float>(std::get
<1>(result
))
1072 .store_aligned(output_addr
+ row_idx
* col_size
+ col_idx
+
1073 xsimd::batch
<int32_t, Arch
>::size
);
1077 void UnquantizeAndWrite::operator()(T
const &total
, size_t row_idx
,
1078 size_t col_idx
, size_t col_size
) {
1079 auto unquantized
= unquantize(total
, row_idx
, col_idx
, col_size
);
1080 write(unquantized
, row_idx
, col_idx
, col_size
);
1084 void UnquantizeAndAddBiasAndWrite::operator()(T
const &total
, size_t row_idx
,
1085 size_t col_idx
, size_t col_size
) {
1086 auto unquantized
= unquantize(total
, row_idx
, col_idx
, col_size
);
1087 auto bias_added
= add_bias(unquantized
, row_idx
, col_idx
, col_size
);
1088 write(bias_added
, row_idx
, col_idx
, col_size
);
1090 } // namespace callbacks
1092 template <class Arch
>
1093 void Engine
<Arch
>::QuantizeU(const float *input
, uint8_t *output
,
1094 float quant_mult
, size_t size
) {
1095 using batch8
= xsimd::batch
<int8_t, Arch
>;
1097 xsimd::batch
<float, Arch
> q(quant_mult
);
1098 const float *end
= input
+ size
;
1099 for (; input
!= end
; input
+= batch8::size
, output
+= batch8::size
) {
1100 auto tile
= QuantizeTile8::ConsecutiveU(q
, input
);
1101 tile
.store_aligned(output
);
1105 template <class Arch
>
1106 void Engine
<Arch
>::Quantize(const float *const input
, int8_t *const output
,
1107 float quant_mult
, size_t size
) {
1108 using batch8
= xsimd::batch
<int8_t, Arch
>;
1110 const std::size_t kBatch
= batch8::size
;
1111 const std::size_t fast_end
= size
& ~(kBatch
- 1);
1113 xsimd::batch
<float, Arch
> q(quant_mult
);
1114 for (std::size_t i
= 0; i
< fast_end
; i
+= kBatch
) {
1115 auto tile
= QuantizeTile8::Consecutive(q
, input
+ i
);
1116 tile
.store_aligned(output
+ i
);
1119 std::size_t overhang
= size
& (kBatch
- 1);
1122 /* Each does size(xsimd::batch<int8_t, Arch>) / 32 == kBatch / 4 floats at a
1123 * time. If we're allowed to read one of them, then we can read the whole
1126 const float *inputs
[4];
1128 for (i
= 0; i
< (overhang
+ (kBatch
/ 4) - 1) / (kBatch
/ 4); ++i
) {
1129 inputs
[i
] = &input
[fast_end
+ i
* (kBatch
/ 4)];
1131 /* These will be clipped off. */
1132 for (; i
< 4; ++i
) {
1133 inputs
[i
] = &input
[fast_end
];
1136 QuantizeTile8::Tile(q
, inputs
[0], inputs
[1], inputs
[2], inputs
[3]);
1137 std::memcpy(output
+ (size
& ~(kBatch
- 1)), &result
, overhang
);
1140 template <class Arch
>
1141 template <typename IntegerTy
>
1142 void Engine
<Arch
>::SelectColumnsB(const int8_t *input
, int8_t *output
,
1143 size_t rows
, const IntegerTy
*cols_begin
,
1144 const IntegerTy
*cols_end
) {
1145 using batch8
= xsimd::batch
<int8_t, Arch
>;
1146 SelectColumnsOfB(reinterpret_cast<const batch8
*>(input
),
1147 reinterpret_cast<batch8
*>(output
), rows
, cols_begin
,
1151 template <class Arch
>
1152 void Engine
<Arch
>::PrepareBTransposed(const float *input
, int8_t *output
,
1153 float quant_mult
, size_t cols
,
1155 using batch8
= xsimd::batch
<int8_t, Arch
>;
1156 const size_t RegisterElemsInt
= batch8::size
;
1157 const size_t kColStride
= 8;
1159 xsimd::batch
<float, Arch
> q(quant_mult
);
1160 auto *output_it
= reinterpret_cast<batch8
*>(output
);
1164 for (size_t ri
= 0; ri
< 8; ++ri
)
1165 *output_it
++ = QuantizeTile8::ConsecutiveWithWrapping(
1166 q
, input
+ (r
+ ri
) * cols
+ c
, cols
- c
, cols
, 8);
1167 c
+= RegisterElemsInt
;
1175 template <class Arch
>
1176 void Engine
<Arch
>::PrepareBQuantizedTransposed(const int8_t *input
,
1177 int8_t *output
, size_t cols
,
1179 using batch8
= xsimd::batch
<int8_t, Arch
>;
1180 const size_t RegisterElems
= batch8::size
;
1181 const size_t kColStride
= 8;
1183 auto *output_it
= reinterpret_cast<batch8
*>(output
);
1184 for (size_t r
= 0; r
< rows
; r
+= kColStride
)
1185 for (size_t c
= 0; c
< cols
; c
+= RegisterElems
)
1186 for (size_t ri
= 0; ri
< 8; ++ri
)
1188 *reinterpret_cast<const batch8
*>(input
+ (r
+ ri
) * cols
+ c
);
1191 template <class Arch
>
1192 void Engine
<Arch
>::PrepareB(const float *input
, int8_t *output_shadow
,
1193 float quant_mult
, size_t rows
, size_t cols
) {
1194 using batch8
= xsimd::batch
<int8_t, Arch
>;
1196 xsimd::batch
<float, Arch
> q(quant_mult
);
1197 /* Currently all multipliers have a stride of 8 columns.*/
1198 const size_t kColStride
= 8;
1199 auto *output
= reinterpret_cast<batch8
*>(output_shadow
);
1200 for (size_t c
= 0; c
< cols
; c
+= kColStride
) {
1201 for (size_t r
= 0; r
< rows
; r
+= sizeof(*output
), output
+= 8) {
1203 QuantizeTile8::ForReshape(q
, input
+ cols
* (r
+ 0) + c
, cols
);
1205 QuantizeTile8::ForReshape(q
, input
+ cols
* (r
+ 1) + c
, cols
);
1207 QuantizeTile8::ForReshape(q
, input
+ cols
* (r
+ 4) + c
, cols
);
1209 QuantizeTile8::ForReshape(q
, input
+ cols
* (r
+ 5) + c
, cols
);
1211 QuantizeTile8::ForReshape(q
, input
+ cols
* (r
+ 8) + c
, cols
);
1213 QuantizeTile8::ForReshape(q
, input
+ cols
* (r
+ 9) + c
, cols
);
1215 QuantizeTile8::ForReshape(q
, input
+ cols
* (r
+ 12) + c
, cols
);
1217 QuantizeTile8::ForReshape(q
, input
+ cols
* (r
+ 13) + c
, cols
);
1218 std::tie(output
[0], output
[1]) =
1219 interleave(xsimd::bitwise_cast
<int8_t>(output
[0]),
1220 xsimd::bitwise_cast
<int8_t>(output
[1]));
1221 std::tie(output
[2], output
[3]) =
1222 interleave(xsimd::bitwise_cast
<int8_t>(output
[2]),
1223 xsimd::bitwise_cast
<int8_t>(output
[3]));
1224 std::tie(output
[4], output
[5]) =
1225 interleave(xsimd::bitwise_cast
<int8_t>(output
[4]),
1226 xsimd::bitwise_cast
<int8_t>(output
[5]));
1227 std::tie(output
[6], output
[7]) =
1228 interleave(xsimd::bitwise_cast
<int8_t>(output
[6]),
1229 xsimd::bitwise_cast
<int8_t>(output
[7]));
1230 Transpose16InLane(output
[0], output
[1], output
[2], output
[3], output
[4],
1231 output
[5], output
[6], output
[7]);
1236 template <class Arch
>
1237 void Engine
<Arch
>::PrepareA(const float *input
, int8_t *output
,
1238 float quant_mult
, size_t rows
, size_t cols
) {
1239 Quantize(input
, output
, quant_mult
, rows
* cols
);
1242 template <class Arch
>
1243 void Engine
<Arch
>::Shift::PrepareA(const float *input
, uint8_t *output
,
1244 float quant_mult
, size_t rows
, size_t cols
) {
1245 QuantizeU(input
, output
, quant_mult
, rows
* cols
);
1248 template <class Arch
>
1249 template <class Callback
>
1250 void Engine
<Arch
>::Shift::Multiply(const uint8_t *A
, const int8_t *B
,
1251 size_t A_rows
, size_t width
, size_t B_cols
,
1252 Callback callback
) {
1254 using batch8
= xsimd::batch
<int8_t, Arch
>;
1255 using ubatch8
= xsimd::batch
<uint8_t, Arch
>;
1256 using batch32
= xsimd::batch
<int32_t, Arch
>;
1258 const size_t simd_width
= width
/ batch8::size
;
1259 for (size_t B0_colidx
= 0; B0_colidx
< B_cols
; B0_colidx
+= 8) {
1260 const auto *B0_col
=
1261 reinterpret_cast<const batch8
*>(B
) + simd_width
* B0_colidx
;
1262 /* Process one row of A at a time. Doesn't seem to be faster to do multiple
1263 * rows of A at once.*/
1264 for (size_t A_rowidx
= 0; A_rowidx
< A_rows
; ++A_rowidx
) {
1266 reinterpret_cast<const ubatch8
*>(A
+ A_rowidx
* width
);
1267 /* These will be packed 16-bit integers containing sums for each row of B
1268 multiplied by the row of A. Iterate over shared (inner) dimension.*/
1269 /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
1272 ubatch8 a
= *(A_row
+ k
);
1273 batch32 isum0
= maddw(a
, *(B0_col
+ k
* 8));
1274 batch32 isum1
= maddw(a
, *(B0_col
+ k
* 8 + 1));
1275 batch32 isum2
= maddw(a
, *(B0_col
+ k
* 8 + 2));
1276 batch32 isum3
= maddw(a
, *(B0_col
+ k
* 8 + 3));
1277 batch32 isum4
= maddw(a
, *(B0_col
+ k
* 8 + 4));
1278 batch32 isum5
= maddw(a
, *(B0_col
+ k
* 8 + 5));
1279 batch32 isum6
= maddw(a
, *(B0_col
+ k
* 8 + 6));
1280 batch32 isum7
= maddw(a
, *(B0_col
+ k
* 8 + 7));
1281 for (k
= 1; k
< simd_width
; ++k
) {
1283 /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/
1284 /* Upcast to 32-bit and horizontally add.*/
1285 isum0
= maddw(a
, *(B0_col
+ k
* 8 + 0), isum0
);
1286 isum1
= maddw(a
, *(B0_col
+ k
* 8 + 1), isum1
);
1287 isum2
= maddw(a
, *(B0_col
+ k
* 8 + 2), isum2
);
1288 isum3
= maddw(a
, *(B0_col
+ k
* 8 + 3), isum3
);
1289 isum4
= maddw(a
, *(B0_col
+ k
* 8 + 4), isum4
);
1290 isum5
= maddw(a
, *(B0_col
+ k
* 8 + 5), isum5
);
1291 isum6
= maddw(a
, *(B0_col
+ k
* 8 + 6), isum6
);
1292 isum7
= maddw(a
, *(B0_col
+ k
* 8 + 7), isum7
);
1294 /* Reduce sums within 128-bit lanes.*/
1295 auto pack0123
= Pack0123(isum0
, isum1
, isum2
, isum3
);
1296 auto pack4567
= Pack0123(isum4
, isum5
, isum6
, isum7
);
1297 /*The specific implementation may need to reduce further.*/
1298 auto total
= PermuteSummer(pack0123
, pack4567
);
1299 callback(total
, A_rowidx
, B0_colidx
, B_cols
);
1304 template <class Arch
>
1305 template <class Callback
>
1306 void Engine
<Arch
>::Shift::PrepareBias(const int8_t *B
, size_t width
,
1307 size_t B_cols
, Callback C
) {
1308 using batch8
= xsimd::batch
<int8_t, Arch
>;
1309 const size_t simd_width
= width
/ batch8::size
;
1310 xsimd::batch
<uint8_t, Arch
> a(1);
1311 for (size_t j
= 0; j
< B_cols
; j
+= 8) {
1312 /*Process one row of A at a time. Doesn't seem to be faster to do multiple
1313 * rows of A at once.*/
1314 const int8_t *B_j
= B
+ j
* width
;
1316 /* Rather than initializing as zeros and adding, just initialize the
1318 /* These will be packed 16-bit integers containing sums for each column of
1319 * B multiplied by the row of A.*/
1320 /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
1322 auto isum0
= maddw(a
, batch8::load_aligned(&B_j
[0 * batch8::size
]));
1323 auto isum1
= maddw(a
, batch8::load_aligned(&B_j
[1 * batch8::size
]));
1324 auto isum2
= maddw(a
, batch8::load_aligned(&B_j
[2 * batch8::size
]));
1325 auto isum3
= maddw(a
, batch8::load_aligned(&B_j
[3 * batch8::size
]));
1326 auto isum4
= maddw(a
, batch8::load_aligned(&B_j
[4 * batch8::size
]));
1327 auto isum5
= maddw(a
, batch8::load_aligned(&B_j
[5 * batch8::size
]));
1328 auto isum6
= maddw(a
, batch8::load_aligned(&B_j
[6 * batch8::size
]));
1329 auto isum7
= maddw(a
, batch8::load_aligned(&B_j
[7 * batch8::size
]));
1331 B_j
+= 8 * batch8::size
;
1333 for (size_t k
= 1; k
< simd_width
; ++k
, B_j
+= 8 * batch8::size
) {
1334 isum0
= maddw(a
, batch8::load_aligned(&B_j
[0 * batch8::size
]), isum0
);
1335 isum1
= maddw(a
, batch8::load_aligned(&B_j
[1 * batch8::size
]), isum1
);
1336 isum2
= maddw(a
, batch8::load_aligned(&B_j
[2 * batch8::size
]), isum2
);
1337 isum3
= maddw(a
, batch8::load_aligned(&B_j
[3 * batch8::size
]), isum3
);
1338 isum4
= maddw(a
, batch8::load_aligned(&B_j
[4 * batch8::size
]), isum4
);
1339 isum5
= maddw(a
, batch8::load_aligned(&B_j
[5 * batch8::size
]), isum5
);
1340 isum6
= maddw(a
, batch8::load_aligned(&B_j
[6 * batch8::size
]), isum6
);
1341 isum7
= maddw(a
, batch8::load_aligned(&B_j
[7 * batch8::size
]), isum7
);
1344 auto pack0123
= Pack0123(isum0
, isum1
, isum2
, isum3
);
1345 auto pack4567
= Pack0123(isum4
, isum5
, isum6
, isum7
);
1347 auto total
= PermuteSummer(pack0123
, pack4567
);
1348 C(total
, 0, j
, B_cols
);
1352 } // namespace gemmology