This software project accompanies the research paper "TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation", which can be cited as:
@article{berthelot2023tract,
title={TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation},
author={Berthelot, David and Autef, Arnaud and Lin, Jierui and Yap, Dian Ang and Zhai, Shuangfei and Hu, Siyuan and Zheng, Daniel and Talbott, Walter and Gu, Eric},
journal={arXiv preprint arXiv:2303.04248},
year={2023}
}
Clone the repository with --recurse-submodules so that the EDM submodule is initialized properly.
Set up the environment variables:
export ML_DATA=~/Data/DDPM-Images
export PYTHONPATH=$PYTHONPATH:.
Then run:
sudo apt install python3.8-dev python3.8-venv python3-dev -y
Set up a virtualenv:
python3.8 -m venv ~/tract_venv
source ~/tract_venv/bin/activate
or via pyenv:
pyenv install 3.8.0
pyenv virtualenv 3.8.0 tract_venv
pyenv local tract_venv
Then upgrade pip:
pip install --upgrade pip
Install the pip packages:
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
Please follow the README to set up the datasets.
Please follow the README to set up the teacher checkpoints.
For running with NVIDIA's Elucidated model (EDM), ensure the edm/
submodule has been initialized properly.
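A minimal sketch of a preflight check for the submodule, assuming it lives in the edm/ directory as stated above. The helper name `submodule_initialized` is hypothetical and not part of this repository. It relies on the fact that Git leaves an uninitialized submodule as an empty directory.

```python
from pathlib import Path


def submodule_initialized(path) -> bool:
    """Return True if `path` is a directory that contains at least one entry.

    An uninitialized Git submodule is an empty directory, so a missing or
    empty edm/ directory suggests the submodule was never checked out.
    (Hypothetical helper for illustration only.)
    """
    p = Path(path)
    return p.is_dir() and any(p.iterdir())
```

If `submodule_initialized("edm")` returns False when called from the repository root, running `git submodule update --init --recursive` should populate the directory.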
Please follow the "Real activation statistics" section of the README to compute and save the real activation statistics used in FID evaluation.
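For orientation, here is a rough sketch of what activation statistics usually mean for FID. It assumes the standard definition: the mean and covariance of feature activations, compared with the Fréchet distance between two Gaussians. The function names are hypothetical, and the repository's own script may compute or store the statistics differently.

```python
import numpy as np


def activation_statistics(activations):
    """Mean vector and covariance matrix of an (N, D) array of activations."""
    acts = np.asarray(activations, dtype=np.float64)
    mu = acts.mean(axis=0)
    sigma = np.cov(acts, rowvar=False)  # columns are feature dimensions
    return mu, sigma


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Uses Tr(sqrt(S1 S2)) = sum of square roots of the eigenvalues of S1 S2,
    which are real and non-negative for positive semi-definite inputs.
    """
    diff = np.asarray(mu1) - np.asarray(mu2)
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    # Clip tiny negative values caused by floating-point error.
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt)
```

In this sketch the statistics of the real images would be computed once with `activation_statistics` and reused for every evaluation, which is the reason for saving them ahead of time.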
The commands below reproduce the results from our paper when run on a cluster of 8 NVIDIA A100 or V100 GPUs.
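The --time_schedule flag used in the examples takes a comma-separated list of step counts. The sketch below assumes each adjacent pair is one distillation phase, from a teacher with more sampling steps to a student with fewer, so 1024,32,1 would mean 1024 to 32 and then 32 to 1. The function `parse_time_schedule` is hypothetical and only illustrates that reading; it is not the repository's parser.

```python
from typing import List, Tuple


def parse_time_schedule(schedule: str) -> List[Tuple[int, int]]:
    """Split a schedule such as '1024,32,1' into (teacher, student) phases."""
    steps = [int(s) for s in schedule.split(",")]
    if len(steps) < 2:
        raise ValueError("a schedule needs at least two step counts")
    phases = list(zip(steps[:-1], steps[1:]))
    for teacher, student in phases:
        # Each phase must reduce the number of sampling steps.
        if student <= 0 or teacher <= student:
            raise ValueError(f"invalid phase {teacher} -> {student}")
    return phases


for teacher, student in parse_time_schedule("1024,32,1"):
    print(f"{teacher} -> {student} steps")
```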
Example: Run TC distillation on CIFAR-10 using the distillation time schedule 1024, 32, 1.
python tc_distill.py --dataset=cifar10 --time_schedule=1024,32,1 --fid_len=50000 --report_fid_len=8M --report_img_len=8M --train_len=96M
Example: Run TC distillation on CIFAR-10 using the EDM teacher.
python tc_distill_edm.py --dataset=cifar10 --time_schedule=40,1 --fid_len=50000 --report_fid_len=8M --report_img_len=8M --train_len=96M --batch=512
Getting help
python tc_distill.py --help
TensorBoard outputs are generated in a directory like e/DATASET/MODEL/EXP_NAME/tb/. For example, you can start TensorBoard to view the metrics with:
tensorboard --logdir e/cifar10/EluDDIM05TCMultiStepx0\(EluUNet\)/aug_prob@0.0_batch@8_dropout@0.0_ema@0.9996_lr@0.001_lr_warmup@None_res@32_sema@0.1_time_schedule@40,1_timesteps@40/tb/ --bind_all
Please follow the README to run FID evaluation.
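The TensorBoard log directory layout described above (e/DATASET/MODEL/EXP_NAME/tb/) can be sketched as follows. Judging only from the example path, the experiment name appears to be the hyperparameters written as key@value, sorted by key and joined with underscores. That naming rule is an inference, and the helpers `experiment_name` and `tensorboard_dir` are hypothetical rather than part of this repository. `shlex.quote` is used because the model name contains parentheses, which a shell would otherwise interpret.

```python
import shlex
from pathlib import PurePosixPath


def experiment_name(hparams: dict) -> str:
    """Join hyperparameters as key@value, sorted by key (inferred format)."""
    return "_".join(f"{key}@{hparams[key]}" for key in sorted(hparams))


def tensorboard_dir(dataset: str, model: str, hparams: dict, root: str = "e") -> str:
    """Build root/DATASET/MODEL/EXP_NAME/tb as a POSIX-style path string."""
    return str(PurePosixPath(root) / dataset / model / experiment_name(hparams) / "tb")


hparams = {
    "aug_prob": 0.0, "batch": 8, "dropout": 0.0, "ema": 0.9996, "lr": 0.001,
    "lr_warmup": None, "res": 32, "sema": 0.1, "time_schedule": "40,1",
    "timesteps": 40,
}
logdir = tensorboard_dir("cifar10", "EluDDIM05TCMultiStepx0(EluUNet)", hparams)
# Quote the path so the parentheses survive when pasted into a shell.
print("tensorboard --logdir", shlex.quote(logdir), "--bind_all")
```

Quoting the whole path is an alternative to escaping each parenthesis with a backslash, as the command above does.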