This is an unofficial PyTorch implementation of shortcut models. The original JAX implementation can be found here: https://github.com/kvfrans/shortcut-models
At the moment this is a very basic implementation with two modes: naive (plain flow matching) and shortcut.
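The two modes differ mainly in the regression target. The sketch below is written in NumPy rather than taken from this repository. It assumes the convention of the original shortcut-models paper: x_0 is Gaussian noise, x_1 is a data sample, x_t = (1 - t) x_0 + t x_1, and the network is a function `model(x_t, t, d)` of the noisy input, the time `t` and the step size `d`. The helper names `flow_matching_pair` and `shortcut_target` are hypothetical.

```python
import numpy as np


def _expand(a, like):
    # Reshape a per-sample vector (batch,) so it broadcasts over the feature dims of `like`.
    return a.reshape(-1, *([1] * (like.ndim - 1)))


def flow_matching_pair(x0, x1, t):
    """Naive mode (assumed convention: x0 = noise, x1 = data).

    Returns the interpolated input x_t and the constant velocity target x1 - x0.
    """
    x_t = (1.0 - _expand(t, x1)) * x0 + _expand(t, x1) * x1
    v_target = x1 - x0
    return x_t, v_target


def shortcut_target(model, x_t, t, d):
    """Shortcut mode: self-consistency target for step size 2d.

    Two chained steps of size d are averaged; the model's prediction at
    step size 2d would be regressed onto this value. `model` is a
    hypothetical callable model(x, t, d) -> velocity with the shape of x.
    """
    s1 = model(x_t, t, d)                  # first small step
    x_next = x_t + _expand(d, x_t) * s1    # move along it by d
    s2 = model(x_next, t + d, d)           # second small step from the new point
    return 0.5 * (s1 + s2)
```

In an actual PyTorch training step the shortcut target would be computed without gradients (for example under `torch.no_grad()`), so that only the prediction at step size 2d is optimized.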
The implementation supports multi-GPU training, thanks to PyTorch Lightning.
I used the CelebA-HQ dataset from Hugging Face for the image generation task: https://huggingface.co/datasets/mattymchen/celeba-hq
This repository also includes a Dockerfile and a docker-compose file that install all the necessary libraries.
To start training, run:
python train.py
Generated samples with 1, 2, 4, 8, 16 and 128 denoising steps.
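Sampling with N denoising steps amounts to integrating the learned velocity from noise to data with N equal Euler steps of size d = 1/N. A minimal NumPy sketch follows, assuming noise sits at t = 0, data at t = 1, and a network `model(x, t, d)` conditioned on the step size (a naive flow-matching model can simply ignore `d`). `sample` is a hypothetical helper, not this repository's API.

```python
import numpy as np


def sample(model, noise, num_steps):
    """Euler integration from noise (t = 0) to data (t = 1).

    `model` is a hypothetical callable model(x, t, d) -> velocity; t and d
    are passed as per-sample vectors of shape (batch,).
    """
    batch = noise.shape[0]
    d = 1.0 / num_steps
    x = noise
    for i in range(num_steps):
        t = np.full(batch, i * d)    # current time for every sample
        dt = np.full(batch, d)       # step size the shortcut model is conditioned on
        x = x + d * model(x, t, dt)  # one Euler step along the predicted velocity
    return x
```

With a single step the model must jump from noise to an image in one evaluation; as d shrinks, the update approaches ordinary flow-matching Euler integration.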