From 156bb82f491829df444ac0c9be6401303221cf60 Mon Sep 17 00:00:00 2001 From: Isuru Janith Ranawaka Date: Tue, 15 Jul 2025 08:56:11 -0700 Subject: [PATCH] Modify sharding_single_rank_test_single_process method to test E2E deterministic results when random_seed is set Summary: refactoring and calling `sharding_single_rank_test_single_process` twice - storing the predictions / module sharding and comparing to ensure are the same Differential Revision: D78348074 --- .../distributed/test_utils/test_sharding.py | 555 +++++++++++------- 1 file changed, 331 insertions(+), 224 deletions(-) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index b78159f37..085e97feb 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -111,6 +111,288 @@ def _gather_all_tensors( return all_local_tensors +def _verify_and_gather_predictions_for_single_run( + pg: dist.ProcessGroup, + device: torch.device, + rank: int, + world_size: int, + model_class: TestSparseNNBase, + embedding_groups: Dict[str, List[str]], + tables: List[EmbeddingTableConfig], + sharders: List[ModuleSharder[nn.Module]], + optim: EmbOptimType, + weighted_tables: Optional[List[EmbeddingTableConfig]] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + local_size: Optional[int] = None, + qcomms_config: Optional[QCommsConfig] = None, + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ] = None, + variable_batch_size: bool = False, # variable batch per rank + batch_size: int = 4, + feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, + variable_batch_per_feature: bool = False, # VBE + global_constant_batch: bool = False, + world_size_2D: Optional[int] = None, # 2D parallel + node_group_size: Optional[int] = None, # 2D parallel + use_inter_host_allreduce: bool = False, # 2D parallel + input_type: str = "kjt", # "kjt" or "td" + allow_zero_batch_size: bool = False, + custom_all_reduce: bool = False, # 2D parallel + submodule_configs: Optional[List[DMPCollectionConfig]] = None, + use_offsets: bool = False, + indices_dtype: torch.dtype = torch.int64, + offsets_dtype: torch.dtype = torch.int64, + lengths_dtype: torch.dtype = torch.int64, + random_seed: Optional[int] = None, +) -> Optional[torch.Tensor]: + batch_size = random.randint(0, batch_size) if allow_zero_batch_size else batch_size + # Generate model & inputs. + (global_model, inputs) = gen_model_and_input( + model_class=model_class, + tables=tables, + # pyre-ignore [6] + generate=( + cast( + VariableBatchModelInputCallable, + ModelInput.generate_variable_batch_input, + ) + if variable_batch_per_feature + else ModelInput.generate + ), + weighted_tables=weighted_tables, + embedding_groups=embedding_groups, + world_size=world_size, + num_float_features=16, + variable_batch_size=variable_batch_size, + batch_size=batch_size, + feature_processor_modules=feature_processor_modules, + global_constant_batch=global_constant_batch, + input_type=input_type, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + random_seed=random_seed, + ) + + global_model = global_model.to(device) + global_input = inputs[0][0].to(device) + local_input = inputs[0][1][rank].to(device) + + # Shard model. 
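A toy illustration of the `inputs` layout unpacked above, and of the invariant the later comparison relies on: `inputs[0][0]` is the batch consumed by the unsharded global model, `inputs[0][1][rank]` is the slice consumed by each sharded rank, and the global batch corresponds to the per-rank batches concatenated in rank order. This is a dense-tensor-only sketch; the real ModelInput also carries sparse KJT features.

import torch

world_size, local_batch, num_float_features = 2, 4, 16
per_rank = [torch.rand(local_batch, num_float_features) for _ in range(world_size)]
global_batch = torch.cat(per_rank)          # what the unsharded model sees
inputs = [(global_batch, per_rank)]         # mirrors inputs[0][0] / inputs[0][1][rank]

rank = 0
assert torch.equal(inputs[0][1][rank], global_batch[:local_batch])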
+ local_model = model_class( + tables=cast(List[BaseEmbeddingConfig], tables), + weighted_tables=cast(List[BaseEmbeddingConfig], weighted_tables), + embedding_groups=embedding_groups, + dense_device=device, + sparse_device=torch.device("meta"), + num_float_features=16, + feature_processor_modules=feature_processor_modules, + ) + + global_model_named_params_as_dict = dict(global_model.named_parameters()) + local_model_named_params_as_dict = dict(local_model.named_parameters()) + # Registers a hook to update parameters in the backward pass, when gradients are computed. + if apply_optimizer_in_backward_config is not None: + for apply_optim_name, ( + optimizer_type, + optimizer_kwargs, + ) in apply_optimizer_in_backward_config.items(): + for name, param in global_model_named_params_as_dict.items(): + if apply_optim_name not in name: + continue + assert name in local_model_named_params_as_dict + local_param = local_model_named_params_as_dict[name] + apply_optimizer_in_backward( + optimizer_type, + [param], + optimizer_kwargs, + ) + apply_optimizer_in_backward( + optimizer_type, [local_param], optimizer_kwargs + ) + + # For 2D parallelism, we use single group world size and local world size + planner = EmbeddingShardingPlanner( + topology=Topology( + world_size=world_size_2D if world_size_2D else world_size, + compute_device=device.type, + local_world_size=node_group_size if node_group_size else local_size, + ), + constraints=constraints, + ) + plan: ShardingPlan = planner.collective_plan(local_model, sharders, pg) + + if submodule_configs is not None: + # Dynamic 2D parallel, create a new plan for each submodule + for config in submodule_configs: + planner = EmbeddingShardingPlanner( + topology=Topology( + world_size=config.sharding_group_size, + compute_device=device.type, + local_world_size=config.node_group_size, + ), + constraints=constraints, + ) + sharder = TestEmbeddingCollectionSharder( + sharding_type=ShardingType.ROW_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + qcomms_config=qcomms_config, + ) + sharders.append(sharder) # pyre-ignore[6] + config.plan = planner.collective_plan( + local_model, [sharder], pg # pyre-ignore[6] + ) + + """ + Simulating multiple nodes on a single node. However, metadata information and + tensor placement must still be consistent. Here we overwrite this to do so. + + NOTE: + inter/intra process groups should still behave as expected. + + TODO: may need to add some checks that only does this if we're running on a + single GPU (which should be most cases). 
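The `apply_optimizer_in_backward` calls above register per-parameter hooks so the embedding optimizer step runs inside `backward()` (which is why the dense optimizer later is built from `in_backward_optimizer_filter`). A minimal concept sketch of that pattern using only core PyTorch's `Tensor.register_post_accumulate_grad_hook` (PyTorch 2.1 or newer); this is an illustration of the idea, not torchrec's implementation.

import torch

lin = torch.nn.Linear(4, 4)
lr = 0.1

def sgd_in_backward(param: torch.Tensor) -> None:
    # Apply the SGD update as soon as this parameter's grad is accumulated,
    # then drop the grad, mirroring fused "optimizer in backward" behavior.
    with torch.no_grad():
        param.add_(param.grad, alpha=-lr)
    param.grad = None

for p in lin.parameters():
    p.register_post_accumulate_grad_hook(sgd_in_backward)

before = lin.weight.detach().clone()
lin(torch.randn(2, 4)).sum().backward()      # parameters update during backward()
assert not torch.equal(before, lin.weight)   # no separate optimizer.step() needed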
+ """ + for group in plan.plan: + for _, parameter_sharding in cast( + EmbeddingModuleShardingPlan, plan.plan[group] + ).items(): + if ( + parameter_sharding.sharding_type + in { + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ShardingType.GRID_SHARD.value, + } + and device.type != "cpu" + ): + sharding_spec = parameter_sharding.sharding_spec + if sharding_spec is not None: + # pyre-ignore + for shard in sharding_spec.shards: + placement = shard.placement + rank: Optional[int] = placement.rank() + assert rank is not None + shard.placement = torch.distributed._remote_device( + f"rank:{rank}/cuda:{rank}" + ) + + hook_called: bool = False + if world_size_2D is not None: + all_reduce_func = None + if custom_all_reduce: + all_reduce_pg: dist.ProcessGroup = create_device_mesh_for_2D( + use_inter_host_allreduce, + world_size=world_size, + local_size=world_size_2D, + ).get_group(mesh_dim="replicate") + + def _custom_hook(input: List[torch.Tensor]) -> None: + nonlocal hook_called + opts = dist.AllreduceCoalescedOptions() + opts.reduceOp = dist.ReduceOp.AVG + handle = all_reduce_pg.allreduce_coalesced(input, opts=opts) + handle.wait() + hook_called = True + + all_reduce_func = _custom_hook + + local_model = DMPCollection( + module=local_model, + sharding_group_size=world_size_2D, + world_size=world_size, + global_pg=pg, + node_group_size=node_group_size, + plan=plan, + sharders=sharders, + device=device, + use_inter_host_allreduce=use_inter_host_allreduce, + custom_all_reduce=all_reduce_func, + submodule_configs=submodule_configs, + ) + else: + local_model = DistributedModelParallel( + local_model, + env=ShardingEnv.from_process_group(pg), + plan=plan, + sharders=sharders, + device=device, + ) + + dense_optim = KeyedOptimizerWrapper( + dict(in_backward_optimizer_filter(local_model.named_parameters())), + lambda params: torch.optim.SGD(params, lr=0.1), + ) + local_opt = CombinedOptimizer([local_model.fused_optimizer, dense_optim]) + + # Load model state from the global model. + copy_state_dict( + local_model.state_dict(), + global_model.state_dict(), + exclude_predfix="sparse.pooled_embedding_arch.embedding_modules._itp_iter", + ) + alter_global_ebc_dtype(global_model) + + # Run a single training step of the sharded model. + local_pred = gen_full_pred_after_one_step( + local_model, + local_opt, + local_input, + ) + + if world_size_2D is not None and custom_all_reduce: + assert hook_called, "custom all reduce hook was not called" + + # TODO: support non-sharded forward with zero batch size KJT + if not allow_zero_batch_size: + all_local_pred = _gather_all_tensors(local_pred, world_size, pg) + + # Run second training step of the unsharded model. + assert optim == EmbOptimType.EXACT_SGD + global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) + + global_pred = gen_full_pred_after_one_step( + global_model, global_opt, global_input + ) + + # Compare predictions of sharded vs unsharded models. 
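The comparison that follows hinges on a shape invariant: concatenating the gathered per-rank predictions in rank order must line up element-for-element with the unsharded model's prediction over the full batch. A stand-in sketch for `_gather_all_tensors` plus `torch.cat`, with no process group involved.

import torch

world_size, local_batch = 2, 4
global_pred = torch.arange(world_size * local_batch, dtype=torch.float32)
# Stand-in for the per-rank outputs that _gather_all_tensors collects over the pg.
all_local_pred = list(global_pred.chunk(world_size))
torch.testing.assert_close(global_pred, torch.cat(all_local_pred))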
+ if qcomms_config is not None: + # With quantized comms, we can relax constraints a bit + rtol = 0.003 + if CommType.FP8 in [ + qcomms_config.forward_precision, + qcomms_config.backward_precision, + ]: + rtol = 0.05 + atol = global_pred.max().item() * rtol + torch.testing.assert_close( + global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol + ) + elif ( + weighted_tables is not None + and weighted_tables[0].data_type == DataType.FP16 + ): + # we relax this accuracy test because when the embedding table weights is FP16, + # the sharded EBC would upscale the precision to FP32 for the returned embedding + # KJT.weights (FP32) + sharded_EBC (FP16) ==> embeddings (FP32) + # the test uses the unsharded EBC for reference to compare the results, but the unsharded EBC + # uses EmbeddingBags can only handle same precision, i.e., + # KJT.weights (FP32) + unsharded_EBC (FP32) ==> embeddings (FP32) + # therefore, the discrepancy leads to a relaxed tol level. + torch.testing.assert_close( + global_pred, + torch.cat(all_local_pred), + atol=1e-4, # relaxed atol due to FP16 in weights + rtol=1e-4, # relaxed rtol due to FP16 in weights + ) + else: + torch.testing.assert_close(global_pred, torch.cat(all_local_pred)) + + return global_pred + + def create_test_sharder( sharder_type: str, sharding_type: str, @@ -761,249 +1043,74 @@ def sharding_single_rank_test_single_process( lengths_dtype: torch.dtype = torch.int64, random_seed: Optional[int] = None, ) -> None: - batch_size = random.randint(0, batch_size) if allow_zero_batch_size else batch_size - # Generate model & inputs. - (global_model, inputs) = gen_model_and_input( + first_run = _verify_and_gather_predictions_for_single_run( + pg=pg, + device=device, + rank=rank, + world_size=world_size, model_class=model_class, + embedding_groups=embedding_groups, tables=tables, - # pyre-ignore [6] - generate=( - cast( - VariableBatchModelInputCallable, - ModelInput.generate_variable_batch_input, - ) - if variable_batch_per_feature - else ModelInput.generate - ), + sharders=sharders, + optim=optim, weighted_tables=weighted_tables, - embedding_groups=embedding_groups, - world_size=world_size, - num_float_features=16, + constraints=constraints, + local_size=local_size, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, batch_size=batch_size, feature_processor_modules=feature_processor_modules, + variable_batch_per_feature=variable_batch_per_feature, global_constant_batch=global_constant_batch, + world_size_2D=world_size_2D, + node_group_size=node_group_size, + use_inter_host_allreduce=use_inter_host_allreduce, input_type=input_type, + allow_zero_batch_size=allow_zero_batch_size, + custom_all_reduce=custom_all_reduce, + submodule_configs=submodule_configs, use_offsets=use_offsets, indices_dtype=indices_dtype, offsets_dtype=offsets_dtype, lengths_dtype=lengths_dtype, random_seed=random_seed, ) - - global_model = global_model.to(device) - global_input = inputs[0][0].to(device) - local_input = inputs[0][1][rank].to(device) - - # Shard model. 
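A worked example of the relaxed-tolerance rule above, using made-up numbers: with FP8 in either comms direction, `rtol` is 0.05 and `atol` scales with the largest reference value, so a reference that peaks at 12.0 tolerates absolute errors of roughly 0.6 everywhere.

import torch

rtol = 0.05                                  # FP8 in forward or backward comms
global_pred = torch.tensor([0.3, 12.0, 4.5])
atol = global_pred.max().item() * rtol       # 12.0 * 0.05 = 0.6
sharded_pred = global_pred + 0.4             # within the allowed absolute error
torch.testing.assert_close(global_pred, sharded_pred, rtol=rtol, atol=atol)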
- local_model = model_class( - tables=cast(List[BaseEmbeddingConfig], tables), - weighted_tables=cast(List[BaseEmbeddingConfig], weighted_tables), - embedding_groups=embedding_groups, - dense_device=device, - sparse_device=torch.device("meta"), - num_float_features=16, - feature_processor_modules=feature_processor_modules, - ) - - global_model_named_params_as_dict = dict(global_model.named_parameters()) - local_model_named_params_as_dict = dict(local_model.named_parameters()) - # Registers a hook to update parameters in the backward pass, when gradients are computed. - if apply_optimizer_in_backward_config is not None: - for apply_optim_name, ( - optimizer_type, - optimizer_kwargs, - ) in apply_optimizer_in_backward_config.items(): - for name, param in global_model_named_params_as_dict.items(): - if apply_optim_name not in name: - continue - assert name in local_model_named_params_as_dict - local_param = local_model_named_params_as_dict[name] - apply_optimizer_in_backward( - optimizer_type, - [param], - optimizer_kwargs, - ) - apply_optimizer_in_backward( - optimizer_type, [local_param], optimizer_kwargs - ) - - # For 2D parallelism, we use single group world size and local world size - planner = EmbeddingShardingPlanner( - topology=Topology( - world_size=world_size_2D if world_size_2D else world_size, - compute_device=device.type, - local_world_size=node_group_size if node_group_size else local_size, - ), - constraints=constraints, - ) - plan: ShardingPlan = planner.collective_plan(local_model, sharders, pg) - - if submodule_configs is not None: - # Dynamic 2D parallel, create a new plan for each submodule - for config in submodule_configs: - planner = EmbeddingShardingPlanner( - topology=Topology( - world_size=config.sharding_group_size, - compute_device=device.type, - local_world_size=config.node_group_size, - ), - constraints=constraints, - ) - sharder = TestEmbeddingCollectionSharder( - sharding_type=ShardingType.ROW_WISE.value, - kernel_type=EmbeddingComputeKernel.FUSED.value, - qcomms_config=qcomms_config, - ) - sharders.append(sharder) # pyre-ignore[6] - config.plan = planner.collective_plan( - local_model, [sharder], pg # pyre-ignore[6] - ) - - """ - Simulating multiple nodes on a single node. However, metadata information and - tensor placement must still be consistent. Here we overwrite this to do so. - - NOTE: - inter/intra process groups should still behave as expected. - - TODO: may need to add some checks that only does this if we're running on a - single GPU (which should be most cases). 
- """ - for group in plan.plan: - for _, parameter_sharding in cast( - EmbeddingModuleShardingPlan, plan.plan[group] - ).items(): - if ( - parameter_sharding.sharding_type - in { - ShardingType.TABLE_ROW_WISE.value, - ShardingType.TABLE_COLUMN_WISE.value, - ShardingType.GRID_SHARD.value, - } - and device.type != "cpu" - ): - sharding_spec = parameter_sharding.sharding_spec - if sharding_spec is not None: - # pyre-ignore - for shard in sharding_spec.shards: - placement = shard.placement - rank: Optional[int] = placement.rank() - assert rank is not None - shard.placement = torch.distributed._remote_device( - f"rank:{rank}/cuda:{rank}" - ) - - hook_called: bool = False - if world_size_2D is not None: - all_reduce_func = None - if custom_all_reduce: - all_reduce_pg: dist.ProcessGroup = create_device_mesh_for_2D( - use_inter_host_allreduce, - world_size=world_size, - local_size=world_size_2D, - ).get_group(mesh_dim="replicate") - - def _custom_hook(input: List[torch.Tensor]) -> None: - nonlocal hook_called - opts = dist.AllreduceCoalescedOptions() - opts.reduceOp = dist.ReduceOp.AVG - handle = all_reduce_pg.allreduce_coalesced(input, opts=opts) - handle.wait() - hook_called = True - - all_reduce_func = _custom_hook - - local_model = DMPCollection( - module=local_model, - sharding_group_size=world_size_2D, + if random_seed is not None: + second_run = _verify_and_gather_predictions_for_single_run( + pg=pg, + device=device, + rank=rank, world_size=world_size, - global_pg=pg, - node_group_size=node_group_size, - plan=plan, + model_class=model_class, + embedding_groups=embedding_groups, + tables=tables, sharders=sharders, - device=device, + optim=optim, + weighted_tables=weighted_tables, + constraints=constraints, + local_size=local_size, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + batch_size=batch_size, + feature_processor_modules=feature_processor_modules, + variable_batch_per_feature=variable_batch_per_feature, + global_constant_batch=global_constant_batch, + world_size_2D=world_size_2D, + node_group_size=node_group_size, use_inter_host_allreduce=use_inter_host_allreduce, - custom_all_reduce=all_reduce_func, + input_type=input_type, + allow_zero_batch_size=allow_zero_batch_size, + custom_all_reduce=custom_all_reduce, submodule_configs=submodule_configs, + use_offsets=use_offsets, + indices_dtype=indices_dtype, + offsets_dtype=offsets_dtype, + lengths_dtype=lengths_dtype, + random_seed=random_seed, ) - else: - local_model = DistributedModelParallel( - local_model, - env=ShardingEnv.from_process_group(pg), - plan=plan, - sharders=sharders, - device=device, - ) - - dense_optim = KeyedOptimizerWrapper( - dict(in_backward_optimizer_filter(local_model.named_parameters())), - lambda params: torch.optim.SGD(params, lr=0.1), - ) - local_opt = CombinedOptimizer([local_model.fused_optimizer, dense_optim]) - - # Load model state from the global model. - copy_state_dict( - local_model.state_dict(), - global_model.state_dict(), - exclude_predfix="sparse.pooled_embedding_arch.embedding_modules._itp_iter", - ) - alter_global_ebc_dtype(global_model) - - # Run a single training step of the sharded model. 
- local_pred = gen_full_pred_after_one_step( - local_model, - local_opt, - local_input, - ) - - if world_size_2D is not None and custom_all_reduce: - assert hook_called, "custom all reduce hook was not called" - - # TODO: support non-sharded forward with zero batch size KJT - if not allow_zero_batch_size: - all_local_pred = _gather_all_tensors(local_pred, world_size, pg) - - # Run second training step of the unsharded model. - assert optim == EmbOptimType.EXACT_SGD - global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) - - global_pred = gen_full_pred_after_one_step( - global_model, global_opt, global_input - ) - - # Compare predictions of sharded vs unsharded models. - if qcomms_config is not None: - # With quantized comms, we can relax constraints a bit - rtol = 0.003 - if CommType.FP8 in [ - qcomms_config.forward_precision, - qcomms_config.backward_precision, - ]: - rtol = 0.05 - atol = global_pred.max().item() * rtol - torch.testing.assert_close( - global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol - ) - elif ( - weighted_tables is not None - and weighted_tables[0].data_type == DataType.FP16 - ): - # we relax this accuracy test because when the embedding table weights is FP16, - # the sharded EBC would upscale the precision to FP32 for the returned embedding - # KJT.weights (FP32) + sharded_EBC (FP16) ==> embeddings (FP32) - # the test uses the unsharded EBC for reference to compare the results, but the unsharded EBC - # uses EmbeddingBags can only handle same precision, i.e., - # KJT.weights (FP32) + unsharded_EBC (FP32) ==> embeddings (FP32) - # therefore, the discrepancy leads to a relaxed tol level. - torch.testing.assert_close( - global_pred, - torch.cat(all_local_pred), - atol=1e-4, # relaxed atol due to FP16 in weights - rtol=1e-4, # relaxed rtol due to FP16 in weights - ) - else: - torch.testing.assert_close(global_pred, torch.cat(all_local_pred)) + torch.testing.assert_close(first_run, second_run) def sharding_single_rank_test(
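A minimal sketch of the end-to-end determinism contract this patch enforces: run the same seeded training step twice and require identical gathered predictions. `run_once` below is a hypothetical stand-in for `_verify_and_gather_predictions_for_single_run`; the real test reruns the full sharded-vs-unsharded step and compares the returned global predictions.

from typing import Optional

import torch

def run_once(seed: Optional[int]) -> torch.Tensor:
    # Hypothetical stand-in for one full sharded training run.
    if seed is not None:
        torch.manual_seed(seed)
    return torch.rand(4, 16)

random_seed: Optional[int] = 100
first_run = run_once(random_seed)
if random_seed is not None:
    second_run = run_once(random_seed)
    torch.testing.assert_close(first_run, second_run)  # same seed, same results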