
Commit 6517556

WIP: Add memchr8 based on candidate phase of Teddy
1 parent 1230fc5 commit 6517556

9 files changed: +575 -24 lines changed


src/arch/all/memchr.rs

Lines changed: 67 additions & 0 deletions
@@ -867,6 +867,64 @@ impl<'a, 'h> DoubleEndedIterator for ThreeIter<'a, 'h> {
     }
 }

+/// TODO
+#[derive(Clone, Copy, Debug)]
+pub struct Eight<'a> {
+    needles: &'a [u8],
+}
+
+impl Eight<'_> {
+    /// TODO
+    #[inline]
+    pub fn new(needles: &[u8]) -> Eight<'_> {
+        assert!(needles.len() <= 8);
+
+        Eight { needles }
+    }
+
+    /// TODO
+    #[inline]
+    pub unsafe fn find_raw(
+        &self,
+        start: *const u8,
+        end: *const u8,
+    ) -> Option<*const u8> {
+        if self.needles.is_empty() || start >= end {
+            return None;
+        }
+        generic::fwd_byte_by_byte(start, end, |b| self.needles.contains(&b))
+    }
+
+    /// TODO
+    pub fn iter<'a, 'h>(&'a self, haystack: &'h [u8]) -> EightIter<'a, 'h> {
+        EightIter { searcher: self, it: generic::Iter::new(haystack) }
+    }
+}
+
+/// TODO
+#[derive(Clone, Debug)]
+pub struct EightIter<'a, 'h> {
+    searcher: &'a Eight<'a>,
+    it: generic::Iter<'h>,
+}
+
+impl<'a, 'h> Iterator for EightIter<'a, 'h> {
+    type Item = usize;
+
+    #[inline]
+    fn next(&mut self) -> Option<usize> {
+        // SAFETY: We rely on the generic iterator to provide valid start
+        // and end pointers, but we guarantee that any pointer returned by
+        // 'find_raw' falls within the bounds of the start and end pointer.
+        unsafe { self.it.next(|s, e| self.searcher.find_raw(s, e)) }
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.it.size_hint()
+    }
+}
+
 /// Return `true` if `x` contains any zero byte.
 ///
 /// That is, this routine treats `x` as a register of 8-bit lanes and returns

@@ -971,6 +1029,15 @@ mod tests {
         )
     }

+    #[test]
+    fn forward_eight() {
+        crate::tests::memchr::Runner::new(8).forward_iter(
+            |haystack, needles| {
+                Some(Eight::new(needles).iter(haystack).collect())
+            },
+        )
+    }
+
     // This was found by quickcheck in the course of refactoring this crate
     // after memchr 2.5.0.
     #[test]
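
As a point of reference, a hedged usage sketch for the portable `Eight` searcher added above. The import path is an assumption based on where this diff places the type; the commit is WIP and may not export it publicly yet.

use memchr::arch::all::memchr::Eight; // assumed export path

fn main() {
    // Up to eight needle bytes; `Eight::new` asserts `needles.len() <= 8`.
    let searcher = Eight::new(b",;: \t\r\n'");
    let haystack = b"key: value; other";
    // Positions of `:` (3), ` ` (4), `;` (10) and ` ` (11).
    let positions: Vec<usize> = searcher.iter(haystack).collect();
    assert_eq!(positions, vec![3, 4, 10, 11]);
}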

src/arch/generic/memchr.rs

Lines changed: 152 additions & 0 deletions
@@ -980,6 +980,158 @@ impl<V: Vector> Three<V> {
     }
 }

+#[derive(Clone, Copy, Debug)]
+pub(crate) struct Eight<'a, V> {
+    needles: &'a [u8],
+    lo: V,
+    hi: V,
+}
+
+impl<V: Vector> Eight<'_, V> {
+    const LOOP_SIZE: usize = 2 * V::BYTES;
+
+    #[inline(always)]
+    pub(crate) unsafe fn new(needles: &[u8]) -> Eight<'_, V> {
+        assert!(needles.len() <= 8);
+
+        debug_assert!(V::BYTES <= 32);
+
+        let mut lo = [0; 32];
+        let mut hi = [0; 32];
+
+        for (idx, byte) in needles.iter().enumerate() {
+            let lo_nibble = byte & 0xF;
+            lo[lo_nibble as usize] |= 1 << idx;
+            lo[16 + lo_nibble as usize] |= 1 << idx;
+
+            let hi_nibble = byte >> 4;
+            hi[hi_nibble as usize] |= 1 << idx;
+            hi[16 + hi_nibble as usize] |= 1 << idx;
+        }
+
+        let lo = V::load_unaligned(lo.as_ptr());
+        let hi = V::load_unaligned(hi.as_ptr());
+
+        Eight { needles, lo, hi }
+    }
+
+    #[inline(always)]
+    pub(crate) fn needles(&self) -> &[u8] {
+        self.needles
+    }
+
+    #[inline(always)]
+    pub(crate) unsafe fn find_raw(
+        &self,
+        start: *const u8,
+        end: *const u8,
+    ) -> Option<*const u8> {
+        debug_assert!(V::BYTES <= 32, "vector cannot be bigger than 32 bytes");
+
+        let topos = V::Mask::first_offset;
+        let len = end.distance(start);
+        debug_assert!(
+            len >= V::BYTES,
+            "haystack has length {}, but must be at least {}",
+            len,
+            V::BYTES
+        );
+
+        // Search a possibly unaligned chunk at `start`. This covers any part
+        // of the haystack prior to where aligned loads can start.
+        if let Some(cur) = self.search_chunk(start, topos) {
+            return Some(cur);
+        }
+        // Set `cur` to the first V-aligned pointer greater than `start`.
+        let mut cur = start.add(V::BYTES - (start.as_usize() & V::ALIGN));
+        debug_assert!(cur > start && end.sub(V::BYTES) >= start);
+        if len >= Self::LOOP_SIZE {
+            while cur <= end.sub(Self::LOOP_SIZE) {
+                debug_assert_eq!(0, cur.as_usize() % V::BYTES);
+
+                let chunk_a = V::load_aligned(cur);
+                let chunk_b = V::load_aligned(cur.add(V::BYTES));
+
+                let lo_chunk_a = chunk_a.and(V::splat(0xF));
+                let lo_chunk_b = chunk_b.and(V::splat(0xF));
+                let lo_hits_a = self.lo.shuffle_bytes(lo_chunk_a);
+                let lo_hits_b = self.lo.shuffle_bytes(lo_chunk_b);
+
+                let hi_chunk_a = chunk_a.shift_8bit_lane_right::<4>();
+                let hi_chunk_b = chunk_b.shift_8bit_lane_right::<4>();
+                let hi_hits_a = self.hi.shuffle_bytes(hi_chunk_a);
+                let hi_hits_b = self.hi.shuffle_bytes(hi_chunk_b);
+
+                let hits_a = lo_hits_a.and(hi_hits_a);
+                let hits_b = lo_hits_b.and(hi_hits_b);
+
+                let mask_a = hits_a.cmpeq(V::splat(0)).inverted_movemask();
+                let mask_b = hits_b.cmpeq(V::splat(0)).inverted_movemask();
+
+                if mask_a.has_non_zero() {
+                    let offset = topos(mask_a);
+
+                    return Some(cur.add(offset));
+                } else if mask_b.has_non_zero() {
+                    let offset = topos(mask_b);
+
+                    return Some(cur.add(V::BYTES).add(offset));
+                }
+
+                cur = cur.add(Self::LOOP_SIZE);
+            }
+        }
+        // Handle any leftovers after the aligned loop above. We use unaligned
+        // loads here, but I believe we are guaranteed that they are aligned
+        // since `cur` is aligned.
+        while cur <= end.sub(V::BYTES) {
+            debug_assert!(end.distance(cur) >= V::BYTES);
+            if let Some(cur) = self.search_chunk(cur, topos) {
+                return Some(cur);
+            }
+            cur = cur.add(V::BYTES);
+        }
+        // Finally handle any remaining bytes less than the size of V. In this
+        // case, our pointer may indeed be unaligned and the load may overlap
+        // with the previous one. But that's okay since we know the previous
+        // load didn't lead to a match (otherwise we wouldn't be here).
+        if cur < end {
+            debug_assert!(end.distance(cur) < V::BYTES);
+            cur = cur.sub(V::BYTES - end.distance(cur));
+            debug_assert_eq!(end.distance(cur), V::BYTES);
+            return self.search_chunk(cur, topos);
+        }
+        None
+    }
+
+    #[inline(always)]
+    unsafe fn search_chunk(
+        &self,
+        cur: *const u8,
+        mask_to_offset: impl Fn(V::Mask) -> usize,
+    ) -> Option<*const u8> {
+        let chunk = V::load_unaligned(cur);
+
+        let lo_chunk = chunk.and(V::splat(0xF));
+        let lo_hits = self.lo.shuffle_bytes(lo_chunk);
+
+        let hi_chunk = chunk.shift_8bit_lane_right::<4>();
+        let hi_hits = self.hi.shuffle_bytes(hi_chunk);
+
+        let hits = lo_hits.and(hi_hits);
+
+        let mask = hits.cmpeq(V::splat(0)).inverted_movemask();
+
+        if mask.has_non_zero() {
+            let offset = mask_to_offset(mask);
+
+            return Some(cur.add(offset));
+        }
+
+        None
+    }
+}
+
 /// An iterator over all occurrences of a set of bytes in a haystack.
 ///
 /// This iterator implements the routines necessary to provide a
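
To make the table construction in `Eight::new` above concrete, here is a scalar sketch (an illustration, not part of the commit) of the candidate phase borrowed from Teddy: each needle sets bit `idx` in the table entry for its low nibble and for its high nibble, and a haystack byte is a hit exactly when the two looked-up bitmasks share a bit. The `shuffle_bytes` calls in the vector code perform these 16-entry table lookups for a whole register of bytes at once.

fn build_tables(needles: &[u8]) -> ([u8; 16], [u8; 16]) {
    assert!(needles.len() <= 8);
    let (mut lo, mut hi) = ([0u8; 16], [0u8; 16]);
    for (idx, byte) in needles.iter().enumerate() {
        // Bit `idx` means "needle `idx` has this nibble".
        lo[(byte & 0xF) as usize] |= 1 << idx;
        hi[(byte >> 4) as usize] |= 1 << idx;
    }
    (lo, hi)
}

fn is_hit(lo: &[u8; 16], hi: &[u8; 16], b: u8) -> bool {
    // The masks intersect iff some single needle has both nibbles of `b`,
    // i.e. that needle equals `b` exactly.
    (lo[(b & 0xF) as usize] & hi[(b >> 4) as usize]) != 0
}

Note that unlike Teddy proper, where the candidate phase over multi-byte patterns can produce false positives that require verification, single-byte needles make this test exact: the AND is non-zero only when one needle contributes the same bit to both lookups, which requires that needle to match both nibbles of the byte. That is presumably why `find_raw` can return a match position directly, with no verification step.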

src/arch/x86_64/avx2/memchr.rs

Lines changed: 136 additions & 0 deletions
@@ -1273,6 +1273,133 @@ impl<'a, 'h> DoubleEndedIterator for ThreeIter<'a, 'h> {

 impl<'a, 'h> core::iter::FusedIterator for ThreeIter<'a, 'h> {}

+/// TODO
+#[derive(Clone, Copy, Debug)]
+pub struct Eight<'a> {
+    sse2: generic::Eight<'a, __m128i>,
+    avx2: generic::Eight<'a, __m256i>,
+}
+
+impl Eight<'_> {
+    /// TODO
+    #[inline]
+    pub fn new(needles: &[u8]) -> Option<Eight<'_>> {
+        if Eight::is_available() {
+            // SAFETY: we check that avx2 is available above.
+            unsafe { Some(Eight::new_unchecked(needles)) }
+        } else {
+            None
+        }
+    }
+
+    /// TODO
+    #[target_feature(enable = "avx2")]
+    #[inline]
+    pub unsafe fn new_unchecked(needles: &[u8]) -> Eight<'_> {
+        Eight {
+            sse2: generic::Eight::new(needles),
+            avx2: generic::Eight::new(needles),
+        }
+    }
+
+    /// TODO
+    #[inline]
+    pub fn is_available() -> bool {
+        #[cfg(target_feature = "avx2")]
+        {
+            true
+        }
+        #[cfg(not(target_feature = "avx2"))]
+        {
+            #[cfg(feature = "std")]
+            {
+                std::is_x86_feature_detected!("avx2")
+            }
+            #[cfg(not(feature = "std"))]
+            {
+                false
+            }
+        }
+    }
+
+    /// TODO
+    #[inline]
+    pub unsafe fn find_raw(
+        &self,
+        start: *const u8,
+        end: *const u8,
+    ) -> Option<*const u8> {
+        if self.avx2.needles().is_empty() || start >= end {
+            return None;
+        }
+        let len = end.distance(start);
+        if len < __m256i::BYTES {
+            return if len < __m128i::BYTES {
+                // SAFETY: We require the caller to pass valid start/end
+                // pointers.
+                return generic::fwd_byte_by_byte(start, end, |b| {
+                    self.sse2.needles().contains(&b)
+                });
+            } else {
+                // SAFETY: We require the caller to pass valid start/end
+                // pointers.
+                self.find_raw_sse2(start, end)
+            };
+        }
+        self.find_raw_avx2(start, end)
+    }
+
+    #[target_feature(enable = "ssse3")]
+    #[inline]
+    unsafe fn find_raw_sse2(
+        &self,
+        start: *const u8,
+        end: *const u8,
+    ) -> Option<*const u8> {
+        self.sse2.find_raw(start, end)
+    }
+
+    #[target_feature(enable = "avx2")]
+    #[inline]
+    unsafe fn find_raw_avx2(
+        &self,
+        start: *const u8,
+        end: *const u8,
+    ) -> Option<*const u8> {
+        self.avx2.find_raw(start, end)
+    }
+
+    /// TODO
+    #[inline]
+    pub fn iter<'a, 'h>(&'a self, haystack: &'h [u8]) -> EightIter<'a, 'h> {
+        EightIter { searcher: self, it: generic::Iter::new(haystack) }
+    }
+}
+
+/// TODO
+#[derive(Clone, Debug)]
+pub struct EightIter<'a, 'h> {
+    searcher: &'a Eight<'a>,
+    it: generic::Iter<'h>,
+}
+
+impl<'a, 'h> Iterator for EightIter<'a, 'h> {
+    type Item = usize;
+
+    #[inline]
+    fn next(&mut self) -> Option<usize> {
+        // SAFETY: We rely on the generic iterator to provide valid start
+        // and end pointers, but we guarantee that any pointer returned by
+        // 'find_raw' falls within the bounds of the start and end pointer.
+        unsafe { self.it.next(|s, e| self.searcher.find_raw(s, e)) }
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.it.size_hint()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

@@ -1349,4 +1476,13 @@ mod tests {
             },
         )
     }
+
+    #[test]
+    fn forward_eight() {
+        crate::tests::memchr::Runner::new(8).forward_iter(
+            |haystack, needles| {
+                Some(Eight::new(needles)?.iter(haystack).collect())
+            },
+        )
+    }
 }
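
For completeness, a hedged sketch of how a caller might select between this AVX2 searcher and the portable fallback. The import paths are assumptions based on where this diff places the types; `Eight::new` returns `None` when AVX2 is unavailable, following the pattern of the existing searchers in this module.

#[cfg(target_arch = "x86_64")]
fn find_any(haystack: &[u8], needles: &[u8]) -> Option<usize> {
    // Assumed export paths for this WIP commit.
    use memchr::arch::all::memchr::Eight as Fallback;
    use memchr::arch::x86_64::avx2::memchr::Eight as Avx2Eight;

    match Avx2Eight::new(needles) {
        // `new` returns `Some` only when AVX2 is enabled at compile time
        // or detected at runtime (with the `std` feature).
        Some(searcher) => searcher.iter(haystack).next(),
        None => Fallback::new(needles).iter(haystack).next(),
    }
}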
