|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +# pyre-strict |
| 9 | + |
| 10 | +import copy |
| 11 | +from dataclasses import dataclass |
| 12 | +from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union |
| 13 | + |
| 14 | +import torch |
| 15 | +import torch.distributed as dist |
| 16 | +from torch import nn, optim |
| 17 | +from torch.optim import Optimizer |
| 18 | +from torchrec.distributed import DistributedModelParallel |
| 19 | +from torchrec.distributed.embedding_types import EmbeddingComputeKernel |
| 20 | +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology |
| 21 | +from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR |
| 22 | +from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner |
| 23 | +from torchrec.distributed.planner.types import ParameterConstraints |
| 24 | +from torchrec.distributed.test_utils.test_input import ModelInput |
| 25 | +from torchrec.distributed.test_utils.test_model import ( |
| 26 | + TestEBCSharder, |
| 27 | + TestOverArchLarge, |
| 28 | + TestSparseNN, |
| 29 | + TestTowerCollectionSparseNN, |
| 30 | + TestTowerSparseNN, |
| 31 | +) |
| 32 | +from torchrec.distributed.train_pipeline import ( |
| 33 | + TrainPipelineBase, |
| 34 | + TrainPipelineFusedSparseDist, |
| 35 | + TrainPipelineSparseDist, |
| 36 | +) |
| 37 | +from torchrec.distributed.train_pipeline.train_pipelines import ( |
| 38 | + PrefetchTrainPipelineSparseDist, |
| 39 | + TrainPipelineSemiSync, |
| 40 | +) |
| 41 | +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType |
| 42 | +from torchrec.modules.embedding_configs import EmbeddingBagConfig |
| 43 | + |
| 44 | + |
| 45 | +@dataclass |
| 46 | +class ModelConfig: |
| 47 | + model_name: str = "test_sparsenn" |
| 48 | + |
| 49 | + batch_size: int = 8192 |
| 50 | + num_float_features: int = 10 |
| 51 | + feature_pooling_avg: int = 10 |
| 52 | + use_offsets: bool = False |
| 53 | + dev_str: str = "" |
| 54 | + long_kjt_indices: bool = True |
| 55 | + long_kjt_offsets: bool = True |
| 56 | + long_kjt_lengths: bool = True |
| 57 | + pin_memory: bool = True |
| 58 | + |
| 59 | + def generate_model( |
| 60 | + self, |
| 61 | + tables: List[EmbeddingBagConfig], |
| 62 | + weighted_tables: List[EmbeddingBagConfig], |
| 63 | + dense_device: torch.device, |
| 64 | + ) -> nn.Module: |
| 65 | + if self.model_name == "test_sparsenn": |
| 66 | + return TestSparseNN( |
| 67 | + tables=tables, |
| 68 | + weighted_tables=weighted_tables, |
| 69 | + dense_device=dense_device, |
| 70 | + sparse_device=torch.device("meta"), |
| 71 | + over_arch_clazz=TestOverArchLarge, |
| 72 | + ) |
| 73 | + elif self.model_name == "test_tower_sparsenn": |
| 74 | + return TestTowerSparseNN( |
| 75 | + tables=tables, |
| 76 | + weighted_tables=weighted_tables, |
| 77 | + dense_device=dense_device, |
| 78 | + sparse_device=torch.device("meta"), |
| 79 | + num_float_features=self.num_float_features, |
| 80 | + ) |
| 81 | + elif self.model_name == "test_tower_collection_sparsenn": |
| 82 | + return TestTowerCollectionSparseNN( |
| 83 | + tables=tables, |
| 84 | + weighted_tables=weighted_tables, |
| 85 | + dense_device=dense_device, |
| 86 | + sparse_device=torch.device("meta"), |
| 87 | + num_float_features=self.num_float_features, |
| 88 | + ) |
| 89 | + else: |
| 90 | + raise RuntimeError(f"Unknown model name: {self.model_name}") |
| 91 | + |
| 92 | + |
| 93 | +def generate_tables( |
| 94 | + num_unweighted_features: int, |
| 95 | + num_weighted_features: int, |
| 96 | + embedding_feature_dim: int, |
| 97 | +) -> Tuple[ |
| 98 | + List[EmbeddingBagConfig], |
| 99 | + List[EmbeddingBagConfig], |
| 100 | +]: |
| 101 | + """ |
| 102 | + Generate embedding bag configurations for both unweighted and weighted features. |
| 103 | +
|
| 104 | + This function creates two lists of EmbeddingBagConfig objects: |
| 105 | + 1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}" |
| 106 | + 2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}" |
| 107 | +
|
| 108 | + For both types, the number of embeddings scales with the feature index, |
| 109 | + calculated as max(i + 1, 100) * 1000. |
| 110 | +
|
| 111 | + Args: |
| 112 | + num_unweighted_features (int): Number of unweighted features to generate. |
| 113 | + num_weighted_features (int): Number of weighted features to generate. |
| 114 | + embedding_feature_dim (int): Dimension of the embedding vectors. |
| 115 | +
|
| 116 | + Returns: |
| 117 | + Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing |
| 118 | + two lists - the first for unweighted embedding tables and the second for |
| 119 | + weighted embedding tables. |
| 120 | + """ |
| 121 | + tables = [ |
| 122 | + EmbeddingBagConfig( |
| 123 | + num_embeddings=max(i + 1, 100) * 1000, |
| 124 | + embedding_dim=embedding_feature_dim, |
| 125 | + name="table_" + str(i), |
| 126 | + feature_names=["feature_" + str(i)], |
| 127 | + ) |
| 128 | + for i in range(num_unweighted_features) |
| 129 | + ] |
| 130 | + weighted_tables = [ |
| 131 | + EmbeddingBagConfig( |
| 132 | + num_embeddings=max(i + 1, 100) * 1000, |
| 133 | + embedding_dim=embedding_feature_dim, |
| 134 | + name="weighted_table_" + str(i), |
| 135 | + feature_names=["weighted_feature_" + str(i)], |
| 136 | + ) |
| 137 | + for i in range(num_weighted_features) |
| 138 | + ] |
| 139 | + return tables, weighted_tables |
| 140 | + |
| 141 | + |
| 142 | +def generate_pipeline( |
| 143 | + pipeline_type: str, |
| 144 | + emb_lookup_stream: str, |
| 145 | + model: nn.Module, |
| 146 | + opt: torch.optim.Optimizer, |
| 147 | + device: torch.device, |
| 148 | +) -> Union[TrainPipelineBase, TrainPipelineSparseDist]: |
| 149 | + """ |
| 150 | + Generate a training pipeline instance based on the configuration. |
| 151 | +
|
| 152 | + This function creates and returns the appropriate training pipeline object |
| 153 | + based on the pipeline type specified. Different pipeline types are optimized |
| 154 | + for different training scenarios. |
| 155 | +
|
| 156 | + Args: |
| 157 | + pipeline_type (str): The type of training pipeline to use. Options include: |
| 158 | + - "base": Basic training pipeline |
| 159 | + - "sparse": Pipeline optimized for sparse operations |
| 160 | + - "fused": Pipeline with fused sparse distribution |
| 161 | + - "semi": Semi-synchronous training pipeline |
| 162 | + - "prefetch": Pipeline with prefetching for sparse distribution |
| 163 | + emb_lookup_stream (str): The stream to use for embedding lookups. |
| 164 | + Only used by certain pipeline types (e.g., "fused"). |
| 165 | + model (nn.Module): The model to be trained. |
| 166 | + opt (torch.optim.Optimizer): The optimizer to use for training. |
| 167 | + device (torch.device): The device to run the training on. |
| 168 | +
|
| 169 | + Returns: |
| 170 | + Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the |
| 171 | + appropriate training pipeline class based on the configuration. |
| 172 | +
|
| 173 | + Raises: |
| 174 | + RuntimeError: If an unknown pipeline type is specified. |
| 175 | + """ |
| 176 | + |
| 177 | + _pipeline_cls: Dict[ |
| 178 | + str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]] |
| 179 | + ] = { |
| 180 | + "base": TrainPipelineBase, |
| 181 | + "sparse": TrainPipelineSparseDist, |
| 182 | + "fused": TrainPipelineFusedSparseDist, |
| 183 | + "semi": TrainPipelineSemiSync, |
| 184 | + "prefetch": PrefetchTrainPipelineSparseDist, |
| 185 | + } |
| 186 | + |
| 187 | + if pipeline_type == "semi": |
| 188 | + return TrainPipelineSemiSync( |
| 189 | + model=model, optimizer=opt, device=device, start_batch=0 |
| 190 | + ) |
| 191 | + elif pipeline_type == "fused": |
| 192 | + return TrainPipelineFusedSparseDist( |
| 193 | + model=model, |
| 194 | + optimizer=opt, |
| 195 | + device=device, |
| 196 | + emb_lookup_stream=emb_lookup_stream, |
| 197 | + ) |
| 198 | + elif pipeline_type in _pipeline_cls: |
| 199 | + Pipeline = _pipeline_cls[pipeline_type] |
| 200 | + return Pipeline(model=model, optimizer=opt, device=device) |
| 201 | + else: |
| 202 | + raise RuntimeError(f"unknown pipeline option {pipeline_type}") |
| 203 | + |
| 204 | + |
| 205 | +def generate_planner( |
| 206 | + planner_type: str, |
| 207 | + topology: Topology, |
| 208 | + tables: Optional[List[EmbeddingBagConfig]], |
| 209 | + weighted_tables: Optional[List[EmbeddingBagConfig]], |
| 210 | + sharding_type: ShardingType, |
| 211 | + compute_kernel: EmbeddingComputeKernel, |
| 212 | + num_batches: int, |
| 213 | + batch_size: int, |
| 214 | + pooling_factors: Optional[List[float]], |
| 215 | + num_poolings: Optional[List[float]], |
| 216 | +) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: |
| 217 | + """ |
| 218 | + Generate an embedding sharding planner based on the specified configuration. |
| 219 | +
|
| 220 | + Args: |
| 221 | + planner_type: Type of planner to use ("embedding" or "hetero") |
| 222 | + topology: Network topology for distributed training |
| 223 | + tables: List of unweighted embedding tables |
| 224 | + weighted_tables: List of weighted embedding tables |
| 225 | + sharding_type: Strategy for sharding embedding tables |
| 226 | + compute_kernel: Compute kernel to use for embedding tables |
| 227 | + num_batches: Number of batches to process |
| 228 | + batch_size: Size of each batch |
| 229 | + pooling_factors: Pooling factors for each feature of the table |
| 230 | + num_poolings: Number of poolings for each feature of the table |
| 231 | +
|
| 232 | + Returns: |
| 233 | + An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner |
| 234 | +
|
| 235 | + Raises: |
| 236 | + RuntimeError: If an unknown planner type is specified |
| 237 | + """ |
| 238 | + # Create parameter constraints for tables |
| 239 | + constraints = {} |
| 240 | + |
| 241 | + if pooling_factors is None: |
| 242 | + pooling_factors = [POOLING_FACTOR] * num_batches |
| 243 | + |
| 244 | + if num_poolings is None: |
| 245 | + num_poolings = [NUM_POOLINGS] * num_batches |
| 246 | + |
| 247 | + batch_sizes = [batch_size] * num_batches |
| 248 | + |
| 249 | + assert ( |
| 250 | + len(pooling_factors) == num_batches and len(num_poolings) == num_batches |
| 251 | + ), "The length of pooling_factors and num_poolings must match the number of batches." |
| 252 | + |
| 253 | + if tables is not None: |
| 254 | + for table in tables: |
| 255 | + constraints[table.name] = ParameterConstraints( |
| 256 | + sharding_types=[sharding_type.value], |
| 257 | + compute_kernels=[compute_kernel.value], |
| 258 | + device_group="cuda", |
| 259 | + pooling_factors=pooling_factors, |
| 260 | + num_poolings=num_poolings, |
| 261 | + batch_sizes=batch_sizes, |
| 262 | + ) |
| 263 | + |
| 264 | + if weighted_tables is not None: |
| 265 | + for table in weighted_tables: |
| 266 | + constraints[table.name] = ParameterConstraints( |
| 267 | + sharding_types=[sharding_type.value], |
| 268 | + compute_kernels=[compute_kernel.value], |
| 269 | + device_group="cuda", |
| 270 | + pooling_factors=pooling_factors, |
| 271 | + num_poolings=num_poolings, |
| 272 | + batch_sizes=batch_sizes, |
| 273 | + is_weighted=True, |
| 274 | + ) |
| 275 | + |
| 276 | + if planner_type == "embedding": |
| 277 | + return EmbeddingShardingPlanner( |
| 278 | + topology=topology, |
| 279 | + constraints=constraints if constraints else None, |
| 280 | + ) |
| 281 | + elif planner_type == "hetero": |
| 282 | + topology_groups = {"cuda": topology} |
| 283 | + return HeteroEmbeddingShardingPlanner( |
| 284 | + topology_groups=topology_groups, |
| 285 | + constraints=constraints if constraints else None, |
| 286 | + ) |
| 287 | + else: |
| 288 | + raise RuntimeError(f"Unknown planner type: {planner_type}") |
| 289 | + |
| 290 | + |
| 291 | +def generate_sharded_model_and_optimizer( |
| 292 | + model: nn.Module, |
| 293 | + sharding_type: str, |
| 294 | + kernel_type: str, |
| 295 | + pg: dist.ProcessGroup, |
| 296 | + device: torch.device, |
| 297 | + fused_params: Optional[Dict[str, Any]] = None, |
| 298 | + planner: Optional[ |
| 299 | + Union[ |
| 300 | + EmbeddingShardingPlanner, |
| 301 | + HeteroEmbeddingShardingPlanner, |
| 302 | + ] |
| 303 | + ] = None, |
| 304 | +) -> Tuple[nn.Module, Optimizer]: |
| 305 | + # Ensure fused_params is always a dictionary |
| 306 | + fused_params_dict = {} if fused_params is None else fused_params |
| 307 | + |
| 308 | + sharder = TestEBCSharder( |
| 309 | + sharding_type=sharding_type, |
| 310 | + kernel_type=kernel_type, |
| 311 | + fused_params=fused_params_dict, |
| 312 | + ) |
| 313 | + sharders = [cast(ModuleSharder[nn.Module], sharder)] |
| 314 | + |
| 315 | + # Use planner if provided |
| 316 | + plan = None |
| 317 | + if planner is not None: |
| 318 | + if pg is not None: |
| 319 | + plan = planner.collective_plan(model, sharders, pg) |
| 320 | + else: |
| 321 | + plan = planner.plan(model, sharders) |
| 322 | + |
| 323 | + sharded_model = DistributedModelParallel( |
| 324 | + module=copy.deepcopy(model), |
| 325 | + env=ShardingEnv.from_process_group(pg), |
| 326 | + init_data_parallel=True, |
| 327 | + device=device, |
| 328 | + sharders=sharders, |
| 329 | + plan=plan, |
| 330 | + ).to(device) |
| 331 | + optimizer = optim.SGD( |
| 332 | + [ |
| 333 | + param |
| 334 | + for name, param in sharded_model.named_parameters() |
| 335 | + if "sparse" not in name |
| 336 | + ], |
| 337 | + lr=0.1, |
| 338 | + ) |
| 339 | + return sharded_model, optimizer |
| 340 | + |
| 341 | + |
| 342 | +def generate_data( |
| 343 | + model_class_name: str, |
| 344 | + tables: List[EmbeddingBagConfig], |
| 345 | + weighted_tables: List[EmbeddingBagConfig], |
| 346 | + model_config: ModelConfig, |
| 347 | + num_batches: int, |
| 348 | +) -> List[ModelInput]: |
| 349 | + """ |
| 350 | + Generate model input data for benchmarking. |
| 351 | +
|
| 352 | + Args: |
| 353 | + tables: List of unweighted embedding tables |
| 354 | + weighted_tables: List of weighted embedding tables |
| 355 | + model_config: Configuration for model generation |
| 356 | + num_batches: Number of batches to generate |
| 357 | +
|
| 358 | + Returns: |
| 359 | + A list of ModelInput objects representing the generated batches |
| 360 | + """ |
| 361 | + device = torch.device(model_config.dev_str) if model_config.dev_str else None |
| 362 | + |
| 363 | + if ( |
| 364 | + model_class_name == "TestSparseNN" |
| 365 | + or model_class_name == "TestTowerSparseNN" |
| 366 | + or model_class_name == "TestTowerCollectionSparseNN" |
| 367 | + ): |
| 368 | + return [ |
| 369 | + ModelInput.generate( |
| 370 | + batch_size=model_config.batch_size, |
| 371 | + tables=tables, |
| 372 | + weighted_tables=weighted_tables, |
| 373 | + num_float_features=model_config.num_float_features, |
| 374 | + pooling_avg=model_config.feature_pooling_avg, |
| 375 | + use_offsets=model_config.use_offsets, |
| 376 | + device=device, |
| 377 | + indices_dtype=( |
| 378 | + torch.int64 if model_config.long_kjt_indices else torch.int32 |
| 379 | + ), |
| 380 | + offsets_dtype=( |
| 381 | + torch.int64 if model_config.long_kjt_offsets else torch.int32 |
| 382 | + ), |
| 383 | + lengths_dtype=( |
| 384 | + torch.int64 if model_config.long_kjt_lengths else torch.int32 |
| 385 | + ), |
| 386 | + pin_memory=model_config.pin_memory, |
| 387 | + ) |
| 388 | + for _ in range(num_batches) |
| 389 | + ] |
| 390 | + else: |
| 391 | + raise RuntimeError(f"Unknown model name: {model_config.model_name}") |
0 commit comments