Commit 6ae6f05

rebase changes and update tests and configs

Signed-off-by: jialusui1102 <[email protected]>
1 parent: 8ba65ba

File tree: 9 files changed, +195 -373 lines changed

examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml (+1 -1)

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-name: diffusion
+name: patched_diffusion
 # Model type.
 hr_mean_conditioning: True
 # Recommended to use high-res conditioning for diffusion.
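
The rename matters because train.py dispatches on this string. Below is a minimal sketch, assuming OmegaConf is installed and the file is loaded standalone rather than composed through Hydra, of how the new value feeds the `cfg.model.name == "patched_diffusion"` checks visible in the train.py diff further down:

    # Minimal sketch: read the base model config and reproduce the name check that
    # train.py performs. Loading standalone (outside Hydra) is an assumption here.
    from omegaconf import OmegaConf

    model_cfg = OmegaConf.load(
        "examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml"
    )
    print(model_cfg.name)                  # "patched_diffusion"
    print(model_cfg.hr_mean_conditioning)  # True

    # train.py branches on this exact string to enable the patched-diffusion path:
    uses_patching = model_cfg.name in ("patched_diffusion", "lt_aware_patched_diffusion")
    print(uses_patching)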

examples/generative/corrdiff/conf/base/model_size/normal.yaml (+1 -1)

@@ -23,4 +23,4 @@ model_args:
   # Per-resolution multipliers for the number of channels.
   channel_mult: [1, 2, 2, 2, 2]
   # Resolutions at which self-attention layers are applied.
-  attention_levels: [28]
+  attn_resolutions: [28]
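
The key rename lines the config up with the keyword argument the network constructor accepts, since entries under `model_args` are ultimately forwarded to it as keyword arguments. A hedged sketch with a stand-in class; `TinyUNet` below is hypothetical, not the real CorrDiff network:

    # Illustration only: model_args entries become keyword arguments of the network
    # constructor, so the YAML key must match the parameter name (attn_resolutions).
    from omegaconf import OmegaConf

    size_cfg = OmegaConf.load(
        "examples/generative/corrdiff/conf/base/model_size/normal.yaml"
    )

    class TinyUNet:
        def __init__(self, channel_mult, attn_resolutions, **extra_args):
            self.channel_mult = list(channel_mult)
            self.attn_resolutions = list(attn_resolutions)

    # With the old key `attention_levels`, attn_resolutions would fall back to whatever
    # default the constructor defines (or raise, if it has no default).
    net = TinyUNet(**OmegaConf.to_container(size_cfg.model_args))
    print(net.attn_resolutions)  # [28]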

examples/generative/corrdiff/conf/base/training/corrdiff_patched_diffusion_opt.yaml (-98)

This file was deleted.
New file (+127 lines):

@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+  job:
+    chdir: true
+    name: patched_diffusion_opt
+  run:
+    dir: ./output/${hydra:job.name}
+  searchpath:
+    - pkg://conf/base # Do not modify
+
+# Base parameters for dataset, model, training, and validation
+defaults:
+
+  - dataset: hrrr_corrdiff_synthetic
+  # The dataset type for training.
+  # Accepted values:
+  #   `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
+  #   `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
+  #   `cwb`: full CWB dataset for Taiwan.
+  #   `custom`: user-defined dataset. Parameters need to be specified below.
+
+  - model: patched_diffusion
+  # The model type.
+  # Accepted values:
+  #   `regression`: a regression UNet for deterministic predictions
+  #   `lt_aware_ce_regression`: similar to `regression` but with lead time
+  #       conditioning
+  #   `diffusion`: a diffusion UNet for residual predictions
+  #   `patched_diffusion`: a more memory-efficient diffusion model
+  #   `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
+  #       with lead time conditioning
+
+  - model_size: normal
+  # The model size configuration.
+  # Accepted values:
+  #   `normal`: normal model size
+  #   `mini`: smaller model size for fast experiments
+
+  - training: ${model}
+  # The base training parameters. Determined by the model type.
+
+
+# Dataset parameters. Used for `custom` dataset type.
+# Modify or add below parameters that should be passed as argument to the
+# user-defined dataset class.
+dataset:
+  data_path: ./data
+  # Path to .nc data file
+  stats_path: ./data/stats.json
+  # Path to json stats file
+
+# Training parameters
+training:
+  hp:
+    training_duration: 200000000
+    # Training duration based on the number of processed samples
+    total_batch_size: 512
+    # Total batch size
+    batch_size_per_gpu: 4
+
+    patch_shape_x: 448
+    patch_shape_y: 448
+    # Patch size. Patch training is used if these dimensions differ from
+    # img_shape_x and img_shape_y.
+    patch_num: 16
+    # Number of patches from a single sample. Total number of patches is
+    # patch_num * batch_size_global.
+    max_patch_per_gpu: 9
+    # Maximum number of patches a gpu can hold
+
+    lr: 0.0002
+    # Learning rate
+    grad_clip_threshold: 1e6
+    lr_decay: 0.7
+    lr_rampup: 1000000
+
+  # Performance
+  perf:
+    fp_optimizations: amp-bf16
+    # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
+    # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
+    dataloader_workers: 4
+    # DataLoader worker processes
+    songunet_checkpoint_level: 0 # 0 means no checkpointing
+    # Gradient checkpointing level, value is number of layers to checkpoint
+    # optimization_mode: True
+    use_apex_gn: True
+    torch_compile: True
+    profile_mode: False
+
+  io:
+    regression_checkpoint_path: /lustre/fsw/portfolios/coreai/users/asui/video-corrdiff-checkpoints/training-state-regression-000513.mdlus
+    # Path to load the regression checkpoint
+
+    # Where to load the regression checkpoint
+    print_progress_freq: 1000
+    # How often to print progress
+    save_checkpoint_freq: 500000
+    # How often to save the checkpoints, measured in number of processed samples
+    validation_freq: 5000
+    # how often to record the validation loss, measured in number of processed samples
+    validation_steps: 10
+    # how many loss evaluations are used to compute the validation loss per checkpoint
+
+# Parameters for wandb logging
+wandb:
+  mode: offline
+  # Configure whether to use wandb: "offline", "online", "disabled"
+  results_dir: "./wandb"
+  # Directory to store wandb results
+  watch_model: false
+  # If true, wandb will track model parameters and gradients
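
The `hp` block above implies multi-round patch accumulation: each sample contributes `patch_num` patches, but a GPU may hold at most `max_patch_per_gpu` of them at once, so the patches are spread over several accumulation rounds (the `accumulation round` NVTX range in train.py). A rough, self-contained illustration follows; the list is named after the `patch_nums_iter` variable used in train.py, but the arithmetic is only a plausible sketch, not the repository's exact logic:

    # Rough sketch of the patch accounting implied by the hp block above.
    import math

    patch_num = 16         # patches drawn from each sample
    max_patch_per_gpu = 9  # upper bound on patches a GPU holds at once
    total_batch_size = 512

    # Hypothetical split of the per-sample patches into accumulation rounds.
    rounds = math.ceil(patch_num / max_patch_per_gpu)
    patch_nums_iter = [
        min(max_patch_per_gpu, patch_num - i * max_patch_per_gpu) for i in range(rounds)
    ]
    print(patch_nums_iter)               # [9, 7] -> len(patch_nums_iter) > 1
    print(patch_num * total_batch_size)  # 8192 patches per optimizer step ("patch_num * batch_size_global")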

examples/generative/corrdiff/train.py (+14 -23)
@@ -162,7 +162,6 @@ def main(cfg: DictConfig) -> None:
         prob_channels = dataset.get_prob_channel_index()
     else:
         prob_channels = None
-
     # Parse the patch shape
     if (
         cfg.model.name == "patched_diffusion"
@@ -348,11 +347,6 @@ def main(cfg: DictConfig) -> None:
     if cfg.model.name == "patched_diffusion" and len(patch_nums_iter)>1:
         loss_fn = ResidualLoss_Opt(
             regression_net=regression_net,
-            img_shape_x=img_shape[1],
-            img_shape_y=img_shape[0],
-            patch_shape_x=patch_shape[1],
-            patch_shape_y=patch_shape[0],
-            patch_num=patch_num,
             hr_mean_conditioning=cfg.model.hr_mean_conditioning,
         )
     elif cfg.model.name in (
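
After this change the loss constructor only receives the regression network and the conditioning flag; the image and patch geometry that used to be constructor arguments is supplied at call time through the `patching` entry that later hunks add to the loss kwargs. A self-contained sketch with a stub class (not the real `ResidualLoss_Opt` API):

    # Stub illustrating the shape of the change: geometry leaves the constructor and a
    # `patching` object arrives with each call instead. All names here are stand-ins.
    class StubResidualLossOpt:
        def __init__(self, regression_net, hr_mean_conditioning):
            self.regression_net = regression_net
            self.hr_mean_conditioning = hr_mean_conditioning
            self.y_mean = None  # cache reset before validation in train.py

        def __call__(self, net, img_clean, img_lr, augment_pipe=None, patching=None):
            # The real loss would run the regression/diffusion networks per patch here.
            return {"patching": patching}

    loss_fn = StubResidualLossOpt(regression_net="regression_net", hr_mean_conditioning=True)
    loss_fn_kwargs = {"net": "model", "img_clean": None, "img_lr": None, "augment_pipe": None}
    loss_fn_kwargs.update({"patching": "patching_object"})  # geometry now passed per call
    print(loss_fn(**loss_fn_kwargs))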
@@ -415,11 +409,11 @@ def main(cfg: DictConfig) -> None:
         tick_start_nimg = cur_nimg
         tick_start_time = time.time()

-        if cur_nimg - start_nimg == 4 * cfg.training.hp.total_batch_size:
+        if cur_nimg - start_nimg == 14 * cfg.training.hp.total_batch_size:
             logger0.info(f"Starting Profiler at {cur_nimg}")
             torch.cuda.profiler.start()

-        if cur_nimg - start_nimg == 6 * cfg.training.hp.total_batch_size:
+        if cur_nimg - start_nimg == 16 * cfg.training.hp.total_batch_size:
             logger0.info(f"Stopping Profiler at {cur_nimg}")
             torch.cuda.profiler.stop()
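
The edit moves the CUDA profiler capture window from the 4th and 6th global batch after start to the 14th and 16th, which skips more warm-up iterations before capture. A standalone sketch of the same sample-count trigger; the loop and counters below are illustrative, train.py drives this from its own training loop and `cur_nimg`:

    # Sample-count-based profiler window: start after 14 global batches, stop after 16.
    import torch

    total_batch_size = 512
    start_nimg = 0
    cur_nimg = start_nimg

    for _ in range(20):
        cur_nimg += total_batch_size  # one global batch processed

        if cur_nimg - start_nimg == 14 * total_batch_size and torch.cuda.is_available():
            torch.cuda.profiler.start()  # begin capture (e.g. under Nsight Systems)

        if cur_nimg - start_nimg == 16 * total_batch_size and torch.cuda.is_available():
            torch.cuda.profiler.stop()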

@@ -432,7 +426,7 @@ def main(cfg: DictConfig) -> None:
                     f"accumulation round {n_i}", color="Magenta"
                 ):
                     with nvtx.annotate(f"loading data", color="green"):
-                        img_clean, img_lr, labels, *lead_time_label = next(
+                        img_clean, img_lr, *lead_time_label = next(
                             dataset_iterator
                         )
                     if use_apex_gn:
@@ -446,7 +440,6 @@ def main(cfg: DictConfig) -> None:
                             dtype=input_dtype,
                             non_blocking=True,
                         ).to(memory_format=torch.channels_last)
-                        labels = labels.to(dist.device, non_blocking=True)
                     else:
                         img_clean = (
                             img_clean.to(dist.device)
@@ -458,15 +451,13 @@ def main(cfg: DictConfig) -> None:
                             .to(input_dtype)
                             .contiguous()
                         )
-                        labels = labels.to(dist.device).contiguous()
                     loss_fn_kwargs = {
                         "net": model,
                         "img_clean": img_clean,
                         "img_lr": img_lr,
-                        "labels": labels,
                         "augment_pipe": None,
                     }
-
+
                     if lead_time_label:
                         lead_time_label = (
                             lead_time_label[0].to(dist.device).contiguous()
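
These three hunks drop the class-label tensors from the training batch and make the lead-time label optional via star-unpacking: `lead_time_label` is an empty list when the dataset yields only `(img_clean, img_lr)`, and it is attached to the loss kwargs only when present. A self-contained sketch of that pattern with toy tuples standing in for real batches:

    # Star-unpacking an optional trailing element and adding it to the kwargs only
    # when it exists. The string tuples are toy stand-ins for real tensors.
    batch_plain = ("img_clean", "img_lr")
    batch_lead_time = ("img_clean", "img_lr", "lead_time_label")

    for batch in (batch_plain, batch_lead_time):
        img_clean, img_lr, *lead_time_label = batch

        loss_fn_kwargs = {
            "net": "model",
            "img_clean": img_clean,
            "img_lr": img_lr,
            "augment_pipe": None,
        }
        if lead_time_label:  # empty list -> falsy -> key not added
            loss_fn_kwargs.update({"lead_time_label": lead_time_label[0]})

        print(sorted(loss_fn_kwargs))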
@@ -570,7 +561,7 @@ def main(cfg: DictConfig) -> None:
                 ):
                     with torch.no_grad():
                         for _ in range(cfg.training.io.validation_steps):
-                            img_clean_valid, img_lr_valid, labels_valid = next(
+                            img_clean_valid, img_lr_valid, *lead_time_label_valid = next(
                                 validation_dataset_iterator
                             )

@@ -585,9 +576,6 @@ def main(cfg: DictConfig) -> None:
                                     dtype=input_dtype,
                                     non_blocking=True,
                                 ).to(memory_format=torch.channels_last)
-                                labels_valid = labels_valid.to(
-                                    dist.device, non_blocking=True
-                                )

                             else:
                                 img_clean_valid = (
@@ -600,17 +588,20 @@ def main(cfg: DictConfig) -> None:
                                     .to(input_dtype)
                                     .contiguous()
                                 )
-                                labels_valid = labels_valid.to(
-                                    dist.device
-                                ).contiguous()

-                                loss_fn_valid_kwargs = {
+                                loss_valid_kwargs = {
                                     "net": model,
                                     "img_clean": img_clean_valid,
                                     "img_lr": img_lr_valid,
-                                    "labels": labels_valid,
                                     "augment_pipe": None,
                                 }
+                                if lead_time_label_valid:
+                                    lead_time_label_valid = (
+                                        lead_time_label_valid[0].to(dist.device).contiguous()
+                                    )
+                                    loss_valid_kwargs.update(
+                                        {"lead_time_label": lead_time_label_valid}
+                                    )
                                 if isinstance(loss_fn, ResidualLoss_Opt):
                                     loss_fn.y_mean = None
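
The validation path now mirrors the training path: `loss_fn_valid_kwargs` becomes `loss_valid_kwargs`, labels are gone, and an optional lead-time label is appended the same way. The `loss_fn.y_mean = None` context line suggests that `ResidualLoss_Opt` caches the regression output across the accumulation rounds of one batch, so the cache has to be cleared before scoring validation batches; the class below is a stub built on that assumption, not the real implementation:

    # Stub of a loss that caches an expensive regression result per batch, showing why
    # the cache is reset before validation. The caching behaviour is an assumption here.
    class CachingLoss:
        def __init__(self):
            self.y_mean = None

        def __call__(self, img_lr):
            if self.y_mean is None:
                self.y_mean = f"regression({img_lr})"  # computed once, reused per patch round
            return self.y_mean

    loss_fn = CachingLoss()
    print(loss_fn("train_batch"))  # fills the cache with the training batch result
    loss_fn.y_mean = None          # as in train.py, drop it before validation
    print(loss_fn("valid_batch"))  # recomputes for the validation input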

@@ -621,7 +612,7 @@ def main(cfg: DictConfig) -> None:
                                 loss_fn_kwargs.update({"patching": patching})
                                 # pdb.set_trace()
                                 with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp):
-                                    loss_valid = loss_fn(**loss_fn_valid_kwargs)
+                                    loss_valid = loss_fn(**loss_valid_kwargs)

                                 loss_valid = (
                                     (loss_valid.sum() / batch_size_per_gpu)
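
The autocast context in this hunk is driven by the `fp_optimizations` setting in the config above (`amp-bf16` enables AMP with bfloat16). A hedged sketch of one way to map that string onto the `enable_amp`/`amp_dtype` pair used here; the helper name and the exact mapping are assumptions, not necessarily the code in train.py:

    # Map an fp_optimizations string to torch.autocast arguments (illustrative mapping).
    import torch

    def resolve_amp(fp_optimizations: str):
        """Return (enable_amp, amp_dtype) for one of fp32 / fp16 / amp-fp16 / amp-bf16."""
        if fp_optimizations == "amp-bf16":
            return True, torch.bfloat16
        if fp_optimizations == "amp-fp16":
            return True, torch.float16
        return False, torch.float32  # fp32 and fp16 run without autocast

    enable_amp, amp_dtype = resolve_amp("amp-bf16")
    if torch.cuda.is_available():
        with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp):
            pass  # forward pass and loss evaluation would run here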
