@@ -76,7 +76,7 @@ class IterData:
76
76
77
77
@dataclass
78
78
class KVZCHCachedData :
79
- cached_optimizer_state_per_table : List [torch .Tensor ]
79
+ cached_optimizer_states_per_table : List [List [ torch .Tensor ] ]
80
80
cached_weight_tensor_per_table : List [torch .Tensor ]
81
81
cached_id_tensor_per_table : List [torch .Tensor ]
82
82
cached_bucket_splits : List [torch .Tensor ]
@@ -2393,18 +2393,10 @@ def split_optimizer_states(
2393
2393
# init for checkpointing loading
2394
2394
assert (
2395
2395
self ._cached_kvzch_data is not None
2396
- and self ._cached_kvzch_data .cached_optimizer_state_per_table
2396
+ and self ._cached_kvzch_data .cached_optimizer_states_per_table
2397
2397
), "optimizer state is not initialized for load checkpointing"
2398
2398
2399
- # NOTE: This is a temporary hack to have split_optimizer_states return a
2400
- # List[List[Tensor]] instead of List[Tensor] to match the behavior of
2401
- # _split_optimizer_states_non_kv_zch. This should be removed after
2402
- # proper support for multiple optimizers is added for the
2403
- # enable_optimizer_offloading=True case.
2404
- return [
2405
- [opt ]
2406
- for opt in self ._cached_kvzch_data .cached_optimizer_state_per_table
2407
- ]
2399
+ return self ._cached_kvzch_data .cached_optimizer_states_per_table
2408
2400
2409
2401
logging .info (
2410
2402
f"split_optimizer_states for KV ZCH: { no_snapshot = } , { should_flush = } "
@@ -2559,25 +2551,15 @@ def get_optimizer_state(
2559
2551
should_flush : bool = False ,
2560
2552
) -> List [Dict [str , torch .Tensor ]]:
2561
2553
"""
2562
- Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
2563
- so only momentum1 state is returned.
2554
+ Returns a list of dictionaries of optimizer states split by table.
2564
2555
"""
2565
- states_list = self .split_optimizer_states (
2556
+ states_list : List [ List [ Tensor ]] = self .split_optimizer_states (
2566
2557
sorted_id_tensor = sorted_id_tensor ,
2567
2558
no_snapshot = no_snapshot ,
2568
2559
should_flush = should_flush ,
2569
2560
)
2570
-
2571
- if self .optimizer == OptimType .EXACT_ROWWISE_ADAGRAD :
2572
- keys = ["momentum1" ]
2573
- elif self .optimizer == OptimType .PARTIAL_ROWWISE_ADAM :
2574
- keys = ["momentum1" , "momentum2" ]
2575
- else :
2576
- raise NotImplementedError (
2577
- f"Getting optimizer states is not supported for { self .optimizer } "
2578
- )
2579
-
2580
- return [dict (zip (keys , states )) for states in states_list ]
2561
+ state_names = self .optimizer .state_names ()
2562
+ return [dict (zip (state_names , states )) for states in states_list ]
2581
2563
2582
2564
@torch .jit .export
2583
2565
def debug_split_embedding_weights (self ) -> List [torch .Tensor ]:
@@ -2829,6 +2811,94 @@ def split_embedding_weights(
2829
2811
2830
2812
return (pmt_splits , bucket_sorted_id_splits , active_id_cnt_per_bucket_split )
2831
2813
2814
@torch.jit.ignore
def _apply_state_dict_w_offloading(self) -> None:
    """
    Apply the cached checkpoint data table by table through
    `streaming_write_weight_and_id_per_table`.

    For every table in `self.embedding_specs`, the cached weight tensor,
    the list of cached optimizer state tensors and the cached id tensor
    are handed to `streaming_write_weight_and_id_per_table` together with
    the table's row offset.  Afterwards the cached weight and optimizer
    state entries for that table are replaced with None so the references
    are dropped.

    Does nothing when `self.embedding_specs` is empty.
    """
    # Cumulative row counts per table for rowwise states: entry t is the
    # total number of rows of all tables before table t.  Built with
    # `initial=0` instead of unpacking `zip(*specs)` so that an empty
    # `embedding_specs` yields `[0]` rather than raising ValueError.
    row_count_cumsum: List[int] = list(
        itertools.accumulate(
            (rows for rows, _ in self.embedding_specs), initial=0
        )
    )

    for t, _ in enumerate(self.embedding_specs):
        # pyre-ignore [16]
        bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
        # pyre-ignore [16]
        bucket_size = self.kv_zch_params.bucket_sizes[t]
        # NOTE(review): presumably shifts the table's bucket-relative ids
        # into this rank's local row space -- confirm against
        # streaming_write_weight_and_id_per_table.
        row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size

        # pyre-ignore [16]
        weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
        # pyre-ignore [16]
        opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]

        self.streaming_write_weight_and_id_per_table(
            weight_state,
            opt_states,
            # pyre-ignore [16]
            self._cached_kvzch_data.cached_id_tensor_per_table[t],
            row_offset,
        )

        # Drop the cached tensors for this table once they have been
        # written.
        self._cached_kvzch_data.cached_weight_tensor_per_table[t] = None
        self._cached_kvzch_data.cached_optimizer_states_per_table[t] = None
2842
+
2843
@torch.jit.ignore
def _apply_state_dict_no_offloading(self) -> None:
    """
    Apply the cached checkpoint data table by table when optimizer states
    are not offloaded.

    For every table in `self.embedding_specs`:
      1. the cached ids are shifted by the table's row offset,
      2. the cached optimizer states are scattered into the optimizer
         state tensors (`momentum1_dev`, and `momentum2_dev` for
         PARTIAL_ROWWISE_ADAM) at the shifted ids,
      3. the cached weights are written with `self.ssd_db.set_cuda`.

    For any optimizer other than EXACT_ROWWISE_ADAGRAD and
    PARTIAL_ROWWISE_ADAM no optimizer state is copied; only the weights
    are written.  Does nothing when `self.embedding_specs` is empty.
    """
    # Cumulative row counts per table for rowwise states: entry t is the
    # total number of rows of all tables before table t.  Built with
    # `initial=0` instead of unpacking `zip(*specs)` so that an empty
    # `embedding_specs` yields `[0]` rather than raising ValueError.
    row_count_cumsum: List[int] = list(
        itertools.accumulate(
            (rows for rows, _ in self.embedding_specs), initial=0
        )
    )

    def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
        # In-place scatter of `src` into `dst` at `indices`; both inputs
        # are moved to the device of `dst` first.
        device = dst.device
        dst.index_put_(
            indices=(
                # indices is expected to be a tuple of Tensors, not Tensor
                indices.to(device).view(-1),
            ),
            values=src.to(device),
        )

    for t, _ in enumerate(self.embedding_specs):
        # pyre-ignore [16]
        bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
        # pyre-ignore [16]
        bucket_size = self.kv_zch_params.bucket_sizes[t]
        row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size

        # pyre-ignore [16]
        weights = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
        # pyre-ignore [16]
        ids = self._cached_kvzch_data.cached_id_tensor_per_table[t]
        local_ids = ids + row_offset

        logging.info(
            f"applying sd for table {t} without optimizer offloading, local_ids is {local_ids}"
        )
        # pyre-ignore [16]
        opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]

        # Set up the plan for copying optimizer states over, as
        # (source cached state, destination state tensor) pairs.
        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
            mapping = [(opt_states[0], self.momentum1_dev)]
        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
            mapping = [
                (opt_states[0], self.momentum1_dev),
                (opt_states[1], self.momentum2_dev),
            ]
        else:
            mapping = []

        # Execute the plan and copy the optimizer states over.  A plain
        # loop is used because the copies are pure side effects.
        for src, dst in mapping:
            # pyre-ignore [6]
            copy_optimizer_state_(dst, src, local_ids)

        self.ssd_db.set_cuda(
            local_ids.view(-1),
            weights,
            torch.as_tensor(local_ids.size(0)),
            1,
            False,
        )
2901
+
2832
2902
@torch .jit .ignore
2833
2903
def apply_state_dict (self ) -> None :
2834
2904
if self .backend_return_whole_row :
@@ -2844,7 +2914,7 @@ def apply_state_dict(self) -> None:
2844
2914
assert self .kv_zch_params is not None , "apply_state_dict supports KV ZCH only"
2845
2915
assert (
2846
2916
self ._cached_kvzch_data is not None
2847
- and self ._cached_kvzch_data .cached_optimizer_state_per_table is not None
2917
+ and self ._cached_kvzch_data .cached_optimizer_states_per_table is not None
2848
2918
), "optimizer state is not initialized for load checkpointing"
2849
2919
assert (
2850
2920
self ._cached_kvzch_data .cached_weight_tensor_per_table is not None
@@ -2855,51 +2925,11 @@ def apply_state_dict(self) -> None:
2855
2925
# optimizer state, round to the nearest 4
2856
2926
# optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
2857
2927
# apply weight and optimizer state per table
2858
- table_offset = 0
2859
- for i , (emb_height , _ ) in enumerate (self .embedding_specs ):
2860
- # pyre-ignore [16]
2861
- bucket_id_start , _ = self .kv_zch_params .bucket_offsets [i ]
2862
- # pyre-ignore [16]
2863
- bucket_size = self .kv_zch_params .bucket_sizes [i ]
2864
- row_offset = table_offset - bucket_id_start * bucket_size
2928
+ if self .enable_optimizer_offloading :
2929
+ self ._apply_state_dict_w_offloading ()
2930
+ else :
2931
+ self ._apply_state_dict_no_offloading ()
2865
2932
2866
- if self .enable_optimizer_offloading :
2867
- # pyre-ignore [16]
2868
- weight_state = self ._cached_kvzch_data .cached_weight_tensor_per_table [i ]
2869
- # pyre-ignore [16]
2870
- opt_state = self ._cached_kvzch_data .cached_optimizer_state_per_table [i ]
2871
- self .streaming_write_weight_and_id_per_table (
2872
- weight_state ,
2873
- [opt_state ],
2874
- # pyre-ignore [16]
2875
- self ._cached_kvzch_data .cached_id_tensor_per_table [i ],
2876
- row_offset ,
2877
- )
2878
- self ._cached_kvzch_data .cached_weight_tensor_per_table [i ] = None
2879
- self ._cached_kvzch_data .cached_optimizer_state_per_table [i ] = None
2880
- else :
2881
- weight = self ._cached_kvzch_data .cached_weight_tensor_per_table [i ]
2882
- id = self ._cached_kvzch_data .cached_id_tensor_per_table [i ]
2883
- local_id = id + row_offset
2884
- logging .info (
2885
- f"applying sd for table { i } without optimizer offloading, local_id is { local_id } "
2886
- )
2887
- opt_state = self ._cached_kvzch_data .cached_optimizer_state_per_table [i ]
2888
- t_device = self .momentum1_dev .device
2889
- self .momentum1_dev .index_put_ (
2890
- indices = (
2891
- local_id .to (t_device ).view (- 1 ),
2892
- ), # expects tuple of tensors
2893
- values = opt_state .to (t_device ),
2894
- )
2895
- self .ssd_db .set_cuda (
2896
- local_id .view (- 1 ),
2897
- weight ,
2898
- torch .as_tensor (local_id .size (0 )),
2899
- 1 ,
2900
- False ,
2901
- )
2902
- table_offset += emb_height
2903
2933
self .clear_cache ()
2904
2934
2905
2935
@torch .jit .ignore
@@ -3001,22 +3031,33 @@ def enable_load_state_dict_mode(self) -> None:
3001
3031
3002
3032
dtype = self .weights_precision .as_dtype ()
3003
3033
self ._cached_kvzch_data = KVZCHCachedData ([], [], [], [])
3004
- for i , (_ , emb_dim ) in enumerate (self .embedding_specs ):
3005
- # for checkpointing loading, we need to store the weight and id tensor temporarily in memory
3034
+
3035
+ for i , _ in enumerate (self .embedding_specs ):
3036
+ # For checkpointing loading, we need to store the weight and id
3037
+ # tensor temporarily in memory. First check that the local_weight_counts
3038
+ # are properly set before even initializing the optimizer states
3006
3039
assert (
3007
3040
self .local_weight_counts [i ] > 0
3008
3041
), f"local_weight_counts for table { i } is not set"
3042
+
3043
+ # pyre-ignore [16]
3044
+ self ._cached_kvzch_data .cached_optimizer_states_per_table = (
3045
+ self .optimizer .empty_states (
3046
+ self .embedding_specs ,
3047
+ self .optimizer_state_dtypes ,
3048
+ self .local_weight_counts ,
3049
+ )
3050
+ )
3051
+
3052
+ for i , (_ , emb_dim ) in enumerate (self .embedding_specs ):
3009
3053
# pyre-ignore [16]
3010
3054
bucket_id_start , bucket_id_end = self .kv_zch_params .bucket_offsets [i ]
3011
3055
rows = self .local_weight_counts [i ]
3012
3056
weight_state = torch .empty (rows , emb_dim , dtype = dtype , device = "cpu" )
3013
- opt_state = torch .empty (rows , dtype = torch .float32 , device = "cpu" )
3014
3057
# pyre-ignore [16]
3015
3058
self ._cached_kvzch_data .cached_weight_tensor_per_table .append (weight_state )
3016
- # pyre-ignore [16]
3017
- self ._cached_kvzch_data .cached_optimizer_state_per_table .append (opt_state )
3018
3059
logging .info (
3019
- f"for checkpoint loading, table { i } , weight_state shape is { weight_state .shape } , opt_state shape is { opt_state . shape } "
3060
+ f"for checkpoint loading, table { i } , weight_state shape is { weight_state .shape } "
3020
3061
)
3021
3062
id_tensor = torch .zeros ((rows , 1 ), dtype = torch .int64 , device = "cpu" )
3022
3063
# pyre-ignore [16]
0 commit comments