Skip to content

Commit

Permalink
corrdiff generate minor bugs fixed (NVIDIA#648)
Browse files Browse the repository at this point in the history
* corrdiff generate minor bugs fixed

* formatting

---------

Co-authored-by: Jay Chen <[email protected]>
Co-authored-by: Mohammad Amin Nabian <[email protected]>
  • Loading branch information
3 people authored Aug 21, 2024
1 parent ee00651 commit d88332b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/generative/corrdiff/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def main(cfg: DictConfig) -> None:
sampler_fn = partial(
deterministic_sampler,
num_steps=cfg.sampler.num_steps,
num_ensembles=cfg.generation.num_ensembles,
# num_ensembles=cfg.generation.num_ensembles,
solver=cfg.sampler.solver,
)
elif cfg.sampler.type == "stochastic":
Expand Down Expand Up @@ -215,7 +215,7 @@ def generate_fn():
).to(memory_format=torch.channels_last),
rank=dist.rank,
device=device,
use_mean_hr=mean_hr,
hr_mean=mean_hr,
)
if cfg.generation.inference_mode == "regression":
image_out = image_reg
Expand Down
2 changes: 2 additions & 0 deletions examples/generative/corrdiff/helpers/generate_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def get_dataset_and_sampler(dataset_cfg, times):
"""
Get a dataset and sampler for generation.
"""
all_time_dataset_cfg = {"train": False, "all_times": True}
dataset_cfg.update(all_time_dataset_cfg)
(dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1)
plot_times = [
convert_datetime_to_cftime(
Expand Down

0 comments on commit d88332b

Please sign in to comment.