diff --git a/CHANGELOG.md b/CHANGELOG.md index e758767a3f..119a6fb7d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ERA5 download example updated to use current file format convention and restricts global statistics computation to the training set - Support for training custom StormCast models and various other improvements for StormCast +- Updated CorrDiff training code to support multiple patch iterations to amortize + regression cost, and to support `torch.compile` +- Refactored `physicsnemo/models/diffusion/layers.py` to optimize the data type + casting workflow, avoiding unnecessary casting under autocast mode +- Refactored Conv2d to enable fusion of conv2d with bias addition +- Refactored GroupNorm, UNetBlock, SongUNet, SongUNetPosEmbd to support usage of + Apex GroupNorm, fusion of activation with GroupNorm, and the AMP workflow +- Updated SongUNetPosEmbd to avoid unnecessary HtoD Memcpy of `pos_embd` +- Updated `from_checkpoint` to accommodate conversion between Apex-optimized + and non-optimized checkpoints +- Refactored CorrDiff NVTX annotation workflow to be configurable +- Refactored `ResidualLoss` to support patch-accumulating training for + amortizing regression costs ### Deprecated diff --git a/examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml b/examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml index bf88bbb649..5b52de3be6 100644 --- a/examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml +++ b/examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: diffusion +name: patched_diffusion # Model type. hr_mean_conditioning: True # Recommended to use high-res conditioning for diffusion. diff --git a/examples/generative/corrdiff/conf/base/model_size/normal.yaml b/examples/generative/corrdiff/conf/base/model_size/normal.yaml index dd3450a33d..b81fe15348 100644 --- a/examples/generative/corrdiff/conf/base/model_size/normal.yaml +++ b/examples/generative/corrdiff/conf/base/model_size/normal.yaml @@ -23,4 +23,4 @@ model_args: # Per-resolution multipliers for the number of channels. channel_mult: [1, 2, 2, 2, 2] # Resolutions at which self-attention layers are applied. - attention_levels: [28] \ No newline at end of file + attn_resolutions: [28] \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_training_hrrr_patched_diffusion_opt.yaml b/examples/generative/corrdiff/conf/config_training_hrrr_patched_diffusion_opt.yaml new file mode 100644 index 0000000000..775032860a --- /dev/null +++ b/examples/generative/corrdiff/conf/config_training_hrrr_patched_diffusion_opt.yaml @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: true + name: patched_diffusion_opt + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: hrrr_corrdiff_synthetic + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: patched_diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data + # Path to .nc data file + stats_path: ./data/stats.json + # Path to json stats file + +# Training parameters +training: + hp: + training_duration: 200000000 + # Training duration based on the number of processed samples + total_batch_size: 512 + # Total batch size + batch_size_per_gpu: 4 + + patch_shape_x: 448 + patch_shape_y: 448 + # Patch size. Patch training is used if these dimensions differ from + # img_shape_x and img_shape_y. + patch_num: 16 + # Number of patches from a single sample. Total number of patches is + # patch_num * batch_size_global. 
+ max_patch_per_gpu: 9 + # Maximum number of patches a GPU can hold + + lr: 0.0002 + # Learning rate + grad_clip_threshold: 1e6 + lr_decay: 0.7 + lr_rampup: 1000000 + + # Performance + perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 4 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + # optimization_mode: True + use_apex_gn: True + torch_compile: True + profile_mode: False + + io: + regression_checkpoint_path: /lustre/fsw/portfolios/coreai/users/asui/video-corrdiff-checkpoints/training-state-regression-000513.mdlus + # Path to load the regression checkpoint + + print_progress_freq: 1000 + # How often to print progress + save_checkpoint_freq: 500000 + # How often to save the checkpoints, measured in number of processed samples + validation_freq: 5000 + # How often to record the validation loss, measured in number of processed samples + validation_steps: 10 + # How many loss evaluations are used to compute the validation loss per checkpoint + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients diff --git a/examples/generative/corrdiff/train.py b/examples/generative/corrdiff/train.py index bce79e89b3..9142aacf78 100644 --- a/examples/generative/corrdiff/train.py +++ b/examples/generative/corrdiff/train.py @@ -47,6 +47,16 @@ handle_and_clip_gradients, is_time_for_periodic_task, ) +import nvtx +import contextlib +import pdb + + +torch._dynamo.reset() +# Increase the cache size limit +torch._dynamo.config.cache_size_limit = 264 # Set to a higher value +torch._dynamo.config.verbose = True # Enable verbose logging +torch._dynamo.config.suppress_errors = False # Forces the error to show all details def checkpoint_list(path, suffix=".mdlus"): @@ -153,7 +163,6 @@ def main(cfg: DictConfig) -> None: prob_channels = dataset.get_prob_channel_index() else: prob_channels = None - # Parse the patch shape if ( cfg.model.name == "patched_diffusion" @@ -192,6 +201,31 @@ def main(cfg: DictConfig) -> None: model_args["prob_channels"] = prob_channels if hasattr(cfg.model, "model_args"): # override defaults from config file model_args.update(OmegaConf.to_container(cfg.model.model_args)) + + optimization_mode = False + # optimization mode: + # if hasattr(cfg.training.perf, "torch_compile") and cfg.training.perf.torch_compile: + # model_args.update({"use_apex_gn":True,"fused_conv_bias":True,"model_type":"SongUNetPosOptEmbd" }) + # optimization_mode = True + use_torch_compile = False + use_apex_gn = False + profile_mode = False + + if hasattr(cfg.training.perf, "torch_compile"): + use_torch_compile = cfg.training.perf.torch_compile + if hasattr(cfg.training.perf, "use_apex_gn"): + use_apex_gn = cfg.training.perf.use_apex_gn + if hasattr(cfg.training.perf, "profile_mode"): + profile_mode = cfg.training.perf.profile_mode + + model_args.update( + { + "use_apex_gn": use_apex_gn, + "profile_mode": profile_mode, + "amp_mode": enable_amp, + } + ) + if cfg.model.name == "regression": model = UNet( img_in_channels=img_in_channels + model_args["N_grid_channels"], @@ -225,6 +259,8 @@ def 
main(cfg: DictConfig) -> None: raise ValueError(f"Invalid model: {cfg.model.name}") model.train().requires_grad_(True).to(dist.device) + if use_apex_gn: + model.to(memory_format=torch.channels_last) # Check if regression model is used with patching if ( @@ -242,7 +278,9 @@ def main(cfg: DictConfig) -> None: device_ids=[dist.local_rank], broadcast_buffers=True, output_device=dist.device, - find_unused_parameters=dist.find_unused_parameters, + find_unused_parameters=True, # dist.find_unused_parameters, + bucket_cap_mb=35, + gradient_as_bucket_view=True, ) if cfg.wandb.watch_model and dist.rank == 0: wandb.watch(model) @@ -259,11 +297,64 @@ def main(cfg: DictConfig) -> None: raise FileNotFoundError( f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" ) - regression_net = Module.from_checkpoint(regression_checkpoint_path) + reg_model_args = { + "use_apex_gn": use_apex_gn, + "profile_mode": profile_mode, + "amp_mode": enable_amp, + } + regression_net = Module.from_checkpoint( + regression_checkpoint_path, reg_model_args + ) regression_net.eval().requires_grad_(False).to(dist.device) + if use_apex_gn: + regression_net.to(memory_format=torch.channels_last) logger0.success("Loaded the pre-trained regression model") + if use_torch_compile: + model = torch.compile(model) + regression_net = torch.compile(regression_net) + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.training.hp.total_batch_size, + cfg.training.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") + + patch_num = getattr(cfg.training.hp, "patch_num", 1) + max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1) + + # calculate patch per iter + if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1: + max_patch_num_per_iter = min( + patch_num, (max_patch_per_gpu // batch_size_per_gpu) + ) # Ensure at least 1 patch per iter + patch_iterations = ( + patch_num + max_patch_num_per_iter - 1 + ) // max_patch_num_per_iter + patch_nums_iter = [ + min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) + for i in range(patch_iterations) + ] + print( + f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" + ) + else: + patch_nums_iter = [patch_num] + + use_patch_grad_acc = False + if len(patch_nums_iter) > 1: + use_patch_grad_acc = True + # Instantiate the loss function + # if cfg.model.name == "patched_diffusion" and len(patch_nums_iter)>1: + # loss_fn = ResidualLoss_Opt( + # regression_net=regression_net, + # hr_mean_conditioning=cfg.model.hr_mean_conditioning, + # ) if cfg.model.name in ( "diffusion", "patched_diffusion", @@ -280,22 +371,16 @@ def main(cfg: DictConfig) -> None: # Instantiate the optimizer optimizer = torch.optim.Adam( - params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8 + params=model.parameters(), + lr=cfg.training.hp.lr, + betas=[0.9, 0.999], + eps=1e-8, + fused=True, ) # Record the current time to measure the duration of subsequent operations. 
start_time = time.time() - # Compute the number of required gradient accumulation rounds - # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size - batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( - cfg.training.hp.total_batch_size, - cfg.training.hp.batch_size_per_gpu, - dist.world_size, - ) - batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu - logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") - ## Resume training from previous checkpoints if exists if dist.world_size > 1: torch.distributed.barrier() @@ -319,207 +404,313 @@ def main(cfg: DictConfig) -> None: # init variables to monitor running mean of average loss since last periodic average_loss_running_mean = 0 n_average_loss_running_mean = 1 + start_nimg = cur_nimg + input_dtype = torch.float32 + if enable_amp: + input_dtype = torch.float32 + elif fp16: + input_dtype = torch.float16 + + # enable profiler: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + + if cur_nimg - start_nimg == 24 * cfg.training.hp.total_batch_size: + logger0.info(f"Starting Profiler at {cur_nimg}") + torch.cuda.profiler.start() + + if cur_nimg - start_nimg == 25 * cfg.training.hp.total_batch_size: + logger0.info(f"Stoping Profiler at {cur_nimg}") + torch.cuda.profiler.stop() + + with nvtx.annotate("Training iteration", color="green"): + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for n_i in range(num_accumulation_rounds): + with nvtx.annotate( + f"accumulation round {n_i}", color="Magenta" + ): + with nvtx.annotate(f"loading data", color="green"): + img_clean, img_lr, *lead_time_label = next( + dataset_iterator + ) + if use_apex_gn: + img_clean = img_clean.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr = img_lr.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean = ( + img_clean.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr = ( + img_lr.to(dist.device) + .to(input_dtype) + .contiguous() + ) + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "augment_pipe": None, + "use_patch_grad_acc": use_patch_grad_acc, + } - while not done: - tick_start_nimg = cur_nimg - tick_start_time = time.time() - # Compute & accumulate gradients - optimizer.zero_grad(set_to_none=True) - loss_accum = 0 - for _ in range(num_accumulation_rounds): - img_clean, img_lr, *lead_time_label = next(dataset_iterator) - img_clean = img_clean.to(dist.device).to(torch.float32).contiguous() - img_lr = img_lr.to(dist.device).to(torch.float32).contiguous() - loss_fn_kwargs = { - "net": model, - "img_clean": img_clean, - "img_lr": img_lr, - "augment_pipe": None, - } - # Sample new random patches for this iteration and add patching to - # loss arguments - if patching is not None: - patching.reset_patch_indices() - loss_fn_kwargs.update({"patching": patching}) - if lead_time_label: - lead_time_label = lead_time_label[0].to(dist.device).contiguous() - loss_fn_kwargs.update({"lead_time_label": lead_time_label}) - else: - lead_time_label = None - with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp): - loss = loss_fn(**loss_fn_kwargs) - loss = loss.sum() / batch_size_per_gpu - loss_accum += loss / num_accumulation_rounds - loss.backward() - - loss_sum = 
torch.tensor([loss_accum], device=dist.device) - if dist.world_size > 1: - torch.distributed.barrier() - torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM) - average_loss = (loss_sum / dist.world_size).cpu().item() - - # update running mean of average loss since last periodic task - average_loss_running_mean += ( - average_loss - average_loss_running_mean - ) / n_average_loss_running_mean - n_average_loss_running_mean += 1 - - if dist.rank == 0: - writer.add_scalar("training_loss", average_loss, cur_nimg) - writer.add_scalar( - "training_loss_running_mean", average_loss_running_mean, cur_nimg - ) - wandb.log( - { - "training_loss": average_loss, - "training_loss_running_mean": average_loss_running_mean, - } - ) + if lead_time_label: + lead_time_label = ( + lead_time_label[0].to(dist.device).contiguous() + ) + loss_fn_kwargs.update( + {"lead_time_label": lead_time_label} + ) + else: + lead_time_label = None + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_sum(patch_num_per_iter) + loss_fn_kwargs.update({"patching": patching}) + # pdb.set_trace() + with nvtx.annotate(f"loss forward", color="green"): + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss = loss_fn(**loss_fn_kwargs) + + loss = loss.sum() / batch_size_per_gpu + loss_accum += loss / num_accumulation_rounds + with nvtx.annotate(f"loss backward", color="yellow"): + loss.backward() + + with nvtx.annotate(f"loss aggregate", color="green"): + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_loss = (loss_sum / dist.world_size).cpu().item() - ptt = is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ) - if ptt: - # reset running mean of average loss - average_loss_running_mean = 0 - n_average_loss_running_mean = 1 - - # Update weights. 
- lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate - for g in optimizer.param_groups: - if lr_rampup > 0: - g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) - if cur_nimg >= lr_rampup: - g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) - current_lr = g["lr"] - if dist.rank == 0: - writer.add_scalar("learning_rate", current_lr, cur_nimg) - handle_and_clip_gradients( - model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold - ) - optimizer.step() - - cur_nimg += cfg.training.hp.total_batch_size - done = cur_nimg >= cfg.training.hp.training_duration - - # Validation - if validation_dataset_iterator is not None: - valid_loss_accum = 0 - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.validation_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - ): - with torch.no_grad(): - for _ in range(cfg.training.io.validation_steps): - img_clean_valid, img_lr_valid, *lead_time_label_valid = next( - validation_dataset_iterator - ) + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 - img_clean_valid = ( - img_clean_valid.to(dist.device) - .to(torch.float32) - .contiguous() - ) - img_lr_valid = ( - img_lr_valid.to(dist.device).to(torch.float32).contiguous() - ) - loss_valid_kwargs = { - "net": model, - "img_clean": img_clean_valid, - "img_lr": img_lr_valid, - "augment_pipe": None, - } - if lead_time_label_valid: - lead_time_label_valid = ( - lead_time_label_valid[0].to(dist.device).contiguous() - ) - loss_valid_kwargs.update( - {"lead_time_label": lead_time_label_valid} - ) - loss_valid = loss_fn(**loss_valid_kwargs) - loss_valid = ( - (loss_valid.sum() / batch_size_per_gpu).cpu().item() - ) - valid_loss_accum += ( - loss_valid / cfg.training.io.validation_steps - ) - valid_loss_sum = torch.tensor( - [valid_loss_accum], device=dist.device - ) - if dist.world_size > 1: - torch.distributed.barrier() - torch.distributed.all_reduce( - valid_loss_sum, op=torch.distributed.ReduceOp.SUM - ) - average_valid_loss = valid_loss_sum / dist.world_size if dist.rank == 0: + writer.add_scalar("training_loss", average_loss, cur_nimg) writer.add_scalar( - "validation_loss", average_valid_loss, cur_nimg - ) - wandb.log( - { - "validation_loss": average_valid_loss, - } + "training_loss_running_mean", + average_loss_running_mean, + cur_nimg, ) - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - # Print stats if we crossed the printing threshold with this batch - tick_end_time = time.time() - fields = [] - fields += [f"samples {cur_nimg:<9.1f}"] - fields += [f"training_loss {average_loss:<7.2f}"] - fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] - fields += [f"learning_rate {current_lr:<7.8f}"] - fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] - fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] - fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" - ] - fields += [ - f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" - ] - 
logger0.info(" ".join(fields)) - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - - # Save checkpoints - if dist.world_size > 1: - torch.distributed.barrier() - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.save_checkpoint_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - save_checkpoint( - path=checkpoint_dir, - models=model, - optimizer=optimizer, - epoch=cur_nimg, - ) + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. + with nvtx.annotate(f"update weights", color="blue"): + lr_rampup = ( + cfg.training.hp.lr_rampup + ) # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.training.hp.lr * min( + cur_nimg / lr_rampup, 1 + ) + if cur_nimg >= lr_rampup: + g["lr"] *= cfg.training.hp.lr_decay ** ( + (cur_nimg - lr_rampup) // 5e6 + ) + current_lr = g["lr"] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + handle_and_clip_gradients( + model, + grad_clip_threshold=cfg.training.hp.grad_clip_threshold, + ) + with nvtx.annotate("optimizer step", color="blue"): + optimizer.step() + + cur_nimg += cfg.training.hp.total_batch_size + done = cur_nimg >= cfg.training.hp.training_duration + + with nvtx.annotate("validation", color="red"): + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + with torch.no_grad(): + for _ in range(cfg.training.io.validation_steps): + ( + img_clean_valid, + img_lr_valid, + *lead_time_label_valid, + ) = next(validation_dataset_iterator) + + if use_apex_gn: + img_clean_valid = img_clean_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_valid = img_lr_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + + else: + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + + loss_valid_kwargs = { + "net": model, + "img_clean": img_clean_valid, + "img_lr": img_lr_valid, + "augment_pipe": None, + "use_patch_grad_acc": use_patch_grad_acc, + } + if lead_time_label_valid: + lead_time_label_valid = ( + lead_time_label_valid[0] + .to(dist.device) + .contiguous() + ) + loss_valid_kwargs.update( + {"lead_time_label": lead_time_label_valid} + ) + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_sum(patch_num_per_iter) + loss_fn_kwargs.update( + {"patching": patching} + ) + # pdb.set_trace() + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss_valid = loss_fn(**loss_valid_kwargs) + + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu) + .cpu() + .item() + ) + valid_loss_accum += ( + loss_valid + / cfg.training.io.validation_steps + ) + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, + op=torch.distributed.ReduceOp.SUM, + 
) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + writer.add_scalar( + "validation_loss", average_valid_loss, cur_nimg + ) + + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [ + f"training_loss_running_mean {average_loss_running_mean:<7.2f}" + ] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [ + f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}" + ] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + logger0.info(" ".join(fields)) + torch.cuda.reset_peak_memory_stats() + + # Save checkpoints + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + save_checkpoint( + path=checkpoint_dir, + models=model, + optimizer=optimizer, + epoch=cur_nimg, + ) # Retain only the recent n checkpoints, if desired if cfg.training.io.save_n_recent_checkpoints > 0: diff --git a/physicsnemo/metrics/diffusion/loss.py b/physicsnemo/metrics/diffusion/loss.py index 6d51e8bb8c..25a9eb49b2 100644 --- a/physicsnemo/metrics/diffusion/loss.py +++ b/physicsnemo/metrics/diffusion/loss.py @@ -518,6 +518,7 @@ def __init__( self.P_std = P_std self.sigma_data = sigma_data self.hr_mean_conditioning = hr_mean_conditioning + self.y_mean = None def __call__( self, @@ -529,6 +530,7 @@ def __call__( augment_pipe: Optional[ Callable[[Tensor], Tuple[Tensor, Optional[Tensor]]] ] = None, + use_patch_grad_acc: bool = False, ) -> Tensor: """ Calculate and return the loss for denoising score matching. @@ -611,6 +613,9 @@ def __call__( Tuple[torch.Tensor, Optional[torch.Tensor]]: - Augmented images of shape (B, C_hr+C_lr, H, W) - Optional augmentation labels + use_patch_grad_acc: bool, optional + A boolean flag indicating whether to enable multi-iterations of patching accumulations + for amortizing regression cost. Default False. 
Returns ------- @@ -656,28 +661,52 @@ def __call__( y_lr_res = y_lr batch_size = y.shape[0] - # form residual - if lead_time_label is not None: - y_mean = self.regression_net( - torch.zeros_like(y, device=img_clean.device), - y_lr_res, - lead_time_label=lead_time_label, - augment_labels=augment_labels, - ) + # if using multi-iterations of patching, switch to optimized version + if use_patch_grad_acc: + # form residual + if self.y_mean is None: + if lead_time_label is not None: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + augment_labels=augment_labels, + ) + self.y_mean = y_mean + + # if on full domain: else: - y_mean = self.regression_net( - torch.zeros_like(y, device=img_clean.device), - y_lr_res, - augment_labels=augment_labels, - ) + # form residual + if lead_time_label is not None: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + augment_labels=augment_labels, + ) + + self.y_mean = y_mean - y = y - y_mean + y = y - self.y_mean if self.hr_mean_conditioning: - y_lr = torch.cat((y_mean, y_lr), dim=1).contiguous() + y_lr = torch.cat((self.y_mean, y_lr), dim=1) # patchified training # conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4 + # removed patch_embedding_selector due to compilation issue with dynamo. if patching: # Patched residual # (batch_size * patch_num, c_out, patch_shape_y, patch_shape_x) @@ -686,17 +715,8 @@ def __call__( # (batch_size * patch_num, 2*c_in, patch_shape_y, patch_shape_x) y_lr_patched = patching.apply(input=y_lr, additional_input=img_lr) - # Function to select the correct positional embedding for each - # patch - def patch_embedding_selector(emb): - # emb: (N_pe, image_shape_y, image_shape_x) - # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x) - return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) - y = y_patched y_lr = y_lr_patched - else: - patch_embedding_selector = None # Noise rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=img_clean.device) @@ -711,7 +731,10 @@ def patch_embedding_selector(emb): latent, y_lr, sigma, - embedding_selector=patch_embedding_selector, + embedding_selector=None, + global_index=patching.global_index(batch_size, img_clean.device) + if patching is not None + else None, lead_time_label=lead_time_label, augment_labels=augment_labels, ) @@ -720,7 +743,10 @@ def patch_embedding_selector(emb): latent, y_lr, sigma, - embedding_selector=patch_embedding_selector, + embedding_selector=None, + global_index=patching.global_index(batch_size, img_clean.device) + if patching is not None + else None, augment_labels=augment_labels, ) loss = weight * ((D_yn - y) ** 2) diff --git a/physicsnemo/models/diffusion/layers.py b/physicsnemo/models/diffusion/layers.py index 1fb3b171e9..e91028a7cf 100644 --- a/physicsnemo/models/diffusion/layers.py +++ b/physicsnemo/models/diffusion/layers.py @@ -19,15 +19,29 @@ Diffusion-Based Generative Models". 
""" +import contextlib +import importlib from typing import Any, Dict, List import numpy as np +import nvtx import torch +import torch.cuda.amp as amp from einops import rearrange -from torch.nn.functional import silu +from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh from physicsnemo.models.diffusion import weight_init +# Import apex GroupNorm if installed only +_is_apex_available = False +if torch.cuda.is_available(): + try: + apex_gn_module = importlib.import_module("apex.contrib.group_norm") + ApexGroupNorm = getattr(apex_gn_module, "GroupNorm") + _is_apex_available = True + except ImportError: + pass + class Linear(torch.nn.Module): """ @@ -56,6 +70,8 @@ class Linear(torch.nn.Module): A scaling factor to multiply with the initialized weights. By default 1. init_bias : float, optional A scaling factor to multiply with the initialized biases. By default 0. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -66,10 +82,12 @@ def __init__( init_mode: str = "kaiming_normal", init_weight: int = 1, init_bias: int = 0, + amp_mode: bool = False, ): super().__init__() self.in_features = in_features self.out_features = out_features + self.amp_mode = amp_mode init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) self.weight = torch.nn.Parameter( weight_init([out_features, in_features], **init_kwargs) * init_weight @@ -81,9 +99,16 @@ def __init__( ) def forward(self, x): - x = x @ self.weight.to(x.dtype).t() + weight, bias = self.weight, self.bias + # pdb.set_trace() + if not self.amp_mode: + if self.weight is not None and self.weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + x = x @ weight.t() if self.bias is not None: - x = x.add_(self.bias.to(x.dtype)) + x = x.add_(bias) return x @@ -128,6 +153,10 @@ class Conv2d(torch.nn.Module): A scaling factor to multiply with the initialized weights. By default 1.0. init_bias : float, optional A scaling factor to multiply with the initialized biases. By default 0.0. + fused_conv_bias: bool, optional + A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -143,9 +172,16 @@ def __init__( init_mode: str = "kaiming_normal", init_weight: float = 1.0, init_bias: float = 0.0, + fused_conv_bias: bool = False, + amp_mode: bool = False, ): if up and down: raise ValueError("Both 'up' and 'down' cannot be true at the same time.") + if not kernel and fused_conv_bias: + print( + "Warning: Kernel is required when fused_conv_bias is enabled. Setting fused_conv_bias to False." 
+ ) + fused_conv_bias = False super().__init__() self.in_channels = in_channels @@ -153,6 +189,8 @@ def __init__( self.up = up self.down = down self.fused_resample = fused_resample + self.fused_conv_bias = fused_conv_bias + self.amp_mode = amp_mode init_kwargs = dict( mode=init_mode, fan_in=in_channels * kernel * kernel, @@ -176,13 +214,21 @@ def __init__( self.register_buffer("resample_filter", f if up or down else None) def forward(self, x): - w = self.weight.to(x.dtype) if self.weight is not None else None - b = self.bias.to(x.dtype) if self.bias is not None else None - f = ( - self.resample_filter.to(x.dtype) - if self.resample_filter is not None - else None - ) + weight, bias, resample_filter = self.weight, self.bias, self.resample_filter + if not self.amp_mode: + if self.weight is not None and self.weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + if ( + self.resample_filter is not None + and self.resample_filter.dtype != x.dtype + ): + resample_filter = self.resample_filter.to(x.dtype) + + w = weight if weight is not None else None + b = bias if bias is not None else None + f = resample_filter if resample_filter is not None else None w_pad = w.shape[-1] // 2 if w is not None else 0 f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 @@ -194,15 +240,29 @@ def forward(self, x): stride=2, padding=max(f_pad - w_pad, 0), ) - x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) + if self.fused_conv_bias: + x = torch.nn.functional.conv2d( + x, w, padding=max(w_pad - f_pad, 0), bias=b + ) + else: + x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) elif self.fused_resample and self.down and w is not None: x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) - x = torch.nn.functional.conv2d( - x, - f.tile([self.out_channels, 1, 1, 1]), - groups=self.out_channels, - stride=2, - ) + if self.fused_conv_bias: + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + bias=b, + ) + else: + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) else: if self.up: x = torch.nn.functional.conv_transpose2d( @@ -220,9 +280,12 @@ def forward(self, x): stride=2, padding=f_pad, ) - if w is not None: - x = torch.nn.functional.conv2d(x, w, padding=w_pad) - if b is not None: + if w is not None: # ask in corrdiff channel whether w will ever be none + if self.fused_conv_bias: + x = torch.nn.functional.conv2d(x, w, padding=w_pad, bias=b) + else: + x = torch.nn.functional.conv2d(x, w, padding=w_pad) + if b is not None and not self.fused_conv_bias: x = x.add_(b.reshape(1, -1, 1, 1)) return x @@ -249,7 +312,15 @@ class GroupNorm(torch.nn.Module): eps : float, optional A small number added to the variance to prevent division by zero, by default 1e-5. - + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + fused_act : bool, optional + Whether to fuse the activation function with GroupNorm. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. 
Notes ----- If `num_channels` is not divisible by `num_groups`, the actual number of groups @@ -262,28 +333,71 @@ def __init__( num_groups: int = 32, min_channels_per_group: int = 4, eps: float = 1e-5, + use_apex_gn: bool = False, + fused_act: bool = False, + act: str = None, + amp_mode: bool = False, ): + if fused_act and act is None: + raise ValueError("'act' must be specified when 'fused_act' is set to True.") + super().__init__() self.num_groups = min(num_groups, num_channels // min_channels_per_group) self.eps = eps self.weight = torch.nn.Parameter(torch.ones(num_channels)) self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + if use_apex_gn and not _is_apex_available: + raise ValueError("'apex' is not installed, set `use_apex_gn=False`") + self.use_apex_gn = use_apex_gn + self.fused_act = fused_act + self.act = act.lower() if act else act + self.act_fn = None + self.amp_mode = amp_mode + if self.use_apex_gn: + if self.act: + self.gn = ApexGroupNorm( + num_groups=self.num_groups, + num_channels=num_channels, + eps=self.eps, + affine=True, + act=self.act, + ) + + else: + self.gn = ApexGroupNorm( + num_groups=self.num_groups, + num_channels=num_channels, + eps=self.eps, + affine=True, + ) + if self.fused_act: + self.act_fn = self.get_activation_function() def forward(self, x): - if self.training: + weight, bias = self.weight, self.bias + if not self.amp_mode: + if not self.use_apex_gn: + if weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + if self.use_apex_gn: + x = self.gn(x) + elif self.training: # Use default torch implementation of GroupNorm for training # This does not support channels last memory format x = torch.nn.functional.group_norm( x, num_groups=self.num_groups, - weight=self.weight.to(x.dtype), - bias=self.bias.to(x.dtype), + weight=weight, + bias=bias, eps=self.eps, ) + if self.fused_act: + x = self.act_fn(x) else: # Use custom GroupNorm implementation that supports channels last # memory layout for inference - dtype = x.dtype x = x.float() x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups) @@ -293,13 +407,34 @@ def forward(self, x): x = (x - mean) * (var + self.eps).rsqrt() x = rearrange(x, "b g c h w -> b (g c) h w") - weight = rearrange(self.weight, "c -> 1 c 1 1") - bias = rearrange(self.bias, "c -> 1 c 1 1") + weight = rearrange(weight, "c -> 1 c 1 1") + bias = rearrange(bias, "c -> 1 c 1 1") x = x * weight + bias - x = x.type(dtype) + if self.fused_act: + x = self.act_fn(x) return x + def get_activation_function(self): + """ + Get activation function given string input + """ + + activation_map = { + "silu": silu, + "relu": relu, + "leaky_relu": leaky_relu, + "sigmoid": sigmoid, + "tanh": tanh, + "gelu": gelu, + "elu": elu, + } + + act_fn = activation_map.get(self.act, None) + if act_fn is None: + raise ValueError(f"Unknown activation function: {self.act}") + return act_fn + class AttentionOp(torch.autograd.Function): """ @@ -331,6 +466,7 @@ def backward(ctx, dw): dim=2, input_dtype=torch.float32, ) + dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( q.dtype ) / np.sqrt(k.shape[1]) @@ -383,6 +519,17 @@ class UNetBlock(torch.nn.Module): init_attn : dict, optional Initialization parameters specific to attention mechanism layers. Defaults to 'init' if not provided. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. 
+ act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + fused_conv_bias: bool, optional + A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -404,6 +551,11 @@ def __init__( init: Dict[str, Any] = dict(), init_zero: Dict[str, Any] = dict(init_weight=0), init_attn: Any = None, + use_apex_gn: bool = False, + act: str = "silu", + fused_conv_bias: bool = False, + profile_mode: bool = False, + amp_mode: bool = False, ): super().__init__() @@ -420,8 +572,16 @@ def __init__( self.dropout = dropout self.skip_scale = skip_scale self.adaptive_scale = adaptive_scale - - self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) + self.profile_mode = profile_mode + self.amp_mode = amp_mode + self.norm0 = GroupNorm( + num_channels=in_channels, + eps=eps, + use_apex_gn=use_apex_gn, + fused_act=True, + act=act, + amp_mode=amp_mode, + ) self.conv0 = Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -429,21 +589,45 @@ def __init__( up=up, down=down, resample_filter=resample_filter, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init, ) self.affine = Linear( in_features=emb_channels, out_features=out_channels * (2 if adaptive_scale else 1), + amp_mode=amp_mode, **init, ) - self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) + if self.adaptive_scale: + self.norm1 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) + else: + self.norm1 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + act=act, + fused_act=True, + amp_mode=amp_mode, + ) self.conv1 = Conv2d( - in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero + in_channels=out_channels, + out_channels=out_channels, + kernel=3, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, + **init_zero, ) self.skip = None if out_channels != in_channels or up or down: kernel = 1 if resample_proj or out_channels != in_channels else 0 + fused_conv_bias = fused_conv_bias if kernel != 0 else False self.skip = Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -451,56 +635,75 @@ def __init__( up=up, down=down, resample_filter=resample_filter, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init, ) if self.num_heads: - self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) + self.norm2 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) self.qkv = Conv2d( in_channels=out_channels, out_channels=out_channels * 3, kernel=1, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **(init_attn if init_attn is not None else init), ) self.proj = Conv2d( in_channels=out_channels, out_channels=out_channels, kernel=1, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init_zero, ) def forward(self, x, emb): - torch.cuda.nvtx.range_push("UNetBlock") - orig = x - x = self.conv0(silu(self.norm0(x))) - - params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) - if self.adaptive_scale: - scale, shift = params.chunk(chunks=2, dim=1) - x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) - else: - x = silu(self.norm1(x.add_(params))) - - x = self.conv1( - torch.nn.functional.dropout(x, p=self.dropout, 
training=self.training) - ) - x = x.add_(self.skip(orig) if self.skip is not None else orig) - x = x * self.skip_scale - - if self.num_heads: - q, k, v = ( - self.qkv(self.norm2(x)) - .reshape( - x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 - ) - .unbind(2) + with nvtx.annotate( + message="UNetBlock", color="purple" + ) if self.profile_mode else contextlib.nullcontext(): + orig = x + x = self.conv0(self.norm0(x)) + params = self.affine(emb).unsqueeze(2).unsqueeze(3) + if not self.amp_mode: + if params.dtype != x.dtype: + params = params.to(x.dtype) + + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = self.norm1(x.add_(params)) + + x = self.conv1( + torch.nn.functional.dropout(x, p=self.dropout, training=self.training) ) - w = AttentionOp.apply(q, k) - a = torch.einsum("nqk,nck->ncq", w, v) - x = self.proj(a.reshape(*x.shape)).add_(x) + x = x.add_(self.skip(orig) if self.skip is not None else orig) x = x * self.skip_scale - torch.cuda.nvtx.range_pop() - return x + + if self.num_heads: + q, k, v = ( + self.qkv(self.norm2(x)) + .reshape( + x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1 + ) + .unbind(3) + ) + # w = AttentionOp.apply(q, k) + # a = torch.einsum("nqk,nck->ncq", w, v) + # Compute attention in one step + with amp.autocast(enabled=self.amp_mode): + attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = self.proj(attn.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + + return x class PositionalEmbedding(torch.nn.Module): @@ -516,16 +719,23 @@ class PositionalEmbedding(torch.nn.Module): Maximum number of positions for the embeddings, by default 10000. endpoint : bool, optional If True, the embedding considers the endpoint. By default False. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( - self, num_channels: int, max_positions: int = 10000, endpoint: bool = False + self, + num_channels: int, + max_positions: int = 10000, + endpoint: bool = False, + amp_mode: bool = False, ): super().__init__() self.num_channels = num_channels self.max_positions = max_positions self.endpoint = endpoint + self.amp_mode = amp_mode def forward(self, x): freqs = torch.arange( @@ -533,7 +743,10 @@ def forward(self, x): ) freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) freqs = (1 / self.max_positions) ** freqs - x = x.ger(freqs.to(x.dtype)) + if not self.amp_mode: + if freqs.dtype != x.dtype: + freqs = freqs.to(x.dtype) + x = x.ger(freqs) x = torch.cat([x.cos(), x.sin()], dim=1) return x @@ -555,13 +768,21 @@ class FourierEmbedding(torch.nn.Module): scale : int, optional A scale factor applied to the random frequencies, controlling their range and thereby the frequency of oscillations in the embedding space. By default 16. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. 
""" - def __init__(self, num_channels: int, scale: int = 16): + def __init__(self, num_channels: int, scale: int = 16, amp_mode: bool = False): super().__init__() self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) + self.amp_mode = amp_mode def forward(self, x): - x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) + freqs = self.freqs + if not self.amp_mode: + if x.dtype != self.freqs.dtype: + freqs = self.freqs.to(x.dtype) + + x = x.ger((2 * np.pi * freqs)) x = torch.cat([x.cos(), x.sin()], dim=1) return x diff --git a/physicsnemo/models/diffusion/preconditioning.py b/physicsnemo/models/diffusion/preconditioning.py index cbc04f4f75..18927fc94b 100644 --- a/physicsnemo/models/diffusion/preconditioning.py +++ b/physicsnemo/models/diffusion/preconditioning.py @@ -24,7 +24,6 @@ from typing import List, Literal, Tuple, Union import numpy as np -import nvtx import torch from physicsnemo.models.meta import ModelMetaData @@ -821,7 +820,6 @@ def _scaling_fn( """ return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) - @nvtx.annotate(message="EDMPrecondSuperResolution", color="orange") def forward( self, x: torch.Tensor, diff --git a/physicsnemo/models/diffusion/song_unet.py b/physicsnemo/models/diffusion/song_unet.py index f5eeaaf517..d45d18c482 100644 --- a/physicsnemo/models/diffusion/song_unet.py +++ b/physicsnemo/models/diffusion/song_unet.py @@ -19,6 +19,7 @@ Diffusion-Based Generative Models". """ +import contextlib from dataclasses import dataclass from typing import Callable, List, Optional, Union @@ -113,6 +114,15 @@ class SongUNet(Module): additive_pos_embed : bool, optional If True, adds a learned positional embedding after the first convolution layer. Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. Reference ---------- @@ -157,6 +167,10 @@ def __init__( resample_filter: List[int] = [1, 1], checkpoint_level: int = 0, additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, ): valid_embedding_types = ["fourier", "positional", "zero"] if embedding_type not in valid_embedding_types: @@ -197,7 +211,14 @@ def __init__( init=init, init_zero=init_zero, init_attn=init_attn, + use_apex_gn=use_apex_gn, + act=act, + fused_conv_bias=True, + profile_mode=profile_mode, + amp_mode=amp_mode, ) + self.profile_mode = profile_mode + self.amp_mode = amp_mode # for compatibility with older versions that took only 1 dimension self.img_resolution = img_resolution @@ -221,12 +242,19 @@ def __init__( # Mapping. 
if self.embedding_type != "zero": self.map_noise = ( - PositionalEmbedding(num_channels=noise_channels, endpoint=True) + PositionalEmbedding( + num_channels=noise_channels, endpoint=True, amp_mode=amp_mode + ) if embedding_type == "positional" - else FourierEmbedding(num_channels=noise_channels) + else FourierEmbedding(num_channels=noise_channels, amp_mode=amp_mode) ) self.map_label = ( - Linear(in_features=label_dim, out_features=noise_channels, **init) + Linear( + in_features=label_dim, + out_features=noise_channels, + amp_mode=amp_mode, + **init, + ) if label_dim else None ) @@ -235,16 +263,23 @@ def __init__( in_features=augment_dim, out_features=noise_channels, bias=False, + amp_mode=amp_mode, **init, ) if augment_dim else None ) self.map_layer0 = Linear( - in_features=noise_channels, out_features=emb_channels, **init + in_features=noise_channels, + out_features=emb_channels, + amp_mode=amp_mode, + **init, ) self.map_layer1 = Linear( - in_features=emb_channels, out_features=emb_channels, **init + in_features=emb_channels, + out_features=emb_channels, + amp_mode=amp_mode, + **init, ) # Encoder. @@ -257,7 +292,12 @@ def __init__( cin = cout cout = model_channels self.enc[f"{res}x{res}_conv"] = Conv2d( - in_channels=cin, out_channels=cout, kernel=3, **init + in_channels=cin, + out_channels=cout, + kernel=3, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, ) else: self.enc[f"{res}x{res}_down"] = UNetBlock( @@ -270,9 +310,15 @@ def __init__( kernel=0, down=True, resample_filter=resample_filter, + amp_mode=amp_mode, ) self.enc[f"{res}x{res}_aux_skip"] = Conv2d( - in_channels=caux, out_channels=cout, kernel=1, **init + in_channels=caux, + out_channels=cout, + kernel=1, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, ) if encoder_type == "residual": self.enc[f"{res}x{res}_aux_residual"] = Conv2d( @@ -282,6 +328,8 @@ def __init__( down=True, resample_filter=resample_filter, fused_resample=True, + fused_conv_bias=True, + amp_mode=amp_mode, **init, ) caux = cout @@ -326,91 +374,112 @@ def __init__( kernel=0, up=True, resample_filter=resample_filter, + amp_mode=amp_mode, ) self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( - num_channels=cout, eps=1e-6 + num_channels=cout, + eps=1e-6, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, ) self.dec[f"{res}x{res}_aux_conv"] = Conv2d( - in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + in_channels=cout, + out_channels=out_channels, + kernel=3, + fused_conv_bias=True, + amp_mode=amp_mode, + **init_zero, ) - @nvtx.annotate(message="SongUNet", color="blue") def forward(self, x, noise_labels, class_labels, augment_labels=None): - if self.embedding_type != "zero": - # Mapping. - emb = self.map_noise(noise_labels) - emb = ( - emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) - ) # swap sin/cos - if self.map_label is not None: - tmp = class_labels - if self.training and self.label_dropout: - tmp = tmp * ( - torch.rand([x.shape[0], 1], device=x.device) - >= self.label_dropout - ).to(tmp.dtype) - emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) - if self.map_augment is not None and augment_labels is not None: - emb = emb + self.map_augment(augment_labels) - emb = silu(self.map_layer0(emb)) - emb = silu(self.map_layer1(emb)) - else: - emb = torch.zeros( - (noise_labels.shape[0], self.emb_channels), device=x.device - ) + with nvtx.annotate( + message="SongUNet", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if self.embedding_type != "zero": + # Mapping. 
+ emb = self.map_noise(noise_labels) + emb = ( + emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) + ) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) + >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label( + tmp * np.sqrt(self.map_label.in_features) + ) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + else: + emb = torch.zeros( + (noise_labels.shape[0], self.emb_channels), device=x.device + ) - # Encoder. - skips = [] - aux = x - for name, block in self.enc.items(): - with nvtx.annotate(f"SongUNet encoder: {name}", color="blue"): - if "aux_down" in name: - aux = block(aux) - elif "aux_skip" in name: - x = skips[-1] = x + block(aux) - elif "aux_residual" in name: - x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) - elif "_conv" in name: - x = block(x) - if self.additive_pos_embed: - x = x + self.spatial_emb.to(dtype=x.dtype) - skips.append(x) - else: - # For UNetBlocks check if we should use gradient checkpointing - if isinstance(block, UNetBlock): - if x.shape[-1] > self.checkpoint_threshold: - x = checkpoint(block, x, emb, use_reentrant=False) - else: - x = block(x, emb) - else: + # Encoder. + skips = [] + aux = x + for name, block in self.enc.items(): + with nvtx.annotate( + f"SongUNet encoder: {name}", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if "aux_down" in name: + aux = block(aux) + elif "aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: x = block(x) - skips.append(x) - - # Decoder. - aux = None - tmp = None - for name, block in self.dec.items(): - with nvtx.annotate(f"SongUNet decoder: {name}", color="blue"): - if "aux_up" in name: - aux = block(aux) - elif "aux_norm" in name: - tmp = block(x) - elif "aux_conv" in name: - tmp = block(silu(tmp)) - aux = tmp if aux is None else tmp + aux - else: - if x.shape[1] != block.in_channels: - x = torch.cat([x, skips.pop()], dim=1) - # check for checkpointing on decoder blocks and up sampling blocks - if ( - x.shape[-1] > self.checkpoint_threshold and "_block" in name - ) or ( - x.shape[-1] > (self.checkpoint_threshold / 2) and "_up" in name - ): - x = checkpoint(block, x, emb, use_reentrant=False) + if self.additive_pos_embed: + x = x + self.spatial_emb.to(dtype=x.dtype) + skips.append(x) + else: + # For UNetBlocks check if we should use gradient checkpointing + if isinstance(block, UNetBlock): + if x.shape[-1] > self.checkpoint_threshold: + # self.checkpoint = checkpoint? + # else: self.checkpoint = lambda(block,x,emb:block(x,emb)) + x = checkpoint(block, x, emb) + else: + # AssertionError: Only support NHWC layout. + x = block(x, emb) + else: + x = block(x) + skips.append(x) + + # Decoder. 
+            aux = None
+            tmp = None
+            for name, block in self.dec.items():
+                with nvtx.annotate(
+                    f"SongUNet decoder: {name}", color="blue"
+                ) if self.profile_mode else contextlib.nullcontext():
+                    if "aux_up" in name:
+                        aux = block(aux)
+                    elif "aux_norm" in name:
+                        tmp = block(x)
+                    elif "aux_conv" in name:
+                        tmp = block(silu(tmp))
+                        aux = tmp if aux is None else tmp + aux
                     else:
-                        x = block(x, emb)
-        return aux
+                        if x.shape[1] != block.in_channels:
+                            x = torch.cat([x, skips.pop()], dim=1)
+                        # check for checkpointing on decoder blocks and up sampling blocks
+                        if (
+                            x.shape[-1] > self.checkpoint_threshold and "_block" in name
+                        ) or (
+                            x.shape[-1] > (self.checkpoint_threshold / 2)
+                            and "_up" in name
+                        ):
+                            x = checkpoint(block, x, emb)
+                        else:
+                            x = block(x, emb)
+            return aux
 
 
 class SongUNetPosEmbd(SongUNet):
@@ -480,6 +549,18 @@ class SongUNetPosEmbd(SongUNet):
     checkpoint_level : int, optional
         Number of layers that should use gradient checkpointing (0 disables checkpointing).
         Higher values trade memory for computation. By default 0.
+    additive_pos_embed : bool, optional
+        If True, adds a learned positional embedding after the first convolution layer.
+        Used in the StormCast model. By default False.
+    use_apex_gn : bool, optional
+        A boolean flag indicating whether to use Apex GroupNorm, which requires the
+        NHWC (channels-last) memory layout. Must be set to False on CPU. By default False.
+    act : str, optional
+        The activation function to use when fusing activation with GroupNorm. By default "silu".
+    profile_mode : bool, optional
+        A boolean flag indicating whether to enable all NVTX annotations during profiling. By default False.
+    amp_mode : bool, optional
+        A boolean flag indicating whether mixed-precision (AMP) training is enabled. By default False.
 
     Note
     -----
@@ -547,6 +628,11 @@ def __init__(
         gridtype: str = "sinusoidal",
         N_grid_channels: int = 4,
         checkpoint_level: int = 0,
+        additive_pos_embed: bool = False,
+        use_apex_gn: bool = False,
+        act: str = "silu",
+        profile_mode: bool = False,
+        amp_mode: bool = False,
     ):
         super().__init__(
             img_resolution,
@@ -567,13 +653,20 @@ def __init__(
             decoder_type,
             resample_filter,
             checkpoint_level,
+            additive_pos_embed,
+            use_apex_gn,
+            act,
+            profile_mode,
+            amp_mode,
         )
         self.gridtype = gridtype
         self.N_grid_channels = N_grid_channels
-        self.pos_embd = self._get_positional_embedding()
+        if self.gridtype == "learnable":
+            self.pos_embd = self._get_positional_embedding()
+        else:
+            self.register_buffer("pos_embd", self._get_positional_embedding().float())
 
-    @nvtx.annotate(message="SongUNet", color="blue")
     def forward(
         self,
         x,
@@ -583,26 +676,31 @@ def forward(
         noise_labels,
         class_labels,
         global_index: Optional[torch.Tensor] = None,
         embedding_selector: Optional[Callable] = None,
         augment_labels=None,
     ):
-        if embedding_selector is not None and global_index is not None:
-            raise ValueError(
-                "Cannot provide both embedding_selector and global_index. "
-                "embedding_selector is the preferred approach for better efficiency."
-            )
-
-        # Append positional embedding to input conditioning
-        if self.pos_embd is not None:
-            # Select positional embeddings with a selector function
-            if embedding_selector is not None:
-                selected_pos_embd = self.positional_embedding_selector(
-                    x, embedding_selector
+        with nvtx.annotate(
+            message="SongUNetPosEmbd", color="blue"
+        ) if self.profile_mode else contextlib.nullcontext():
+            if embedding_selector is not None and global_index is not None:
+                raise ValueError(
+                    "Cannot provide both embedding_selector and global_index. "
+                    "embedding_selector is the preferred approach for better efficiency."
) - # Select positional embeddings using global indices (selects all - # embeddings if global_index is None) - else: - selected_pos_embd = self.positional_embedding_indexing(x, global_index) - x = torch.cat((x, selected_pos_embd), dim=1) - return super().forward(x, noise_labels, class_labels, augment_labels) + # Append positional embedding to input conditioning + if self.pos_embd is not None: + # Select positional embeddings with a selector function + if embedding_selector is not None: + selected_pos_embd = self.positional_embedding_selector( + x, embedding_selector + ) + # Select positional embeddings using global indices (selects all + # embeddings if global_index is None) + else: + selected_pos_embd = self.positional_embedding_indexing( + x, global_index + ) + x = torch.cat((x, selected_pos_embd), dim=1) + + return super().forward(x, noise_labels, class_labels, augment_labels) def positional_embedding_indexing( self, @@ -657,32 +755,28 @@ def positional_embedding_indexing( """ # If no global indices are provided, select all embeddings and expand # to match the batch size of the input + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + if global_index is None: - return ( - self.pos_embd.to(x.dtype) - .to(x.device)[None] - .expand((x.shape[0], -1, -1, -1)) + selected_pos_embd = self.pos_embd[None].expand( + (x.shape[0], -1, -1, -1) ) # (B, N_pe, H, W) - B = global_index.shape[0] - H = global_index.shape[2] - W = global_index.shape[3] - global_index = torch.reshape( - torch.permute(global_index, (1, 0, 2, 3)), (2, -1) - ) # (B, 2, H, W) to (2, B*H*W) - # Use advanced indexing to select the positional embeddings based on - # their y-x coordinates - selected_pos_embd = self.pos_embd.to(x.device)[ - :, global_index[0], global_index[1] - ] # (N_pe, B*H*W) - selected_pos_embd = ( - torch.permute( - torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], B, H, W)), + else: + B = global_index.shape[0] + X = global_index.shape[2] + Y = global_index.shape[3] + global_index = torch.reshape( + torch.permute(global_index, (1, 0, 2, 3)), (2, -1) + ) # (B, 2, X, Y) to (2, B*X*Y) + selected_pos_embd = self.pos_embd[ + :, global_index[0], global_index[1] + ] # (N_pe, B*X*Y) + selected_pos_embd = torch.permute( + torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], B, X, Y)), (1, 0, 2, 3), - ) - .to(x.device) - .to(x.dtype) - ) # (B, N_pe, H, W) + ) # (B, N_pe, X, Y) return selected_pos_embd def positional_embedding_selector( @@ -733,9 +827,9 @@ def positional_embedding_selector( :meth:`physicsnemo.utils.patching.BasePatching2D.apply` For the base patching method typically used in embedding_selector. """ - return embedding_selector( - self.pos_embd.to(x.dtype).to(x.device) - ) # (B, N_pe, H, W) + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + return embedding_selector(self.pos_embd) # (B, N_pe, H, W) def _get_positional_embedding(self): if self.N_grid_channels == 0: diff --git a/physicsnemo/models/module.py b/physicsnemo/models/module.py index 58c7b55949..d95461394c 100644 --- a/physicsnemo/models/module.py +++ b/physicsnemo/models/module.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import importlib import inspect import json @@ -23,12 +24,13 @@ import tempfile import warnings from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import torch import physicsnemo from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.util_compatibility import convert_ckp_apex from physicsnemo.registry import ModelRegistry from physicsnemo.utils.filesystem import _download_cached, _get_fs @@ -337,7 +339,9 @@ def load( self.load_state_dict(model_dict, strict=strict) @classmethod - def from_checkpoint(cls, file_name: str) -> "Module": + def from_checkpoint( + cls, file_name: str, model_args: Optional[Dict] = None + ) -> "Module": """Simple utility for constructing a model from a checkpoint Parameters @@ -374,14 +378,22 @@ def from_checkpoint(cls, file_name: str) -> "Module": # Load model arguments and instantiate the model with open(local_path.joinpath("args.json"), "r") as f: args = json.load(f) + + ckp_args = copy.deepcopy(args) + + # Merge model_args (adding new keys and updating existing ones) + if model_args is not None: + args["__args__"].update(model_args) + model = cls.instantiate(args) # Load the model weights model_dict = torch.load( local_path.joinpath("model.pt"), map_location=model.device ) - model.load_state_dict(model_dict) + model_dict = convert_ckp_apex(ckp_args, model_args, model_dict) + model.load_state_dict(model_dict, strict=False) return model @staticmethod diff --git a/physicsnemo/models/util_compatibility.py b/physicsnemo/models/util_compatibility.py new file mode 100644 index 0000000000..27edcf9992 --- /dev/null +++ b/physicsnemo/models/util_compatibility.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict + + +def convert_ckp_apex( + ckp_args_dict: Dict[str, Any], + model_args: Dict[str, Any], + model_dict: Dict[str, Any], +) -> Dict[str, Any]: + + """Utility for converting Apex GroupNorm-related keys in a checkpoint. + + This function modifies the checkpoint arguments and model dictionary + to ensure compatibility when switching between Apex-optimized models + and standard PyTorch models. + + Parameters + ---------- + ckp_args_dict : Dict[str, Any] + Dictionary of checkpoint arguments (e.g., configuration parameters saved during training). + model_args : Dict[str, Any] + Dictionary of model initialization arguments that may need updating. + model_dict : Dict[str, Any] + Dictionary containing model state_dict (weights) loaded from checkpoint. + + Returns + ------- + Dict[str, Any] + Updated model_dict with necessary key modifications applied for compatibility. + + Raises + ------ + KeyError + If essential expected keys are missing during the conversion process. 
+    """
+
+    apex_in_ckp = ("use_apex_gn" in ckp_args_dict["__args__"].keys()) and (
+        ckp_args_dict["__args__"]["use_apex_gn"]
+    )
+    apex_in_workflow = (
+        (model_args is not None)
+        and ("use_apex_gn" in model_args.keys())
+        and (model_args["use_apex_gn"])
+    )
+
+    filtered_state_dict = {}
+    # case 1: load a non-optimized checkpoint in an Apex-optimized workflow
+    if (not apex_in_ckp) and apex_in_workflow:
+        # transfer GroupNorm weight & bias to Apex GroupNorm weight & bias
+        for key, value in model_dict.items():
+            filtered_state_dict[key] = value  # Keep the original key
+            # Duplicate weight/bias under the Apex GroupNorm keys (without removing the original)
+            for norm_layer in ["norm0", "norm1", "norm2", "aux_norm"]:
+                if f"{norm_layer}.weight" in key:
+                    new_key = key.replace(
+                        f"{norm_layer}.weight", f"{norm_layer}.gn.weight"
+                    )
+                    filtered_state_dict[new_key] = value  # Duplicate weight
+                elif f"{norm_layer}.bias" in key:
+                    new_key = key.replace(f"{norm_layer}.bias", f"{norm_layer}.gn.bias")
+                    filtered_state_dict[new_key] = value  # Duplicate bias
+
+    # case 2: load an Apex-optimized checkpoint in a non-optimized workflow
+    elif apex_in_ckp and (not apex_in_workflow):
+        # transfer Apex GroupNorm weight & bias to GroupNorm weight & bias
+        for key, value in model_dict.items():
+            filtered_state_dict[key] = value  # Keep the original key
+            # Duplicate weight/bias under the standard GroupNorm keys (without removing the original)
+            for norm_layer in ["norm0", "norm1", "norm2", "aux_norm"]:
+                if f"{norm_layer}.gn.weight" in key:
+                    new_key = key.replace(
+                        f"{norm_layer}.gn.weight", f"{norm_layer}.weight"
+                    )
+                    filtered_state_dict[new_key] = value  # Duplicate weight
+                elif f"{norm_layer}.gn.bias" in key:
+                    new_key = key.replace(f"{norm_layer}.gn.bias", f"{norm_layer}.bias")
+                    filtered_state_dict[new_key] = value  # Duplicate bias
+    else:
+        # no conversion needed
+        return model_dict
+
+    return filtered_state_dict
diff --git a/physicsnemo/utils/patching.py b/physicsnemo/utils/patching.py
index a3570fd50c..7c9ea302de 100644
--- a/physicsnemo/utils/patching.py
+++ b/physicsnemo/utils/patching.py
@@ -19,7 +19,7 @@
 import random
 import warnings
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from einops import rearrange
@@ -112,7 +112,9 @@ def fuse(self, input: Tensor, **kwargs) -> Tensor:
         """
         raise NotImplementedError("'fuse' method must be implemented in subclasses.")
 
-    def global_index(self, batch_size: int) -> Tensor:
+    def global_index(
+        self, batch_size: int, device: Union[torch.device, str] = "cpu"
+    ) -> Tensor:
         """
         Returns a tensor containing the global indices for each patch.
 
@@ -125,6 +127,8 @@ def global_index(self, batch_size: int) -> Tensor:
         ----------
         batch_size : int
             The size of the batch of images to patch.
+        device : Union[torch.device, str], optional
+            The device on which to initialize `global_index`. Defaults to "cpu".
 
         Returns
         -------
@@ -134,12 +138,12 @@ def global_index(self, batch_size: int) -> Tensor:
         y-coordinate (height), and `global_index[:, 1, :, :]` contains the
         x-coordinate (width).
""" - Ny = torch.arange(self.img_shape[0]).int() - Nx = torch.arange(self.img_shape[1]).int() + Ny = torch.arange(self.img_shape[0], device=device).int() + Nx = torch.arange(self.img_shape[1], device=device).int() grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ None, ].expand(batch_size, -1, -1, -1) - global_index = self.apply(grid) + global_index = self.apply(grid).long() return global_index @@ -289,10 +293,16 @@ def apply( self.patch_shape[1], device=input.device, ) + out = out.to( + memory_format=torch.channels_last + if input.is_contiguous(memory_format=torch.channels_last) + else torch.contiguous_format + ) if additional_input is not None: add_input_interp = torch.nn.functional.interpolate( input=additional_input, size=self.patch_shape, mode="bilinear" ) + for i, (py, px) in enumerate(self.patch_indices): if additional_input is not None: out[B * i : B * (i + 1),] = torch.cat( diff --git a/test/metrics/diffusion/test_losses.py b/test/metrics/diffusion/test_losses.py index 0c8bf27cdb..0acab919cd 100644 --- a/test/metrics/diffusion/test_losses.py +++ b/test/metrics/diffusion/test_losses.py @@ -366,6 +366,31 @@ def mock_augment_pipe(imgs): expected_shape = (batch_size * patch_num, channels, patch_shape[0], patch_shape[1]) assert loss_value_with_patching.shape == expected_shape + # Tests with patching accumulation + loss_func.y_mean = None + patch_nums_iter = [4, 4, 4, 2] + patch_shape = (16, 16) + for patch_num in patch_nums_iter: + patching = RandomPatching2D( + img_shape=(32, 32), patch_shape=patch_shape, patch_num=patch_num + ) + loss_value_with_patching = loss_func( + fake_residual_net, + img_clean, + img_lr, + patching=patching, + use_patch_grad_acc=True, + ) + assert isinstance(loss_value_with_patching, torch.Tensor) + # Shape should be (batch_size * patch_num, channels, patch_shape_y, patch_shape_x) + expected_shape = ( + batch_size * patch_num, + channels, + patch_shape[0], + patch_shape[1], + ) + assert loss_value_with_patching.shape == expected_shape + # Test error on invalid patching object with pytest.raises(ValueError): loss_func( diff --git a/test/models/common/checkpoints.py b/test/models/common/checkpoints.py index 8bbb234275..3aecd67475 100644 --- a/test/models/common/checkpoints.py +++ b/test/models/common/checkpoints.py @@ -35,6 +35,7 @@ def validate_checkpoint( in_args: Tuple[Tensor] = (), rtol: float = 1e-5, atol: float = 1e-5, + enable_autocast: bool = False, ) -> bool: """Check network's checkpoint safely saves and loads the state of the model @@ -54,6 +55,8 @@ def validate_checkpoint( Relative tolerance of error allowed, by default 1e-5 atol : float, optional Absolute tolerance of error allowed, by default 1e-5 + enable_autocast: bool, optional + Whether to enable autocast in model forward Returns ------- @@ -72,8 +75,9 @@ def validate_checkpoint( pass # Now test forward passes - output_1 = model_1.forward(*in_args) - output_2 = model_2.forward(*in_args) + with torch.autocast("cuda", enabled=enable_autocast): + output_1 = model_1.forward(*in_args) + output_2 = model_2.forward(*in_args) # Model outputs should initially be different assert not compare_output( @@ -85,12 +89,15 @@ def validate_checkpoint( model_2.load("checkpoint.mdlus") # Forward with loaded checkpoint - output_2 = model_2.forward(*in_args) + with torch.autocast("cuda", enabled=enable_autocast): + output_2 = model_2.forward(*in_args) + loaded_checkpoint = compare_output(output_1, output_2, rtol, atol) # Restore checkpoint with from_checkpoint, checks initialization of model directly 
from checkpoint model_2 = physicsnemo.Module.from_checkpoint("checkpoint.mdlus").to(model_1.device) - output_2 = model_2.forward(*in_args) + with torch.autocast("cuda", enabled=enable_autocast): + output_2 = model_2.forward(*in_args) restored_checkpoint = compare_output(output_1, output_2, rtol, atol) # Delete checkpoint file (it should exist!) diff --git a/test/models/diffusion/test_song_unet_agn_amp.py b/test/models/diffusion/test_song_unet_agn_amp.py new file mode 100644 index 0000000000..3b8445b885 --- /dev/null +++ b/test/models/diffusion/test_song_unet_agn_amp.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: E402 +import os +import sys + +import pytest +import torch + +script_path = os.path.abspath(__file__) +sys.path.append(os.path.join(os.path.dirname(script_path), "..")) + +import common + +from physicsnemo.models.diffusion import SongUNet as UNet + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_constructor(device): + """Test the Song UNet constructor options""" + + # DDM++ + img_resolution = 16 + in_channels = 2 + out_channels = 2 + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution) + + # DDM++ with additive pos embed + model_channels = 64 + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + model_channels=model_channels, + additive_pos_embed=True, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(input_image, noise_labels, class_labels) + assert model.spatial_emb.shape == ( + 1, + model_channels, + img_resolution, + img_resolution, + ) + + # NCSN++ + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 
16, 16]).to(device) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution) + + # test rectangular shape + model = ( + UNet( + img_resolution=[img_resolution, img_resolution * 2], + in_channels=in_channels, + out_channels=out_channels, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, out_channels, img_resolution, img_resolution * 2]).to( + device + ) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution * 2) + + # Also test failure cases + try: + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + embedding_type=None, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + raise AssertionError("Failed to error for invalid argument") + except ValueError: + pass + + try: + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + encoder_type=None, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + raise AssertionError("Failed to error for invalid argument") + except ValueError: + pass + + try: + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + decoder_type=None, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + raise AssertionError("Failed to error for invalid argument") + except ValueError: + pass + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_optims(device): + """Test Song UNet optimizations""" + + def setup_model(): + model = ( + UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + + return model, [input_image, noise_labels, class_labels] + + # Ideally always check graphs first + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_cuda_graphs(model, (*invar,)) + + # Check JIT + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_jit(model, (*invar,)) + # Check AMP + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_amp(model, (*invar,)) + # Check Combo + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_combo_optims(model, (*invar,)) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_song_unet_checkpoint(device): + """Test Song UNet checkpoint save/load""" + # Construct FNO models + model_1 = 
UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + use_apex_gn=True, + amp_mode=True, + ).to(device) + + model_2 = UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + use_apex_gn=True, + amp_mode=True, + ).to(device) + + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + assert common.validate_checkpoint( + model_1, + model_2, + (*[input_image, noise_labels, class_labels],), + enable_autocast=True, + ) + + +@common.check_ort_version() +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_son_unet_deploy(device): + """Test Song UNet deployment support""" + model = ( + UNet( + img_resolution=16, + in_channels=2, + out_channels=2, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + + assert common.validate_onnx_export( + model, (*[input_image, noise_labels, class_labels],) + ) + + assert common.validate_onnx_runtime( + model, (*[input_image, noise_labels, class_labels],) + ) diff --git a/test/models/diffusion/test_song_unet_pos_embd_agn_amp.py b/test/models/diffusion/test_song_unet_pos_embd_agn_amp.py new file mode 100644 index 0000000000..1e5e92a6c1 --- /dev/null +++ b/test/models/diffusion/test_song_unet_pos_embd_agn_amp.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ruff: noqa: E402 +import os +import sys + +import pytest +import torch + +script_path = os.path.abspath(__file__) +sys.path.append(os.path.join(os.path.dirname(script_path), "..")) + +import common + +from physicsnemo.models.diffusion import SongUNetPosEmbd as UNet + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_global_indexing(device): + torch.manual_seed(0) + N_pos = 2 + batch_shape_x = 32 + batch_shape_y = 64 + # Construct the DDM++ UNet model + + model = ( + UNet( + img_resolution=128, + in_channels=2 + N_pos, + out_channels=2, + gridtype="test", + N_grid_channels=N_pos, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + input_image = torch.ones([1, 2, batch_shape_x, batch_shape_y]).to(device) + noise_labels = noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + idx_x = torch.arange(45, 45 + batch_shape_x) + idx_y = torch.arange(12, 12 + batch_shape_y) + mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) + global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to(device) + + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(input_image, noise_labels, class_labels, global_index) + pos_embed = model.positional_embedding_indexing(input_image, global_index) + assert output_image.shape == (1, 2, batch_shape_x, batch_shape_y) + assert torch.equal(pos_embed, global_index) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_constructor(device): + """Test the Song UNet constructor options""" + + # DDM++ + img_resolution = 16 + in_channels = 2 + out_channels = 2 + N_pos = 4 + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels + N_pos, + out_channels=out_channels, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution) + + # test rectangular shape + model = ( + UNet( + img_resolution=[img_resolution, img_resolution * 2], + in_channels=in_channels + N_pos, + out_channels=out_channels, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, out_channels, img_resolution, img_resolution * 2]).to( + device + ) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution * 2) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_position_embedding(device): + # build unet + img_resolution = 16 + in_channels = 2 + out_channels = 2 + # NCSN++ + N_pos = 100 + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels + N_pos, + out_channels=out_channels, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + gridtype="learnable", + N_grid_channels=N_pos, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = 
torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(input_image, noise_labels, class_labels) + assert output_image.shape == (1, out_channels, img_resolution, img_resolution) + assert model.pos_embd.shape == (100, img_resolution, img_resolution) + + model = ( + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + N_grid_channels=40, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + assert model.pos_embd.shape == (40, img_resolution, img_resolution) + + +def test_fails_if_grid_is_invalid(): + """Test the positional embedding options. "linear" gridtype only support 2 channels, and N_grid_channels in "sinusoidal" should be a factor of 4""" + img_resolution = 16 + in_channels = 2 + out_channels = 2 + + with pytest.raises(ValueError): + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + gridtype="linear", + N_grid_channels=20, + use_apex_gn=True, + amp_mode=True, + ).to(memory_format=torch.channels_last) + + with pytest.raises(ValueError): + UNet( + img_resolution=img_resolution, + in_channels=in_channels, + out_channels=out_channels, + gridtype="sinusoidal", + N_grid_channels=11, + use_apex_gn=True, + amp_mode=True, + ).to(memory_format=torch.channels_last) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_optims(device): + """Test Song UNet optimizations""" + + def setup_model(): + model = ( + UNet( + img_resolution=16, + in_channels=6, + out_channels=2, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + + return model, [input_image, noise_labels, class_labels] + + # Ideally always check graphs first + model, invar = setup_model() + assert common.validate_cuda_graphs(model, (*invar,)) + + # Check JIT + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_jit(model, (*invar,)) + # Check AMP + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_amp(model, (*invar,)) + # Check Combo + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_combo_optims(model, (*invar,)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_checkpoint(device): + """Test Song UNet checkpoint save/load""" + # Construct FNO models + model_1 = ( + UNet( + img_resolution=16, + in_channels=6, + out_channels=2, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + + model_2 = ( + UNet( + img_resolution=16, + in_channels=6, + out_channels=2, + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + assert common.validate_checkpoint( + model_1, + model_2, + (*[input_image, noise_labels, class_labels],), + 
enable_autocast=True, + ) + + +@common.check_ort_version() +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_son_unet_deploy(device): + """Test Song UNet deployment support""" + model = ( + UNet( + img_resolution=16, + in_channels=6, + out_channels=2, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + use_apex_gn=True, + amp_mode=True, + ) + .to(device) + .to(memory_format=torch.channels_last) + ) + + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + input_image = torch.ones([1, 2, 16, 16]).to(device) + + assert common.validate_onnx_export( + model, (*[input_image, noise_labels, class_labels],) + ) + assert common.validate_onnx_runtime( + model, (*[input_image, noise_labels, class_labels],) + )
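
A minimal usage sketch of the checkpoint-compatibility path above, assuming a SongUNet checkpoint that was saved without Apex GroupNorm; the file name "checkpoint.mdlus" and the tensor shapes are placeholders taken from the tests, not part of the patch:

    import torch
    import physicsnemo

    # Load a non-Apex checkpoint into the Apex-GroupNorm/AMP workflow. `model_args`
    # overrides the constructor arguments stored in the checkpoint, and the state
    # dict is routed through `convert_ckp_apex`, which duplicates the GroupNorm
    # weights/biases under the `*.gn.*` keys expected by Apex GroupNorm.
    model = physicsnemo.Module.from_checkpoint(
        "checkpoint.mdlus",  # placeholder path
        model_args={"use_apex_gn": True, "amp_mode": True},
    )
    model = model.to("cuda").to(memory_format=torch.channels_last)

    # Inputs mirror the SongUNet test cases: image, noise labels, class labels.
    input_image = torch.ones([1, 2, 16, 16], device="cuda")
    noise_labels = torch.randn([1], device="cuda")
    class_labels = torch.randint(0, 1, (1, 1), device="cuda")

    # Apex GroupNorm expects channels-last (NHWC) activations; run under bf16 autocast.
    with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
        output = model(input_image, noise_labels, class_labels)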