@@ -2007,3 +2007,230 @@ def test_kv_opt_state_w_offloading(
20072007 atol = tolerance ,
20082008 rtol = tolerance ,
20092009 )
2010+
2011+ @given (
2012+ ** default_st ,
2013+ num_buckets = st .integers (min_value = 10 , max_value = 15 ),
2014+ )
2015+ @settings (verbosity = Verbosity .verbose , max_examples = MAX_EXAMPLES , deadline = None )
2016+ def test_kv_opt_state_w_offloading (
2017+ self ,
2018+ T : int ,
2019+ D : int ,
2020+ B : int ,
2021+ log_E : int ,
2022+ L : int ,
2023+ weighted : bool ,
2024+ cache_set_scale : float ,
2025+ pooling_mode : PoolingMode ,
2026+ weights_precision : SparseType ,
2027+ output_dtype : SparseType ,
2028+ share_table : bool ,
2029+ trigger_bounds_check : bool ,
2030+ mixed_B : bool ,
2031+ num_buckets : int ,
2032+ ) -> None :
2033+ # Constants
2034+ lr = 0.5
2035+ eps = 0.2
2036+ ssd_shards = 2
2037+
2038+ trigger_bounds_check = False # don't stimulate boundary check cases
2039+ assume (not weighted or pooling_mode == PoolingMode .SUM )
2040+ assume (not mixed_B or pooling_mode != PoolingMode .NONE )
2041+
2042+ # TODO: check split_optimizer_states when optimizer offloading is ready
2043+ # Generate embedding modules and inputs
2044+ (
2045+ emb ,
2046+ emb_ref ,
2047+ Es ,
2048+ _ ,
2049+ bucket_offsets ,
2050+ bucket_sizes ,
2051+ ) = self .generate_kvzch_tbes (
2052+ T ,
2053+ D ,
2054+ B ,
2055+ log_E ,
2056+ L ,
2057+ weighted ,
2058+ lr = lr ,
2059+ eps = eps ,
2060+ ssd_shards = ssd_shards ,
2061+ cache_set_scale = cache_set_scale ,
2062+ pooling_mode = pooling_mode ,
2063+ weights_precision = weights_precision ,
2064+ output_dtype = output_dtype ,
2065+ share_table = share_table ,
2066+ num_buckets = num_buckets ,
2067+ enable_optimizer_offloading = False ,
2068+ )
2069+
2070+ # Generate inputs
2071+ (
2072+ indices_list ,
2073+ per_sample_weights_list ,
2074+ indices ,
2075+ offsets ,
2076+ per_sample_weights ,
2077+ batch_size_per_feature_per_rank ,
2078+ ) = self .generate_inputs_ (
2079+ B ,
2080+ L ,
2081+ Es ,
2082+ emb .feature_table_map ,
2083+ weights_precision = weights_precision ,
2084+ trigger_bounds_check = trigger_bounds_check ,
2085+ mixed_B = mixed_B ,
2086+ bucket_offsets = bucket_offsets ,
2087+ bucket_sizes = bucket_sizes ,
2088+ is_kv_tbes = True ,
2089+ )
2090+
2091+ # Execute forward
2092+ output_ref_list , output = self .execute_ssd_forward_ (
2093+ emb ,
2094+ emb_ref ,
2095+ indices_list ,
2096+ per_sample_weights_list ,
2097+ indices ,
2098+ offsets ,
2099+ per_sample_weights ,
2100+ B ,
2101+ L ,
2102+ weighted ,
2103+ batch_size_per_feature_per_rank = batch_size_per_feature_per_rank ,
2104+ )
2105+
2106+ # Generate output gradient
2107+ output_grad_list = [torch .randn_like (out ) for out in output_ref_list ]
2108+
2109+ # Execute torch EmbeddingBag backward
2110+ [out .backward (grad ) for (out , grad ) in zip (output_ref_list , output_grad_list )]
2111+ if batch_size_per_feature_per_rank is not None :
2112+ grad_test = self .concat_ref_tensors_vbe (
2113+ output_grad_list , batch_size_per_feature_per_rank
2114+ )
2115+ else :
2116+ grad_test = self .concat_ref_tensors (
2117+ output_grad_list ,
2118+ pooling_mode != PoolingMode .NONE , # do_pooling
2119+ B ,
2120+ D * 4 ,
2121+ )
2122+
2123+ # Execute TBE SSD backward
2124+ output .backward (grad_test )
2125+
2126+ tolerance = (
2127+ 1.0e-4
2128+ if weights_precision == SparseType .FP32 and output_dtype == SparseType .FP32
2129+ else 1.0e-2
2130+ )
2131+
2132+ emb .flush ()
2133+
2134+ # Compare emb state dict with expected values from nn.EmbeddingBag
2135+ emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
2136+ emb .split_embedding_weights (no_snapshot = False , should_flush = True )
2137+ )
2138+ split_optimizer_states = emb .split_optimizer_states (bucket_asc_ids_list )
2139+ table_input_id_range = []
2140+ for t , row in enumerate (Es ):
2141+ bucket_id_start = bucket_offsets [t ][0 ]
2142+ bucket_id_end = bucket_offsets [t ][1 ]
2143+ bucket_size = bucket_sizes [t ]
2144+ table_input_id_range .append (
2145+ (
2146+ min (bucket_id_start * bucket_size , row ),
2147+ min (bucket_id_end * bucket_size , row ),
2148+ )
2149+ )
2150+ # since we use ref_emb in dense format, the rows start from id 0
2151+ self .assertEqual (table_input_id_range [- 1 ][0 ], 0 )
2152+
2153+ # Compare optimizer states
2154+ for f , t in self .get_physical_table_arg_indices_ (emb .feature_table_map ):
2155+ # pyre-fixme[16]: Optional type has no attribute `float`.
2156+ ref_emb = emb_ref [f ].weight .grad .float ().to_dense ().pow (2 ).cpu ()
2157+ ref_optimizer_state = ref_emb .mean (dim = 1 )[
2158+ table_input_id_range [t ][0 ] : min (
2159+ table_input_id_range [t ][1 ], emb_ref [f ].weight .size (0 )
2160+ )
2161+ ]
2162+ ref_kv_opt = ref_optimizer_state [bucket_asc_ids_list [t ]].view (- 1 )
2163+ torch .testing .assert_close (
2164+ split_optimizer_states [t ].float (),
2165+ ref_kv_opt ,
2166+ atol = tolerance ,
2167+ rtol = tolerance ,
2168+ )
2169+
2170+ for feature_index , table_index in self .get_physical_table_arg_indices_ (
2171+ emb .feature_table_map
2172+ ):
2173+ """
2174+ validate bucket_asc_ids_list and num_active_id_per_bucket_list
2175+ """
2176+ bucket_asc_id = bucket_asc_ids_list [table_index ]
2177+ num_active_id_per_bucket = num_active_id_per_bucket_list [table_index ]
2178+
2179+ bucket_id_start = bucket_offsets [table_index ][0 ]
2180+ bucket_id_offsets = torch .ops .fbgemm .asynchronous_complete_cumsum (
2181+ num_active_id_per_bucket .view (- 1 )
2182+ )
2183+ for bucket_idx , id_count in enumerate (num_active_id_per_bucket ):
2184+ bucket_id = bucket_idx + bucket_id_start
2185+ active_id_cnt = 0
2186+ for idx in range (
2187+ bucket_id_offsets [bucket_idx ],
2188+ bucket_id_offsets [bucket_idx + 1 ],
2189+ ):
2190+ # for chunk-based hashing
2191+ self .assertEqual (
2192+ bucket_id , bucket_asc_id [idx ] // bucket_sizes [table_index ]
2193+ )
2194+ active_id_cnt += 1
2195+ self .assertEqual (active_id_cnt , id_count )
2196+
2197+ """
2198+ validate embeddings
2199+ """
2200+ num_ids = len (bucket_asc_ids_list [table_index ])
2201+ emb_r_w = emb_ref [feature_index ].weight [
2202+ bucket_asc_ids_list [table_index ].view (- 1 )
2203+ ]
2204+ emb_r_w_g = (
2205+ emb_ref [feature_index ]
2206+ .weight .grad .float ()
2207+ .to_dense ()[bucket_asc_ids_list [table_index ].view (- 1 )]
2208+ )
2209+ self .assertLess (table_index , len (emb_state_dict_list ))
2210+ assert len (split_optimizer_states [table_index ]) == num_ids
2211+ opt = split_optimizer_states [table_index ]
2212+ new_ref_weight = torch .addcdiv (
2213+ emb_r_w .float (),
2214+ value = - lr ,
2215+ tensor1 = emb_r_w_g ,
2216+ tensor2 = opt .float ()
2217+ .sqrt_ ()
2218+ .add_ (eps )
2219+ .view (
2220+ num_ids ,
2221+ 1 ,
2222+ )
2223+ .cuda (),
2224+ ).cpu ()
2225+
2226+ emb_w = (
2227+ emb_state_dict_list [table_index ]
2228+ .narrow (0 , 0 , bucket_asc_ids_list [table_index ].size (0 ))
2229+ .float ()
2230+ )
2231+ torch .testing .assert_close (
2232+ emb_w ,
2233+ new_ref_weight ,
2234+ atol = tolerance ,
2235+ rtol = tolerance ,
2236+ )
0 commit comments