Skip to content

Commit 17245be

Browse files
LostnEkkojialusui1102
authored andcommitted
Lint and format code properly
Signed-off-by: Neal Pan <[email protected]>
1 parent 79cfc7b commit 17245be

File tree

12 files changed

+594
-343
lines changed

12 files changed

+594
-343
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2424
- Updated CorrDiff training code to support multiple patch iterations to amortize regression cost and usage of `torch.compile`
2525
- Refactored `physicsnemo/models/diffusion/layers.py` to optimize data type casting workflow, avoiding unnecessary casting under autocast mode
2626
- Refactored Conv2d to enable fusion of conv2d with bias addition
27-
- Refactored GroupNorm, UNetBlock, SongUNet, SongUNetPosEmbd to support usage of Apex GroupNorm, fusion of activation with GroupNorm, and AMP workflow.
27+
- Refactored GroupNorm, UNetBlock, SongUNet, SongUNetPosEmbd to support usage of
28+
Apex GroupNorm, fusion of activation with GroupNorm, and AMP workflow.
2829
- Updated SongUNetPosEmbd to avoid unnecessary HtoD Memcpy of `pos_embd`
2930
- Updated `from_checkpoint` to accommodate usage of Apex GroupNorm
3031
- Refactored CorrDiff NVTX annotation workflow to be configurable

examples/generative/corrdiff/train.py

+131-57
Large diffs are not rendered by default.

physicsnemo/metrics/diffusion/loss.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,15 @@ def __init__(
10571057
self.hr_mean_conditioning = hr_mean_conditioning
10581058
self.y_mean = None
10591059

1060-
def __call__(self, net, img_clean, img_lr, patch_num_per_iter=-1, labels=None, augment_pipe=None):
1060+
def __call__(
1061+
self,
1062+
net,
1063+
img_clean,
1064+
img_lr,
1065+
patch_num_per_iter=-1,
1066+
labels=None,
1067+
augment_pipe=None,
1068+
):
10611069
"""
10621070
Calculate and return the loss for denoising score matching.
10631071
@@ -1085,27 +1093,29 @@ def __call__(self, net, img_clean, img_lr, patch_num_per_iter=-1, labels=None, a
10851093
A tensor representing the loss calculated based on the network's
10861094
predictions.
10871095
"""
1088-
1089-
self.patch_num = patch_num_per_iter if patch_num_per_iter != -1 else self.patch_num
1090-
1096+
1097+
self.patch_num = (
1098+
patch_num_per_iter if patch_num_per_iter != -1 else self.patch_num
1099+
)
1100+
10911101
rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device)
10921102
sigma = (rnd_normal * self.P_std + self.P_mean).exp()
10931103
weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
1094-
1104+
10951105
# augment for conditional generaiton
10961106
img_tot = torch.cat((img_clean, img_lr), dim=1)
10971107
y_tot, augment_labels = (
10981108
augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None)
10991109
)
11001110
y = y_tot[:, : img_clean.shape[1], :, :]
11011111
y_lr = y_tot[:, img_clean.shape[1] :, :, :]
1102-
1112+
11031113
y_lr_res = y_lr.to(memory_format=torch.channels_last)
11041114

11051115
# global index
11061116
b = y.shape[0]
1107-
Nx = torch.arange(self.img_shape_x,device=img_clean.device).int()
1108-
Ny = torch.arange(self.img_shape_y,device=img_clean.device).int()
1117+
Nx = torch.arange(self.img_shape_x, device=img_clean.device).int()
1118+
Ny = torch.arange(self.img_shape_y, device=img_clean.device).int()
11091119
grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[
11101120
None,
11111121
].expand(b, -1, -1, -1)
@@ -1122,25 +1132,24 @@ def __call__(self, net, img_clean, img_lr, patch_num_per_iter=-1, labels=None, a
11221132
self.y_mean = y_mean
11231133
# else:
11241134
# y_mean = self.y_mean
1125-
1135+
11261136
y = y - self.y_mean
11271137

11281138
if self.hr_mean_conditioning:
11291139
y_lr = torch.cat((self.y_mean, y_lr), dim=1)
1130-
1131-
1140+
11321141
global_index = None
11331142
# patchified training
11341143
# conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4
1135-
1144+
11361145
if (
11371146
self.img_shape_x != self.patch_shape_x
11381147
or self.img_shape_y != self.patch_shape_y
11391148
):
1140-
1149+
11411150
c_in = y_lr.shape[1]
11421151
c_out = y.shape[1]
1143-
1152+
11441153
rnd_normal = torch.randn(
11451154
[img_clean.shape[0] * self.patch_num, 1, 1, 1], device=img_clean.device
11461155
)
@@ -1179,7 +1188,7 @@ def __call__(self, net, img_clean, img_lr, patch_num_per_iter=-1, labels=None, a
11791188
dtype=torch.int,
11801189
device=img_clean.device,
11811190
)
1182-
1191+
11831192
for i in range(self.patch_num):
11841193
rnd_x = random.randint(0, self.img_shape_x - self.patch_shape_x)
11851194
rnd_y = random.randint(0, self.img_shape_y - self.patch_shape_y)
@@ -1207,10 +1216,10 @@ def __call__(self, net, img_clean, img_lr, patch_num_per_iter=-1, labels=None, a
12071216
),
12081217
1,
12091218
)
1210-
1219+
12111220
y = y_new
12121221
y_lr = y_lr_new
1213-
1222+
12141223
latent = y + torch.randn_like(y) * sigma
12151224
D_yn = net(
12161225
latent,
@@ -1220,11 +1229,12 @@ def __call__(self, net, img_clean, img_lr, patch_num_per_iter=-1, labels=None, a
12201229
global_index=global_index,
12211230
augment_labels=augment_labels,
12221231
)
1223-
1232+
12241233
loss = weight * ((D_yn - y) ** 2)
12251234

12261235
return loss
12271236

1237+
12281238
class VELoss_dfsr:
12291239
"""
12301240
Loss function for dfsr model, modified from class VELoss.

physicsnemo/models/diffusion/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
Linear,
2424
PositionalEmbedding,
2525
UNetBlock,
26-
2726
)
2827
from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd
2928
from .dhariwal_unet import DhariwalUNet

0 commit comments

Comments
 (0)