Skip to content

Commit

Permalink
Add boundary check for SSE2/AVX2 brace match code.
Browse files Browse the repository at this point in the history
  • Loading branch information
zufuliu committed Nov 25, 2024
1 parent 21e16e6 commit a632725
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 50 deletions.
50 changes: 30 additions & 20 deletions scintilla/src/Document.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -2960,37 +2960,42 @@ Sci::Position Document::BraceMatch(Sci::Position position, Sci::Position /*maxRe
position = useStartPos ? startPos : NextPosition(position, direction);
const Sci::Position length = LengthNoExcept();
int depth = 1;
if (chBrace <= asciiBackwardSafeChar && IsValidIndex(position + 32*direction, length)) {
if (chBrace <= asciiBackwardSafeChar && IsValidIndex(position + 64*direction, length)) {
#if NP2_USE_AVX2
if (direction >= 0) {
const SplitView cbView = cb.AllView();
const __m256i mmBrace = mm256_set1_epi8(chBrace);
const __m256i mmSeek = mm256_set1_epi8(chSeek);
const Sci::Position maxPos = length - 2*sizeof(__m256i);
const Sci::Position segmentEndPos = std::min<Sci::Position>(maxPos, cbView.length1 - 1);
do {
const bool scanFirst = IsValidIndex(position, cbView.length1);
const Sci::Position segmentLength = scanFirst ? cbView.length1 : length;
const Sci::Position segmentLength = cbView.length1;
const bool scanFirst = IsValidIndex(position, segmentLength);
const Sci::Position endPos = scanFirst ? segmentEndPos : maxPos;
const char * const segment = scanFirst ? cbView.segment1 : cbView.segment2;
const __m256i *ptr = reinterpret_cast<const __m256i *>(segment + position);
Sci::Position index = position;
uint32_t mask = 0;
uint64_t mask = 0;
do {
const __m256i chunk1 = _mm256_loadu_si256(ptr);
const __m256i chunk2 = _mm256_loadu_si256(ptr + 1);
mask = mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk1, mmBrace), _mm256_cmpeq_epi8(chunk1, mmSeek)));
mask |= static_cast<uint64_t>(mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk2, mmBrace), _mm256_cmpeq_epi8(chunk2, mmSeek)))) << sizeof(__m256i);
if (mask != 0) {
index = position;
position += 2*sizeof(__m256i);
break;
}
ptr++;
position += sizeof(mmBrace);
} while (position < segmentLength);
position += sizeof(mmBrace);
if (position >= segmentLength && index < segmentLength) {
ptr += 2;
position += 2*sizeof(__m256i);
} while (position <= endPos);
if (position > segmentLength && index < segmentLength) {
position = segmentLength;
const uint32_t offset = static_cast<uint32_t>(position - index);
mask = bit_zero_high_u32(mask, offset);
mask = bit_zero_high_u64(mask, offset);
}
while (mask) {
const uint32_t trailing = np2::ctz(mask);
const uint64_t trailing = np2::ctz(mask);
index += trailing;
mask >>= trailing;
if (index > GetEndStyled() || StyleIndexAt(index) == styBrace) {
Expand All @@ -3003,33 +3008,38 @@ Sci::Position Document::BraceMatch(Sci::Position position, Sci::Position /*maxRe
index++;
mask >>= 1;
}
} while (position < length);
} while (position <= maxPos);
}
// end NP2_USE_AVX2
#elif NP2_USE_SSE2
if (direction >= 0) {
const SplitView cbView = cb.AllView();
const __m128i mmBrace = _mm_set1_epi8(chBrace);
const __m128i mmSeek = _mm_set1_epi8(chSeek);
const Sci::Position maxPos = length - 2*sizeof(__m128i);
const Sci::Position segmentEndPos = std::min<Sci::Position>(maxPos, cbView.length1 - 1);
do {
const bool scanFirst = IsValidIndex(position, cbView.length1);
const Sci::Position segmentLength = scanFirst ? cbView.length1 : length;
const Sci::Position segmentLength = cbView.length1;
const bool scanFirst = IsValidIndex(position, segmentLength);
const Sci::Position endPos = scanFirst ? segmentEndPos : maxPos;
const char * const segment = scanFirst ? cbView.segment1 : cbView.segment2;
const __m128i *ptr = reinterpret_cast<const __m128i *>(segment + position);
Sci::Position index = position;
uint32_t mask = 0;
do {
const __m128i chunk1 = _mm_loadu_si128(ptr);
const __m128i chunk2 = _mm_loadu_si128(ptr + 1);
mask = mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk1, mmBrace), _mm_cmpeq_epi8(chunk1, mmSeek)));
mask |= mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk2, mmBrace), _mm_cmpeq_epi8(chunk2, mmSeek))) << sizeof(__m128i);
if (mask != 0) {
index = position;
position += 2*sizeof(__m128i);
break;
}
ptr++;
position += sizeof(mmBrace);
} while (position < segmentLength);
position += sizeof(mmBrace);
if (position >= segmentLength && index < segmentLength) {
ptr += 2;
position += 2*sizeof(__m128i);
} while (position <= endPos);
if (position > segmentLength && index < segmentLength) {
position = segmentLength;
const uint32_t offset = static_cast<uint32_t>(position - index);
mask = bit_zero_high_u32(mask, offset);
Expand All @@ -3048,7 +3058,7 @@ Sci::Position Document::BraceMatch(Sci::Position position, Sci::Position /*maxRe
index++;
mask >>= 1;
}
} while (position < length);
} while (position <= maxPos);
}
// end NP2_USE_SSE2
#endif
Expand Down
76 changes: 46 additions & 30 deletions scintilla/test/BraceMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
// cl /EHsc /std:c++20 /DNDEBUG /O2 /FAcs /GS- /GR- /Gv /W4 /arch:AVX2 BraceMatchTest.cpp
// clang-cl /EHsc /std:c++20 /DNDEBUG /O2 /FA /GS- /GR- /Gv /W4 -march=x86-64-v3 BraceMatchTest.cpp
// g++ -S -std=gnu++20 -DNDEBUG -O3 -fno-rtti -Wall -Wextra -march=x86-64-v3 BraceMatchTest.cpp
// Returns the smaller of two values of the same type, compared with operator<.
// When neither argument is less than the other, the second argument is returned.
// Local stand-in so the test can compute the minimum without pulling in <algorithm>.
template <typename T>
constexpr T min(T x, T y) noexcept {
	if (x < y) {
		return x;
	}
	return y;
}
// Reports whether `index` addresses an element inside a range of `length` items,
// i.e. 0 <= index < length. Both parameters are unsigned, so a negative signed
// argument converts to a very large value and is reported as invalid.
constexpr bool IsValidIndex(size_t index, size_t length) noexcept {
	return length > index;
}
Expand All @@ -31,38 +35,43 @@ struct SplitView {
};
constexpr char chBrace = '{';
constexpr char chSeek = '}';
constexpr uint32_t maxLength = 256;
constexpr uint32_t maxLength = 512;

void FindAllBraceForward(const SplitView &cbView, ptrdiff_t position, const ptrdiff_t length, uint32_t (&result)[maxLength]) noexcept {
unsigned j = 0;
#if NP2_USE_AVX2
const __m256i mmBrace = _mm256_set1_epi8(chBrace);
const __m256i mmSeek = _mm256_set1_epi8(chSeek);
while (position < length) {
const bool scanFirst = IsValidIndex(position, cbView.length1);
const ptrdiff_t segmentLength = scanFirst ? cbView.length1 : length;
const ptrdiff_t maxPos = length - 2*sizeof(__m256i);
const ptrdiff_t segmentLength = cbView.length1;
const ptrdiff_t segmentEndPos = min(maxPos, segmentLength - 1);
while (position <= maxPos) {
const bool scanFirst = IsValidIndex(position, segmentLength);
const ptrdiff_t endPos = scanFirst ? segmentEndPos : maxPos;
const char * const segment = scanFirst ? cbView.segment1 : cbView.segment2;
const __m256i *ptr = reinterpret_cast<const __m256i *>(segment + position);
ptrdiff_t index = position;
uint32_t mask = 0;
uint64_t mask = 0;
do {
const __m256i chunk1 = _mm256_loadu_si256(ptr);
mask = _mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk1, mmBrace), _mm256_cmpeq_epi8(chunk1, mmSeek)));
const __m256i chunk2 = _mm256_loadu_si256(ptr + 1);
mask = mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk1, mmBrace), _mm256_cmpeq_epi8(chunk1, mmSeek)));
mask |= static_cast<uint64_t>(mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk2, mmBrace), _mm256_cmpeq_epi8(chunk2, mmSeek)))) << sizeof(__m256i);
if (mask != 0) {
index = position;
position += 2*sizeof(__m256i);
break;
}
ptr++;
position += sizeof(__m256i);
} while (position < segmentLength);
position += sizeof(__m256i);
if (position >= segmentLength && index < segmentLength) {
ptr += 2;
position += 2*sizeof(__m256i);
} while (position <= endPos);
if (position > segmentLength && index < segmentLength) {
position = segmentLength;
const uint32_t offset = static_cast<uint32_t>(position - index);
mask = bit_zero_high_u32(mask, offset);
mask = bit_zero_high_u64(mask, offset);
}
while (mask) {
const uint32_t trailing = np2::ctz(mask);
const uint64_t trailing = np2::ctz(mask);
index += trailing;
mask >>= trailing;
result[j++] = static_cast<uint32_t>(index + 1);
Expand All @@ -74,25 +83,30 @@ void FindAllBraceForward(const SplitView &cbView, ptrdiff_t position, const ptrd
#elif NP2_USE_SSE2
const __m128i mmBrace = _mm_set1_epi8(chBrace);
const __m128i mmSeek = _mm_set1_epi8(chSeek);
while (position < length) {
const bool scanFirst = IsValidIndex(position, cbView.length1);
const ptrdiff_t segmentLength = scanFirst ? cbView.length1 : length;
const ptrdiff_t maxPos = length - 2*sizeof(__m128i);
const ptrdiff_t segmentLength = cbView.length1;
const ptrdiff_t segmentEndPos = min(maxPos, segmentLength - 1);
while (position <= maxPos) {
const bool scanFirst = IsValidIndex(position, segmentLength);
const ptrdiff_t endPos = scanFirst ? segmentEndPos : maxPos;
const char * const segment = scanFirst ? cbView.segment1 : cbView.segment2;
const __m128i *ptr = reinterpret_cast<const __m128i *>(segment + position);
ptrdiff_t index = position;
uint32_t mask = 0;
do {
const __m128i chunk1 = _mm_loadu_si128(ptr);
mask = _mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk1, mmBrace), _mm_cmpeq_epi8(chunk1, mmSeek)));
const __m128i chunk2 = _mm_loadu_si128(ptr + 1);
mask = mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk1, mmBrace), _mm_cmpeq_epi8(chunk1, mmSeek)));
mask |= mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk2, mmBrace), _mm_cmpeq_epi8(chunk2, mmSeek))) << sizeof(__m128i);
if (mask != 0) {
index = position;
position += 2*sizeof(__m128i);
break;
}
ptr++;
position += sizeof(__m128i);
} while (position < segmentLength);
position += sizeof(__m128i);
if (position >= segmentLength && index < segmentLength) {
ptr += 2;
position += 2*sizeof(__m128i);
} while (position <= endPos);
if (position > segmentLength && index < segmentLength) {
position = segmentLength;
const uint32_t offset = static_cast<uint32_t>(position - index);
mask = bit_zero_high_u32(mask, offset);
Expand Down Expand Up @@ -184,10 +198,11 @@ int __cdecl main(int argc, char *argv[]) {
constexpr uint32_t gapLength = 0;
constexpr uint32_t position = 0;
constexpr uint32_t length = 0;
constexpr bool hasGap = gapPosition != 0 && gapLength != 0;
const SplitView cbView {
buffer + padding,
(gapPosition != 0 && gapLength != 0) ? gapPosition : length,
buffer + padding + gapLength,
hasGap ? gapPosition : length,
buffer + padding + (hasGap ? gapLength : 0),
length,
};
printf("doc: (%u, %u), gap: (%u, %u)\n", position, length, gapPosition, gapLength);
Expand All @@ -207,18 +222,19 @@ int __cdecl main(int argc, char *argv[]) {
}

const uint32_t value = rand();
const uint32_t gapPosition = value & 127;
const uint32_t gapLength = (value >> 4) & 127;
uint32_t position = (value >> 8) & (maxLength - 1);
const uint32_t length = maxLength - gapLength;
const uint32_t gapPosition = value & (maxLength/2 - 1);
const uint32_t gapLength = (value >> 16) & (maxLength/2 - 1);
uint32_t position = rand() & (maxLength - 1);
const bool hasGap = gapPosition != 0 && gapLength != 0;
const uint32_t length = maxLength - (hasGap ? gapLength : 0);
if (position >= length) {
position = length - 1;
}
memset(buffer + padding + gapPosition, chBrace, gapLength);
const SplitView cbView {
buffer + padding,
(gapPosition != 0 && gapLength != 0) ? gapPosition : length,
buffer + padding + gapLength,
hasGap ? gapPosition : length,
buffer + padding + (hasGap ? gapLength : 0),
length,
};

Expand Down

0 comments on commit a632725

Please sign in to comment.