Skip to content

Commit 76a0826

Browse files
SSYernar authored and facebook-github-bot committed
Add documentation to configuration classes and skip envs with not enough GPUs (#3068)
Summary: Pull Request resolved: #3068. 1) Add comprehensive docstrings to RunOptions, EmbeddingTablesConfig, and PipelineConfig. 2) Replace a direct return with an assert statement. Reviewed By: TroyGarden, aliafzal. Differential Revision: D76160331. fbshipit-source-id: 107c83d8d56bce3f3d327cab9e4571dfd5f21356
1 parent 6b692a6 commit 76a0826

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

torchrec/distributed/benchmark/benchmark_train_sparsenn.py

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,25 @@
5555

5656
@dataclass
5757
class RunOptions:
58+
"""
59+
Configuration options for running sparse neural network benchmarks.
60+
61+
This class defines the parameters that control how the benchmark is executed,
62+
including distributed training settings, batch configuration, and profiling options.
63+
64+
Args:
65+
world_size (int): Number of processes/GPUs to use for distributed training.
66+
Default is 2.
67+
num_batches (int): Number of batches to process during the benchmark.
68+
Default is 10.
69+
sharding_type (ShardingType): Strategy for sharding embedding tables across devices.
70+
Default is ShardingType.TABLE_WISE (entire tables are placed on single devices).
71+
input_type (str): Type of input format to use for the model.
72+
Default is "kjt" (KeyedJaggedTensor).
73+
profile (str): Directory to save profiling results. If empty, profiling is disabled.
74+
Default is "" (disabled).
75+
"""
76+
5877
world_size: int = 2
5978
num_batches: int = 10
6079
sharding_type: ShardingType = ShardingType.TABLE_WISE
@@ -64,6 +83,22 @@ class RunOptions:
6483

6584
@dataclass
6685
class EmbeddingTablesConfig:
86+
"""
87+
Configuration for embedding tables used in sparse neural network benchmarks.
88+
89+
This class defines the parameters for generating embedding tables with both weighted
90+
and unweighted features. It provides a method to generate the actual embedding bag
91+
configurations that can be used to create embedding tables.
92+
93+
Args:
94+
num_unweighted_features (int): Number of unweighted features to generate.
95+
Default is 100.
96+
num_weighted_features (int): Number of weighted features to generate.
97+
Default is 100.
98+
embedding_feature_dim (int): Dimension of the embedding vectors.
99+
Default is 512.
100+
"""
101+
67102
num_unweighted_features: int = 100
68103
num_weighted_features: int = 100
69104
embedding_feature_dim: int = 512
@@ -74,6 +109,21 @@ def generate_tables(
74109
List[EmbeddingBagConfig],
75110
List[EmbeddingBagConfig],
76111
]:
112+
"""
113+
Generate embedding bag configurations for both unweighted and weighted features.
114+
115+
This method creates two lists of EmbeddingBagConfig objects:
116+
1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}"
117+
2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}"
118+
119+
For both types, the number of embeddings scales with the feature index,
120+
calculated as max(i + 1, 100) * 1000.
121+
122+
Returns:
123+
Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
124+
two lists - the first for unweighted embedding tables and the second for
125+
weighted embedding tables.
126+
"""
77127
tables = [
78128
EmbeddingBagConfig(
79129
num_embeddings=max(i + 1, 100) * 1000,
@@ -97,12 +147,50 @@ def generate_tables(
97147

98148
@dataclass
99149
class PipelineConfig:
150+
"""
151+
Configuration for training pipelines used in sparse neural network benchmarks.
152+
153+
This class defines the parameters for configuring the training pipeline and provides
154+
a method to generate the appropriate pipeline instance based on the configuration.
155+
156+
Args:
157+
pipeline (str): The type of training pipeline to use. Options include:
158+
- "base": Basic training pipeline
159+
- "sparse": Pipeline optimized for sparse operations
160+
- "fused": Pipeline with fused sparse distribution
161+
- "semi": Semi-synchronous training pipeline
162+
- "prefetch": Pipeline with prefetching for sparse distribution
163+
Default is "base".
164+
emb_lookup_stream (str): The stream to use for embedding lookups.
165+
Only used by certain pipeline types (e.g., "fused").
166+
Default is "data_dist".
167+
"""
168+
100169
pipeline: str = "base"
101170
emb_lookup_stream: str = "data_dist"
102171

103172
def generate_pipeline(
104173
self, model: nn.Module, opt: torch.optim.Optimizer, device: torch.device
105174
) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
175+
"""
176+
Generate a training pipeline instance based on the configuration.
177+
178+
This method creates and returns the appropriate training pipeline object
179+
based on the pipeline type specified in the configuration. Different
180+
pipeline types are optimized for different training scenarios.
181+
182+
Args:
183+
model (nn.Module): The model to be trained.
184+
opt (torch.optim.Optimizer): The optimizer to use for training.
185+
device (torch.device): The device to run the training on.
186+
187+
Returns:
188+
Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
189+
appropriate training pipeline class based on the configuration.
190+
191+
Raises:
192+
RuntimeError: If an unknown pipeline type is specified.
193+
"""
106194
_pipeline_cls: Dict[
107195
str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
108196
] = {
@@ -228,6 +316,10 @@ def runner(
228316
input_config: TestSparseNNInputConfig,
229317
pipeline_config: PipelineConfig,
230318
) -> None:
319+
# Ensure GPUs are available and we have enough of them
320+
assert (
321+
torch.cuda.is_available() and torch.cuda.device_count() >= world_size
322+
), "CUDA not available or insufficient GPUs for the requested world_size"
231323

232324
torch.autograd.set_detect_anomaly(True)
233325
with MultiProcessContext(

0 commit comments

Comments (0)