Commit a7f0bee

SSYernar authored and facebook-github-bot committed
Added model config for supporting SparseNN, TowerSparseNN, TowerCollectionSparseNN models.
Summary: Refactored the training benchmarking code by moving the generative helper functions into a separate util file and added a model configuration that supports SparseNN, TowerSparseNN, and TowerCollectionSparseNN models. Future commits will add support for DeepFM and DLRM models.

Differential Revision: D76539867
1 parent 17cd308 commit a7f0bee

File tree: 3 files changed (+649, −504 lines)

Lines changed: 391 additions & 0 deletions
@@ -0,0 +1,391 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
from dataclasses import dataclass
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.optim import Optimizer
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.distributed.test_utils.test_model import (
    TestEBCSharder,
    TestOverArchLarge,
    TestSparseNN,
    TestTowerCollectionSparseNN,
    TestTowerSparseNN,
)
from torchrec.distributed.train_pipeline import (
    TrainPipelineBase,
    TrainPipelineFusedSparseDist,
    TrainPipelineSparseDist,
)
from torchrec.distributed.train_pipeline.train_pipelines import (
    PrefetchTrainPipelineSparseDist,
    TrainPipelineSemiSync,
)
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig

@dataclass
class ModelConfig:
    """Configuration for the benchmark models.

    ``model_name`` selects one of the supported test models; the remaining
    fields control input-batch generation (size, dtypes, pooling, pinning).
    """

    model_name: str = "test_sparsenn"

    batch_size: int = 8192
    num_float_features: int = 10
    feature_pooling_avg: int = 10
    use_offsets: bool = False
    dev_str: str = ""
    long_kjt_indices: bool = True
    long_kjt_offsets: bool = True
    long_kjt_lengths: bool = True
    pin_memory: bool = True

    def generate_model(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        dense_device: torch.device,
    ) -> nn.Module:
        if self.model_name == "test_sparsenn":
            return TestSparseNN(
                tables=tables,
                weighted_tables=weighted_tables,
                dense_device=dense_device,
                sparse_device=torch.device("meta"),
                over_arch_clazz=TestOverArchLarge,
            )
        elif self.model_name == "test_tower_sparsenn":
            return TestTowerSparseNN(
                tables=tables,
                weighted_tables=weighted_tables,
                dense_device=dense_device,
                sparse_device=torch.device("meta"),
                num_float_features=self.num_float_features,
            )
        elif self.model_name == "test_tower_collection_sparsenn":
            return TestTowerCollectionSparseNN(
                tables=tables,
                weighted_tables=weighted_tables,
                dense_device=dense_device,
                sparse_device=torch.device("meta"),
                num_float_features=self.num_float_features,
            )
        else:
            raise RuntimeError(f"Unknown model name: {self.model_name}")

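# Editor's sketch (illustrative, not part of this commit): wiring ModelConfig to
# generate_tables (defined below). The helper name `_example_generate_model` and
# the argument values are assumptions for illustration only.
def _example_generate_model() -> nn.Module:
    config = ModelConfig(model_name="test_tower_sparsenn", batch_size=16)
    tables, weighted_tables = generate_tables(
        num_unweighted_features=2,
        num_weighted_features=1,
        embedding_feature_dim=64,
    )
    # Dense layers are materialized on CPU; sparse modules stay on the meta
    # device until DistributedModelParallel shards them.
    return config.generate_model(
        tables=tables,
        weighted_tables=weighted_tables,
        dense_device=torch.device("cpu"),
    )
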
def generate_tables(
    num_unweighted_features: int,
    num_weighted_features: int,
    embedding_feature_dim: int,
) -> Tuple[
    List[EmbeddingBagConfig],
    List[EmbeddingBagConfig],
]:
    """
    Generate embedding bag configurations for both unweighted and weighted features.

    This function creates two lists of EmbeddingBagConfig objects:
    1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}"
    2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}"

    For both types, the number of embeddings scales with the feature index,
    calculated as max(i + 1, 100) * 1000.

    Args:
        num_unweighted_features (int): Number of unweighted features to generate.
        num_weighted_features (int): Number of weighted features to generate.
        embedding_feature_dim (int): Dimension of the embedding vectors.

    Returns:
        Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
        two lists - the first for unweighted embedding tables and the second for
        weighted embedding tables.
    """
    tables = [
        EmbeddingBagConfig(
            num_embeddings=max(i + 1, 100) * 1000,
            embedding_dim=embedding_feature_dim,
            name="table_" + str(i),
            feature_names=["feature_" + str(i)],
        )
        for i in range(num_unweighted_features)
    ]
    weighted_tables = [
        EmbeddingBagConfig(
            num_embeddings=max(i + 1, 100) * 1000,
            embedding_dim=embedding_feature_dim,
            name="weighted_table_" + str(i),
            feature_names=["weighted_feature_" + str(i)],
        )
        for i in range(num_weighted_features)
    ]
    return tables, weighted_tables

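# Editor's sketch (illustrative, not part of this commit): checking the naming
# and row-count scaling promised by the docstring above on a tiny configuration.
def _example_table_scaling() -> None:
    tables, weighted_tables = generate_tables(
        num_unweighted_features=2,
        num_weighted_features=1,
        embedding_feature_dim=128,
    )
    assert tables[0].name == "table_0"
    assert tables[0].feature_names == ["feature_0"]
    # max(i + 1, 100) * 1000 == 100_000 for every i < 100.
    assert tables[0].num_embeddings == 100_000
    assert weighted_tables[0].name == "weighted_table_0"
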
def generate_pipeline(
    pipeline_type: str,
    emb_lookup_stream: str,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    device: torch.device,
) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
    """
    Generate a training pipeline instance based on the configuration.

    This function creates and returns the appropriate training pipeline object
    based on the pipeline type specified. Different pipeline types are optimized
    for different training scenarios.

    Args:
        pipeline_type (str): The type of training pipeline to use. Options include:
            - "base": Basic training pipeline
            - "sparse": Pipeline optimized for sparse operations
            - "fused": Pipeline with fused sparse distribution
            - "semi": Semi-synchronous training pipeline
            - "prefetch": Pipeline with prefetching for sparse distribution
        emb_lookup_stream (str): The stream to use for embedding lookups.
            Only used by certain pipeline types (e.g., "fused").
        model (nn.Module): The model to be trained.
        opt (torch.optim.Optimizer): The optimizer to use for training.
        device (torch.device): The device to run the training on.

    Returns:
        Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
        appropriate training pipeline class based on the configuration.

    Raises:
        RuntimeError: If an unknown pipeline type is specified.
    """

    _pipeline_cls: Dict[
        str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
    ] = {
        "base": TrainPipelineBase,
        "sparse": TrainPipelineSparseDist,
        "fused": TrainPipelineFusedSparseDist,
        "semi": TrainPipelineSemiSync,
        "prefetch": PrefetchTrainPipelineSparseDist,
    }

    if pipeline_type == "semi":
        return TrainPipelineSemiSync(
            model=model, optimizer=opt, device=device, start_batch=0
        )
    elif pipeline_type == "fused":
        return TrainPipelineFusedSparseDist(
            model=model,
            optimizer=opt,
            device=device,
            emb_lookup_stream=emb_lookup_stream,
        )
    elif pipeline_type in _pipeline_cls:
        Pipeline = _pipeline_cls[pipeline_type]
        return Pipeline(model=model, optimizer=opt, device=device)
    else:
        raise RuntimeError(f"unknown pipeline option {pipeline_type}")

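# Editor's sketch (illustrative, not part of this commit): selecting a pipeline
# by name for an already-sharded model. The "data_dist" stream name is an
# assumption mirroring the fused pipeline's default elsewhere in the benchmarks;
# it is ignored by non-"fused" pipeline types.
def _example_pipeline(
    sharded_model: nn.Module,
    opt: torch.optim.Optimizer,
    device: torch.device,
) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
    return generate_pipeline(
        pipeline_type="sparse",
        emb_lookup_stream="data_dist",
        model=sharded_model,
        opt=opt,
        device=device,
    )
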
def generate_planner(
    planner_type: str,
    topology: Topology,
    tables: Optional[List[EmbeddingBagConfig]],
    weighted_tables: Optional[List[EmbeddingBagConfig]],
    sharding_type: ShardingType,
    compute_kernel: EmbeddingComputeKernel,
    num_batches: int,
    batch_size: int,
    pooling_factors: Optional[List[float]],
    num_poolings: Optional[List[float]],
) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
    """
    Generate an embedding sharding planner based on the specified configuration.

    Args:
        planner_type: Type of planner to use ("embedding" or "hetero")
        topology: Network topology for distributed training
        tables: List of unweighted embedding tables
        weighted_tables: List of weighted embedding tables
        sharding_type: Strategy for sharding embedding tables
        compute_kernel: Compute kernel to use for embedding tables
        num_batches: Number of batches to process
        batch_size: Size of each batch
        pooling_factors: Pooling factors for each feature of the table
        num_poolings: Number of poolings for each feature of the table

    Returns:
        An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner

    Raises:
        RuntimeError: If an unknown planner type is specified
    """
    # Create parameter constraints for tables
    constraints = {}

    if pooling_factors is None:
        pooling_factors = [POOLING_FACTOR] * num_batches

    if num_poolings is None:
        num_poolings = [NUM_POOLINGS] * num_batches

    batch_sizes = [batch_size] * num_batches

    assert (
        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
    ), "The length of pooling_factors and num_poolings must match the number of batches."

    if tables is not None:
        for table in tables:
            constraints[table.name] = ParameterConstraints(
                sharding_types=[sharding_type.value],
                compute_kernels=[compute_kernel.value],
                device_group="cuda",
                pooling_factors=pooling_factors,
                num_poolings=num_poolings,
                batch_sizes=batch_sizes,
            )

    if weighted_tables is not None:
        for table in weighted_tables:
            constraints[table.name] = ParameterConstraints(
                sharding_types=[sharding_type.value],
                compute_kernels=[compute_kernel.value],
                device_group="cuda",
                pooling_factors=pooling_factors,
                num_poolings=num_poolings,
                batch_sizes=batch_sizes,
                is_weighted=True,
            )

    if planner_type == "embedding":
        return EmbeddingShardingPlanner(
            topology=topology,
            constraints=constraints if constraints else None,
        )
    elif planner_type == "hetero":
        topology_groups = {"cuda": topology}
        return HeteroEmbeddingShardingPlanner(
            topology_groups=topology_groups,
            constraints=constraints if constraints else None,
        )
    else:
        raise RuntimeError(f"Unknown planner type: {planner_type}")

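# Editor's sketch (illustrative, not part of this commit): a planner request for
# a small single-host topology. The world size and the table-wise/fused choices
# are assumptions; passing None for the pooling arguments exercises the
# POOLING_FACTOR / NUM_POOLINGS defaults above.
def _example_planner(
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
    topology = Topology(world_size=2, compute_device="cuda")
    return generate_planner(
        planner_type="embedding",
        topology=topology,
        tables=tables,
        weighted_tables=weighted_tables,
        sharding_type=ShardingType.TABLE_WISE,
        compute_kernel=EmbeddingComputeKernel.FUSED,
        num_batches=10,
        batch_size=8192,
        pooling_factors=None,
        num_poolings=None,
    )
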
def generate_sharded_model_and_optimizer(
    model: nn.Module,
    sharding_type: str,
    kernel_type: str,
    pg: dist.ProcessGroup,
    device: torch.device,
    fused_params: Optional[Dict[str, Any]] = None,
    planner: Optional[
        Union[
            EmbeddingShardingPlanner,
            HeteroEmbeddingShardingPlanner,
        ]
    ] = None,
) -> Tuple[nn.Module, Optimizer]:
    """Shard the model with DistributedModelParallel and build an SGD optimizer
    over its dense (non-sparse) parameters."""
    # Ensure fused_params is always a dictionary
    fused_params_dict = {} if fused_params is None else fused_params

    sharder = TestEBCSharder(
        sharding_type=sharding_type,
        kernel_type=kernel_type,
        fused_params=fused_params_dict,
    )
    sharders = [cast(ModuleSharder[nn.Module], sharder)]

    # Use planner if provided
    plan = None
    if planner is not None:
        if pg is not None:
            plan = planner.collective_plan(model, sharders, pg)
        else:
            plan = planner.plan(model, sharders)

    sharded_model = DistributedModelParallel(
        module=copy.deepcopy(model),
        env=ShardingEnv.from_process_group(pg),
        init_data_parallel=True,
        device=device,
        sharders=sharders,
        plan=plan,
    ).to(device)
    optimizer = optim.SGD(
        [
            param
            for name, param in sharded_model.named_parameters()
            if "sparse" not in name
        ],
        lr=0.1,
    )
    return sharded_model, optimizer

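# Editor's sketch (illustrative, not part of this commit): sharding a model in
# an existing process group. It assumes dist.init_process_group was already
# called by the benchmark harness; the fused "learning_rate" entry is the
# standard TorchRec knob for the optimizer fused into the embedding kernel.
def _example_shard(
    model: nn.Module, device: torch.device
) -> Tuple[nn.Module, Optimizer]:
    pg = dist.group.WORLD
    assert pg is not None, "call dist.init_process_group first"
    return generate_sharded_model_and_optimizer(
        model=model,
        sharding_type=ShardingType.TABLE_WISE.value,
        kernel_type=EmbeddingComputeKernel.FUSED.value,
        pg=pg,
        device=device,
        fused_params={"learning_rate": 0.1},
    )
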
def generate_data(
    model_class_name: str,
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    model_config: ModelConfig,
    num_batches: int,
) -> List[ModelInput]:
    """
    Generate model input data for benchmarking.

    Args:
        model_class_name: Name of the model class to generate inputs for
        tables: List of unweighted embedding tables
        weighted_tables: List of weighted embedding tables
        model_config: Configuration for model generation
        num_batches: Number of batches to generate

    Returns:
        A list of ModelInput objects representing the generated batches
    """
    device = torch.device(model_config.dev_str) if model_config.dev_str else None

    if (
        model_class_name == "TestSparseNN"
        or model_class_name == "TestTowerSparseNN"
        or model_class_name == "TestTowerCollectionSparseNN"
    ):
        return [
            ModelInput.generate(
                batch_size=model_config.batch_size,
                tables=tables,
                weighted_tables=weighted_tables,
                num_float_features=model_config.num_float_features,
                pooling_avg=model_config.feature_pooling_avg,
                use_offsets=model_config.use_offsets,
                device=device,
                indices_dtype=(
                    torch.int64 if model_config.long_kjt_indices else torch.int32
                ),
                offsets_dtype=(
                    torch.int64 if model_config.long_kjt_offsets else torch.int32
                ),
                lengths_dtype=(
                    torch.int64 if model_config.long_kjt_lengths else torch.int32
                ),
                pin_memory=model_config.pin_memory,
            )
            for _ in range(num_batches)
        ]
    else:
        raise RuntimeError(f"Unknown model class name: {model_class_name}")
