This repository contains the PyTorch code for reproducing the results in the paper Image Generation with a Sphere Encoder.
Install the dependencies:

```bash
pip install --no-cache-dir torch torchvision fvcore numpy tqdm wandb git+https://github.com/LTH14/torch-fidelity.git
```

To train a model from scratch on CIFAR-10, run the following command:
```bash
./scripts/train_cifar10.sh
```

A folder named ./workspace/ will be created to store everything related to training, evaluation, and visualization.
The training jobs will be organized in ./workspace/experiments/ as follows:
```
./workspace/experiments/
└── sphere-base-base-cifar-10-32px
    ├── ckpt      # checkpoints
    ├── vis       # visualization
    ├── cfg.json  # configuration
    └── log.txt   # training log
```

For other datasets such as ImageNet, Animal Faces, and Oxford Flowers, you first need to prepare dataset list files (train.json, val.json, or test.json).
Place these files in the ./workspace/datasets/imagenet/ directory.
Each entry in the JSON files should follow this format:
```json
{"class_id": 0, "class_name": "imagenet", "image_path": "/absolute/path/to/imagenet.jpg", "is_absolute_path": true}
```

Once the dataset files are prepared, run the following command to begin training:
```bash
./scripts/train_imagenet.sh
```

The other datasets can be downloaded from the following links:
| Dataset | Link 1 | Link 2 |
|---|---|---|
| Animal Faces | kaggle | stargan |
| Oxford Flowers | kaggle | vgg/data |
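For illustration, the dataset list files can be generated from a class-per-subfolder image directory with a short script. The helper below (`build_list` is a hypothetical name, not part of this repo) writes one JSON object per line in the entry format shown above; if the training code expects a single JSON array instead, adjust the writer accordingly.

```python
import json
from pathlib import Path

def build_list(root, out_file, exts=(".jpg", ".jpeg", ".png")):
    """Write one JSON entry per image: class_id is assigned from the
    sorted order of the class subfolders under `root`."""
    root = Path(root).resolve()
    classes = sorted(d for d in root.iterdir() if d.is_dir())
    with open(out_file, "w") as f:
        for class_id, class_dir in enumerate(classes):
            images = sorted(p for p in class_dir.rglob("*") if p.suffix.lower() in exts)
            for img in images:
                entry = {
                    "class_id": class_id,
                    "class_name": class_dir.name,
                    "image_path": str(img),      # absolute, since root was resolved
                    "is_absolute_path": True,
                }
                f.write(json.dumps(entry) + "\n")
```

Run it once per split (e.g. on the train and val image roots) and place the resulting files under ./workspace/datasets/&lt;dataset&gt;/.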
After preparing the dataset files, run the corresponding training scripts:
```bash
./scripts/train_af.sh  # for Animal Faces
./scripts/train_of.sh  # for Oxford Flowers
```

| FID artifacts | 🤗 HF dataset repo |
|---|---|
| data statistic files | fid_stats |
| reference images | fid_refs |
Download the FID artifacts from the links above and place them in ./workspace/.
The directory tree should look like this:
```
./workspace/
├── fid_stats
│   ├── fid_stats_extr_animal-faces_256px.npz
│   ├── fid_stats_extr_cifar-10_32px.npz
│   ├── fid_stats_extr_flowers-102_256px.npz
│   └── fid_stats_rand-50k_imagenet_256px.npz
└── fid_refs
    └── ref_images_imagenet_256px/images
```

To evaluate a trained model, for example ./workspace/experiments/sphere-base-base-cifar-10-32px created by the CIFAR-10 training script, run the following command:
```bash
./run.sh eval.py \
    --job_dir sphere-base-base-cifar-10-32px \
    --forward_steps 1 4 \
    --report_fid rfid gfid \
    --use_cfg True \
    --cfg_min 1.2 \
    --cfg_max 1.2 \
    --cfg_position combo \
    --rm_folder_after_eval True
```

The evaluation results will be saved in ./workspace/experiments/sphere-base-base-cifar-10-32px/eval/ in table format.
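The --use_cfg, --cfg_min, and --cfg_max flags control classifier-free guidance. As a reference point only (the repo's exact cfg_position semantics are not spelled out here), the standard guidance combination of the unconditional and conditional predictions is:

```python
import numpy as np

def cfg_combine(uncond_pred, cond_pred, scale):
    """Standard classifier-free guidance: push the conditional
    prediction away from the unconditional one by `scale`."""
    return uncond_pred + scale * (cond_pred - uncond_pred)
```

With scale 1.0 this reduces to the plain conditional prediction; the command above fixes the scale at 1.2 by setting cfg_min and cfg_max to the same value.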
For ease of comparison and reproducibility, please see the model card in MODEL.md. It details our trained models, evaluation scripts, and the performance results reported in the paper.
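The reported FID numbers compare Gaussian statistics (mean and covariance of Inception features). Assuming the .npz stats files follow the common mu/sigma layout used by torch-fidelity-style tools, the Fréchet distance itself can be sketched as:

```python
import numpy as np

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).
    tr(sqrtm(S1 @ S2)) is obtained from the eigenvalues of S1 @ S2."""
    diff = mu1 - mu2
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.maximum(eigvals.real, 0.0)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```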
```bash
# --job_dir can be
# sphere-l-af, sphere-l-of, sphere-l-imagenet, or sphere-xl-imagenet
./run.sh sample.py \
    --job_dir sphere-xl-imagenet \
    --num_gen_samples 16 \
    --class_of_interests 980 985
```

Output images can be found in ./workspace/visualization/, which will look like:
```bash
./run.sh lerp.py \
    --job_dir sphere-l-of \
    --grid_nrow 16
```

Output images can be found in ./workspace/interpolation/, which will look like:
We can also try to interpolate 4 images with bilinear interpolation (blerp):
```bash
./run.sh lerp.py \
    --job_dir sphere-l-af \
    --interp_mode blerp \
    --grid_nrow 8 \
    --grid_ncol 8 \
    --num_trials 25
```

Output images look like:
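Since the encoder's latents live on a sphere, interpolation is naturally spherical. The sketch below shows a generic slerp and the bilinear variant over four corner latents; it is an illustration of the technique, not the repo's exact lerp.py implementation.

```python
import numpy as np

def slerp(a, b, t):
    """Spherical interpolation between vectors a and b at t in [0, 1]."""
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    omega = np.arccos(np.clip(a @ b, -1.0, 1.0))
    if omega < 1e-8:  # nearly parallel: fall back to a
        return a
    return (np.sin((1 - t) * omega) * a + np.sin(t * omega) * b) / np.sin(omega)

def blerp(c00, c01, c10, c11, u, v):
    """Bilinear spherical interpolation over four corner latents."""
    return slerp(slerp(c00, c01, v), slerp(c10, c11, v), u)
```

Unlike linear interpolation, slerp keeps every intermediate point on the unit sphere, so intermediate latents stay in-distribution for the decoder.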
```bash
./run.sh edit.py \
    --edit_mode crossover \
    --input_image images/dog.jpg \
    --extra_image images/cat.jpg \
    --job_dir sphere-l-af \
    --noise_strength_scaler 0.25 \
    --stitch_mode tri_backward \
    --stitch_swap True
```

Output images can be found in ./workspace/image_editing/, which will look like:
We can also try different stitching modes and swapping, for example:
```bash
    --stitch_mode tri_backward \
    --stitch_swap False
```

Output images will look like:
```bash
./run.sh edit.py \
    --edit_mode condition \
    --input_image images/wolly_panda.jpg \
    --job_dir sphere-l-imagenet \
    --noise_strength_scaler 0.25 \
    --num_trials 10 \
    --forward_steps 5
```

Output images can be found in ./workspace/image_editing/, which will look like:
Here we use the sphere encoder trained only on Oxford Flowers to reconstruct OOD images, which is a more challenging setting. To reconstruct an input image, run the following command:
```bash
./run.sh edit.py \
    --edit_mode reconstruction \
    --job_dir sphere-l-of \
    --input_image images/cake.jpg
```

| image A and reconstruction | image B and reconstruction | image C and reconstruction |
|---|---|---|
| ![]() | ![]() | ![]() |
The code is licensed under the CC-BY-NC 4.0 License.