Skip to content

Commit 79cfc7b

Browse files
committed
revised from_checkpoint, update tests and CHANGELOG
Signed-off-by: jialusui1102 <[email protected]>
1 parent 0799545 commit 79cfc7b

File tree

6 files changed

+38
-25
lines changed

6 files changed

+38
-25
lines changed

CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- General purpose patching API for patch-based diffusion
1515
- New positional embedding selection strategy for CorrDiff SongUNet models
1616
- Added Multi-Storage Client to allow checkpointing to/from Object Storage
17+
- Added `ResidualLoss_Opt` for patch amortized CorrDiff training
1718

1819
### Changed
1920

2021
- Simplified CorrDiff config files, updated default values
2122
- Refactored CorrDiff losses and samplers to use the patching API
2223
- Support for non-square images and patches in patch-based diffusion
24+
- Updated CorrDiff training code to support multiple patch iterations to amortize regression cost and usage of `torch.compile`
25+
- Refactored `physicsnemo/models/diffusion/layers.py` to optimize data type casting workflow, avoiding unnecessary casting under autocast mode
26+
- Refactored Conv2d to enable fusion of conv2d with bias addition
27+
- Refactored GroupNorm, UNetBlock, SongUNet, SongUNetPosEmbd to support usage of Apex GroupNorm, fusion of activation with GroupNorm, and AMP workflow.
28+
- Updated SongUNetPosEmbd to avoid unnecessary HtoD Memcpy of `pos_embd`
29+
- Updated `from_checkpoint` to accommodate usage of Apex GroupNorm
30+
- Refactored CorrDiff NVTX annotation workflow to be configurable
2331

2432
### Deprecated
2533

physicsnemo/models/diffusion/layers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import nvtx
3232
import contextlib
3333
import torch.cuda.amp as amp
34-
34+
import pdb
3535

3636
class Linear(torch.nn.Module):
3737
"""
@@ -353,7 +353,7 @@ def forward(self, x):
353353
bias = self.bias.to(x.dtype)
354354
if self.use_apex_gn:
355355
x = self.gn(x)
356-
elif self.training: #check
356+
elif self.training:
357357
# Use default torch implementation of GroupNorm for training
358358
# This does not support channels last memory format
359359
x = torch.nn.functional.group_norm(

physicsnemo/models/module.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,8 @@ def from_checkpoint(cls, file_name: str, model_args: Optional[Dict] = None) -> "
375375
# Load model arguments and instantiate the model
376376
with open(local_path.joinpath("args.json"), "r") as f:
377377
args = json.load(f)
378-
378+
apex_in_ckp = "use_apex_gn" in args["__args__"].keys()
379+
379380
# Merge model_args (adding new keys and updating existing ones)
380381
if model_args is not None:
381382
args["__args__"].update(model_args)
@@ -384,7 +385,8 @@ def from_checkpoint(cls, file_name: str, model_args: Optional[Dict] = None) -> "
384385
model_dict = torch.load(
385386
local_path.joinpath("model.pt"), map_location=model.device
386387
)
387-
if "use_apex_gn" in args["__args__"].keys() and args["__args__"]["use_apex_gn"]:
388+
#TODO: for corrdiff model architecture specifically
389+
if not apex_in_ckp and "use_apex_gn" in args["__args__"].keys() and args["__args__"]["use_apex_gn"]:
388390
filtered_state_dict = {}
389391
for key, value in model_dict.items():
390392
filtered_state_dict[key] = value # Keep the original key
@@ -399,7 +401,6 @@ def from_checkpoint(cls, file_name: str, model_args: Optional[Dict] = None) -> "
399401
model.load_state_dict(filtered_state_dict,strict=False)
400402
else:
401403
model.load_state_dict(model_dict,strict=False)
402-
403404
return model
404405

405406
@staticmethod

test/models/common/checkpoints.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def validate_checkpoint(
3535
in_args: Tuple[Tensor] = (),
3636
rtol: float = 1e-5,
3737
atol: float = 1e-5,
38+
enable_autocast: bool = False,
3839
) -> bool:
3940
"""Check network's checkpoint safely saves and loads the state of the model
4041
@@ -54,6 +55,8 @@ def validate_checkpoint(
5455
Relative tolerance of error allowed, by default 1e-5
5556
atol : float, optional
5657
Absolute tolerance of error allowed, by default 1e-5
58+
enable_autocast: bool, optional
59+
Whether to enable autocast in model forward
5760
5861
Returns
5962
-------
@@ -72,8 +75,9 @@ def validate_checkpoint(
7275
pass
7376

7477
# Now test forward passes
75-
output_1 = model_1.forward(*in_args)
76-
output_2 = model_2.forward(*in_args)
78+
with torch.autocast("cuda", enabled=enable_autocast):
79+
output_1 = model_1.forward(*in_args)
80+
output_2 = model_2.forward(*in_args)
7781

7882
# Model outputs should initially be different
7983
assert not compare_output(
@@ -85,12 +89,15 @@ def validate_checkpoint(
8589
model_2.load("checkpoint.mdlus")
8690

8791
# Forward with loaded checkpoint
88-
output_2 = model_2.forward(*in_args)
92+
with torch.autocast("cuda", enabled=enable_autocast):
93+
output_2 = model_2.forward(*in_args)
94+
8995
loaded_checkpoint = compare_output(output_1, output_2, rtol, atol)
9096

9197
# Restore checkpoint with from_checkpoint, checks initialization of model directly from checkpoint
9298
model_2 = physicsnemo.Module.from_checkpoint("checkpoint.mdlus").to(model_1.device)
93-
output_2 = model_2.forward(*in_args)
99+
with torch.autocast("cuda", enabled=enable_autocast):
100+
output_2 = model_2.forward(*in_args)
94101
restored_checkpoint = compare_output(output_1, output_2, rtol, atol)
95102

96103
# Delete checkpoint file (it should exist!)

test/models/diffusion/test_song_unet_agn_amp.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,9 @@ def test_song_unet_checkpoint(device):
217217
noise_labels = torch.randn([1]).to(device)
218218
class_labels = torch.randint(0, 1, (1, 1)).to(device)
219219
input_image = torch.ones([1, 2, 16, 16]).to(device)
220-
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
221-
assert common.validate_checkpoint(
222-
model_1, model_2, (*[input_image, noise_labels, class_labels],)
223-
)
220+
assert common.validate_checkpoint(
221+
model_1, model_2, (*[input_image, noise_labels, class_labels],),enable_autocast=True
222+
)
224223

225224

226225
@common.check_ort_version()
@@ -243,11 +242,10 @@ def test_son_unet_deploy(device):
243242
class_labels = torch.randint(0, 1, (1, 1)).to(device)
244243
input_image = torch.ones([1, 2, 16, 16]).to(device)
245244

246-
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
247-
assert common.validate_onnx_export(
248-
model, (*[input_image, noise_labels, class_labels],)
249-
)
250-
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
251-
assert common.validate_onnx_runtime(
252-
model, (*[input_image, noise_labels, class_labels],)
253-
)
245+
assert common.validate_onnx_export(
246+
model, (*[input_image, noise_labels, class_labels],)
247+
)
248+
249+
assert common.validate_onnx_runtime(
250+
model, (*[input_image, noise_labels, class_labels],)
251+
)

test/models/diffusion/test_song_unet_pos_embd_agn_amp.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ def test_song_unet_checkpoint(device):
233233
noise_labels = torch.randn([1]).to(device)
234234
class_labels = torch.randint(0, 1, (1, 1)).to(device)
235235
input_image = torch.ones([1, 2, 16, 16]).to(device)
236-
with torch.autocast("cuda", enabled=True):
237-
assert common.validate_checkpoint(
238-
model_1, model_2, (*[input_image, noise_labels, class_labels],), rtol=1e-5, atol=1e-5,
239-
)
236+
assert common.validate_checkpoint(
237+
model_1, model_2, (*[input_image, noise_labels, class_labels],),enable_autocast=True
238+
)
240239

241240

242241
@common.check_ort_version()

0 commit comments

Comments
 (0)