@@ -227,34 +227,286 @@ impl_marker_for!(BytewiseEquality,
     u8 i8 u16 i16 u32 i32 u64 i64 u128 i128 usize isize char bool);

 pub(super) trait SliceContains: Sized {
-    fn slice_contains(&self, x: &[Self]) -> bool;
+    fn slice_contains_element(hs: &[Self], needle: &Self) -> bool;
+    fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool;
 }

 impl<T> SliceContains for T
 where
     T: PartialEq,
 {
-    default fn slice_contains(&self, x: &[Self]) -> bool {
-        x.iter().any(|y| *y == *self)
+    default fn slice_contains_element(hs: &[Self], needle: &Self) -> bool {
+        hs.iter().any(|element| *element == *needle)
+    }
+
+    default fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool {
+        default_slice_contains_slice(hs, needle)
     }
 }
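
Editorial aside (not part of the patch): the trait now separates element search from subslice search. A minimal standalone sketch of what the generic element default computes, written as a free function since `SliceContains` itself is `pub(super)` and not callable outside `core::slice`; the name is illustrative only.

```rust
fn contains_element<T: PartialEq>(haystack: &[T], needle: &T) -> bool {
    // Same linear scan as the `default fn slice_contains_element` above.
    haystack.iter().any(|element| element == needle)
}

fn main() {
    assert!(contains_element(&[1u8, 2, 3], &2));
    assert!(!contains_element(&["a", "b"], &"c"));
}
```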

 impl SliceContains for u8 {
     #[inline]
-    fn slice_contains(&self, x: &[Self]) -> bool {
-        memchr::memchr(*self, x).is_some()
+    fn slice_contains_element(hs: &[Self], needle: &Self) -> bool {
+        memchr::memchr(*needle, hs).is_some()
+    }
+
+    #[inline]
+    fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool {
+        if needle.len() <= 32 {
+            if let Some(result) = simd_contains(hs, needle) {
+                return result;
+            }
+        }
+        default_slice_contains_slice(hs, needle)
     }
 }

+unsafe fn bytes_of<T>(slice: &[T]) -> &[u8] {
+    // SAFETY: caller promises that `T` and `u8` have the same memory layout,
+    // thus casting `slice.as_ptr()` to `*const u8` is safe. `slice.as_ptr()`
+    // comes from a reference and is thus guaranteed to be valid for reads for
+    // the length of the slice `slice.len()`, which cannot be larger than
+    // `isize::MAX`. The returned slice is never mutated.
+    unsafe { from_raw_parts(slice.as_ptr() as *const u8, slice.len()) }
+}
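
Editorial aside (not part of the patch): `bytes_of` is only sound when `T` has the same layout as `u8`. A hedged sketch of the same cast with that assumption made explicit through a runtime check instead of a caller promise; the function name is illustrative only.

```rust
use core::slice;

/// Illustrative checked variant: a 1-byte `T` also has alignment 1, so the
/// size check alone makes the layout assumption explicit.
fn bytes_of_checked<T>(s: &[T]) -> &[u8] {
    assert_eq!(core::mem::size_of::<T>(), 1, "T must be 1 byte, like u8/i8/bool");
    // SAFETY: T is 1 byte (checked above), the pointer and length come from a
    // valid slice, and the returned slice is never mutated.
    unsafe { slice::from_raw_parts(s.as_ptr() as *const u8, s.len()) }
}

fn main() {
    assert_eq!(bytes_of_checked(&[-1i8, 0, 1]), &[255u8, 0, 1]);
    assert_eq!(bytes_of_checked(&[true, false]), &[1u8, 0]);
}
```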
+
 impl SliceContains for i8 {
     #[inline]
-    fn slice_contains(&self, x: &[Self]) -> bool {
-        let byte = *self as u8;
-        // SAFETY: `i8` and `u8` have the same memory layout, thus casting `x.as_ptr()`
-        // as `*const u8` is safe. The `x.as_ptr()` comes from a reference and is thus guaranteed
-        // to be valid for reads for the length of the slice `x.len()`, which cannot be larger
-        // than `isize::MAX`. The returned slice is never mutated.
-        let bytes: &[u8] = unsafe { from_raw_parts(x.as_ptr() as *const u8, x.len()) };
-        memchr::memchr(byte, bytes).is_some()
+    fn slice_contains_element(hs: &[Self], needle: &Self) -> bool {
+        // SAFETY: i8 and u8 have the same memory layout
+        u8::slice_contains_element(unsafe { bytes_of(hs) }, &(*needle as u8))
+    }
+
+    #[inline]
+    fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool {
+        // SAFETY: i8 and u8 have the same memory layout
+        unsafe { u8::slice_contains_slice(bytes_of(hs), bytes_of(needle)) }
+    }
+}
+
+impl SliceContains for bool {
+    #[inline]
+    fn slice_contains_element(hs: &[Self], needle: &Self) -> bool {
+        // SAFETY: bool and u8 have the same memory layout and all valid bool
+        // bit patterns are valid u8 bit patterns.
+        u8::slice_contains_element(unsafe { bytes_of(hs) }, &(*needle as u8))
+    }
+
+    #[inline]
+    fn slice_contains_slice(hs: &[Self], needle: &[Self]) -> bool {
+        // SAFETY: bool and u8 have the same memory layout and all valid bool
+        // bit patterns are valid u8 bit patterns.
+        unsafe { u8::slice_contains_slice(bytes_of(hs), bytes_of(needle)) }
+    }
+}
+
+fn default_slice_contains_slice<T: PartialEq>(hs: &[T], needle: &[T]) -> bool {
+    super::pattern::NaiveSearcherState::new(hs.len())
+        .next_match(hs, needle)
+        .is_some()
+}
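
Editorial aside (not part of the patch): `default_slice_contains_slice` delegates to the naive searcher from `core::str::pattern`, which answers the same question as a `windows`-based scan. A hedged sketch of that equivalent; the empty-needle convention here is an assumption of the sketch, not taken from the patch.

```rust
/// Roughly what the naive fallback computes: O(haystack.len() * needle.len()).
fn naive_contains_slice<T: PartialEq>(haystack: &[T], needle: &[T]) -> bool {
    if needle.is_empty() {
        // Assumed convention for this sketch; `windows(0)` would panic.
        return true;
    }
    haystack.windows(needle.len()).any(|window| window == needle)
}

fn main() {
    let haystack: &[u8] = b"needle in a haystack";
    let needle: &[u8] = b"needle";
    assert!(naive_contains_slice(haystack, needle));
    assert!(!naive_contains_slice(needle, haystack));
}
```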
+
+
+/// SIMD search for short needles based on
+/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
+///
+/// It skips ahead by the vector width on each iteration (rather than the needle length as two-way
+/// does) by probing the first and last byte of the needle for the whole vector width
+/// and only doing full needle comparisons when the vectorized probe indicated potential matches.
+///
+/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
+/// If we ever ship std for x86-64-v3 or adapt this for other platforms then wider vectors
+/// should be evaluated.
+///
+/// For haystacks smaller than vector-size + needle length it falls back to
+/// a naive O(n*m) search, so this implementation should not be called with large needles.
+///
+/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
+#[inline]
+fn simd_contains(haystack: &[u8], needle: &[u8]) -> Option<bool> {
+    debug_assert!(needle.len() > 1);
+
+    use crate::ops::BitAnd;
+    use crate::simd::mask8x16 as Mask;
+    use crate::simd::u8x16 as Block;
+    use crate::simd::{SimdPartialEq, ToBitMask};
+
+    let first_probe = needle[0];
+    let last_byte_offset = needle.len() - 1;
+
+    // the offset used for the 2nd vector
+    let second_probe_offset = if needle.len() == 2 {
+        // never bail out on len=2 needles because the probes will fully cover them and have
+        // no degenerate cases.
+        1
+    } else {
+        // try a few bytes in case first and last byte of the needle are the same
+        let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
+            // fall back to other search methods if we can't find any different bytes
+            // since we could otherwise hit some degenerate cases
+            return None;
+        };
+        second_probe_offset
+    };
+
+    // do a naive search if the haystack is too small to fit
+    if haystack.len() < Block::LANES + last_byte_offset {
+        return Some(haystack.windows(needle.len()).any(|c| c == needle));
+    }
+
+    let first_probe: Block = Block::splat(first_probe);
+    let second_probe: Block = Block::splat(needle[second_probe_offset]);
+    // The first byte is already checked by the outer loop; to verify a match only the
+    // remainder has to be compared.
+    let trimmed_needle = &needle[1..];
+
+    // this #[cold] is load-bearing, benchmark before removing it...
+    let check_mask = #[cold]
+    |idx, mask: u16, skip: bool| -> bool {
+        if skip {
+            return false;
+        }
+
+        // and so is this. optimizations are weird.
+        let mut mask = mask;
+
+        while mask != 0 {
+            let trailing = mask.trailing_zeros();
+            let offset = idx + trailing as usize + 1;
+            // SAFETY: mask has between 0 and 15 trailing zeros, we skip one additional byte that was already compared
+            // and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop.
+            unsafe {
+                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
+                if small_slice_eq(sub, trimmed_needle) {
+                    return true;
+                }
+            }
+            mask &= !(1 << trailing);
+        }
+        return false;
+    };

+    let test_chunk = |idx| -> u16 {
+        // SAFETY: this requires at least LANES bytes being readable at idx
+        // that is ensured by the loop ranges (see comments below)
+        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
+        // SAFETY: this requires LANES + second_probe_offset bytes being readable at idx
+        let b: Block = unsafe {
+            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
+        };
+        let eq_first: Mask = a.simd_eq(first_probe);
+        let eq_last: Mask = b.simd_eq(second_probe);
+        let both = eq_first.bitand(eq_last);
+        let mask = both.to_bitmask();
+
+        return mask;
+    };
+
+    let mut i = 0;
+    let mut result = false;
+    // The loop condition must ensure that there's enough headroom to read LANES bytes,
+    // and not only at the current index but also at the index shifted by second_probe_offset.
+    const UNROLL: usize = 4;
+    while i + last_byte_offset + UNROLL * Block::LANES < haystack.len() && !result {
+        let mut masks = [0u16; UNROLL];
+        for j in 0..UNROLL {
+            masks[j] = test_chunk(i + j * Block::LANES);
+        }
+        for j in 0..UNROLL {
+            let mask = masks[j];
+            if mask != 0 {
+                result |= check_mask(i + j * Block::LANES, mask, result);
+            }
+        }
+        i += UNROLL * Block::LANES;
+    }
+    while i + last_byte_offset + Block::LANES < haystack.len() && !result {
+        let mask = test_chunk(i);
+        if mask != 0 {
+            result |= check_mask(i, mask, result);
+        }
+        i += Block::LANES;
+    }
+
+    // Process the tail that didn't fit into LANES-sized steps.
+    // This simply repeats the same procedure but as a right-aligned chunk instead
+    // of a left-aligned one. The last byte must be exactly flush with the string end so
+    // we don't miss a single byte or read out of bounds.
+    let i = haystack.len() - last_byte_offset - Block::LANES;
+    let mask = test_chunk(i);
+    if mask != 0 {
+        result |= check_mask(i, mask, result);
+    }
+
+    Some(result)
+}
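
Editorial aside (not part of the patch): the core of the algorithm is "probe two needle bytes across a whole vector, then verify candidates". A hedged scalar sketch of the same control flow, one position at a time instead of 16 lanes and without unsafe; the function name and the simplified choice of the last byte as second probe are assumptions of the sketch.

```rust
/// Scalar model of `simd_contains`: test a first and a second probe byte at
/// each candidate position, and only compare the full needle when both hit.
fn probe_then_verify(haystack: &[u8], needle: &[u8]) -> bool {
    assert!(needle.len() > 1);
    let first = needle[0];
    let second_offset = needle.len() - 1; // simplification: always probe the last byte
    let second = needle[second_offset];

    if haystack.len() < needle.len() {
        return false;
    }
    // The real code tests 16 positions per iteration via u8x16 compares and a
    // 16-bit candidate mask; here we test one position per iteration.
    for i in 0..=haystack.len() - needle.len() {
        if haystack[i] == first && haystack[i + second_offset] == second {
            // Both probes hit: verify the remainder (the SIMD version compares
            // `&needle[1..]` via `small_slice_eq`).
            if &haystack[i + 1..i + needle.len()] == &needle[1..] {
                return true;
            }
        }
    }
    false
}

fn main() {
    assert!(probe_then_verify(b"some haystack with a needle inside", b"needle"));
    assert!(!probe_then_verify(b"some haystack without one", b"needle"));
}
```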
+
+/// Compares short slices for equality.
+///
+/// It avoids a call to libc's memcmp, which is faster on long slices
+/// due to SIMD optimizations but incurs function call overhead.
+///
+/// # Safety
+///
+/// Both slices must have the same length.
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
+#[inline]
+unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    debug_assert_eq!(x.len(), y.len());
+    // This function is adapted from
+    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
+
+    // If we don't have enough bytes to do 4-byte at a time loads, then
+    // fall back to the naive slow version.
+    //
+    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
+    // of a loop. Benchmark it.
+    if x.len() < 4 {
+        for (&b1, &b2) in x.iter().zip(y) {
+            if b1 != b2 {
+                return false;
+            }
+        }
+        return true;
+    }
+    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
+    // a time using unaligned loads.
+    //
+    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
+    // that this particular version of memcmp is likely to be called with tiny
+    // needles. That means that if we do 8 byte loads, then a higher proportion
+    // of memcmp calls will use the slower variant above. With that said, this
+    // is a hypothesis and is only loosely supported by benchmarks. There's
+    // likely some improvement that could be made here. The main thing here
+    // though is to optimize for latency, not throughput.
+
+    // SAFETY: Via the caller's contract and the conditional above, we know that
+    // `x` and `y` have the same length of at least 4 bytes, so `px < pxend`
+    // implies that `py < pyend`.
+    // Thus, dereferencing both `px` and `py` in the loop below is safe.
+    //
+    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
+    // end of `px` and `py`. Thus, the final dereference outside of the
+    // loop is guaranteed to be valid. (The final comparison will overlap with
+    // the last comparison done in the loop for lengths that aren't multiples
+    // of four.)
+    //
+    // Finally, we needn't worry about alignment here, since we do unaligned
+    // loads.
+    unsafe {
+        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
+        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
+        while px < pxend {
+            let vx = (px as *const u32).read_unaligned();
+            let vy = (py as *const u32).read_unaligned();
+            if vx != vy {
+                return false;
+            }
+            px = px.add(4);
+            py = py.add(4);
+        }
+        let vx = (pxend as *const u32).read_unaligned();
+        let vy = (pyend as *const u32).read_unaligned();
+        vx == vy
     }
 }
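
Editorial aside (not part of the patch): the trick in `small_slice_eq` is that the final 4-byte load may overlap the previous one, so no scalar tail loop is needed. A hedged safe-Rust sketch of the same comparison strategy using `u32::from_ne_bytes` instead of raw pointer reads; the function name is illustrative only.

```rust
/// Safe model of `small_slice_eq`: compare 4 bytes at a time, then compare the
/// last 4 bytes as a block flush with the end, which may overlap the last chunk.
fn chunked_eq(x: &[u8], y: &[u8]) -> bool {
    assert_eq!(x.len(), y.len());
    if x.len() < 4 {
        return x == y;
    }
    let load = |s: &[u8], i: usize| u32::from_ne_bytes(s[i..i + 4].try_into().unwrap());
    let mut i = 0;
    while i + 4 < x.len() {
        if load(x, i) != load(y, i) {
            return false;
        }
        i += 4;
    }
    // Final, possibly overlapping, 4-byte block.
    load(x, x.len() - 4) == load(y, y.len() - 4)
}

fn main() {
    assert!(chunked_eq(b"abcdefg", b"abcdefg"));
    assert!(!chunked_eq(b"abcdefg", b"abcdeXg"));
    assert!(chunked_eq(b"ab", b"ab"));
}
```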