
Commit e62add5

SSYernar authored and facebook-github-bot committed
Added basic planner to benchmark_train_sparsenn (#3069)
Summary:
Pull Request resolved: #3069

Created an EmbeddingShardingPlanner in the runner function after generating the unsharded model and modified _generate_sharded_model_and_optimizer to accept and use this planner. This change enables optimized sharding of embedding tables based on the topology.

Reviewed By: aliafzal

Differential Revision: D76188112

fbshipit-source-id: 997e7ae2587118c4899f39350d07dd521ee54105
1 parent: b93a0d7 · commit: e62add5
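
The summary above describes wiring a topology-aware planner into the benchmark. The minimal sketch below, adapted from the diff that follows, shows roughly how such a planner is built from a Topology and used to produce a plan for DistributedModelParallel; the names model, sharders, pg, world_size, and device are assumed to be provided by the surrounding benchmark code and are placeholders here, not part of the commit itself.

# Minimal sketch, assuming `model`, `sharders`, `pg`, `world_size`, and `device`
# are defined by the caller (as they are in the benchmark runner below).
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

# Describe the hardware layout so the planner can place embedding tables.
planner = EmbeddingShardingPlanner(
    topology=Topology(
        local_world_size=get_local_size(world_size),
        world_size=world_size,
        compute_device=device.type,
    )
)

# Produce a sharding plan; use the collective variant when a process group exists.
plan = (
    planner.collective_plan(model, sharders, pg)
    if pg is not None
    else planner.plan(model, sharders)
)

# Hand the plan to DistributedModelParallel so sharding follows the planner output.
sharded_model = DistributedModelParallel(
    module=model,
    env=ShardingEnv.from_process_group(pg),
    device=device,
    sharders=sharders,
    plan=plan,
)

As in the diff, collective_plan is preferred when a process group is available so that the ranks share a single plan, while plan is the local fallback when no process group is passed.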


torchrec/distributed/benchmark/benchmark_train_sparsenn.py

Lines changed: 26 additions & 7 deletions
@@ -23,15 +23,16 @@
 from torch.optim import Optimizer
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

 from torchrec.distributed.test_utils.multi_process import (
     MultiProcessContext,
     run_multi_process_func,
 )
 from torchrec.distributed.test_utils.test_input import (
     ModelInput,
-    TdModelInput,
     TestSparseNNInputConfig,
 )
 from torchrec.distributed.test_utils.test_model import (
@@ -278,23 +279,31 @@ def _generate_sharded_model_and_optimizer(
     pg: dist.ProcessGroup,
     device: torch.device,
     fused_params: Optional[Dict[str, Any]] = None,
+    planner: Optional[EmbeddingShardingPlanner] = None,
 ) -> Tuple[nn.Module, Optimizer]:
     sharder = TestEBCSharder(
         sharding_type=sharding_type,
         kernel_type=kernel_type,
         fused_params=fused_params,
     )
+
+    sharders = [cast(ModuleSharder[nn.Module], sharder)]
+
+    # Use planner if provided
+    plan = None
+    if planner is not None:
+        if pg is not None:
+            plan = planner.collective_plan(model, sharders, pg)
+        else:
+            plan = planner.plan(model, sharders)
+
     sharded_model = DistributedModelParallel(
         module=copy.deepcopy(model),
         env=ShardingEnv.from_process_group(pg),
         init_data_parallel=True,
         device=device,
-        sharders=[
-            cast(
-                ModuleSharder[nn.Module],
-                sharder,
-            )
-        ],
+        sharders=sharders,
+        plan=plan,
     ).to(device)
     optimizer = optim.SGD(
         [
@@ -334,6 +343,15 @@ def runner(
         dense_device=ctx.device,
     )

+    # Create a planner for sharding
+    planner = EmbeddingShardingPlanner(
+        topology=Topology(
+            local_world_size=get_local_size(world_size),
+            world_size=world_size,
+            compute_device=ctx.device.type,
+        )
+    )
+
     sharded_model, optimizer = _generate_sharded_model_and_optimizer(
         model=unsharded_model,
         sharding_type=run_option.sharding_type.value,
@@ -345,6 +363,7 @@ def runner(
             "optimizer": EmbOptimType.EXACT_ADAGRAD,
             "learning_rate": 0.1,
         },
+        planner=planner,
     )
     bench_inputs = _generate_data(
         tables=tables,
