@@ -8,7 +8,10 @@ use core::iter::{Chain, FromIterator, FusedIterator};
8
8
use core:: mem;
9
9
use core:: ops:: { BitAnd , BitOr , BitXor , Sub } ;
10
10
11
- use super :: map:: { self , ConsumeAllOnDrop , DefaultHashBuilder , DrainFilterInner , HashMap , Keys } ;
11
+ use super :: map:: {
12
+ self , make_hash, make_insert_hash, ConsumeAllOnDrop , DefaultHashBuilder , DrainFilterInner ,
13
+ HashMap , Keys , RawEntryMut ,
14
+ } ;
12
15
use crate :: raw:: { Allocator , Global } ;
13
16
14
17
// Future Optimization (FIXME!)
@@ -953,6 +956,12 @@ where
953
956
/// Inserts a value computed from `f` into the set if the given `value` is
954
957
/// not present, then returns a reference to the value in the set.
955
958
///
959
+ /// # Panics
960
+ ///
961
+ /// Panics if the value from the function and the provided lookup value
962
+ /// are not equivalent or have different hashes. See [`Equivalent`]
963
+ /// and [`Hash`] for more information.
964
+ ///
956
965
/// # Examples
957
966
///
958
967
/// ```
@@ -967,20 +976,40 @@ where
967
976
/// assert_eq!(value, pet);
968
977
/// }
969
978
/// assert_eq!(set.len(), 4); // a new "fish" was inserted
979
+ /// assert!(set.contains("fish"));
970
980
/// ```
971
981
#[ cfg_attr( feature = "inline-more" , inline) ]
972
982
pub fn get_or_insert_with < Q : ?Sized , F > ( & mut self , value : & Q , f : F ) -> & T
973
983
where
974
984
Q : Hash + Equivalent < T > ,
975
985
F : FnOnce ( & Q ) -> T ,
976
986
{
987
+ #[ cold]
988
+ #[ inline( never) ]
989
+ fn assert_failed ( ) {
990
+ panic ! (
991
+ "the value from the function and the lookup value \
992
+ must be equivalent and have the same hash"
993
+ ) ;
994
+ }
995
+
977
996
// Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
978
997
// `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
979
- self . map
980
- . raw_entry_mut ( )
981
- . from_key ( value)
982
- . or_insert_with ( || ( f ( value) , ( ) ) )
983
- . 0
998
+ let hash = make_hash :: < Q , S > ( & self . map . hash_builder , value) ;
999
+ let raw_entry_builder = self . map . raw_entry_mut ( ) ;
1000
+ match raw_entry_builder. from_key_hashed_nocheck ( hash, value) {
1001
+ RawEntryMut :: Occupied ( entry) => entry. into_key ( ) ,
1002
+ RawEntryMut :: Vacant ( entry) => {
1003
+ let insert_value = f ( value) ;
1004
+ let insert_value_hash = make_insert_hash :: < T , S > ( entry. hasher ( ) , & insert_value) ;
1005
+ if !( hash == insert_value_hash && value. equivalent ( & insert_value) ) {
1006
+ assert_failed ( ) ;
1007
+ }
1008
+ entry
1009
+ . insert_hashed_nocheck ( insert_value_hash, insert_value, ( ) )
1010
+ . 0
1011
+ }
1012
+ }
984
1013
}
985
1014
986
1015
/// Gets the given value's corresponding entry in the set for in-place manipulation.
@@ -2429,7 +2458,7 @@ fn assert_covariance() {
2429
2458
#[ cfg( test) ]
2430
2459
mod test_set {
2431
2460
use super :: super :: map:: DefaultHashBuilder ;
2432
- use super :: HashSet ;
2461
+ use super :: { make_hash , Equivalent , HashSet } ;
2433
2462
use std:: vec:: Vec ;
2434
2463
2435
2464
#[ test]
@@ -2886,4 +2915,88 @@ mod test_set {
2886
2915
set. insert ( i) ;
2887
2916
}
2888
2917
}
2918
+
2919
+ #[ test]
2920
+ fn duplicate_insert ( ) {
2921
+ let mut set = HashSet :: new ( ) ;
2922
+ set. insert ( 1 ) ;
2923
+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2924
+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2925
+ assert ! ( [ 1 ] . iter( ) . eq( set. iter( ) ) ) ;
2926
+ }
2927
+
2928
+ #[ test]
2929
+ #[ allow( clippy:: derived_hash_with_manual_eq) ]
2930
+ #[ should_panic]
2931
+ fn some_invalid_hash ( ) {
2932
+ use core:: hash:: { Hash , Hasher } ;
2933
+ #[ derive( Eq , PartialEq ) ]
2934
+ struct Invalid {
2935
+ count : u32 ,
2936
+ }
2937
+
2938
+ struct InvalidRef {
2939
+ count : u32 ,
2940
+ }
2941
+ impl Equivalent < Invalid > for InvalidRef {
2942
+ fn equivalent ( & self , key : & Invalid ) -> bool {
2943
+ self . count == key. count
2944
+ }
2945
+ }
2946
+ impl Hash for Invalid {
2947
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2948
+ self . count . hash ( state) ;
2949
+ }
2950
+ }
2951
+ impl Hash for InvalidRef {
2952
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2953
+ let double = self . count * 2 ;
2954
+ double. hash ( state) ;
2955
+ }
2956
+ }
2957
+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
2958
+ let key = InvalidRef { count : 1 } ;
2959
+ let value = Invalid { count : 1 } ;
2960
+ if key. equivalent ( & value) {
2961
+ set. get_or_insert_with ( & key, |_| value) ;
2962
+ }
2963
+ }
2964
+
2965
+ #[ test]
2966
+ #[ allow( clippy:: derived_hash_with_manual_eq) ]
2967
+ #[ should_panic]
2968
+ fn some_invalid_equivalent ( ) {
2969
+ use core:: hash:: { Hash , Hasher } ;
2970
+ #[ derive( Eq , PartialEq ) ]
2971
+ struct Invalid {
2972
+ count : u32 ,
2973
+ other : u32 ,
2974
+ }
2975
+
2976
+ struct InvalidRef {
2977
+ count : u32 ,
2978
+ other : u32 ,
2979
+ }
2980
+ impl Equivalent < Invalid > for InvalidRef {
2981
+ fn equivalent ( & self , key : & Invalid ) -> bool {
2982
+ self . count == key. count && self . other == key. other
2983
+ }
2984
+ }
2985
+ impl Hash for Invalid {
2986
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2987
+ self . count . hash ( state) ;
2988
+ }
2989
+ }
2990
+ impl Hash for InvalidRef {
2991
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2992
+ self . count . hash ( state) ;
2993
+ }
2994
+ }
2995
+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
2996
+ let key = InvalidRef { count : 1 , other : 1 } ;
2997
+ let value = Invalid { count : 1 , other : 2 } ;
2998
+ if make_hash ( set. hasher ( ) , & key) == make_hash ( set. hasher ( ) , & value) {
2999
+ set. get_or_insert_with ( & key, |_| value) ;
3000
+ }
3001
+ }
2889
3002
}
0 commit comments