Skip to content

Commit 8ba65ba

Browse files
committed
add multi-gpu optimization
1 parent 17245be commit 8ba65ba

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

examples/generative/corrdiff/train.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ def main(cfg: DictConfig) -> None:
279279
broadcast_buffers=True,
280280
output_device=dist.device,
281281
find_unused_parameters=True, # dist.find_unused_parameters,
282+
bucket_cap_mb = 35,
283+
gradient_as_bucket_view = True,
282284
)
283285
if cfg.wandb.watch_model and dist.rank == 0:
284286
wandb.watch(model)
@@ -369,7 +371,7 @@ def main(cfg: DictConfig) -> None:
369371

370372
# Instantiate the optimizer
371373
optimizer = torch.optim.Adam(
372-
params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8
374+
params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8, fused=True
373375
)
374376

375377
# Record the current time to measure the duration of subsequent operations.

0 commit comments

Comments
 (0)