Add Diffusion example for Kármán vortex street dataset #619

Open
wants to merge 4 commits into
base: main
Binary file added docs/img/diffusion_karman.png
79 changes: 79 additions & 0 deletions examples/cfd/turbulence_diffusion/README.md
@@ -0,0 +1,79 @@
# A Diffusion model for a 2D Kármán vortex street about a fixed cylinder

This example uses a PyTorch implementation of DDPM, which can be found [here](https://github.com/lucidrains/denoising-diffusion-pytorch/).

## Problem overview

Turbulent flows are notoriously difficult to model. The structures involved can be found across a wide
range of temporal and spatial scales, and the high degree of non-linearity as well as sensitivity to the initial
conditions make this an especially challenging problem. In particular, even in the presence of statistical
or geometrical symmetries, it is not possible to reduce the dimensionality of the problem (e.g. from three- to
two-dimensional) in numerical simulations.
A general analytic solution to the governing (Navier-Stokes) equations
remains elusive, requiring a variety of computational methods to obtain even numerical solutions.
Scale-resolving techniques, such as direct numerical simulation (DNS) and large eddy simulation (LES),
strive to resolve the entire spectrum of length and time scales present in the flow, or at least the
energetically most significant part (as is the case for LES), whereas conventional averaging approaches
(such as Reynolds-averaged Navier-Stokes, or RANS) instead attempt to fit a statistical model. Whilst the
former approaches provide the highest possible modelling accuracy, the latter are affected by an inevitable
loss of information, with non-generalisable closures that are strongly affected by flow typology and
boundary conditions. Even with techniques such as RANS, which considerably reduce the computational cost
of complex simulations, the cost remains prohibitively high if a large number of computations must be
carried out in a relatively short time, such as in rapid prototyping and design optimisation loops. This
makes the use of machine learning, especially generative probabilistic AI, extremely promising, particularly
when only statistical distributions resulting from stochastic initial conditions are required. [1]

In this example, we apply a modern DDPM implementation [2] to a 2D Kármán vortex street about a fixed cylinder. The trained model
effectively captures the flow distribution.

## Dataset

We perform our experiments on the case of a flow around a cylinder at Reynolds number 3900, which is well studied in the literature.
The flow field is characterized by a Kármán vortex street developing in the wake of the cylinder. The vortex street consists of a characteristic coherent vortex system in which the rotational axes of the individual vortices are aligned with the cylinder axis.

The dataset of grayscale images (see image below) was generated by post-processing the transient LES velocity field data using a projection mapping in the sense that the system remains ergodic on a reduced state space. Let $V(\xi,t)=(V_x(\xi,t),V_y(\xi,t),V_z(\xi,t))$ be the velocity field of the fluid, and let $c(\xi,t)=\sqrt{V_x(\xi,t)^2+V_y(\xi,t)^2+V_z(\xi,t)^2}$ be the local velocity magnitude at the location $\xi$. The grayscale value then shows the absolute deviation of $c$ from its time average,
$c'(\xi,t) = |c(\xi,t) - \overline{c}(\xi)|$ with $\overline{c}(\xi)=\frac{1}{T}\int_0^T c(\xi,t) \,\mathrm{d}t$.
For the numerical setup of the LES, see [3]. In total, the dataset consists of 100,000 images with a resolution of $1000 \times 600$ pixels.
The full dataset is available at [4].
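
As a rough illustration of the formula above, the following minimal sketch computes $c'$ at a single location from a short time series of velocity samples. The function name, the sample values, and the use of a discrete mean in place of the time integral are assumptions made for illustration; they are not taken from the LES post-processing pipeline.

```python
import math


def fluctuation_magnitude(velocity_samples):
    """Return c'(xi, t_k) = |c(xi, t_k) - mean(c)| for one location xi.

    velocity_samples: list of (Vx, Vy, Vz) tuples at equally spaced times.
    The time integral is approximated by an arithmetic mean (assumption).
    """
    # Velocity magnitude c at every time step
    magnitudes = [
        math.sqrt(vx * vx + vy * vy + vz * vz) for vx, vy, vz in velocity_samples
    ]
    # Discrete stand-in for the time average of c
    time_average = sum(magnitudes) / len(magnitudes)
    # Absolute deviation from the time average
    return [abs(c - time_average) for c in magnitudes]


# Hypothetical velocity samples at one point in the wake
samples = [(3.0, 4.0, 0.0), (0.0, 0.0, 1.0), (0.0, 6.0, 8.0)]
deviations = fluctuation_magnitude(samples)
```

In the dataset, a value of this kind is what each pixel's gray level encodes, evaluated at every location of the projected field.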

## Model overview and architecture

The model is an implementation of DDPM combined with a transformer and shadowed by an Exponential Moving Average (EMA) model.
Due to the large number of parameters and slow training times involved, the model is built with parallelisation across multiple GPUs in mind,
and multi-GPU training is supported out of the box.
A UNet with five downsampling layers, interspersed with attention and ResNet blocks, is used to represent the decoder.
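
The EMA model mentioned above keeps a slowly moving average of the training weights. The sketch below shows the basic update rule with plain floats standing in for tensors. The function name `ema_update` is hypothetical, `beta` plays the role of `ema_beta` and `update_every` the role of `update_ema_every` in `config.json`, and the effect of `ema_power` is deliberately not modelled, so this should not be read as the code path used by the training script.

```python
def ema_update(shadow, params, beta=0.9999):
    """Blend the current parameters into the shadow (EMA) copy, in place."""
    for name, value in params.items():
        shadow[name] = beta * shadow[name] + (1.0 - beta) * value
    return shadow


# Toy example: a single scalar "weight" that training has moved to 1.0
shadow = {"w": 0.0}
params = {"w": 1.0}
update_every = 10  # assumption: mirrors update_ema_every

for step in range(1, 101):
    # Only refresh the shadow weights every `update_every` optimiser steps
    if step % update_every == 0:
        ema_update(shadow, params, beta=0.9)  # small beta so the effect is visible
```

After the ten updates in this toy loop the shadow weight equals $1 - 0.9^{10}$, which shows how the EMA copy lags behind the raw weights; with a decay close to 1 the lag is much longer.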

![Real image on the left, model on the right](../../../docs/img/diffusion_karman.png)

## Getting Started

The scripts provided include code for training, sampling and preprocessing the dataset.
To view the available command-line arguments for training, run

```bash
python main.py -h
```

and for sampling

```bash
python sample.py -h
```

Arguments can be provided on the command line or in ```config.json```. For training, the only required argument is ```experiment_name```.
For sampling, ```model``` is also required.
Providing ```model``` to the training script resumes training from that checkpoint.
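
One way such a scheme can work is to treat the values in the JSON file as defaults and let command-line flags override them. The sketch below is a hypothetical illustration of that idea using only the standard library; the function `get_args`, its signature, and its error handling are assumptions and may differ from the `params` module in this example.

```python
import argparse
import json


def parse_bool(text):
    """Interpret common spellings of true/false given on the command line."""
    return text.lower() in ("1", "true", "yes")


def get_args(config_path, argv=None):
    """Build an argument parser whose defaults come from a JSON config file."""
    with open(config_path) as config_file:
        defaults = json.load(config_file)

    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        if isinstance(value, bool):
            parser.add_argument(f"--{key}", type=parse_bool, default=value)
        elif isinstance(value, list) and value:
            # Lists such as adam_betas are passed as space-separated values
            parser.add_argument(f"--{key}", type=type(value[0]), nargs="+", default=value)
        else:
            parser.add_argument(f"--{key}", type=type(value), default=value)

    args = parser.parse_args(argv)
    if not args.experiment_name:
        # Exits with an error message when the required name is missing
        parser.error("experiment_name must be set on the command line or in the config")
    return args
```

With a parser like this, a call such as `get_args("config.json", ["--experiment_name", "run1", "--batch_size", "4"])` would keep every other value from the file and override only the batch size.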

The dataset is automatically downloaded from [4], and further processing is applied upon extraction. The download can be turned off by setting ```download``` to ```false``` in ```config.json```.

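Depending on your particular setup, you may need to set the environment variables
```bash
export WORLD_SIZE=$N
export CUDA_VISIBLE_DEVICES=0,1,2,3  # comma-separated device indices, shown here for N=4
```
where `$N` is the number of available GPUs. Note that `CUDA_VISIBLE_DEVICES` expects a list of device indices rather than a count.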

## References
- [1] [Comparison of Generative Learning Methods for Turbulence Modeling](https://arxiv.org/abs/2411.16417)
- [2] [Denoising Diffusion Probabilistic Models](https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf)
- [3] [Generative Modelling of Turbulence](https://arxiv.org/abs/2112.02548)
- [4] [Flow Around a Cylinder for Generative Learning](https://zenodo.org/records/13820259)
28 changes: 28 additions & 0 deletions examples/cfd/turbulence_diffusion/config.json
@@ -0,0 +1,28 @@
{
    "dataset_dir": "data",
    "download": true,
    "image_size": 512,
    "batch_size": 2,
    "num_epochs": 1,
    "grad_accumulation": 1,
    "seed": 42,

    "ema_beta": 0.9999,
    "ema_power": 0.666666,
    "update_ema_every": 10,

    "learning_rate": 1e-5,
    "adam_betas": [0.9, 0.99],

    "gamma": 0.1,
    "step_size": 50,

    "sample": true,
    "sample_size": 4,
    "sample_timesteps": 1000,
    "model": "",

    "save_every": 2,
    "num_workers": 6,
    "experiment_name": ""
}
123 changes: 123 additions & 0 deletions examples/cfd/turbulence_diffusion/dataset.py
@@ -0,0 +1,123 @@
# ignore_header_test
# coding=utf-8
#
# SPDX-FileCopyrightText: Copyright (c) 2024 - Edmund Ross
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import util
import os
import zipfile
import requests

from io import BytesIO
from PIL import Image
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm


ZENODO_URL = "https://zenodo.org/records/13820259/files/Turbulence_AI.zip?download=1"
DATASET_NAME = "Turbulence_AI"


def initialise_dataset(args):
    """Check if the dataset has been downloaded, and if not, download and apply post-processing"""
    dataset_path = util.to_path(args.dataset_dir, DATASET_NAME)
    zip_path = dataset_path + ".zip"

    os.makedirs(util.to_path(args.dataset_dir), exist_ok=True)

    print(f'[{args.experiment_name}] Downloading dataset to {zip_path}...')
    download_dataset(zip_path, args)

    process_dataset(zip_path, dataset_path, args)


def download_dataset(output_file, args):
    """Download dataset and save to a zip file"""
    if os.path.exists(output_file) or not args.download:
        print(f'[{args.experiment_name}] {output_file} already exists or download option is false. Skipping download.')
        return

    response = requests.get(ZENODO_URL, stream=True)
    if response.status_code != 200:
        raise Exception(f'[{args.experiment_name}] Failed to download dataset: HTTP {response.status_code}')

    # Get total file size for progress tracking
    total_size = int(response.headers.get('content-length', 0))

    # Write to a temporary file first, so an interrupted download is not mistaken for a complete zip
    partial_file = output_file + ".part"
    with open(partial_file, "wb") as file, tqdm(
        total=total_size, unit="B", unit_scale=True, desc=f"[{args.experiment_name}] Downloading dataset"
    ) as progress:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
            progress.update(len(chunk))
    os.replace(partial_file, output_file)


def process_dataset(zip_file, output_dir, args):
    """Extract zip_file and apply post-processing"""
    if os.path.exists(output_dir) or not os.path.exists(zip_file):
        print(f"[{args.experiment_name}] {output_dir} already exists, or couldn't find the zip. Skipping processing.")
        return

    os.makedirs(util.to_path(output_dir), exist_ok=True)

    print(f'[{args.experiment_name}] Extracting {zip_file}...')
    with zipfile.ZipFile(zip_file) as zip_file_obj:
        # Prepare a list of tasks
        tasks = []
        for file_name in zip_file_obj.namelist():
            if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue

            # Read the image bytes
            with zip_file_obj.open(file_name) as image_file:
                image_bytes = image_file.read()
            tasks.append((file_name, image_bytes, output_dir))

    print(f'[{args.experiment_name}] Beginning post-processing on {args.num_workers} core(s)...')
    with tqdm(total=len(tasks), desc="Processing images", unit="file") as progress_bar:
        try:
            with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
                # 'chunksize' parameter ensures the executor doesn't get overloaded
                for _ in executor.map(crop_image_wrapper, tasks, chunksize=10):
                    progress_bar.update(1)
        except Exception as e:
            print(f"[{args.experiment_name}] Error: {e}")


def crop_image_wrapper(args):
    """Wrapper around crop_image for the executor, which passes each task as a single tuple"""
    crop_image(*args)


def crop_image(file_name, image_bytes, output_dir, crop_left_width=150):
    """Crop an image to remove useless white space to the left of the cylinder"""
    # Get and open image
    try:
        with Image.open(BytesIO(image_bytes)) as img:
            # Set up the crop. Top left corner is 0,0 and the y-axis is inverted
            width, height = img.size
            left = crop_left_width
            top = 0
            right = width
            bottom = height

            # Execute and save
            cropped_img = img.crop((left, top, right, bottom))

            output_path = util.to_path(output_dir, os.path.basename(file_name))
            cropped_img.save(output_path)
    except Exception as e:
        print(f"Error processing image: {e}")
89 changes: 89 additions & 0 deletions examples/cfd/turbulence_diffusion/distribute.py
@@ -0,0 +1,89 @@
# ignore_header_test
# coding=utf-8
#
# SPDX-FileCopyrightText: Copyright (c) 2024 - Edmund Ross
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import params
import util

import torch
import torch.distributed as dist
import torch.multiprocessing as mult

from itertools import chain


def setup(rank, world_size):
    """Set the rendezvous address and port, then join the process group"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12359'

    # Initialise the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def spawn_processes(fn):
    """Setup hyperparameters, initialise experiment"""
    args = params.get_args()
    experiment_path = util.initialise_experiment(args)

    n_gpus = torch.cuda.device_count()
    print(f'[{args.experiment_name}] Found {n_gpus} GPUs. Spawning processes...')

    # Spawn the processes
    mult.spawn(fn,
               args=(n_gpus, args, experiment_path),
               nprocs=n_gpus,
               join=True)


def cleanup():
    """Clean up distributed session"""
    dist.destroy_process_group()


def load(experiment_path, args_model, ddp_diffusion, optimizer=None, ema=None):
    """Load checkpoint from disk"""
    model_path = util.to_path(experiment_path, 'checkpoints', args_model)

    # Load onto the CPU first so that every rank does not place the checkpoint on the saving GPU;
    # load_state_dict then copies the tensors to the device of each module or optimizer
    checkpoint = torch.load(model_path, map_location='cpu')
    epoch_start = checkpoint['epoch'] + 1
    ddp_diffusion.load_state_dict(checkpoint['diffusion'])

    if ema is not None:
        ema.load_state_dict(checkpoint['ema'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Make sure everyone is loaded before proceeding
    dist.barrier()
    return epoch_start


def save(experiment_path, epoch, ddp_diffusion, optimizer, ema):
    """Save checkpoint to disk"""
    torch.save({
        'epoch': epoch,
        'diffusion': ddp_diffusion.state_dict(),
        'optimizer': optimizer.state_dict(),
        'ema': ema.state_dict(),
    }, util.to_path(experiment_path, 'checkpoints', f'model_{epoch}.pt'))


def interleave_arrays(*arrays):
    """Collect data from the GPUs, and interleave into a single array, respecting the order"""
    interleaved = list(chain.from_iterable(zip(*arrays)))
    return interleaved