
Commit cc92389

SSYernar authored and facebook-github-bot committed
Refactored benchmarking script for SparseNN training. (#3105)
Summary: Pull Request resolved: #3105

Refactored the training benchmarking by moving the generative helper functions into a separate util file. The benchmarking script will later be updated to support general non-sparse models as well.

Reviewed By: aliafzal

Differential Revision: D76833400

fbshipit-source-id: 9ca884fbe40676c02b8167ea2edfdb9a362b6f06
1 parent 6351273 commit cc92389

3 files changed: +617 -504 lines changed

Lines changed: 360 additions & 0 deletions
@@ -0,0 +1,360 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
from dataclasses import dataclass
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.optim import Optimizer
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.distributed.test_utils.test_model import (
    TestEBCSharder,
    TestOverArchLarge,
    TestSparseNN,
)
from torchrec.distributed.train_pipeline import (
    TrainPipelineBase,
    TrainPipelineFusedSparseDist,
    TrainPipelineSparseDist,
)
from torchrec.distributed.train_pipeline.train_pipelines import (
    PrefetchTrainPipelineSparseDist,
    TrainPipelineSemiSync,
)
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig

@dataclass
class ModelConfig:
    batch_size: int = 8192
    num_float_features: int = 10
    feature_pooling_avg: int = 10
    use_offsets: bool = False
    dev_str: str = ""
    long_kjt_indices: bool = True
    long_kjt_offsets: bool = True
    long_kjt_lengths: bool = True
    pin_memory: bool = True

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        return TestSparseNN(
            tables=tables,
            weighted_tables=weighted_tables,
            dense_device=dense_device,
            sparse_device=torch.device("meta"),
            over_arch_clazz=TestOverArchLarge,
        )

def generate_tables(
    num_unweighted_features: int,
    num_weighted_features: int,
    embedding_feature_dim: int,
) -> Tuple[
    List[EmbeddingBagConfig],
    List[EmbeddingBagConfig],
]:
    """
    Generate embedding bag configurations for both unweighted and weighted features.

    This function creates two lists of EmbeddingBagConfig objects:
    1. Unweighted tables: named "table_{i}" with feature names "feature_{i}"
    2. Weighted tables: named "weighted_table_{i}" with feature names "weighted_feature_{i}"

    For both types, the number of embeddings scales with the feature index,
    calculated as max(i + 1, 100) * 1000.

    Args:
        num_unweighted_features (int): Number of unweighted features to generate.
        num_weighted_features (int): Number of weighted features to generate.
        embedding_feature_dim (int): Dimension of the embedding vectors.

    Returns:
        Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
        two lists - the first for unweighted embedding tables and the second for
        weighted embedding tables.
    """
    tables = [
        EmbeddingBagConfig(
            num_embeddings=max(i + 1, 100) * 1000,
            embedding_dim=embedding_feature_dim,
            name="table_" + str(i),
            feature_names=["feature_" + str(i)],
        )
        for i in range(num_unweighted_features)
    ]
    weighted_tables = [
        EmbeddingBagConfig(
            num_embeddings=max(i + 1, 100) * 1000,
            embedding_dim=embedding_feature_dim,
            name="weighted_table_" + str(i),
            feature_names=["weighted_feature_" + str(i)],
        )
        for i in range(num_weighted_features)
    ]
    return tables, weighted_tables

def generate_pipeline(
    pipeline_type: str,
    emb_lookup_stream: str,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    device: torch.device,
) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
    """
    Generate a training pipeline instance based on the configuration.

    This function creates and returns the appropriate training pipeline object
    based on the pipeline type specified. Different pipeline types are optimized
    for different training scenarios.

    Args:
        pipeline_type (str): The type of training pipeline to use. Options include:
            - "base": Basic training pipeline
            - "sparse": Pipeline optimized for sparse operations
            - "fused": Pipeline with fused sparse distribution
            - "semi": Semi-synchronous training pipeline
            - "prefetch": Pipeline with prefetching for sparse distribution
        emb_lookup_stream (str): The stream to use for embedding lookups.
            Only used by certain pipeline types (e.g., "fused").
        model (nn.Module): The model to be trained.
        opt (torch.optim.Optimizer): The optimizer to use for training.
        device (torch.device): The device to run the training on.

    Returns:
        Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
        appropriate training pipeline class based on the configuration.

    Raises:
        RuntimeError: If an unknown pipeline type is specified.
    """

    _pipeline_cls: Dict[
        str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
    ] = {
        "base": TrainPipelineBase,
        "sparse": TrainPipelineSparseDist,
        "fused": TrainPipelineFusedSparseDist,
        "semi": TrainPipelineSemiSync,
        "prefetch": PrefetchTrainPipelineSparseDist,
    }

    if pipeline_type == "semi":
        return TrainPipelineSemiSync(
            model=model, optimizer=opt, device=device, start_batch=0
        )
    elif pipeline_type == "fused":
        return TrainPipelineFusedSparseDist(
            model=model,
            optimizer=opt,
            device=device,
            emb_lookup_stream=emb_lookup_stream,
        )
    elif pipeline_type in _pipeline_cls:
        Pipeline = _pipeline_cls[pipeline_type]
        return Pipeline(model=model, optimizer=opt, device=device)
    else:
        raise RuntimeError(f"unknown pipeline option {pipeline_type}")

def generate_planner(
    planner_type: str,
    topology: Topology,
    tables: Optional[List[EmbeddingBagConfig]],
    weighted_tables: Optional[List[EmbeddingBagConfig]],
    sharding_type: ShardingType,
    compute_kernel: EmbeddingComputeKernel,
    num_batches: int,
    batch_size: int,
    pooling_factors: Optional[List[float]],
    num_poolings: Optional[List[float]],
) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
    """
    Generate an embedding sharding planner based on the specified configuration.

    Args:
        planner_type: Type of planner to use ("embedding" or "hetero")
        topology: Network topology for distributed training
        tables: List of unweighted embedding tables
        weighted_tables: List of weighted embedding tables
        sharding_type: Strategy for sharding embedding tables
        compute_kernel: Compute kernel to use for embedding tables
        num_batches: Number of batches to process
        batch_size: Size of each batch
        pooling_factors: Pooling factors for each feature of the table
        num_poolings: Number of poolings for each feature of the table

    Returns:
        An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner

    Raises:
        RuntimeError: If an unknown planner type is specified
    """
    # Create parameter constraints for tables
    constraints = {}

    if pooling_factors is None:
        pooling_factors = [POOLING_FACTOR] * num_batches

    if num_poolings is None:
        num_poolings = [NUM_POOLINGS] * num_batches

    batch_sizes = [batch_size] * num_batches

    assert (
        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
    ), "The length of pooling_factors and num_poolings must match the number of batches."

    if tables is not None:
        for table in tables:
            constraints[table.name] = ParameterConstraints(
                sharding_types=[sharding_type.value],
                compute_kernels=[compute_kernel.value],
                device_group="cuda",
                pooling_factors=pooling_factors,
                num_poolings=num_poolings,
                batch_sizes=batch_sizes,
            )

    if weighted_tables is not None:
        for table in weighted_tables:
            constraints[table.name] = ParameterConstraints(
                sharding_types=[sharding_type.value],
                compute_kernels=[compute_kernel.value],
                device_group="cuda",
                pooling_factors=pooling_factors,
                num_poolings=num_poolings,
                batch_sizes=batch_sizes,
                is_weighted=True,
            )

    if planner_type == "embedding":
        return EmbeddingShardingPlanner(
            topology=topology,
            constraints=constraints if constraints else None,
        )
    elif planner_type == "hetero":
        topology_groups = {"cuda": topology}
        return HeteroEmbeddingShardingPlanner(
            topology_groups=topology_groups,
            constraints=constraints if constraints else None,
        )
    else:
        raise RuntimeError(f"Unknown planner type: {planner_type}")

def generate_sharded_model_and_optimizer(
    model: nn.Module,
    sharding_type: str,
    kernel_type: str,
    pg: dist.ProcessGroup,
    device: torch.device,
    fused_params: Optional[Dict[str, Any]] = None,
    planner: Optional[
        Union[
            EmbeddingShardingPlanner,
            HeteroEmbeddingShardingPlanner,
        ]
    ] = None,
) -> Tuple[nn.Module, Optimizer]:
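    """
    Shard the model with DistributedModelParallel and build an SGD optimizer
    over its dense parameters.

    A TestEBCSharder is constructed from `sharding_type` and `kernel_type`.
    If a planner is given, a sharding plan is computed first (collectively
    across `pg` when a process group is available). Parameters whose names
    contain "sparse" are excluded from the optimizer, since sharded embedding
    tables are typically updated by their own fused optimizers.
    """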
    # Ensure fused_params is always a dictionary
    fused_params_dict = {} if fused_params is None else fused_params

    sharder = TestEBCSharder(
        sharding_type=sharding_type,
        kernel_type=kernel_type,
        fused_params=fused_params_dict,
    )
    sharders = [cast(ModuleSharder[nn.Module], sharder)]

    # Use planner if provided
    plan = None
    if planner is not None:
        if pg is not None:
            plan = planner.collective_plan(model, sharders, pg)
        else:
            plan = planner.plan(model, sharders)

    sharded_model = DistributedModelParallel(
        module=copy.deepcopy(model),
        env=ShardingEnv.from_process_group(pg),
        init_data_parallel=True,
        device=device,
        sharders=sharders,
        plan=plan,
    ).to(device)
    optimizer = optim.SGD(
        [
            param
            for name, param in sharded_model.named_parameters()
            if "sparse" not in name
        ],
        lr=0.1,
    )
    return sharded_model, optimizer

def generate_data(
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    model_config: ModelConfig,
    num_batches: int,
) -> List[ModelInput]:
    """
    Generate model input data for benchmarking.

    Args:
        tables: List of unweighted embedding tables
        weighted_tables: List of weighted embedding tables
        model_config: Configuration for model generation
        num_batches: Number of batches to generate

    Returns:
        A list of ModelInput objects representing the generated batches
    """
    device = torch.device(model_config.dev_str) if model_config.dev_str else None

    return [
        ModelInput.generate(
            batch_size=model_config.batch_size,
            tables=tables,
            weighted_tables=weighted_tables,
            num_float_features=model_config.num_float_features,
            pooling_avg=model_config.feature_pooling_avg,
            use_offsets=model_config.use_offsets,
            device=device,
            indices_dtype=(
                torch.int64 if model_config.long_kjt_indices else torch.int32
            ),
            offsets_dtype=(
                torch.int64 if model_config.long_kjt_offsets else torch.int32
            ),
            lengths_dtype=(
                torch.int64 if model_config.long_kjt_lengths else torch.int32
            ),
            pin_memory=model_config.pin_memory,
        )
        for _ in range(num_batches)
    ]
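To make the intended composition of these helpers concrete, here is a minimal usage sketch. It is not part of the commit: the driver `run_benchmark` and its argument values are invented for illustration, and it assumes an initialized torch.distributed process group, an available CUDA device, and that the helpers above are importable from the new util file.

import torch
import torch.distributed as dist


def run_benchmark() -> None:
    # Hypothetical driver, for illustration only.
    tables, weighted_tables = generate_tables(
        num_unweighted_features=26,
        num_weighted_features=1,
        embedding_feature_dim=128,
    )
    model_config = ModelConfig(batch_size=8192)
    device = torch.device("cuda:0")

    # Dense modules live on the GPU; sparse modules start on the meta device
    # and are materialized when DistributedModelParallel shards them.
    unsharded_model = model_config.generate_model(
        tables=tables, weighted_tables=weighted_tables, dense_device=device
    )

    planner = generate_planner(
        planner_type="embedding",
        topology=Topology(world_size=1, compute_device="cuda"),
        tables=tables,
        weighted_tables=weighted_tables,
        sharding_type=ShardingType.TABLE_WISE,
        compute_kernel=EmbeddingComputeKernel.FUSED,
        num_batches=10,
        batch_size=model_config.batch_size,
        pooling_factors=None,
        num_poolings=None,
    )

    sharded_model, opt = generate_sharded_model_and_optimizer(
        model=unsharded_model,
        sharding_type=ShardingType.TABLE_WISE.value,
        kernel_type=EmbeddingComputeKernel.FUSED.value,
        pg=dist.group.WORLD,
        device=device,
        planner=planner,
    )

    pipeline = generate_pipeline(
        pipeline_type="sparse",
        emb_lookup_stream="data_dist",  # ignored by the "sparse" pipeline
        model=sharded_model,
        opt=opt,
        device=device,
    )

    batches = generate_data(tables, weighted_tables, model_config, num_batches=10)
    batch_iter = iter(batches)
    while True:
        try:
            pipeline.progress(batch_iter)  # one forward/backward/optimizer step
        except StopIteration:
            break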

0 commit comments
