
Commit 7af0733

core: implement Pattern<&[T]> for &[T]; get rid of SlicePattern
Implement Haystack<&[T]> and a corresponding Pattern<&[T]> for &[T]; that is, provide an implementation for searching for subslices within slices. This replaces the SlicePattern type.

To make use of these new implementations, add a few new methods to the [T] type, modelled after those on str. Specifically, introduce {starts,ends}_with_pattern, find, rfind, [T]::{split,rsplit}_once and trim{,_start,_end}_matches. Note that because of the existing starts_with and ends_with methods, the _pattern suffix had to be used. This is unfortunate, but the type of starts_with’s argument cannot be changed without affecting type inference and thus breaking the API.

This change doesn’t implement functions returning iterators, such as split_pattern or matches, which in the str type are built on top of the Pattern API.

Issue: rust-lang#49802
Issue: rust-lang#56345
1 parent 0ff2270 commit 7af0733
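
As a rough illustration of the semantics the new methods are meant to provide (mirroring their str counterparts), the sketch below uses only stable slice primitives; find_subslice and split_once_subslice are hypothetical helper names for this illustration, not the API added by the patch:

    // Reference semantics for subslice search, sketched with stable Rust only.
    // With this patch applied, `<[T]>::find` and `<[T]>::split_once` are meant
    // to provide this behaviour through the Pattern API.
    fn find_subslice<T: PartialEq>(haystack: &[T], needle: &[T]) -> Option<usize> {
        if needle.is_empty() {
            return Some(0);
        }
        haystack.windows(needle.len()).position(|w| w == needle)
    }

    fn split_once_subslice<'a, T: PartialEq>(
        haystack: &'a [T],
        needle: &[T],
    ) -> Option<(&'a [T], &'a [T])> {
        let start = find_subslice(haystack, needle)?;
        Some((&haystack[..start], &haystack[start + needle.len()..]))
    }

    fn main() {
        let hay: &[u8] = b"hello world";
        assert_eq!(find_subslice(hay, &b"world"[..]), Some(6));
        assert_eq!(
            split_once_subslice(hay, &b" "[..]),
            Some((&b"hello"[..], &b"world"[..]))
        );
    }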

4 files changed: +1294 −351 lines

library/core/src/slice/cmp.rs

+265-13
@@ -227,34 +227,286 @@ impl_marker_for!(BytewiseEquality,
     u8 i8 u16 i16 u32 i32 u64 i64 u128 i128 usize isize char bool);
 
 pub(super) trait SliceContains: Sized {
-    fn slice_contains(&self, x: &[Self]) -> bool;
+    fn slice_contains_element(hs: &[Self], needle: &Self) -> bool;
+    fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool;
 }
 
 impl<T> SliceContains for T
 where
     T: PartialEq,
 {
-    default fn slice_contains(&self, x: &[Self]) -> bool {
-        x.iter().any(|y| *y == *self)
+    default fn slice_contains_element(hs: &[Self], needle: &Self) -> bool {
+        hs.iter().any(|element| *element == *needle)
+    }
+
+    default fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool {
+        default_slice_contains_slice(hs, needle)
     }
 }
 
 impl SliceContains for u8 {
     #[inline]
-    fn slice_contains(&self, x: &[Self]) -> bool {
-        memchr::memchr(*self, x).is_some()
+    fn slice_contains_element(hs: &[Self], needle: &Self) -> bool {
+        memchr::memchr(*needle, hs).is_some()
+    }
+
+    #[inline]
+    fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool {
+        if needle.len() <= 32 {
+            if let Some(result) = simd_contains(hs, needle) {
+                return result;
+            }
+        }
+        default_slice_contains_slice(hs, needle)
     }
 }
 
+unsafe fn bytes_of<T>(slice: &[T]) -> &[u8] {
+    // SAFETY: caller promises that `T` and `u8` have the same memory layout,
+    // thus casting `slice.as_ptr()` as `*const u8` is safe. The `slice.as_ptr()`
+    // comes from a reference and is thus guaranteed to be valid for reads for
+    // the length of the slice `slice.len()`, which cannot be larger than
+    // `isize::MAX`. The returned slice is never mutated.
+    unsafe { from_raw_parts(slice.as_ptr() as *const u8, slice.len()) }
+}
+
 impl SliceContains for i8 {
     #[inline]
-    fn slice_contains(&self, x: &[Self]) -> bool {
-        let byte = *self as u8;
-        // SAFETY: `i8` and `u8` have the same memory layout, thus casting `x.as_ptr()`
-        // as `*const u8` is safe. The `x.as_ptr()` comes from a reference and is thus guaranteed
-        // to be valid for reads for the length of the slice `x.len()`, which cannot be larger
-        // than `isize::MAX`. The returned slice is never mutated.
-        let bytes: &[u8] = unsafe { from_raw_parts(x.as_ptr() as *const u8, x.len()) };
-        memchr::memchr(byte, bytes).is_some()
+    fn slice_contains_element(hs: &[Self], needle: &Self) -> bool {
+        // SAFETY: `i8` and `u8` have the same memory layout.
+        u8::slice_contains_element(unsafe { bytes_of(hs) }, &(*needle as u8))
+    }
+
+    #[inline]
+    fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool {
+        // SAFETY: `i8` and `u8` have the same memory layout.
+        unsafe { u8::slice_contains_slice(bytes_of(hs), bytes_of(needle)) }
+    }
+}
+
+impl SliceContains for bool {
+    #[inline]
+    fn slice_contains_element(hs: &[Self], needle: &Self) -> bool {
+        // SAFETY: `bool` and `u8` have the same memory layout and all valid
+        // `bool` bit patterns are valid `u8` bit patterns.
+        u8::slice_contains_element(unsafe { bytes_of(hs) }, &(*needle as u8))
+    }
+
+    #[inline]
+    fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool {
+        // SAFETY: `bool` and `u8` have the same memory layout and all valid
+        // `bool` bit patterns are valid `u8` bit patterns.
+        unsafe { u8::slice_contains_slice(bytes_of(hs), bytes_of(needle)) }
+    }
+}
+
+fn default_slice_contains_slice<T: PartialEq>(hs: &[T], needle: &[T]) -> bool {
+    super::pattern::NaiveSearcherState::new(hs.len())
+        .next_match(hs, needle)
+        .is_some()
+}
+
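
The i8 and bool impls above reuse the optimized u8 search by viewing the haystack and needle as raw bytes through bytes_of, while the generic fallback delegates to the naive searcher in core's internal pattern module (not part of this hunk). A stand-alone sketch of the layout-based delegation, with byte_search as a hypothetical stand-in for the memchr-backed u8 path:

    // Sketch of the delegation used for `i8` and `bool`: reinterpret the slice
    // as `&[u8]` and reuse the byte search. `byte_search` stands in for the
    // memchr-backed `u8` implementation; illustration only.
    fn byte_search(haystack: &[u8], needle: u8) -> bool {
        haystack.iter().any(|&b| b == needle)
    }

    fn bool_slice_contains(haystack: &[bool], needle: bool) -> bool {
        // SAFETY: `bool` and `u8` have the same size and alignment, and every
        // valid `bool` bit pattern (0 or 1) is a valid `u8` bit pattern, so
        // reading the same memory as `u8` is sound. The slice is never mutated.
        let bytes: &[u8] = unsafe {
            core::slice::from_raw_parts(haystack.as_ptr() as *const u8, haystack.len())
        };
        byte_search(bytes, needle as u8)
    }

    fn main() {
        assert!(bool_slice_contains(&[false, false, true], true));
        assert!(!bool_slice_contains(&[false, false], true));
    }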
+/// SIMD search for short needles based on
+/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
+///
+/// It skips ahead by the vector width on each iteration (rather than by the needle length, as
+/// two-way does) by probing the first and last byte of the needle for the whole vector width
+/// and only doing full needle comparisons when the vectorized probe indicated potential matches.
+///
+/// Since the x86_64 baseline only offers SSE2, we only use u8x16 here.
+/// If we ever ship std for x86-64-v3 or adapt this for other platforms then wider vectors
+/// should be evaluated.
+///
+/// For haystacks smaller than vector-size + needle length it falls back to
+/// a naive O(n*m) search, so this implementation should not be called on larger needles.
+///
+/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
+#[inline]
+fn simd_contains(haystack: &[u8], needle: &[u8]) -> Option<bool> {
+    debug_assert!(needle.len() > 1);
+
+    use crate::ops::BitAnd;
+    use crate::simd::mask8x16 as Mask;
+    use crate::simd::u8x16 as Block;
+    use crate::simd::{SimdPartialEq, ToBitMask};
+
+    let first_probe = needle[0];
+    let last_byte_offset = needle.len() - 1;
+
+    // the offset used for the 2nd vector
+    let second_probe_offset = if needle.len() == 2 {
+        // never bail out on len=2 needles because the probes will fully cover them and have
+        // no degenerate cases.
+        1
+    } else {
+        // try a few bytes in case the first and last byte of the needle are the same
+        let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
+            // fall back to other search methods if we can't find any different bytes
+            // since we could otherwise hit some degenerate cases
+            return None;
+        };
+        second_probe_offset
+    };
+
+    // do a naive search if the haystack is too small to fit
+    if haystack.len() < Block::LANES + last_byte_offset {
+        return Some(haystack.windows(needle.len()).any(|c| c == needle));
+    }
+
+    let first_probe: Block = Block::splat(first_probe);
+    let second_probe: Block = Block::splat(needle[second_probe_offset]);
+    // The first byte is already checked by the outer loop; to verify a match only the
+    // remainder has to be compared.
+    let trimmed_needle = &needle[1..];
+
+    // this #[cold] is load-bearing, benchmark before removing it...
+    let check_mask = #[cold]
+    |idx, mask: u16, skip: bool| -> bool {
+        if skip {
+            return false;
+        }
+
+        // and so is this. optimizations are weird.
+        let mut mask = mask;
+
+        while mask != 0 {
+            let trailing = mask.trailing_zeros();
+            let offset = idx + trailing as usize + 1;
+            // SAFETY: mask has between 0 and 15 trailing zeroes; we skip one additional byte
+            // that was already compared, and then take trimmed_needle.len() bytes. This is
+            // within the bounds defined by the outer loop.
+            unsafe {
+                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
+                if small_slice_eq(sub, trimmed_needle) {
+                    return true;
+                }
+            }
+            mask &= !(1 << trailing);
+        }
+        return false;
+    };
+
+    let test_chunk = |idx| -> u16 {
+        // SAFETY: this requires at least LANES bytes being readable at idx;
+        // that is ensured by the loop ranges (see comments below)
+        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
+        // SAFETY: this requires LANES + block_offset bytes being readable at idx
+        let b: Block = unsafe {
+            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
+        };
+        let eq_first: Mask = a.simd_eq(first_probe);
+        let eq_last: Mask = b.simd_eq(second_probe);
+        let both = eq_first.bitand(eq_last);
+        let mask = both.to_bitmask();
+
+        return mask;
+    };
+
+    let mut i = 0;
+    let mut result = false;
+    // The loop condition must ensure that there's enough headroom to read LANES bytes,
+    // and not only at the current index but also at the index shifted by block_offset
+    const UNROLL: usize = 4;
+    while i + last_byte_offset + UNROLL * Block::LANES < haystack.len() && !result {
+        let mut masks = [0u16; UNROLL];
+        for j in 0..UNROLL {
+            masks[j] = test_chunk(i + j * Block::LANES);
+        }
+        for j in 0..UNROLL {
+            let mask = masks[j];
+            if mask != 0 {
+                result |= check_mask(i + j * Block::LANES, mask, result);
+            }
+        }
+        i += UNROLL * Block::LANES;
+    }
+    while i + last_byte_offset + Block::LANES < haystack.len() && !result {
+        let mask = test_chunk(i);
+        if mask != 0 {
+            result |= check_mask(i, mask, result);
+        }
+        i += Block::LANES;
+    }
+
+    // Process the tail that didn't fit into LANES-sized steps.
+    // This simply repeats the same procedure but as a right-aligned chunk instead
+    // of a left-aligned one. The last byte must be exactly flush with the string end so
+    // we don't miss a single byte or read out of bounds.
+    let i = haystack.len() - last_byte_offset - Block::LANES;
+    let mask = test_chunk(i);
+    if mask != 0 {
+        result |= check_mask(i, mask, result);
+    }
+
+    Some(result)
+}
+
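The doc comment above describes the probe strategy: test the needle's first byte and a second probe byte at every candidate position, and only run a full comparison when both probes hit. A scalar sketch of that idea (illustration only; the shipped code does the probe step sixteen positions at a time with SSE2 and folds the results into a bitmask):

    // Scalar version of the probe-then-verify idea behind `simd_contains`.
    fn probe_then_verify(haystack: &[u8], needle: &[u8]) -> bool {
        assert!(needle.len() >= 2);
        let first = needle[0];
        // The real code prefers a second probe byte that differs from the first.
        let second_offset = needle.len() - 1;
        let second = needle[second_offset];

        if haystack.len() < needle.len() {
            return false;
        }
        for i in 0..=haystack.len() - needle.len() {
            // Cheap probes first; the full comparison runs only on probe hits.
            if haystack[i] == first && haystack[i + second_offset] == second {
                if &haystack[i..i + needle.len()] == needle {
                    return true;
                }
            }
        }
        false
    }

    fn main() {
        assert!(probe_then_verify(b"the quick brown fox", b"brown"));
        assert!(!probe_then_verify(b"the quick brown fox", b"bride"));
    }
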
+/// Compares short slices for equality.
+///
+/// It avoids a call to libc's memcmp, which is faster on long slices
+/// due to SIMD optimizations but incurs function call overhead.
+///
+/// # Safety
+///
+/// Both slices must have the same length.
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
+#[inline]
+unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    debug_assert_eq!(x.len(), y.len());
+    // This function is adapted from
+    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
+
+    // If we don't have enough bytes to do 4-byte-at-a-time loads, then
+    // fall back to the naive slow version.
+    //
+    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
+    // of a loop. Benchmark it.
+    if x.len() < 4 {
+        for (&b1, &b2) in x.iter().zip(y) {
+            if b1 != b2 {
+                return false;
+            }
+        }
+        return true;
+    }
+    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
+    // a time using unaligned loads.
+    //
+    // Also, why do 4-byte loads instead of, say, 8-byte loads? The reason is
+    // that this particular version of memcmp is likely to be called with tiny
+    // needles. That means that if we do 8-byte loads, then a higher proportion
+    // of memcmp calls will use the slower variant above. With that said, this
+    // is a hypothesis and is only loosely supported by benchmarks. There's
+    // likely some improvement that could be made here. The main thing here
+    // though is to optimize for latency, not throughput.
+
+    // SAFETY: Via the conditional above, we know that both `x` and `y`
+    // have the same length, so `px < pxend` implies that `py < pyend`.
+    // Thus, dereferencing both `px` and `py` in the loop below is safe.
+    //
+    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
+    // end of `px` and `py`. Thus, the final dereference outside of the
+    // loop is guaranteed to be valid. (The final comparison will overlap with
+    // the last comparison done in the loop for lengths that aren't multiples
+    // of four.)
+    //
+    // Finally, we needn't worry about alignment here, since we do unaligned
+    // loads.
+    unsafe {
+        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
+        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
+        while px < pxend {
+            let vx = (px as *const u32).read_unaligned();
+            let vy = (py as *const u32).read_unaligned();
+            if vx != vy {
+                return false;
+            }
+            px = px.add(4);
+            py = py.add(4);
+        }
+        let vx = (pxend as *const u32).read_unaligned();
+        let vy = (pyend as *const u32).read_unaligned();
+        vx == vy
     }
 }
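
small_slice_eq's comments explain the overlapping-load trick: compare four bytes at a time and let the final load overlap the previous one, so lengths that aren't multiples of four need no scalar tail loop. A safe sketch of the same idea (illustration only; the shipped code uses raw unaligned pointer reads instead):

    // Safe sketch of the overlapping 4-byte comparison used by `small_slice_eq`.
    fn small_eq(x: &[u8], y: &[u8]) -> bool {
        assert_eq!(x.len(), y.len());
        if x.len() < 4 {
            return x == y;
        }
        let mut i = 0;
        while i + 4 < x.len() {
            let a = u32::from_ne_bytes(x[i..i + 4].try_into().unwrap());
            let b = u32::from_ne_bytes(y[i..i + 4].try_into().unwrap());
            if a != b {
                return false;
            }
            i += 4;
        }
        // Final, possibly overlapping, 4-byte comparison flush with the end.
        let a = u32::from_ne_bytes(x[x.len() - 4..].try_into().unwrap());
        let b = u32::from_ne_bytes(y[y.len() - 4..].try_into().unwrap());
        a == b
    }

    fn main() {
        assert!(small_eq(b"abcdefg", b"abcdefg"));
        assert!(!small_eq(b"abcdefg", b"abcdefz"));
    }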
