Backed out 2 changesets (bug 903746) for causing non-unified build bustages on nsIPri...
[gecko.git] / third_party / gemmology / gemmology.h
blobeb5ebed3b4598f157b15128266dddbb44fd9ece8
1 #ifndef GEMMOLOGY_H
2 #define GEMMOLOGY_H
4 #include "gemmology_fwd.h"
6 #include <cstdint>
7 #include <cstring>
8 #include <tuple>
10 #include <xsimd/xsimd.hpp>
12 namespace gemmology {
14 namespace {
17 // Arch specific implementation of various elementary operations
20 namespace kernel {
#ifdef __AVX512BW__
// AVX512BW kernels. The unpack/pack/maddubs instructions used below operate
// independently on each of the four 128-bit lanes of a 512-bit register.

template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi8(first, second),
          _mm512_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi16(first, second),
          _mm512_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi32(first, second),
          _mm512_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi64(first, second),
          _mm512_unpackhi_epi64(first, second)};
}

// Narrow two 16-bit batches into one 8-bit batch with signed saturation.
template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi16(first, second);
}

// Narrow two 32-bit batches into one 16-bit batch with signed saturation.
template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi32(first, second);
}

// Signed 16-bit products, adjacent pairs summed into 32-bit lanes.
template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_madd_epi16(x, y);
}

// Unsigned x signed 8-bit products, adjacent pairs summed into 16-bit lanes
// with signed saturation.
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_maddubs_epi16(x, y);
}

// Signed x signed 8-bit products, adjacent pairs summed into 16-bit lanes.
// Built on the unsigned x signed maddubs, the same way as the AVX2/SSSE3
// kernels: |x| * (y carrying the sign of x). (_mm512_madd_epi16 would treat
// the bytes as 16-bit lanes and return 32-bit sums instead.)
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  // AVX512BW has no _mm512_sign_epi8: negate y in the lanes where x is
  // negative. Lanes where x == 0 keep y, but |x| == 0 there, so the product is
  // 0 just as with _mm256_sign_epi8.
  __mmask64 x_negative = _mm512_movepi8_mask(x);
  __m512i y_signed =
      _mm512_mask_sub_epi8(y, x_negative, _mm512_setzero_si512(), y);
  return _mm512_maddubs_epi16(xsimd::abs(x), y_signed);
}

template <class Arch>
inline xsimd::batch<int32_t, xsimd::avx2>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  // Form [0th 128-bit register of pack0123, 0th 128-bit register of pack4567,
  // 2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567]
  __m512i mix0 =
      _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6));
  // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567,
  // 3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567]
  __m512i mix1 =
      _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2));
  __m512i added = _mm512_add_epi32(mix0, mix1);
  // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7.
  // Fold register over itself.
  return _mm256_add_epi32(_mm512_castsi512_si256(added),
                          _mm512_extracti64x4_epi64(added, 1));
}
#endif
#ifdef __AVX2__
// AVX2 kernels. The unpack, pack and horizontal-add instructions on 256-bit
// registers act on each 128-bit half independently.

template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i lo = _mm256_unpacklo_epi8(a, b);
  __m256i hi = _mm256_unpackhi_epi8(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> a, xsimd::batch<int16_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i lo = _mm256_unpacklo_epi16(a, b);
  __m256i hi = _mm256_unpackhi_epi16(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> a, xsimd::batch<int32_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i lo = _mm256_unpacklo_epi32(a, b);
  __m256i hi = _mm256_unpackhi_epi32(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> a, xsimd::batch<int64_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i lo = _mm256_unpacklo_epi64(a, b);
  __m256i hi = _mm256_unpackhi_epi64(a, b);
  return {lo, hi};
}

// 16 -> 8 bit narrowing with signed saturation.
template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> a, xsimd::batch<int16_t, Arch> b,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i narrowed = _mm256_packs_epi16(a, b);
  return narrowed;
}

// 32 -> 16 bit narrowing with signed saturation.
template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> a, xsimd::batch<int32_t, Arch> b,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i narrowed = _mm256_packs_epi32(a, b);
  return narrowed;
}

// int16 x int16 products, adjacent pairs added into int32 lanes.
template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i pair_sums = _mm256_madd_epi16(x, y);
  return pair_sums;
}

// uint8 x int8 products, adjacent pairs added into saturated int16 lanes.
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i pair_sums = _mm256_maddubs_epi16(x, y);
  return pair_sums;
}

// int8 x int8 variant: maddubs needs an unsigned left operand, so feed it |x|
// and transfer the sign of x onto y.
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  auto magnitude = xsimd::abs(x);
  __m256i y_with_sign = _mm256_sign_epi8(y, x);
  return _mm256_maddubs_epi16(magnitude, y_with_sign);
}

// Adds the 128-bit halves of each input: the low half of the result is
// pack0123.lo + pack0123.hi, the high half pack4567.lo + pack4567.hi.
template <class Arch>
inline xsimd::batch<int32_t, Arch>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx2>) {
  // [pack0123.hi, pack4567.lo]
  __m256i crossed = _mm256_permute2f128_si256(pack0123, pack4567, 0x21);
  // [pack0123.lo, pack4567.hi]
  __m256i straight = _mm256_blend_epi32(pack0123, pack4567, 0xf0);
  return _mm256_add_epi32(crossed, straight);
}

// Horizontal adds, first within each pair of inputs, then across the pairs.
template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
                                            xsimd::batch<int32_t, Arch> sum1,
                                            xsimd::batch<int32_t, Arch> sum2,
                                            xsimd::batch<int32_t, Arch> sum3,
                                            xsimd::kernel::requires_arch<xsimd::avx2>) {
  __m256i first_pair = _mm256_hadd_epi32(sum0, sum1);
  __m256i second_pair = _mm256_hadd_epi32(sum2, sum3);
  return _mm256_hadd_epi32(first_pair, second_pair);
}

#ifdef __AVXVNNI__

// z accumulated with the uint8 x int8 dot products computed by VPDPBUSD
// (AVX-VNNI encoding).
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avxvnni>) {
  __m256i accumulated = _mm256_dpbusd_avx_epi32(z, x, y);
  return accumulated;
}
#endif

#ifdef __AVX512VNNI__

// Same operation with the 512-bit AVX512-VNNI instruction, on top of AVX512BW.
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avx512vnni<xsimd::avx512bw>>) {
  __m512i accumulated = _mm512_dpbusd_epi32(z, x, y);
  return accumulated;
}

// Same operation with the 512-bit AVX512-VNNI instruction, on top of
// AVX512VBMI.
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avx512vnni<xsimd::avx512vbmi>>) {
  __m512i accumulated = _mm512_dpbusd_epi32(z, x, y);
  return accumulated;
}
#endif

#endif
#ifdef __SSSE3__
// SSSE3 kernels: pmaddubsw, psignb and phaddd on 128-bit registers.

// uint8 x int8 products, adjacent pairs added into saturated int16 lanes.
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  __m128i pair_sums = _mm_maddubs_epi16(x, y);
  return pair_sums;
}

// int8 x int8 variant: give maddubs |x| as its unsigned operand and move the
// sign of x onto y.
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  auto magnitude = xsimd::abs(x);
  __m128i y_with_sign = _mm_sign_epi8(y, x);
  return _mm_maddubs_epi16(magnitude, y_with_sign);
}

// Two rounds of horizontal adds: lane k of the result is the sum of all
// lanes of sum<k>.
template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
                                            xsimd::batch<int32_t, Arch> sum1,
                                            xsimd::batch<int32_t, Arch> sum2,
                                            xsimd::batch<int32_t, Arch> sum3,
                                            xsimd::kernel::requires_arch<xsimd::ssse3>) {
  __m128i first_pair = _mm_hadd_epi32(sum0, sum1);
  __m128i second_pair = _mm_hadd_epi32(sum2, sum3);
  return _mm_hadd_epi32(first_pair, second_pair);
}
#endif
272 #ifdef __SSE2__
273 template <class Arch>
274 std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
275 interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
276 xsimd::kernel::requires_arch<xsimd::sse2>) {
277 return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
280 template <class Arch>
281 std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
282 interleave(xsimd::batch<int16_t, Arch> first,
283 xsimd::batch<int16_t, Arch> second,
284 xsimd::kernel::requires_arch<xsimd::sse2>) {
285 return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
288 template <class Arch>
289 std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
290 interleave(xsimd::batch<int32_t, Arch> first,
291 xsimd::batch<int32_t, Arch> second,
292 xsimd::kernel::requires_arch<xsimd::sse2>) {
293 return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
296 template <class Arch>
297 std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
298 interleave(xsimd::batch<int64_t, Arch> first,
299 xsimd::batch<int64_t, Arch> second,
300 xsimd::kernel::requires_arch<xsimd::sse2>) {
301 return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
304 template <class Arch>
305 xsimd::batch<int8_t, Arch>
306 deinterleave(xsimd::batch<int16_t, Arch> first,
307 xsimd::batch<int16_t, Arch> second,
308 xsimd::kernel::requires_arch<xsimd::sse2>) {
309 return _mm_packs_epi16(first, second);
312 template <class Arch>
313 xsimd::batch<int16_t, Arch>
314 deinterleave(xsimd::batch<int32_t, Arch> first,
315 xsimd::batch<int32_t, Arch> second,
316 xsimd::kernel::requires_arch<xsimd::sse2>) {
317 return _mm_packs_epi32(first, second);
320 template <class Arch>
321 inline xsimd::batch<int32_t, Arch>
322 madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
323 xsimd::kernel::requires_arch<xsimd::sse2>) {
324 return _mm_madd_epi16(x, y);
327 template <class Arch>
328 inline xsimd::batch<int16_t, Arch>
329 madd(xsimd::batch<uint8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
330 xsimd::kernel::requires_arch<xsimd::sse2>) {
331 // Adapted from
332 // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
333 // a = 0x00 0x01 0xFE 0x04 ...
334 // b = 0x00 0x02 0x80 0x84 ...
336 // To extend signed 8-bit value, MSB has to be set to 0xFF
337 __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());
339 // sign_mask_b = 0x00 0x00 0xFF 0xFF ...
341 // Unpack positives with 0x00, negatives with 0xFF
342 __m128i a_epi16_l = _mm_unpacklo_epi8(a, _mm_setzero_si128());
343 __m128i a_epi16_h = _mm_unpackhi_epi8(a, _mm_setzero_si128());
344 __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
345 __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);
347 // Here - valid 16-bit signed integers corresponding to the 8-bit input
348 // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...
350 // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
351 __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
352 __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);
354 // Now go back from 32-bit values to 16-bit values & signed saturate
355 return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
358 template <class Arch>
359 inline xsimd::batch<int16_t, Arch>
360 madd(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
361 xsimd::kernel::requires_arch<xsimd::sse2>) {
362 // adapted
363 // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
364 // a = 0x00 0x01 0xFE 0x04 ...
365 // b = 0x00 0x02 0x80 0x84 ...
367 // To extend signed 8-bit value, MSB has to be set to 0xFF
368 __m128i sign_mask_a = _mm_cmplt_epi8(a, _mm_setzero_si128());
369 __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());
371 // sign_mask_a = 0x00 0x00 0xFF 0x00 ...
372 // sign_mask_b = 0x00 0x00 0xFF 0xFF ...
374 // Unpack positives with 0x00, negatives with 0xFF
375 __m128i a_epi16_l = _mm_unpacklo_epi8(a, sign_mask_a);
376 __m128i a_epi16_h = _mm_unpackhi_epi8(a, sign_mask_a);
377 __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
378 __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);
380 // Here - valid 16-bit signed integers corresponding to the 8-bit input
381 // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...
383 // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
384 __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
385 __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);
387 // Now go back from 32-bit values to 16-bit values & signed saturate
388 return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
391 template <class Arch>
392 inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
393 PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
394 xsimd::batch<int32_t, Arch> pack4567,
395 xsimd::kernel::requires_arch<xsimd::sse2>) {
396 return {pack0123, pack4567};
399 #endif
#if __ARM_ARCH >= 7
// ARMv7+ NEON kernels (128-bit registers).

template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  auto lo = xsimd::zip_lo(a, b);
  auto hi = xsimd::zip_hi(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> a, xsimd::batch<int16_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  auto lo = xsimd::zip_lo(a, b);
  auto hi = xsimd::zip_hi(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> a, xsimd::batch<int32_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  auto lo = xsimd::zip_lo(a, b);
  auto hi = xsimd::zip_hi(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> a, xsimd::batch<int64_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  auto lo = xsimd::zip_lo(a, b);
  auto hi = xsimd::zip_hi(a, b);
  return {lo, hi};
}

// Saturating narrow of each input to 8 bits, a in the low half, b in the high.
template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> a, xsimd::batch<int16_t, Arch> b,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  int8x8_t low_half = vqmovn_s16(a);
  int8x8_t high_half = vqmovn_s16(b);
  return vcombine_s8(low_half, high_half);
}

// Saturating narrow of each input to 16 bits, a in the low half, b in the high.
template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> a, xsimd::batch<int32_t, Arch> b,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  int16x4_t low_half = vqmovn_s32(a);
  int16x4_t high_half = vqmovn_s32(b);
  return vcombine_s16(low_half, high_half);
}

// Widening int16 products, adjacent pairs added into int32 lanes
// (NEON counterpart of _mm_madd_epi16).
template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int16x8_t xv = x;
  int16x8_t yv = y;
  int32x4_t prod_lo = vmull_s16(vget_low_s16(xv), vget_low_s16(yv));
  int32x4_t prod_hi = vmull_s16(vget_high_s16(xv), vget_high_s16(yv));
  int32x2_t sums_lo = vpadd_s32(vget_low_s32(prod_lo), vget_high_s32(prod_lo));
  int32x2_t sums_hi = vpadd_s32(vget_low_s32(prod_hi), vget_high_s32(prod_hi));
  return vcombine_s32(sums_lo, sums_hi);
}

// Emulation of x86 maddubs: zero-extended x times sign-extended y, the
// even/odd byte products added with 16-bit saturation.
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  // x86 zero-extends one operand and sign-extends the other, which is why
  // this takes so many steps; there is probably room for optimization.
  uint16x8_t x_lanes = vreinterpretq_u16_u8(x);
  int16x8_t y_lanes = vreinterpretq_s16_s8(y);

  // Zero extension: odd bytes by a logical shift, even bytes by clearing the
  // high byte of every 16-bit lane.
  int16x8_t x_odd = vreinterpretq_s16_u16(vshrq_n_u16(x_lanes, 8));
  int16x8_t x_even =
      vreinterpretq_s16_u16(vbicq_u16(x_lanes, vdupq_n_u16(0xff00)));

  // Sign extension: shift left, then arithmetic shift right.
  int16x8_t y_even = vshrq_n_s16(vshlq_n_s16(y_lanes, 8), 8);
  int16x8_t y_odd = vshrq_n_s16(y_lanes, 8);

  int16x8_t even_products = vmulq_s16(x_even, y_even);
  int16x8_t odd_products = vmulq_s16(x_odd, y_odd);
  return vqaddq_s16(even_products, odd_products);
}

// Widening int8 products, adjacent pairs added (non-saturating) into int16.
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int8x16_t xv = x;
  int8x16_t yv = y;
  int16x8_t prod_lo = vmull_s8(vget_low_s8(xv), vget_low_s8(yv));
  int16x8_t prod_hi = vmull_s8(vget_high_s8(xv), vget_high_s8(yv));
  int16x4_t sums_lo = vpadd_s16(vget_low_s16(prod_lo), vget_high_s16(prod_lo));
  int16x4_t sums_hi = vpadd_s16(vget_low_s16(prod_hi), vget_high_s16(prod_hi));
  return vcombine_s16(sums_lo, sums_hi);
}

// Nothing to permute with 128-bit registers: return both inputs as a pair.
template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::neon>) {
  return std::make_tuple(pack0123, pack4567);
}
#endif
#ifdef __aarch64__
// AArch64-only kernels: zip1/zip2, narrowing-high, full-width pairwise adds
// and, when available, the i8mm mixed-sign dot product.

template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  int8x16_t lo = vzip1q_s8(a, b);
  int8x16_t hi = vzip2q_s8(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> a, xsimd::batch<int16_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  int16x8_t lo = vzip1q_s16(a, b);
  int16x8_t hi = vzip2q_s16(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> a, xsimd::batch<int32_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  int32x4_t lo = vzip1q_s32(a, b);
  int32x4_t hi = vzip2q_s32(a, b);
  return {lo, hi};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> a, xsimd::batch<int64_t, Arch> b,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  int64x2_t lo = vzip1q_s64(a, b);
  int64x2_t hi = vzip2q_s64(a, b);
  return {lo, hi};
}

// Saturating narrow: a fills the low half, b the high half (single register).
template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> a, xsimd::batch<int16_t, Arch> b,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  int8x8_t low_half = vqmovn_s16(a);
  return vqmovn_high_s16(low_half, b);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> a, xsimd::batch<int32_t, Arch> b,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  int16x4_t low_half = vqmovn_s32(a);
  return vqmovn_high_s32(low_half, b);
}

#ifdef __ARM_FEATURE_MATMUL_INT8
// z accumulated with the uint8 x int8 dot products computed by USDOT (i8mm).
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::i8mm<xsimd::neon64>>) {
  int32x4_t accumulated = vusdotq_s32(z, x, y);
  return accumulated;
}
#endif

// Fallback without i8mm: widen each half of x (zero-extend) and y
// (sign-extend), multiply in 16 bits, then pairwise add-accumulate into z.
// Lane i of the result gets products of bytes 2i, 2i+1, 8+2i and 9+2i, which is
// a different grouping from VPDPBUSD/USDOT; the total across lanes is the same.
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::neon64>) {
  uint8x16_t xu = x;
  int8x16_t ys = y;
  int16x8_t products_lo =
      vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(xu))),
                vmovl_s8(vget_low_s8(ys)));
  int16x8_t products_hi =
      vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(xu))),
                vmovl_s8(vget_high_s8(ys)));
  int32x4_t partial = vpadalq_s16(z, products_lo);
  return vpadalq_s16(partial, products_hi);
}

// Two rounds of pairwise adds: lane k of the result sums all lanes of sum<k>.
template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
                                            xsimd::batch<int32_t, Arch> sum1,
                                            xsimd::batch<int32_t, Arch> sum2,
                                            xsimd::batch<int32_t, Arch> sum3,
                                            xsimd::kernel::requires_arch<xsimd::neon64>) {
  int32x4_t first_pair = vpaddq_s32(sum0, sum1);
  int32x4_t second_pair = vpaddq_s32(sum2, sum3);
  return vpaddq_s32(first_pair, second_pair);
}

#endif
596 template <class Arch>
597 inline xsimd::batch<int32_t, Arch>
598 maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
599 xsimd::batch<int32_t, Arch> z,
600 xsimd::kernel::requires_arch<xsimd::generic>) {
601 return z + madd(xsimd::batch<int16_t, Arch>(1), madd(x, y, Arch{}), Arch{});
604 } // namespace kernel
// Generic dispatchers for interleave, deinterleave, madd, maddw and PermuteSummer
610 template <class T, class Arch>
611 std::tuple<xsimd::batch<T, Arch>, xsimd::batch<T, Arch>>
612 interleave(xsimd::batch<T, Arch> first, xsimd::batch<T, Arch> second) {
613 return kernel::interleave(first, second, Arch{});
616 template <class Arch>
617 xsimd::batch<int8_t, Arch> deinterleave(xsimd::batch<int16_t, Arch> first,
618 xsimd::batch<int16_t, Arch> second) {
619 return kernel::deinterleave(first, second, Arch{});
621 template <class Arch>
622 xsimd::batch<int16_t, Arch> deinterleave(xsimd::batch<int32_t, Arch> first,
623 xsimd::batch<int32_t, Arch> second) {
624 return kernel::deinterleave(first, second, Arch{});
627 template <class Arch>
628 inline xsimd::batch<int32_t, Arch> madd(xsimd::batch<int16_t, Arch> x,
629 xsimd::batch<int16_t, Arch> y) {
630 return kernel::madd(x, y, Arch{});
632 template <class Arch>
633 inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<int8_t, Arch> x,
634 xsimd::batch<int8_t, Arch> y) {
635 return kernel::madd(x, y, Arch{});
637 template <class Arch>
638 inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<uint8_t, Arch> x,
639 xsimd::batch<int8_t, Arch> y) {
640 return kernel::madd(x, y, Arch{});
642 template <class Arch>
643 inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
644 xsimd::batch<int8_t, Arch> y,
645 xsimd::batch<int32_t, Arch> z
647 return kernel::maddw(x, y, z, Arch{});
649 template <class Arch>
650 inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
651 xsimd::batch<int8_t, Arch> y
653 return maddw(x, y, xsimd::batch<int32_t, Arch>((int32_t)0));
656 template <class Arch>
657 inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
658 xsimd::batch<int32_t, Arch> pack4567)
659 -> decltype(kernel::PermuteSummer(pack0123, pack4567, Arch{})) {
660 return kernel::PermuteSummer(pack0123, pack4567, Arch{});
664 namespace kernel {
666 template <class Arch>
667 inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
668 xsimd::batch<int32_t, Arch> sum1,
669 xsimd::batch<int32_t, Arch> sum2,
670 xsimd::batch<int32_t, Arch> sum3,
671 xsimd::kernel::requires_arch<xsimd::generic>) {
673 std::tie(sum0, sum1) = interleave(sum0, sum1, Arch{});
674 auto pack01 = sum0 + sum1;
675 std::tie(sum2, sum3) = interleave(sum2, sum3, Arch{});
676 auto pack23 = sum2 + sum3;
678 auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01),
679 xsimd::bitwise_cast<int64_t>(pack23),
680 Arch{});
681 return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) +
682 xsimd::bitwise_cast<int32_t>(std::get<1>(packed));
686 template <class Arch>
687 inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
688 xsimd::batch<int32_t, Arch> sum1,
689 xsimd::batch<int32_t, Arch> sum2,
690 xsimd::batch<int32_t, Arch> sum3) {
691 return kernel::Pack0123(sum0, sum1, sum2, sum3, Arch{});
694 template <class Arch>
695 static inline xsimd::batch<int32_t, Arch>
696 quantize(xsimd::batch<float, Arch> input,
697 xsimd::batch<float, Arch> quant_mult) {
698 return xsimd::nearbyint_as_int(input * quant_mult);
701 template <class Arch>
702 inline xsimd::batch<int32_t, Arch>
703 QuantizerGrab(const float *input, xsimd::batch<float, Arch> quant_mult_reg) {
704 return quantize(xsimd::batch<float, Arch>::load_unaligned(input),
705 quant_mult_reg);
708 #ifdef __AVX512BW__
709 inline __m512 Concat(const __m256 first, const __m256 second) {
710 // INTGEMM_AVX512DQ but that goes with INTGEMM_AVX512BW anyway.
711 return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1);
714 // Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be
715 // controlled independently.
716 /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set
717 * INTGEMM_AVX512BW */
718 inline __m512i QuantizerGrabHalves(const float *input0, const float *input1,
719 const __m512 quant_mult_reg) {
720 __m512 appended = Concat(_mm256_loadu_ps(input0), _mm256_loadu_ps(input1));
721 appended = _mm512_mul_ps(appended, quant_mult_reg);
722 return _mm512_cvtps_epi32(appended);
724 #else
725 template <class Arch>
726 inline xsimd::batch<int32_t, Arch>
727 QuantizerGrabHalves(const float *input0, const float *input1,
728 xsimd::batch<float, Arch> quant_mult_reg);
729 #endif
731 /* Read 8 floats at a time from input0, input1, input2, and input3. Quantize
732 * them to 8-bit by multiplying with quant_mult_reg then rounding. Concatenate
733 * the result into one register and return it.
735 class QuantizeTile8 {
736 template <class Arch> struct Tiler {
737 static constexpr uint32_t get(std::size_t i, std::size_t n) {
738 size_t factor = xsimd::batch<float, Arch>::size / 4;
739 return (i % factor) * 4 + i / factor;
743 public:
744 template <class Arch>
745 static inline xsimd::batch<int8_t, Arch>
746 Consecutive(xsimd::batch<float, Arch> quant_mult, const float *input) {
747 return Tile(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
748 input + 1 * xsimd::batch<float, Arch>::size,
749 input + 2 * xsimd::batch<float, Arch>::size,
750 input + 3 * xsimd::batch<float, Arch>::size);
753 template <class Arch>
754 static inline xsimd::batch<uint8_t, Arch>
755 ConsecutiveU(xsimd::batch<float, Arch> quant_mult, const float *input) {
756 return TileU(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
757 input + 1 * xsimd::batch<float, Arch>::size,
758 input + 2 * xsimd::batch<float, Arch>::size,
759 input + 3 * xsimd::batch<float, Arch>::size);
762 template <class Arch>
763 static inline xsimd::batch<int8_t, Arch>
764 ConsecutiveWithWrapping(xsimd::batch<float, Arch> quant_mult,
765 const float *input, size_t cols_left, size_t cols,
766 size_t row_step) {
767 using batchf32 = xsimd::batch<float, Arch>;
768 const float *inputs[4];
769 for (size_t i = 0; i < std::size(inputs); ++i) {
770 while (cols_left < batchf32::size) {
771 input += cols * (row_step - 1);
772 cols_left += cols;
774 inputs[i] = input;
775 input += batchf32::size;
776 cols_left -= batchf32::size;
778 return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]);
781 template <class Arch>
782 static inline xsimd::batch<int8_t, Arch>
783 ForReshape(xsimd::batch<float, Arch> quant_mult, const float *input,
784 size_t cols) {
785 using batchf32 = xsimd::batch<float, Arch>;
786 using batch8 = xsimd::batch<int8_t, Arch>;
787 using batch16 = xsimd::batch<int16_t, Arch>;
788 using batch32 = xsimd::batch<int32_t, Arch>;
789 using ubatch32 = xsimd::batch<uint32_t, Arch>;
791 // Put higher rows in the second half of the register. These will jumble
792 // around in the same way then conveniently land in the right place.
793 if constexpr (batchf32::size == 16) {
794 const batch8 neg127(-127);
795 // In reverse order: grabbing the first 32-bit values from each 128-bit
796 // register, then the second 32-bit values, etc. Grab 4 registers at a
797 // time in 32-bit format.
798 batch32 g0 =
799 QuantizerGrabHalves(input + 0 * cols, input + 2 * cols, quant_mult);
800 batch32 g1 =
801 QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult);
802 batch32 g2 =
803 QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult);
804 batch32 g3 =
805 QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult);
807 // Pack 32-bit to 16-bit.
808 batch16 packed0 = deinterleave(g0, g1);
809 batch16 packed1 = deinterleave(g2, g3);
810 // Pack 16-bit to 8-bit.
811 batch8 packed = deinterleave(packed0, packed1);
812 // Ban -128.
813 packed = xsimd::max(packed, neg127);
815 return xsimd::bitwise_cast<int8_t>(
816 xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
817 xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
818 } else if constexpr (batchf32::size == 8)
819 return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols,
820 input + 18 * cols);
821 else if constexpr (batchf32::size == 4)
822 // Skip a row.
823 return Tile(quant_mult, input, input + 4, input + 2 * cols,
824 input + 2 * cols + 4);
825 else
826 return {};
829 template <class Arch>
830 static inline xsimd::batch<int8_t, Arch>
831 Tile(xsimd::batch<float, Arch> quant_mult, const float *input0,
832 const float *input1, const float *input2, const float *input3) {
833 using batch8 = xsimd::batch<int8_t, Arch>;
834 using batch16 = xsimd::batch<int16_t, Arch>;
835 using batch32 = xsimd::batch<int32_t, Arch>;
836 using ubatch32 = xsimd::batch<uint32_t, Arch>;
838 const batch8 neg127(-127);
839 // Grab 4 registers at a time in 32-bit format.
840 batch32 g0 = QuantizerGrab(input0, quant_mult);
841 batch32 g1 = QuantizerGrab(input1, quant_mult);
842 batch32 g2 = QuantizerGrab(input2, quant_mult);
843 batch32 g3 = QuantizerGrab(input3, quant_mult);
844 // Pack 32-bit to 16-bit.
845 batch16 packed0 = deinterleave(g0, g1);
846 batch16 packed1 = deinterleave(g2, g3);
847 // Pack 16-bit to 8-bit.
848 batch8 packed = deinterleave(packed0, packed1);
849 // Ban -128.
850 packed = xsimd::max(packed, neg127);
852 if constexpr (batch32::size == 4)
853 return packed;
854 // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
855 // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7
856 // Technically this could be removed so long as the rows are bigger than 16
857 // and the values are only used for GEMM.
858 return xsimd::bitwise_cast<int8_t>(
859 xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
860 xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
863 private:
864 // A version that produces uint8_ts
865 template <class Arch>
866 static inline xsimd::batch<uint8_t, Arch>
867 TileU(xsimd::batch<float, Arch> quant_mult, const float *input0,
868 const float *input1, const float *input2, const float *input3) {
869 using batch8 = xsimd::batch<int8_t, Arch>;
870 using batch16 = xsimd::batch<int16_t, Arch>;
871 using batch32 = xsimd::batch<int32_t, Arch>;
872 using ubatch32 = xsimd::batch<uint32_t, Arch>;
874 const batch8 neg127 = -127;
875 const batch8 pos127 = +127;
876 // Grab 4 registers at a time in 32-bit format.
877 batch32 g0 = QuantizerGrab(input0, quant_mult);
878 batch32 g1 = QuantizerGrab(input1, quant_mult);
879 batch32 g2 = QuantizerGrab(input2, quant_mult);
880 batch32 g3 = QuantizerGrab(input3, quant_mult);
881 // Pack 32-bit to 16-bit.
882 batch16 packed0 = deinterleave(g0, g1);
883 batch16 packed1 = deinterleave(g2, g3);
884 // Pack 16-bit to 8-bit.
885 batch8 packed = deinterleave(packed0, packed1);
886 // Ban -128.
887 packed = xsimd::max(packed, neg127); // Could be removed if we use +128
888 packed = packed + pos127;
889 if (batch32::size == 4)
890 return xsimd::bitwise_cast<uint8_t>(packed);
891 // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
892 // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7
893 // Technically this could be removed so long as the rows are bigger than 16
894 // and the values are only used for GEMM.
895 return xsimd::bitwise_cast<uint8_t>(
896 xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
897 xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
901 template <class Arch>
902 inline void Transpose16InLane(
903 xsimd::batch<int8_t, Arch> &r0, xsimd::batch<int8_t, Arch> &r1,
904 xsimd::batch<int8_t, Arch> &r2, xsimd::batch<int8_t, Arch> &r3,
905 xsimd::batch<int8_t, Arch> &r4, xsimd::batch<int8_t, Arch> &r5,
906 xsimd::batch<int8_t, Arch> &r6, xsimd::batch<int8_t, Arch> &r7) {
907 /* r0: columns 0 1 2 3 4 5 6 7 from row 0
908 r1: columns 0 1 2 3 4 5 6 7 from row 1*/
909 auto r0_16 = xsimd::bitwise_cast<int16_t>(r0);
910 auto r1_16 = xsimd::bitwise_cast<int16_t>(r1);
911 auto r2_16 = xsimd::bitwise_cast<int16_t>(r2);
912 auto r3_16 = xsimd::bitwise_cast<int16_t>(r3);
913 auto r4_16 = xsimd::bitwise_cast<int16_t>(r4);
914 auto r5_16 = xsimd::bitwise_cast<int16_t>(r5);
915 auto r6_16 = xsimd::bitwise_cast<int16_t>(r6);
916 auto r7_16 = xsimd::bitwise_cast<int16_t>(r7);
918 std::tie(r0_16, r1_16) = interleave(r0_16, r1_16);
919 std::tie(r2_16, r3_16) = interleave(r2_16, r3_16);
920 std::tie(r4_16, r5_16) = interleave(r4_16, r5_16);
921 std::tie(r6_16, r7_16) = interleave(r6_16, r7_16);
922 /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
923 r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
924 r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
925 r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
926 r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
927 r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
928 r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
929 r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/
930 auto r0_32 = xsimd::bitwise_cast<int32_t>(r0_16);
931 auto r2_32 = xsimd::bitwise_cast<int32_t>(r2_16);
932 auto r1_32 = xsimd::bitwise_cast<int32_t>(r1_16);
933 auto r3_32 = xsimd::bitwise_cast<int32_t>(r3_16);
934 auto r4_32 = xsimd::bitwise_cast<int32_t>(r4_16);
935 auto r6_32 = xsimd::bitwise_cast<int32_t>(r6_16);
936 auto r5_32 = xsimd::bitwise_cast<int32_t>(r5_16);
937 auto r7_32 = xsimd::bitwise_cast<int32_t>(r7_16);
939 std::tie(r0_32, r2_32) = interleave(r0_32, r2_32);
940 std::tie(r1_32, r3_32) = interleave(r1_32, r3_32);
941 std::tie(r4_32, r6_32) = interleave(r4_32, r6_32);
942 std::tie(r5_32, r7_32) = interleave(r5_32, r7_32);
943 /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
944 r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
945 r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
946 r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
947 r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
948 r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
949 r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
950 r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/
952 auto r0_64 = xsimd::bitwise_cast<int64_t>(r0_32);
953 auto r2_64 = xsimd::bitwise_cast<int64_t>(r2_32);
954 auto r1_64 = xsimd::bitwise_cast<int64_t>(r1_32);
955 auto r3_64 = xsimd::bitwise_cast<int64_t>(r3_32);
956 auto r4_64 = xsimd::bitwise_cast<int64_t>(r4_32);
957 auto r6_64 = xsimd::bitwise_cast<int64_t>(r6_32);
958 auto r5_64 = xsimd::bitwise_cast<int64_t>(r5_32);
959 auto r7_64 = xsimd::bitwise_cast<int64_t>(r7_32);
961 std::tie(r0_64, r4_64) = interleave(r0_64, r4_64);
962 std::tie(r1_64, r5_64) = interleave(r1_64, r5_64);
963 std::tie(r2_64, r6_64) = interleave(r2_64, r6_64);
964 std::tie(r3_64, r7_64) = interleave(r3_64, r7_64);
966 r0 = xsimd::bitwise_cast<int8_t>(r0_64);
967 r1 = xsimd::bitwise_cast<int8_t>(r1_64);
968 r2 = xsimd::bitwise_cast<int8_t>(r2_64);
969 r3 = xsimd::bitwise_cast<int8_t>(r3_64);
970 r4 = xsimd::bitwise_cast<int8_t>(r4_64);
971 r5 = xsimd::bitwise_cast<int8_t>(r5_64);
972 r6 = xsimd::bitwise_cast<int8_t>(r6_64);
973 r7 = xsimd::bitwise_cast<int8_t>(r7_64);
974 /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
975 r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
976 r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
977 r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
978 r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
979 r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/
980 /* Empirically gcc is able to remove these movs and just rename the outputs of
981 * Interleave64. */
982 std::swap(r1, r4);
983 std::swap(r3, r6);
986 template <class Arch, typename IntegerTy>
987 void SelectColumnsOfB(const xsimd::batch<int8_t, Arch> *input,
988 xsimd::batch<int8_t, Arch> *output,
989 size_t rows_bytes /* number of bytes in a row */,
990 const IntegerTy *cols_begin, const IntegerTy *cols_end) {
991 using batch8 = xsimd::batch<int8_t, Arch>;
992 /* Do columns for multiples of 8.*/
993 size_t register_rows = rows_bytes / batch8::size;
994 const batch8 *starts[8];
995 for (; cols_begin != cols_end; cols_begin += 8) {
996 for (size_t k = 0; k < 8; ++k) {
997 starts[k] =
998 input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows;
1000 for (size_t r = 0; r < register_rows; ++r) {
1001 for (size_t k = 0; k < 8; ++k) {
1002 *(output++) = *starts[k];
1003 starts[k] += 8;
1009 } // namespace
1011 namespace callbacks {
1012 template <class Arch>
1013 xsimd::batch<float, Arch> Unquantize::operator()(xsimd::batch<int32_t, Arch> total, size_t, size_t,
1014 size_t) {
1015 return xsimd::batch_cast<float>(total) * unquant_mult;
1018 template <class Arch>
1019 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> Unquantize::operator()(
1020 std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> total,
1021 size_t, size_t, size_t) {
1022 return std::make_tuple(
1023 xsimd::batch_cast<float>(std::get<0>(total)) * unquant_mult,
1024 xsimd::batch_cast<float>(std::get<1>(total)) * unquant_mult);
1027 template <class Arch>
1028 xsimd::batch<float, Arch> AddBias::operator()(xsimd::batch<float, Arch> total, size_t,
1029 size_t col_idx, size_t) {
1030 return total + xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx);
1033 template <class Arch>
1034 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
1035 AddBias::operator()(
1036 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total,
1037 size_t, size_t col_idx, size_t) {
1038 return std::make_tuple(
1039 std::get<0>(total) + xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx + 0),
1040 std::get<1>(total) +
1041 xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx +
1042 xsimd::batch<float, Arch>::size));
1045 template <class Arch>
1046 void Write::operator()(xsimd::batch<float, Arch> result, size_t row_idx,
1047 size_t col_idx, size_t col_size) {
1048 result.store_aligned(output_addr + row_idx * col_size + col_idx);
1051 template <class Arch>
1052 void Write::operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx,
1053 size_t col_idx, size_t col_size) {
1054 xsimd::bitwise_cast<float>(result).store_aligned(
1055 output_addr + row_idx * col_size + col_idx);
1058 template <class Arch>
1059 void Write::operator()(
1060 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result,
1061 size_t row_idx, size_t col_idx, size_t col_size) {
1062 std::get<0>(result).store_aligned(output_addr + row_idx * col_size + col_idx +
1064 std::get<1>(result).store_aligned(output_addr + row_idx * col_size + col_idx +
1065 xsimd::batch<float, Arch>::size);
1068 template <class Arch>
1069 void Write::operator()(
1070 std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> result,
1071 size_t row_idx, size_t col_idx, size_t col_size) {
1072 xsimd::bitwise_cast<float>(std::get<0>(result))
1073 .store_aligned(output_addr + row_idx * col_size + col_idx + 0);
1074 xsimd::bitwise_cast<float>(std::get<1>(result))
1075 .store_aligned(output_addr + row_idx * col_size + col_idx +
1076 xsimd::batch<int32_t, Arch>::size);
1079 template <class T>
1080 void UnquantizeAndWrite::operator()(T const &total, size_t row_idx,
1081 size_t col_idx, size_t col_size) {
1082 auto unquantized = unquantize(total, row_idx, col_idx, col_size);
1083 write(unquantized, row_idx, col_idx, col_size);
1086 template <class T>
1087 void UnquantizeAndAddBiasAndWrite::operator()(T const &total, size_t row_idx,
1088 size_t col_idx, size_t col_size) {
1089 auto unquantized = unquantize(total, row_idx, col_idx, col_size);
1090 auto bias_added = add_bias(unquantized, row_idx, col_idx, col_size);
1091 write(bias_added, row_idx, col_idx, col_size);
1093 } // namespace callbacks
1095 template <class Arch>
1096 void Engine<Arch>::QuantizeU(const float *input, uint8_t *output,
1097 float quant_mult, size_t size) {
1098 using batch8 = xsimd::batch<int8_t, Arch>;
1100 xsimd::batch<float, Arch> q(quant_mult);
1101 const float *end = input + size;
1102 for (; input != end; input += batch8::size, output += batch8::size) {
1103 auto tile = QuantizeTile8::ConsecutiveU(q, input);
1104 tile.store_aligned(output);
1108 template <class Arch>
1109 void Engine<Arch>::Quantize(const float *const input, int8_t *const output,
1110 float quant_mult, size_t size) {
1111 using batch8 = xsimd::batch<int8_t, Arch>;
1113 const std::size_t kBatch = batch8::size;
1114 const std::size_t fast_end = size & ~(kBatch - 1);
1116 xsimd::batch<float, Arch> q(quant_mult);
1117 for (std::size_t i = 0; i < fast_end; i += kBatch) {
1118 auto tile = QuantizeTile8::Consecutive(q, input + i);
1119 tile.store_aligned(output + i);
1122 std::size_t overhang = size & (kBatch - 1);
1123 if (!overhang)
1124 return;
1125 /* Each does size(xsimd::batch<int8_t, Arch>) / 32 == kBatch / 4 floats at a
1126 * time. If we're allowed to read one of them, then we can read the whole
1127 * register.
1129 const float *inputs[4];
1130 std::size_t i;
1131 for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) {
1132 inputs[i] = &input[fast_end + i * (kBatch / 4)];
1134 /* These will be clipped off. */
1135 for (; i < 4; ++i) {
1136 inputs[i] = &input[fast_end];
1138 auto result =
1139 QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]);
1140 std::memcpy(output + (size & ~(kBatch - 1)), &result, overhang);
1143 template <class Arch>
1144 template <typename IntegerTy>
1145 void Engine<Arch>::SelectColumnsB(const int8_t *input, int8_t *output,
1146 size_t rows, const IntegerTy *cols_begin,
1147 const IntegerTy *cols_end) {
1148 using batch8 = xsimd::batch<int8_t, Arch>;
1149 SelectColumnsOfB(reinterpret_cast<const batch8 *>(input),
1150 reinterpret_cast<batch8 *>(output), rows, cols_begin,
1151 cols_end);
1154 template <class Arch>
1155 void Engine<Arch>::PrepareBTransposed(const float *input, int8_t *output,
1156 float quant_mult, size_t cols,
1157 size_t rows) {
1158 using batch8 = xsimd::batch<int8_t, Arch>;
1159 const size_t RegisterElemsInt = batch8::size;
1160 const size_t kColStride = 8;
1162 xsimd::batch<float, Arch> q(quant_mult);
1163 auto *output_it = reinterpret_cast<batch8 *>(output);
1164 size_t r = 0;
1165 size_t c = 0;
1166 while (r < rows) {
1167 for (size_t ri = 0; ri < 8; ++ri)
1168 *output_it++ = QuantizeTile8::ConsecutiveWithWrapping(
1169 q, input + (r + ri) * cols + c, cols - c, cols, 8);
1170 c += RegisterElemsInt;
1171 while (c >= cols) {
1172 r += kColStride;
1173 c -= cols;
1178 template <class Arch>
1179 void Engine<Arch>::PrepareBQuantizedTransposed(const int8_t *input,
1180 int8_t *output, size_t cols,
1181 size_t rows) {
1182 using batch8 = xsimd::batch<int8_t, Arch>;
1183 const size_t RegisterElems = batch8::size;
1184 const size_t kColStride = 8;
1186 auto *output_it = reinterpret_cast<batch8 *>(output);
1187 for (size_t r = 0; r < rows; r += kColStride)
1188 for (size_t c = 0; c < cols; c += RegisterElems)
1189 for (size_t ri = 0; ri < 8; ++ri)
1190 *output_it++ =
1191 *reinterpret_cast<const batch8 *>(input + (r + ri) * cols + c);
1194 template <class Arch>
1195 void Engine<Arch>::PrepareB(const float *input, int8_t *output_shadow,
1196 float quant_mult, size_t rows, size_t cols) {
1197 using batch8 = xsimd::batch<int8_t, Arch>;
1199 xsimd::batch<float, Arch> q(quant_mult);
1200 /* Currently all multipliers have a stride of 8 columns.*/
1201 const size_t kColStride = 8;
1202 auto *output = reinterpret_cast<batch8 *>(output_shadow);
1203 for (size_t c = 0; c < cols; c += kColStride) {
1204 for (size_t r = 0; r < rows; r += sizeof(*output), output += 8) {
1205 output[0] =
1206 QuantizeTile8::ForReshape(q, input + cols * (r + 0) + c, cols);
1207 output[1] =
1208 QuantizeTile8::ForReshape(q, input + cols * (r + 1) + c, cols);
1209 output[2] =
1210 QuantizeTile8::ForReshape(q, input + cols * (r + 4) + c, cols);
1211 output[3] =
1212 QuantizeTile8::ForReshape(q, input + cols * (r + 5) + c, cols);
1213 output[4] =
1214 QuantizeTile8::ForReshape(q, input + cols * (r + 8) + c, cols);
1215 output[5] =
1216 QuantizeTile8::ForReshape(q, input + cols * (r + 9) + c, cols);
1217 output[6] =
1218 QuantizeTile8::ForReshape(q, input + cols * (r + 12) + c, cols);
1219 output[7] =
1220 QuantizeTile8::ForReshape(q, input + cols * (r + 13) + c, cols);
1221 std::tie(output[0], output[1]) =
1222 interleave(xsimd::bitwise_cast<int8_t>(output[0]),
1223 xsimd::bitwise_cast<int8_t>(output[1]));
1224 std::tie(output[2], output[3]) =
1225 interleave(xsimd::bitwise_cast<int8_t>(output[2]),
1226 xsimd::bitwise_cast<int8_t>(output[3]));
1227 std::tie(output[4], output[5]) =
1228 interleave(xsimd::bitwise_cast<int8_t>(output[4]),
1229 xsimd::bitwise_cast<int8_t>(output[5]));
1230 std::tie(output[6], output[7]) =
1231 interleave(xsimd::bitwise_cast<int8_t>(output[6]),
1232 xsimd::bitwise_cast<int8_t>(output[7]));
1233 Transpose16InLane(output[0], output[1], output[2], output[3], output[4],
1234 output[5], output[6], output[7]);
1239 template <class Arch>
1240 void Engine<Arch>::PrepareA(const float *input, int8_t *output,
1241 float quant_mult, size_t rows, size_t cols) {
1242 Quantize(input, output, quant_mult, rows * cols);
1245 template <class Arch>
1246 void Engine<Arch>::Shift::PrepareA(const float *input, uint8_t *output,
1247 float quant_mult, size_t rows, size_t cols) {
1248 QuantizeU(input, output, quant_mult, rows * cols);
1251 template <class Arch>
1252 template <class Callback>
1253 void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
1254 size_t A_rows, size_t width, size_t B_cols,
1255 Callback callback) {
1257 using batch8 = xsimd::batch<int8_t, Arch>;
1258 using ubatch8 = xsimd::batch<uint8_t, Arch>;
1259 using batch32 = xsimd::batch<int32_t, Arch>;
1261 const size_t simd_width = width / batch8::size;
1262 for (size_t B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
1263 const auto *B0_col =
1264 reinterpret_cast<const batch8 *>(B) + simd_width * B0_colidx;
1265 /* Process one row of A at a time. Doesn't seem to be faster to do multiple
1266 * rows of A at once.*/
1267 for (size_t A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
1268 const auto *A_row =
1269 reinterpret_cast<const ubatch8 *>(A + A_rowidx * width);
1270 /* These will be packed 16-bit integers containing sums for each row of B
1271 multiplied by the row of A. Iterate over shared (inner) dimension.*/
1272 /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
1273 * declared here.*/
1274 size_t k = 0;
1275 ubatch8 a = *(A_row + k);
1276 batch32 isum0 = maddw(a, *(B0_col + k * 8));
1277 batch32 isum1 = maddw(a, *(B0_col + k * 8 + 1));
1278 batch32 isum2 = maddw(a, *(B0_col + k * 8 + 2));
1279 batch32 isum3 = maddw(a, *(B0_col + k * 8 + 3));
1280 batch32 isum4 = maddw(a, *(B0_col + k * 8 + 4));
1281 batch32 isum5 = maddw(a, *(B0_col + k * 8 + 5));
1282 batch32 isum6 = maddw(a, *(B0_col + k * 8 + 6));
1283 batch32 isum7 = maddw(a, *(B0_col + k * 8 + 7));
1284 for (k = 1; k < simd_width; ++k) {
1285 a = *(A_row + k);
1286 /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/
1287 /* Upcast to 32-bit and horizontally add.*/
1288 isum0 = maddw(a, *(B0_col + k * 8 + 0), isum0);
1289 isum1 = maddw(a, *(B0_col + k * 8 + 1), isum1);
1290 isum2 = maddw(a, *(B0_col + k * 8 + 2), isum2);
1291 isum3 = maddw(a, *(B0_col + k * 8 + 3), isum3);
1292 isum4 = maddw(a, *(B0_col + k * 8 + 4), isum4);
1293 isum5 = maddw(a, *(B0_col + k * 8 + 5), isum5);
1294 isum6 = maddw(a, *(B0_col + k * 8 + 6), isum6);
1295 isum7 = maddw(a, *(B0_col + k * 8 + 7), isum7);
1297 /* Reduce sums within 128-bit lanes.*/
1298 auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
1299 auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);
1300 /*The specific implementation may need to reduce further.*/
1301 auto total = PermuteSummer(pack0123, pack4567);
1302 callback(total, A_rowidx, B0_colidx, B_cols);
1307 template <class Arch>
1308 template <class Callback>
1309 void Engine<Arch>::Shift::PrepareBias(const int8_t *B, size_t width,
1310 size_t B_cols, Callback C) {
1311 using batch8 = xsimd::batch<int8_t, Arch>;
1312 const size_t simd_width = width / batch8::size;
1313 xsimd::batch<uint8_t, Arch> a(1);
1314 for (size_t j = 0; j < B_cols; j += 8) {
1315 /*Process one row of A at a time. Doesn't seem to be faster to do multiple
1316 * rows of A at once.*/
1317 const int8_t *B_j = B + j * width;
1319 /* Rather than initializing as zeros and adding, just initialize the
1320 * first.*/
1321 /* These will be packed 16-bit integers containing sums for each column of
1322 * B multiplied by the row of A.*/
1323 /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
1324 * declared here.*/
1325 auto isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]));
1326 auto isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]));
1327 auto isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]));
1328 auto isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]));
1329 auto isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]));
1330 auto isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]));
1331 auto isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]));
1332 auto isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]));
1334 B_j += 8 * batch8::size;
1336 for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) {
1337 isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]), isum0);
1338 isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]), isum1);
1339 isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]), isum2);
1340 isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]), isum3);
1341 isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]), isum4);
1342 isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]), isum5);
1343 isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]), isum6);
1344 isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]), isum7);
1347 auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
1348 auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);
1350 auto total = PermuteSummer(pack0123, pack4567);
1351 C(total, 0, j, B_cols);
1355 } // namespace gemmology
1357 #endif