Skip to content

Commit 17cd308

Browse files
SSYernar authored and facebook-github-bot committed
Enhance ParameterConstraint configuration in the benchmarking script
Summary: Updated the `ParameterConstraints` in the TorchRec benchmarking script to include pooling factors, number of poolings, and batch sizes. This enhancement allows for more detailed configuration of embedding tables, improving the flexibility and precision of sharding strategies in distributed training scenarios. Differential Revision: D76440004
1 parent 3b6b537 commit 17cd308

File tree

1 file changed

+36
-3
lines changed

1 file changed

+36
-3
lines changed

torchrec/distributed/benchmark/benchmark_train_sparsenn.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import copy
1313

14-
from dataclasses import dataclass
14+
from dataclasses import dataclass, field
1515
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
1616

1717
import click
@@ -26,6 +26,7 @@
2626
from torchrec.distributed.comm import get_local_size
2727
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
2828
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
29+
from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
2930
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
3031
from torchrec.distributed.planner.types import ParameterConstraints
3132

@@ -80,6 +81,9 @@ class RunOptions:
8081
planner_type (str): Type of sharding planner to use. Options are:
8182
- "embedding": EmbeddingShardingPlanner (default)
8283
- "hetero": HeteroEmbeddingShardingPlanner
84+
pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
85+
This is the average number of values each sample has for the feature.
86+
num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
8387
"""
8488

8589
world_size: int = 2
@@ -89,6 +93,8 @@ class RunOptions:
8993
input_type: str = "kjt"
9094
profile: str = ""
9195
planner_type: str = "embedding"
96+
pooling_factors: Optional[List[float]] = None
97+
num_poolings: Optional[List[float]] = None
9298

9399

94100
@dataclass
@@ -111,7 +117,7 @@ class EmbeddingTablesConfig:
111117

112118
num_unweighted_features: int = 100
113119
num_weighted_features: int = 100
114-
embedding_feature_dim: int = 512
120+
embedding_feature_dim: int = 128
115121

116122
def generate_tables(
117123
self,
@@ -286,17 +292,36 @@ def _generate_planner(
286292
tables: Optional[List[EmbeddingBagConfig]],
287293
weighted_tables: Optional[List[EmbeddingBagConfig]],
288294
sharding_type: ShardingType,
289-
compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED,
295+
compute_kernel: EmbeddingComputeKernel,
296+
num_batches: int,
297+
batch_size: int,
298+
pooling_factors: Optional[List[float]],
299+
num_poolings: Optional[List[float]],
290300
) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
291301
# Create parameter constraints for tables
292302
constraints = {}
293303

304+
if pooling_factors is None:
305+
pooling_factors = [POOLING_FACTOR] * num_batches
306+
307+
if num_poolings is None:
308+
num_poolings = [NUM_POOLINGS] * num_batches
309+
310+
batch_sizes = [batch_size] * num_batches
311+
312+
assert (
313+
len(pooling_factors) == num_batches and len(num_poolings) == num_batches
314+
), "The length of pooling_factors and num_poolings must match the number of batches."
315+
294316
if tables is not None:
295317
for table in tables:
296318
constraints[table.name] = ParameterConstraints(
297319
sharding_types=[sharding_type.value],
298320
compute_kernels=[compute_kernel.value],
299321
device_group="cuda",
322+
pooling_factors=pooling_factors,
323+
num_poolings=num_poolings,
324+
batch_sizes=batch_sizes,
300325
)
301326

302327
if weighted_tables is not None:
@@ -305,6 +330,10 @@ def _generate_planner(
305330
sharding_types=[sharding_type.value],
306331
compute_kernels=[compute_kernel.value],
307332
device_group="cuda",
333+
pooling_factors=pooling_factors,
334+
num_poolings=num_poolings,
335+
batch_sizes=batch_sizes,
336+
is_weighted=True,
308337
)
309338

310339
if planner_type == "embedding":
@@ -413,6 +442,10 @@ def runner(
413442
weighted_tables=weighted_tables,
414443
sharding_type=run_option.sharding_type,
415444
compute_kernel=run_option.compute_kernel,
445+
num_batches=run_option.num_batches,
446+
batch_size=input_config.batch_size,
447+
pooling_factors=run_option.pooling_factors,
448+
num_poolings=run_option.num_poolings,
416449
)
417450

418451
sharded_model, optimizer = _generate_sharded_model_and_optimizer(

0 commit comments

Comments
 (0)