Commit 26f6058

SSYernar authored and facebook-github-bot committed
Added optimizer configuration that supports optimizer type, learning rate, momentum, and weight decay. (#3094)
Summary: Pull Request resolved: #3094

This commit introduces enhancements to the optimizer configuration in TorchRec. It now supports specifying the optimizer type, learning rate, momentum, and weight decay. These changes provide more flexibility and control over the training process, allowing users to fine-tune their models with different optimization strategies and hyperparameters.

Differential Revision: D76559261
1 parent 7d6306a commit 26f6058

File tree

2 files changed: +74, -16 lines


torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 40 additions & 12 deletions
@@ -294,21 +294,23 @@ def generate_sharded_model_and_optimizer(
     kernel_type: str,
     pg: dist.ProcessGroup,
     device: torch.device,
-    fused_params: Optional[Dict[str, Any]] = None,
+    fused_params: Dict[str, Any],
+    dense_optimizer: str = "SGD",
+    dense_lr: float = 0.1,
+    dense_momentum: Optional[float] = None,
+    dense_weight_decay: Optional[float] = None,
     planner: Optional[
         Union[
             EmbeddingShardingPlanner,
             HeteroEmbeddingShardingPlanner,
         ]
     ] = None,
 ) -> Tuple[nn.Module, Optimizer]:
-    # Ensure fused_params is always a dictionary
-    fused_params_dict = {} if fused_params is None else fused_params
 
     sharder = TestEBCSharder(
         sharding_type=sharding_type,
         kernel_type=kernel_type,
-        fused_params=fused_params_dict,
+        fused_params=fused_params,
     )
     sharders = [cast(ModuleSharder[nn.Module], sharder)]
 
@@ -328,14 +330,40 @@ def generate_sharded_model_and_optimizer(
         sharders=sharders,
         plan=plan,
     ).to(device)
-    optimizer = optim.SGD(
-        [
-            param
-            for name, param in sharded_model.named_parameters()
-            if "sparse" not in name
-        ],
-        lr=0.1,
-    )
+
+    # Get dense parameters
+    dense_params = [
+        param
+        for name, param in sharded_model.named_parameters()
+        if "sparse" not in name
+    ]
+
+    # Create optimizer based on the specified type
+    optimizer_classes = {
+        "sgd": optim.SGD,
+        "adam": optim.Adam,
+        "adagrad": optim.Adagrad,
+        "rmsprop": optim.RMSprop,
+        "adadelta": optim.Adadelta,
+        "adamw": optim.AdamW,
+        "adamax": optim.Adamax,
+        "nadam": optim.NAdam,
+        "asgd": optim.ASGD,
+        "lbfgs": optim.LBFGS,
+    }
+    optimizer_class = optimizer_classes.get(dense_optimizer.lower(), optim.SGD)
+
+    # Create optimizer with momentum and/or weight_decay if provided
+    optimizer_kwargs = {"lr": dense_lr}
+
+    if dense_momentum is not None:
+        optimizer_kwargs["momentum"] = dense_momentum
+
+    if dense_weight_decay is not None:
+        optimizer_kwargs["weight_decay"] = dense_weight_decay
+
+    optimizer = optimizer_class(dense_params, **optimizer_kwargs)  # pyre-ignore[6]
 
     return sharded_model, optimizer

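Note: to see the dense-optimizer selection pattern from this diff in isolation, here is a minimal, self-contained sketch. The build_dense_optimizer helper, the toy nn.Linear model, and the shortened optimizer table are illustrative assumptions, not part of the TorchRec benchmark code; only the name-to-class lookup and the conditional keyword handling mirror the change above.

# Sketch of the dense-optimizer selection logic, under the assumptions noted above.
import torch.nn as nn
import torch.optim as optim
from typing import Any, Dict, Optional


def build_dense_optimizer(
    model: nn.Module,
    dense_optimizer: str = "SGD",
    dense_lr: float = 0.1,
    dense_momentum: Optional[float] = None,
    dense_weight_decay: Optional[float] = None,
) -> optim.Optimizer:
    # Keep only parameters whose names do not mark them as sparse (embedding) params.
    dense_params = [
        param for name, param in model.named_parameters() if "sparse" not in name
    ]

    # Map the string option to a torch.optim class, defaulting to SGD (shortened table).
    optimizer_classes = {
        "sgd": optim.SGD,
        "adam": optim.Adam,
        "adagrad": optim.Adagrad,
        "adamw": optim.AdamW,
    }
    optimizer_class = optimizer_classes.get(dense_optimizer.lower(), optim.SGD)

    # Forward momentum / weight_decay only when explicitly set; not every
    # optimizer accepts a momentum argument (Adam and AdamW do not).
    optimizer_kwargs: Dict[str, Any] = {"lr": dense_lr}
    if dense_momentum is not None:
        optimizer_kwargs["momentum"] = dense_momentum
    if dense_weight_decay is not None:
        optimizer_kwargs["weight_decay"] = dense_weight_decay

    return optimizer_class(dense_params, **optimizer_kwargs)


if __name__ == "__main__":
    toy_model = nn.Linear(8, 4)  # stand-in for the dense part of a sharded model
    opt = build_dense_optimizer(
        toy_model, dense_optimizer="adamw", dense_lr=0.01, dense_weight_decay=1e-4
    )
    print(type(opt).__name__)  # AdamW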
torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 34 additions & 4 deletions
@@ -65,6 +65,14 @@ class RunOptions:
         pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
             This is the average number of values each sample has for the feature.
         num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
+        dense_optimizer (str): Optimizer to use for dense parameters.
+            Default is "SGD".
+        dense_lr (float): Learning rate for dense parameters.
+            Default is 0.1.
+        sparse_optimizer (str): Optimizer to use for sparse parameters.
+            Default is "EXACT_ADAGRAD".
+        sparse_lr (float): Learning rate for sparse parameters.
+            Default is 0.1.
     """
 
     world_size: int = 2
@@ -76,6 +84,14 @@ class RunOptions:
     planner_type: str = "embedding"
     pooling_factors: Optional[List[float]] = None
     num_poolings: Optional[List[float]] = None
+    dense_optimizer: str = "SGD"
+    dense_lr: float = 0.1
+    dense_momentum: Optional[float] = None
+    dense_weight_decay: Optional[float] = None
+    sparse_optimizer: str = "EXACT_ADAGRAD"
+    sparse_lr: float = 0.1
+    sparse_momentum: Optional[float] = None
+    sparse_weight_decay: Optional[float] = None
 
 
 @dataclass
@@ -204,17 +220,31 @@ def runner(
         num_batches=run_option.num_batches,
     )
 
+    # Prepare fused_params for sparse optimizer
+    fused_params = {
+        "optimizer": getattr(EmbOptimType, run_option.sparse_optimizer.upper()),
+        "learning_rate": run_option.sparse_lr,
+    }
+
+    # Add momentum and weight_decay to fused_params if provided
+    if run_option.sparse_momentum is not None:
+        fused_params["momentum"] = run_option.sparse_momentum
+
+    if run_option.sparse_weight_decay is not None:
+        fused_params["weight_decay"] = run_option.sparse_weight_decay
+
     sharded_model, optimizer = generate_sharded_model_and_optimizer(
         model=unsharded_model,
         sharding_type=run_option.sharding_type.value,
         kernel_type=run_option.compute_kernel.value,
         # pyre-ignore
         pg=ctx.pg,
         device=ctx.device,
-        fused_params={
-            "optimizer": EmbOptimType.EXACT_ADAGRAD,
-            "learning_rate": 0.1,
-        },
+        fused_params=fused_params,
+        dense_optimizer=run_option.dense_optimizer,
+        dense_lr=run_option.dense_lr,
+        dense_momentum=run_option.dense_momentum,
+        dense_weight_decay=run_option.dense_weight_decay,
         planner=planner,
     )
     pipeline = generate_pipeline(

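Note: the sketch below shows how the new sparse-optimizer fields feed into fused_params, following the same getattr-and-conditional pattern as the diff above. RunOptionsSketch mirrors only the added fields, and EmbOptimTypeStub is a stand-in for fbgemm_gpu's EmbOptimType so the example stays self-contained; both are illustrative assumptions rather than the benchmark's real classes.

# Self-contained sketch of the fused_params construction, under the assumptions above.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional


class EmbOptimTypeStub(Enum):
    # Stand-in for fbgemm_gpu's EmbOptimType; only a few members for illustration.
    SGD = "sgd"
    EXACT_ADAGRAD = "exact_adagrad"
    EXACT_ROWWISE_ADAGRAD = "exact_rowwise_adagrad"


@dataclass
class RunOptionsSketch:
    # Mirrors only the optimizer-related fields added to RunOptions in this commit.
    dense_optimizer: str = "SGD"
    dense_lr: float = 0.1
    dense_momentum: Optional[float] = None
    dense_weight_decay: Optional[float] = None
    sparse_optimizer: str = "EXACT_ADAGRAD"
    sparse_lr: float = 0.1
    sparse_momentum: Optional[float] = None
    sparse_weight_decay: Optional[float] = None


def build_fused_params(run_option: RunOptionsSketch) -> Dict[str, Any]:
    # Resolve the sparse optimizer by name, as the diff does with
    # getattr(EmbOptimType, run_option.sparse_optimizer.upper()).
    fused_params: Dict[str, Any] = {
        "optimizer": getattr(EmbOptimTypeStub, run_option.sparse_optimizer.upper()),
        "learning_rate": run_option.sparse_lr,
    }
    # Optional hyperparameters are only forwarded when explicitly set.
    if run_option.sparse_momentum is not None:
        fused_params["momentum"] = run_option.sparse_momentum
    if run_option.sparse_weight_decay is not None:
        fused_params["weight_decay"] = run_option.sparse_weight_decay
    return fused_params


if __name__ == "__main__":
    opts = RunOptionsSketch(
        sparse_optimizer="EXACT_ROWWISE_ADAGRAD", sparse_lr=0.05, sparse_weight_decay=1e-4
    )
    print(build_fused_params(opts))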