Bug 1883521 [wpt PR 44917] - Stop modifying author provided selection colors, a=testonly
[gecko.git] / mozglue / misc / SIMD.cpp
blob3893de57b32cd2d07c9afd376aedfd789e105995
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
7 #include "mozilla/SIMD.h"
9 #include <cstring>
10 #include <stdint.h>
11 #include <type_traits>
13 #include "mozilla/EndianUtils.h"
14 #include "mozilla/SSE.h"
16 #ifdef MOZILLA_PRESUME_SSE2
18 # include <immintrin.h>
20 #endif
22 namespace mozilla {
// Scalar linear search used where no vectorized path exists. Returns a pointer
// to the first element of [ptr, ptr + length) equal to `value`, or nullptr if
// there is no such element.
template <typename TValue>
const TValue* FindInBufferNaive(const TValue* ptr, TValue value,
                                size_t length) {
  for (const TValue* const stop = ptr + length; ptr != stop; ++ptr) {
    if (*ptr == value) {
      return ptr;
    }
  }
  return nullptr;
}
37 #ifdef MOZILLA_PRESUME_SSE2
// Reinterprets an integer address as a pointer to a 16-byte SSE vector. The
// pointer is only handed to load intrinsics; no alignment is implied.
const __m128i* Cast128(uintptr_t ptr) {
  const __m128i* vec = reinterpret_cast<const __m128i*>(ptr);
  return vec;
}
// Reads one T from the integer address `ptr` (which must point at a T).
template <typename T>
T GetAs(uintptr_t ptr) {
  const T* typed = reinterpret_cast<const T*>(ptr);
  return *typed;
}
// Round an address down (AlignDown16) or up (AlignUp16) to a multiple of 16.
// Akin to floor/ceil, an address that is already 16-byte aligned is returned
// unchanged by both.
uintptr_t AlignDown16(uintptr_t ptr) { return ptr - (ptr & 0xf); }

uintptr_t AlignUp16(uintptr_t ptr) {
  uintptr_t misalignment = ptr & 0xf;
  return misalignment ? ptr + (16 - misalignment) : ptr;
}
// Lane-wise equality of `a` and `b` at TValue granularity: each TValue-sized
// lane becomes all ones where the lanes are equal and all zeros otherwise.
template <typename TValue>
__m128i CmpEq128(__m128i a, __m128i b) {
  static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2);
  if constexpr (sizeof(TValue) == 2) {
    return _mm_cmpeq_epi16(a, b);
  } else {
    return _mm_cmpeq_epi8(a, b);
  }
}
# ifdef __GNUC__

// Earlier versions of GCC are missing the _mm_loadu_si32 intrinsic, so on
// GCC-compatible compilers we use this workaround from Peter Cordes
// (https://stackoverflow.com/a/72837992), which compiles down to the same
// instructions: memcpy the four bytes into an int (an unaligned,
// aliasing-safe load) and move that int into the low lane of an XMM register.
// Other compilers use the intrinsic directly (see the #else branch).
__m128i Load32BitsIntoXMM(uintptr_t ptr) {
  int lowBits;
  std::memcpy(&lowBits, reinterpret_cast<const void*>(ptr), sizeof(lowBits));
  return _mm_cvtsi32_si128(lowBits);
}

# else

// Loads the 4 bytes at `ptr` (no alignment requirement) into the low 32 bits
// of an XMM register.
__m128i Load32BitsIntoXMM(uintptr_t ptr) {
  return _mm_loadu_si32(reinterpret_cast<const void*>(ptr));
}

# endif
83 const char* Check4x4Chars(__m128i needle, uintptr_t a, uintptr_t b, uintptr_t c,
84 uintptr_t d) {
85 __m128i haystackA = Load32BitsIntoXMM(a);
86 __m128i cmpA = CmpEq128<char>(needle, haystackA);
87 __m128i haystackB = Load32BitsIntoXMM(b);
88 __m128i cmpB = CmpEq128<char>(needle, haystackB);
89 __m128i haystackC = Load32BitsIntoXMM(c);
90 __m128i cmpC = CmpEq128<char>(needle, haystackC);
91 __m128i haystackD = Load32BitsIntoXMM(d);
92 __m128i cmpD = CmpEq128<char>(needle, haystackD);
93 __m128i or_ab = _mm_or_si128(cmpA, cmpB);
94 __m128i or_cd = _mm_or_si128(cmpC, cmpD);
95 __m128i or_abcd = _mm_or_si128(or_ab, or_cd);
96 int orMask = _mm_movemask_epi8(or_abcd);
97 if (orMask & 0xf) {
98 int cmpMask;
99 cmpMask = _mm_movemask_epi8(cmpA);
100 if (cmpMask & 0xf) {
101 return reinterpret_cast<const char*>(a + __builtin_ctz(cmpMask));
103 cmpMask = _mm_movemask_epi8(cmpB);
104 if (cmpMask & 0xf) {
105 return reinterpret_cast<const char*>(b + __builtin_ctz(cmpMask));
107 cmpMask = _mm_movemask_epi8(cmpC);
108 if (cmpMask & 0xf) {
109 return reinterpret_cast<const char*>(c + __builtin_ctz(cmpMask));
111 cmpMask = _mm_movemask_epi8(cmpD);
112 if (cmpMask & 0xf) {
113 return reinterpret_cast<const char*>(d + __builtin_ctz(cmpMask));
117 return nullptr;
120 template <typename TValue>
121 const TValue* Check4x16Bytes(__m128i needle, uintptr_t a, uintptr_t b,
122 uintptr_t c, uintptr_t d) {
123 __m128i haystackA = _mm_loadu_si128(Cast128(a));
124 __m128i cmpA = CmpEq128<TValue>(needle, haystackA);
125 __m128i haystackB = _mm_loadu_si128(Cast128(b));
126 __m128i cmpB = CmpEq128<TValue>(needle, haystackB);
127 __m128i haystackC = _mm_loadu_si128(Cast128(c));
128 __m128i cmpC = CmpEq128<TValue>(needle, haystackC);
129 __m128i haystackD = _mm_loadu_si128(Cast128(d));
130 __m128i cmpD = CmpEq128<TValue>(needle, haystackD);
131 __m128i or_ab = _mm_or_si128(cmpA, cmpB);
132 __m128i or_cd = _mm_or_si128(cmpC, cmpD);
133 __m128i or_abcd = _mm_or_si128(or_ab, or_cd);
134 int orMask = _mm_movemask_epi8(or_abcd);
135 if (orMask) {
136 int cmpMask;
137 cmpMask = _mm_movemask_epi8(cmpA);
138 if (cmpMask) {
139 return reinterpret_cast<const TValue*>(a + __builtin_ctz(cmpMask));
141 cmpMask = _mm_movemask_epi8(cmpB);
142 if (cmpMask) {
143 return reinterpret_cast<const TValue*>(b + __builtin_ctz(cmpMask));
145 cmpMask = _mm_movemask_epi8(cmpC);
146 if (cmpMask) {
147 return reinterpret_cast<const TValue*>(c + __builtin_ctz(cmpMask));
149 cmpMask = _mm_movemask_epi8(cmpD);
150 if (cmpMask) {
151 return reinterpret_cast<const TValue*>(d + __builtin_ctz(cmpMask));
155 return nullptr;
// How the second chunk passed to Check2x2x16Bytes relates to the first.
enum class HaystackOverlap {
  // The second chunk may overlap the first; no cross-chunk pair is checked.
  Overlapping,
  // The second chunk directly follows the first (b == a + 16).
  Sequential,
};
163 // Check two 16-byte chunks for the two-byte sequence loaded into needle1
164 // followed by needle1. `carryOut` is an optional pointer which we will
165 // populate based on whether the last character of b matches needle1. This
166 // should be provided on subsequent calls via `carryIn` so we can detect cases
167 // where the last byte of b's 16-byte chunk is needle1 and the first byte of
168 // the next a's 16-byte chunk is needle2. `overlap` and whether
169 // `carryIn`/`carryOut` are NULL should be knowable at compile time to avoid
170 // branching.
171 template <typename TValue>
172 const TValue* Check2x2x16Bytes(__m128i needle1, __m128i needle2, uintptr_t a,
173 uintptr_t b, __m128i* carryIn, __m128i* carryOut,
174 HaystackOverlap overlap) {
175 const int shiftRightAmount = 16 - sizeof(TValue);
176 const int shiftLeftAmount = sizeof(TValue);
177 __m128i haystackA = _mm_loadu_si128(Cast128(a));
178 __m128i cmpA1 = CmpEq128<TValue>(needle1, haystackA);
179 __m128i cmpA2 = CmpEq128<TValue>(needle2, haystackA);
180 __m128i cmpA;
181 if (carryIn) {
182 cmpA = _mm_and_si128(
183 _mm_or_si128(_mm_bslli_si128(cmpA1, shiftLeftAmount), *carryIn), cmpA2);
184 } else {
185 cmpA = _mm_and_si128(_mm_bslli_si128(cmpA1, shiftLeftAmount), cmpA2);
187 __m128i haystackB = _mm_loadu_si128(Cast128(b));
188 __m128i cmpB1 = CmpEq128<TValue>(needle1, haystackB);
189 __m128i cmpB2 = CmpEq128<TValue>(needle2, haystackB);
190 __m128i cmpB;
191 if (overlap == HaystackOverlap::Overlapping) {
192 cmpB = _mm_and_si128(_mm_bslli_si128(cmpB1, shiftLeftAmount), cmpB2);
193 } else {
194 MOZ_ASSERT(overlap == HaystackOverlap::Sequential);
195 __m128i carryAB = _mm_bsrli_si128(cmpA1, shiftRightAmount);
196 cmpB = _mm_and_si128(
197 _mm_or_si128(_mm_bslli_si128(cmpB1, shiftLeftAmount), carryAB), cmpB2);
199 __m128i or_ab = _mm_or_si128(cmpA, cmpB);
200 int orMask = _mm_movemask_epi8(or_ab);
201 if (orMask) {
202 int cmpMask;
203 cmpMask = _mm_movemask_epi8(cmpA);
204 if (cmpMask) {
205 return reinterpret_cast<const TValue*>(a + __builtin_ctz(cmpMask) -
206 shiftLeftAmount);
208 cmpMask = _mm_movemask_epi8(cmpB);
209 if (cmpMask) {
210 return reinterpret_cast<const TValue*>(b + __builtin_ctz(cmpMask) -
211 shiftLeftAmount);
215 if (carryOut) {
216 _mm_store_si128(carryOut, _mm_bsrli_si128(cmpB1, shiftRightAmount));
219 return nullptr;
222 template <typename TValue>
223 const TValue* FindInBuffer(const TValue* ptr, TValue value, size_t length) {
224 static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2);
225 static_assert(std::is_unsigned<TValue>::value);
226 uint64_t splat64;
227 if (sizeof(TValue) == 1) {
228 splat64 = 0x0101010101010101llu;
229 } else {
230 splat64 = 0x0001000100010001llu;
233 // Load our needle into a 16-byte register
234 uint64_t u64_value = static_cast<uint64_t>(value) * splat64;
235 int64_t i64_value = *reinterpret_cast<int64_t*>(&u64_value);
236 __m128i needle = _mm_set_epi64x(i64_value, i64_value);
238 size_t numBytes = length * sizeof(TValue);
239 uintptr_t cur = reinterpret_cast<uintptr_t>(ptr);
240 uintptr_t end = cur + numBytes;
242 if ((sizeof(TValue) > 1 && numBytes < 16) || numBytes < 4) {
243 while (cur < end) {
244 if (GetAs<TValue>(cur) == value) {
245 return reinterpret_cast<const TValue*>(cur);
247 cur += sizeof(TValue);
249 return nullptr;
252 if (numBytes < 16) {
253 // NOTE: here and below, we have some bit fiddling which could look a
254 // little weird. The important thing to note though is it's just a trick
255 // for getting the number 4 if numBytes is greater than or equal to 8,
256 // and 0 otherwise. This lets us fully cover the range without any
257 // branching for the case where numBytes is in [4,8), and [8,16). We get
258 // four ranges from this - if numbytes > 8, we get:
259 // [0,4), [4,8], [end - 8), [end - 4)
260 // and if numbytes < 8, we get
261 // [0,4), [0,4), [end - 4), [end - 4)
262 uintptr_t a = cur;
263 uintptr_t b = cur + ((numBytes & 8) >> 1);
264 uintptr_t c = end - 4 - ((numBytes & 8) >> 1);
265 uintptr_t d = end - 4;
266 const char* charResult = Check4x4Chars(needle, a, b, c, d);
267 // Note: we ensure above that sizeof(TValue) == 1 here, so this is
268 // either char to char or char to something like a uint8_t.
269 return reinterpret_cast<const TValue*>(charResult);
272 if (numBytes < 64) {
273 // NOTE: see the above explanation of the similar chunk of code, but in
274 // this case, replace 8 with 32 and 4 with 16.
275 uintptr_t a = cur;
276 uintptr_t b = cur + ((numBytes & 32) >> 1);
277 uintptr_t c = end - 16 - ((numBytes & 32) >> 1);
278 uintptr_t d = end - 16;
279 return Check4x16Bytes<TValue>(needle, a, b, c, d);
282 // Get the initial unaligned load out of the way. This will overlap with the
283 // aligned stuff below, but the overlapped part should effectively be free
284 // (relative to a mispredict from doing a byte-by-byte loop).
285 __m128i haystack = _mm_loadu_si128(Cast128(cur));
286 __m128i cmp = CmpEq128<TValue>(needle, haystack);
287 int cmpMask = _mm_movemask_epi8(cmp);
288 if (cmpMask) {
289 return reinterpret_cast<const TValue*>(cur + __builtin_ctz(cmpMask));
292 // Now we're working with aligned memory. Hooray! \o/
293 cur = AlignUp16(cur);
295 // The address of the final 48-63 bytes. We overlap this with what we check in
296 // our hot loop below to avoid branching. Again, the overlap should be
297 // negligible compared with a branch mispredict.
298 uintptr_t tailStartPtr = AlignDown16(end - 48);
299 uintptr_t tailEndPtr = end - 16;
301 while (cur < tailStartPtr) {
302 uintptr_t a = cur;
303 uintptr_t b = cur + 16;
304 uintptr_t c = cur + 32;
305 uintptr_t d = cur + 48;
306 const TValue* result = Check4x16Bytes<TValue>(needle, a, b, c, d);
307 if (result) {
308 return result;
310 cur += 64;
313 uintptr_t a = tailStartPtr;
314 uintptr_t b = tailStartPtr + 16;
315 uintptr_t c = tailStartPtr + 32;
316 uintptr_t d = tailEndPtr;
317 return Check4x16Bytes<TValue>(needle, a, b, c, d);
320 template <typename TValue>
321 const TValue* TwoElementLoop(uintptr_t start, uintptr_t end, TValue v1,
322 TValue v2) {
323 static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2);
325 const TValue* cur = reinterpret_cast<const TValue*>(start);
326 const TValue* preEnd = reinterpret_cast<const TValue*>(end - sizeof(TValue));
328 uint32_t expected = static_cast<uint32_t>(v1) |
329 (static_cast<uint32_t>(v2) << (sizeof(TValue) * 8));
330 while (cur < preEnd) {
331 // NOTE: this should only ever be called on little endian architectures.
332 static_assert(MOZ_LITTLE_ENDIAN());
333 // We or cur[0] and cur[1] together explicitly and compare to expected,
334 // in order to avoid UB from just loading them as a uint16_t/uint32_t.
335 // However, it will compile down the same code after optimizations on
336 // little endian systems which support unaligned loads. Comparing them
337 // value-by-value, however, will not, and seems to perform worse in local
338 // microbenchmarking. Even after bitwise or'ing the comparison values
339 // together to avoid the short circuit, the compiler doesn't seem to get
340 // the hint and creates two branches, the first of which might be
341 // frequently mispredicted.
342 uint32_t actual = static_cast<uint32_t>(cur[0]) |
343 (static_cast<uint32_t>(cur[1]) << (sizeof(TValue) * 8));
344 if (actual == expected) {
345 return cur;
347 cur++;
349 return nullptr;
352 template <typename TValue>
353 const TValue* FindTwoInBuffer(const TValue* ptr, TValue v1, TValue v2,
354 size_t length) {
355 static_assert(sizeof(TValue) == 1 || sizeof(TValue) == 2);
356 static_assert(std::is_unsigned<TValue>::value);
357 uint64_t splat64;
358 if (sizeof(TValue) == 1) {
359 splat64 = 0x0101010101010101llu;
360 } else {
361 splat64 = 0x0001000100010001llu;
364 // Load our needle into a 16-byte register
365 uint64_t u64_v1 = static_cast<uint64_t>(v1) * splat64;
366 int64_t i64_v1 = *reinterpret_cast<int64_t*>(&u64_v1);
367 __m128i needle1 = _mm_set_epi64x(i64_v1, i64_v1);
368 uint64_t u64_v2 = static_cast<uint64_t>(v2) * splat64;
369 int64_t i64_v2 = *reinterpret_cast<int64_t*>(&u64_v2);
370 __m128i needle2 = _mm_set_epi64x(i64_v2, i64_v2);
372 size_t numBytes = length * sizeof(TValue);
373 uintptr_t cur = reinterpret_cast<uintptr_t>(ptr);
374 uintptr_t end = cur + numBytes;
376 if (numBytes < 16) {
377 return TwoElementLoop<TValue>(cur, end, v1, v2);
380 if (numBytes < 32) {
381 uintptr_t a = cur;
382 uintptr_t b = end - 16;
383 return Check2x2x16Bytes<TValue>(needle1, needle2, a, b, nullptr, nullptr,
384 HaystackOverlap::Overlapping);
387 // Get the initial unaligned load out of the way. This will likely overlap
388 // with the aligned stuff below, but the overlapped part should effectively
389 // be free.
390 __m128i haystack = _mm_loadu_si128(Cast128(cur));
391 __m128i cmp1 = CmpEq128<TValue>(needle1, haystack);
392 __m128i cmp2 = CmpEq128<TValue>(needle2, haystack);
393 int cmpMask1 = _mm_movemask_epi8(cmp1);
394 int cmpMask2 = _mm_movemask_epi8(cmp2);
395 int cmpMask = (cmpMask1 << sizeof(TValue)) & cmpMask2;
396 if (cmpMask) {
397 return reinterpret_cast<const TValue*>(cur + __builtin_ctz(cmpMask) -
398 sizeof(TValue));
401 // Now we're working with aligned memory. Hooray! \o/
402 cur = AlignUp16(cur);
404 // The address of the final 48-63 bytes. We overlap this with what we check in
405 // our hot loop below to avoid branching. Again, the overlap should be
406 // negligible compared with a branch mispredict.
407 uintptr_t tailEndPtr = end - 16;
408 uintptr_t tailStartPtr = AlignDown16(tailEndPtr);
410 __m128i cmpMaskCarry = _mm_set1_epi32(0);
411 while (cur < tailStartPtr) {
412 uintptr_t a = cur;
413 uintptr_t b = cur + 16;
414 const TValue* result =
415 Check2x2x16Bytes<TValue>(needle1, needle2, a, b, &cmpMaskCarry,
416 &cmpMaskCarry, HaystackOverlap::Sequential);
417 if (result) {
418 return result;
420 cur += 32;
423 uint32_t carry = (cur == tailStartPtr) ? 0xffffffff : 0;
424 __m128i wideCarry = Load32BitsIntoXMM(reinterpret_cast<uintptr_t>(&carry));
425 cmpMaskCarry = _mm_and_si128(cmpMaskCarry, wideCarry);
426 uintptr_t a = tailStartPtr;
427 uintptr_t b = tailEndPtr;
428 return Check2x2x16Bytes<TValue>(needle1, needle2, a, b, &cmpMaskCarry,
429 nullptr, HaystackOverlap::Overlapping);
432 const char* SIMD::memchr8SSE2(const char* ptr, char value, size_t length) {
433 // Signed chars are just really annoying to do bit logic with. Convert to
434 // unsigned at the outermost scope so we don't have to worry about it.
435 const unsigned char* uptr = reinterpret_cast<const unsigned char*>(ptr);
436 unsigned char uvalue = static_cast<unsigned char>(value);
437 const unsigned char* uresult =
438 FindInBuffer<unsigned char>(uptr, uvalue, length);
439 return reinterpret_cast<const char*>(uresult);
// So, this is a bit awkward. It generally simplifies things if we can just
// assume all the AVX2 code is 64-bit. SIMD_avx2 therefore wraps all of its
// real code in a preprocessor guard, and also defines versions of its entry
// points that just assert false when the guard is not satisfied. A 32-bit
// processor could implement the AVX2 instruction set, though, which would make
// it pass the supports_avx2() check and land in that assertion failure. So we
// simply report "no AVX2" unless we are building 64-bit x86 code that may
// support it. We are not particularly concerned about giving newer 32-bit
// processors access to the AVX2 functions exposed here.
bool SupportsAVX2() {
# if defined(MOZILLA_MAY_SUPPORT_AVX2) && defined(__x86_64__)
  return supports_avx2();
# else
  return false;
# endif
}
461 const char* SIMD::memchr8(const char* ptr, char value, size_t length) {
462 if (SupportsAVX2()) {
463 return memchr8AVX2(ptr, value, length);
465 return memchr8SSE2(ptr, value, length);
468 const char16_t* SIMD::memchr16SSE2(const char16_t* ptr, char16_t value,
469 size_t length) {
470 return FindInBuffer<char16_t>(ptr, value, length);
473 const char16_t* SIMD::memchr16(const char16_t* ptr, char16_t value,
474 size_t length) {
475 if (SupportsAVX2()) {
476 return memchr16AVX2(ptr, value, length);
478 return memchr16SSE2(ptr, value, length);
481 const uint64_t* SIMD::memchr64(const uint64_t* ptr, uint64_t value,
482 size_t length) {
483 if (SupportsAVX2()) {
484 return memchr64AVX2(ptr, value, length);
486 return FindInBufferNaive<uint64_t>(ptr, value, length);
489 const char* SIMD::memchr2x8(const char* ptr, char v1, char v2, size_t length) {
490 // Signed chars are just really annoying to do bit logic with. Convert to
491 // unsigned at the outermost scope so we don't have to worry about it.
492 const unsigned char* uptr = reinterpret_cast<const unsigned char*>(ptr);
493 unsigned char uv1 = static_cast<unsigned char>(v1);
494 unsigned char uv2 = static_cast<unsigned char>(v2);
495 const unsigned char* uresult =
496 FindTwoInBuffer<unsigned char>(uptr, uv1, uv2, length);
497 return reinterpret_cast<const char*>(uresult);
500 const char16_t* SIMD::memchr2x16(const char16_t* ptr, char16_t v1, char16_t v2,
501 size_t length) {
502 return FindTwoInBuffer<char16_t>(ptr, v1, v2, length);
505 #else
507 const char* SIMD::memchr8(const char* ptr, char value, size_t length) {
508 const void* result = ::memchr(reinterpret_cast<const void*>(ptr),
509 static_cast<int>(value), length);
510 return reinterpret_cast<const char*>(result);
513 const char* SIMD::memchr8SSE2(const char* ptr, char value, size_t length) {
514 return memchr8(ptr, value, length);
517 const char16_t* SIMD::memchr16(const char16_t* ptr, char16_t value,
518 size_t length) {
519 return FindInBufferNaive<char16_t>(ptr, value, length);
522 const char16_t* SIMD::memchr16SSE2(const char16_t* ptr, char16_t value,
523 size_t length) {
524 return memchr16(ptr, value, length);
527 const uint64_t* SIMD::memchr64(const uint64_t* ptr, uint64_t value,
528 size_t length) {
529 return FindInBufferNaive<uint64_t>(ptr, value, length);
532 const char* SIMD::memchr2x8(const char* ptr, char v1, char v2, size_t length) {
533 const char* end = ptr + length - 1;
534 while (ptr < end) {
535 ptr = memchr8(ptr, v1, end - ptr);
536 if (!ptr) {
537 return nullptr;
539 if (ptr[1] == v2) {
540 return ptr;
542 ptr++;
544 return nullptr;
547 const char16_t* SIMD::memchr2x16(const char16_t* ptr, char16_t v1, char16_t v2,
548 size_t length) {
549 const char16_t* end = ptr + length - 1;
550 while (ptr < end) {
551 ptr = memchr16(ptr, v1, end - ptr);
552 if (!ptr) {
553 return nullptr;
555 if (ptr[1] == v2) {
556 return ptr;
558 ptr++;
560 return nullptr;
563 #endif
565 } // namespace mozilla