daruoktab/diffusion-mnist-pytorch

MNIST diffusion (DDPM) with PyTorch and DeepInv

Small portfolio project: train a noise-predicting U-Net on MNIST with a linear beta schedule (DDPM-style), save weights, and generate digits interactively in a second marimo app.

Structure

diffusion-mnist-pytorch/
├── trainingdiffusion.py    # Marimo: train DiffUNet, plots, save checkpoint
├── inferencediffusion.py   # Marimo: load weights, DDPM sampling, UI
├── environment.yml         # Conda: Python 3.12
├── requirements.txt        # pip / uv
├── LICENSE                 # MIT
├── README.md
└── data/                   # MNIST (downloaded on first run; gitignored)

Default checkpoint path: trained_diffusion_model.pth (used by both apps).
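
As a rough illustration of how two apps can share one checkpoint file, the sketch below saves and restores a `state_dict`. It assumes the checkpoint holds a plain `state_dict` (the README does not say which format is used), and the small `nn.Conv2d` is a hypothetical stand-in for the trained network. A temporary directory is used so the example does not overwrite a real checkpoint.

```python
import os
import tempfile

import torch
import torch.nn as nn

CHECKPOINT_NAME = "trained_diffusion_model.pth"  # default path named in this README

# Hypothetical stand-in for the trained network (assumption, not the real model).
model = nn.Conv2d(1, 1, 3, padding=1)

# Training side: write the weights to disk.
ckpt_path = os.path.join(tempfile.mkdtemp(), CHECKPOINT_NAME)
torch.save(model.state_dict(), ckpt_path)

# Inference side: rebuild the same architecture, then load the weights.
restored = nn.Conv2d(1, 1, 3, padding=1)
restored.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
restored.eval()  # switch to inference mode before sampling
```

Loading with `map_location="cpu"` lets a checkpoint trained on a GPU be opened on a machine without one.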

Environment

Use Conda to provide Python 3.12, then install the dependencies into that environment (e.g. with uv pip):

conda env create -f environment.yml
conda activate dmnist
uv pip install -r requirements.txt

Without Conda: uv venv --python 3.12 .venv, activate it, then uv pip install -r requirements.txt.

Run

Training (interactive editor):

marimo edit trainingdiffusion.py

Or one-shot: uv run marimo run trainingdiffusion.py

Inference / sampling:

marimo edit inferencediffusion.py

Model and training

| Item | Value |
| --- | --- |
| Architecture | deepinv.models.DiffUNet, 1→1 channels |
| Input size | 32×32 grayscale (resized from MNIST) |
| Schedule | Linear β from 1e−4 to 0.02, T = 1000 |
| Objective | MSE on predicted noise |
| Optimizer | Adam, lr = 1e−4 |

VRAM: memory use is controlled mainly by batch_size and image_size in trainingdiffusion.py (see the configuration cell); lower either one if you run out of GPU memory.
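
The linear β schedule with T = 1000 can be precomputed once before training. The following is a minimal sketch of that precomputation, assuming the standard DDPM definitions; the names `betas`, `alphas`, and `alphas_cumprod` are illustrative and may differ from those in trainingdiffusion.py.

```python
import torch

T = 1000  # number of diffusion steps

# Linear beta schedule from 1e-4 to 0.02
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas

# Cumulative product of alphas: fraction of signal variance kept at step t
alphas_cumprod = torch.cumprod(alphas, dim=0)

# Coefficients for the closed-form forward noising step
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
```

With this schedule the signal coefficient starts close to 1 and decays toward 0 by the last step, so an image at t = T − 1 is almost pure Gaussian noise.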

Forward noising (training)

noisy_imgs = (
    sqrt_alphas_cumprod[t, None, None, None] * imgs
    + sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise
)
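
To show where the forward noising expression fits, here is a hedged sketch of one training step with the noise-prediction MSE objective. `TinyEps` is a hypothetical stand-in for the U-Net (it is not `deepinv.models.DiffUNet`, and its call signature is an assumption), `training_step` is an illustrative helper, and the random batch scaled to [-1, 1] stands in for MNIST.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()


class TinyEps(nn.Module):
    """Hypothetical stand-in: maps (noisy image, t) to predicted noise."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 1, 3, padding=1),
        )

    def forward(self, x, t):
        # Feed the normalized timestep as an extra constant input channel.
        t_map = (t.float() / T)[:, None, None, None].expand(-1, 1, *x.shape[2:])
        return self.net(torch.cat([x, t_map], dim=1))


def training_step(model, optimizer, imgs):
    """One optimizer step on a batch of images (assumed scaled to [-1, 1])."""
    t = torch.randint(0, T, (imgs.shape[0],))  # one random timestep per image
    noise = torch.randn_like(imgs)
    noisy_imgs = (
        sqrt_alphas_cumprod[t, None, None, None] * imgs
        + sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise
    )
    loss = F.mse_loss(model(noisy_imgs, t), noise)  # MSE on predicted noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


model = TinyEps()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
imgs = torch.rand(8, 1, 32, 32) * 2 - 1  # random batch standing in for MNIST
loss = training_step(model, optimizer, imgs)
```

Indexing the schedule tensors with `[t, None, None, None]` yields shape `(batch, 1, 1, 1)`, so each image is scaled by the coefficients of its own timestep through broadcasting.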

License

MIT. You may replace the copyright line in LICENSE with your name if this is a personal repo.

About

🎨 PyTorch diffusion model for MNIST digit generation using DeepInv. Educational project demonstrating DDPM fundamentals with U-Net architecture.
