diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 968d7aa15..2919f4684 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -246,6 +246,12 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: ssd_tbe_params["kvzch_eviction_trigger_mode"] = fused_params.get( "kvzch_eviction_trigger_mode" ) + if "enable_optimizer_offloading" in fused_params: + ssd_tbe_params["enable_optimizer_offloading"] = fused_params.get( + "enable_optimizer_offloading" + ) + else: + ssd_tbe_params["enable_optimizer_offloading"] = False ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]