PoincaréVAE (#51)
* [WIP] work on pvae

* [WIP] fix Hyperbolic geometry and distributions

* [WIP] add tests and work on PVAE

* increase coverage

* [WIP] fix device

* [WIP] work on PVAE

* [WIP] testing the model

* test

* test

* test

* fix pvae

* work on repro

* change data processing in repro experiment

* Add PoincaréDisk Sampler

* add sampler tests

* fix device issue

* add tutorials and fix device setting

* add PVAE to doc

* minor change in sampler

* minor change

* small change in docs

* work on reproducibility

* add ref to readme

* clean up readme

* add Wrapped PVAE results

* black and isort

* update README

* remove not needed

* update repro

* isort
clementchadebec authored Sep 3, 2022
1 parent c9f0fff commit 32b76a2
Showing 50 changed files with 3,425 additions and 72 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -51,6 +51,8 @@ examples/notebooks/my_model_with_custom_archi/
examples/notebooks/my_model/
examples/net.py
examples/scripts/configs/*
examples/scripts/reproducibility/reproducibility/*




13 changes: 12 additions & 1 deletion README.md
@@ -36,6 +36,15 @@ provides the possibility to perform benchmark experiments and comparisons by tra
the models with the same autoencoding neural network architecture. The feature *make your own autoencoder*
allows you to train any of these models with your own data and own Encoder and Decoder neural networks. It integrates an experiment monitoring tool [wandb](https://wandb.ai/) 🧪 and allows model sharing and loading from the [HuggingFace Hub](https://huggingface.co/models) 🤗 in a few lines of code.

## Quick access:
- [Installation](#installation)
- [Implemented models](#available-models) / [Implemented samplers](#available-samplers)
- [Reproducibility statement](#reproducibility) / [Results flavor](#results)
- [Model training](#launching-a-model-training) / [Data generation](#launching-data-generation) / [Custom network architectures](#define-you-own-autoencoder-architecture)
- [Model sharing with 🤗 Hub](#sharing-your-models-with-the-huggingface-hub-) / [Experiment tracking with `wandb`](#monitoring-your-experiments-with-wandb-)
- [Tutorials](#getting-your-hands-on-the-code) / [Documentation](https://pythae.readthedocs.io/en/latest/)
- [Contributing 🚀](#contributing-) / [Issues 🛠️](#dealing-with-issues-%EF%B8%8F)
- [Citing this repository](#citation)

# Installation

@@ -81,7 +90,8 @@ VAE with Inverse Autoregressive Flows (VAE_IAF) | [![Open In Colab](https://col
| Wasserstein Autoencoder (WAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/wae_training.ipynb) | [link](https://arxiv.org/abs/1711.01558) | [link](https://github.com/tolstikhin/wae) |
| Info Variational Autoencoder (INFOVAE_MMD) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/info_vae_training.ipynb) | [link](https://arxiv.org/abs/1706.02262) | |
| VAMP Autoencoder (VAMP) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vamp_training.ipynb) | [link](https://arxiv.org/abs/1705.07120) | [link](https://github.com/jmtomczak/vae_vampprior) |
| Hyperspherical VAE (SVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/svae_training.ipynb) | [link](https://arxiv.org/abs/1804.00891) | [link](https://github.com/nicola-decao/s-vae-pytorch) |
| Hyperspherical VAE (SVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/svae_training.ipynb) | [link](https://arxiv.org/abs/1804.00891) | [link](https://github.com/nicola-decao/s-vae-pytorch)
| Poincaré Disk VAE (PoincareVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/pvae_training.ipynb) | [link](https://arxiv.org/abs/1901.06033) | [link](https://github.com/emilemathieu/pvae) |
| Adversarial Autoencoder (Adversarial_AE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/adversarial_ae_training.ipynb) | [link](https://arxiv.org/abs/1511.05644)
| Variational Autoencoder GAN (VAEGAN) 🥗 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vaegan_training.ipynb) | [link](https://arxiv.org/abs/1512.09300) | [link](https://github.com/andersbll/autoencoding_beyond_pixels) |
| Vector Quantized VAE (VQVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vqvae_training.ipynb) | [link](https://arxiv.org/abs/1711.00937) | [link](https://github.com/deepmind/sonnet/blob/v2/sonnet/)
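
For orientation, the Poincaré Disk VAE row added above can be exercised with pythae's usual training pipeline. Below is a minimal sketch; the `PoincareVAEConfig` fields `prior_distribution`, `posterior_distribution` and `curvature` are assumptions drawn from the PVAE paper and this PR's tutorial notebook (check the config class for the exact names and defaults), and trainer field names may vary across pythae versions.

```python
import torchvision.datasets as datasets

from pythae.models import PoincareVAE, PoincareVAEConfig
from pythae.pipelines import TrainingPipeline
from pythae.trainers import BaseTrainerConfig

# MNIST digits scaled to [0, 1] and shaped (N, 1, 28, 28)
mnist = datasets.MNIST(root="./data", train=True, download=True)
train_data = mnist.data[:10000].reshape(-1, 1, 28, 28) / 255.0

# Assumed config fields: the distribution names and curvature follow the
# PVAE paper (Mathieu et al., 2019); verify against PoincareVAEConfig.
model_config = PoincareVAEConfig(
    input_dim=(1, 28, 28),
    latent_dim=2,
    prior_distribution="wrapped_normal",
    posterior_distribution="wrapped_normal",
    curvature=0.7,
)
model = PoincareVAE(model_config)

# Standard pythae trainer configuration (field names may differ by version)
training_config = BaseTrainerConfig(
    output_dir="my_pvae",
    num_epochs=10,
    learning_rate=1e-3,
)

pipeline = TrainingPipeline(training_config=training_config, model=model)
pipeline(train_data=train_data)
```

Because training always goes through `TrainingPipeline`, swapping `PoincareVAE` for any other model in the table above only changes the config and model classes.
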
@@ -102,6 +112,7 @@ Below is the list of the models currently implemented in the library.
| Gaussian mixture (GaussianMixtureSampler) | all models | [link](https://arxiv.org/abs/1903.12436) | [link](https://github.com/ParthaEth/Regularized_autoencoders-RAE-/tree/master/models/rae) |
| Two stage VAE sampler (TwoStageVAESampler) | all VAE based models | [link](https://openreview.net/pdf?id=B1e0X3C9tQ) | [link](https://github.com/daib13/TwoStageVAE/) |
| Unit sphere uniform sampler (HypersphereUniformSampler) | SVAE | [link](https://arxiv.org/abs/1804.00891) | [link](https://github.com/nicola-decao/s-vae-pytorch)
| Poincaré Disk sampler (PoincareDiskSampler) | PoincareVAE | [link](https://arxiv.org/abs/1901.06033) | [link](https://github.com/emilemathieu/pvae)
| VAMP prior sampler (VAMPSampler) | VAMP | [link](https://arxiv.org/abs/1705.07120) | [link](https://github.com/jmtomczak/vae_vampprior) |
| Manifold sampler (RHVAESampler) | RHVAE | [link](https://arxiv.org/abs/2105.00026) | [link](https://github.com/clementchadebec/pyraug)|
| Masked Autoregressive Flow Sampler (MAFSampler) | all models | [link](https://arxiv.org/abs/1705.07057v4) | [link](https://github.com/gpapamak/maf) |
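
The `PoincareDiskSampler` added to the samplers table pairs with the trained `PoincareVAE`. A minimal generation sketch, assuming `trained_model` is a fitted `PoincareVAE` (e.g. the model from the training sketch above) and that the sampler exposes the same `model=` constructor argument and `sample()` method as the other pythae samplers:

```python
from pythae.samplers import PoincareDiskSampler

# trained_model: a fitted PoincareVAE. The model= keyword and sample()
# signature mirror the other pythae samplers and are assumed to carry
# over to PoincareDiskSampler.
sampler = PoincareDiskSampler(model=trained_model)

# Draw 25 latent codes on the Poincaré disk and decode them
gen_data = sampler.sample(num_samples=25)
print(gen_data.shape)  # e.g. (25, 1, 28, 28) for the MNIST setup above
```
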
2 changes: 2 additions & 0 deletions docs/source/models/autoencoders/models.rst
@@ -22,6 +22,7 @@ Autoencoders
infovae
vamp
svae
pvae
aae
vaegan
vqvae
@@ -53,6 +54,7 @@ Available Models
~pythae.models.INFOVAE_MMD
~pythae.models.VAMP
~pythae.models.SVAE
~pythae.models.PoincareVAE
~pythae.models.Adversarial_AE
~pythae.models.VAEGAN
~pythae.models.VQVAE
13 changes: 13 additions & 0 deletions docs/source/models/autoencoders/pvae.rst
@@ -0,0 +1,13 @@
**********************************
PoincareVAE
**********************************


.. automodule::
pythae.models.pvae

.. autoclass:: pythae.models.PoincareVAEConfig
:members:

.. autoclass:: pythae.models.PoincareVAE
:members:
1 change: 1 addition & 0 deletions docs/source/models/pythae.models.rst
@@ -36,6 +36,7 @@ Available Autoencoders
~pythae.models.INFOVAE_MMD
~pythae.models.VAMP
~pythae.models.SVAE
~pythae.models.PoincareVAE
~pythae.models.Adversarial_AE
~pythae.models.VAEGAN
~pythae.models.VQVAE
9 changes: 9 additions & 0 deletions docs/source/samplers/poincare_disk_sampler.rst
@@ -0,0 +1,9 @@
**********************************
PoincareDiskSampler
**********************************

.. automodule::
pythae.samplers.pvae_sampler

.. autoclass:: pythae.samplers.PoincareDiskSampler
:members:
2 changes: 2 additions & 0 deletions docs/source/samplers/pythae.samplers.rst
@@ -11,6 +11,7 @@ Samplers
gmm_sampler
twostage_sampler
unit_sphere_unif_sampler
poincare_disk_sampler
vamp_sampler
rhvae_sampler
maf_sampler
@@ -28,6 +29,7 @@ Samplers
~pythae.samplers.GaussianMixtureSampler
~pythae.samplers.TwoStageVAESampler
~pythae.samplers.HypersphereUniformSampler
~pythae.samplers.PoincareDiskSampler
~pythae.samplers.VAMPSampler
~pythae.samplers.RHVAESampler
~pythae.samplers.MAFSampler
@@ -16,8 +16,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.datasets as datasets\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
@@ -253,7 +256,7 @@
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
"reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()"
]
},
{
@@ -301,7 +304,7 @@
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
"interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()"
]
},
{
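
Each notebook diff below repeats the same device fix: detect the device once, move the evaluation tensors onto it before calling the trained model, and bring the results back to CPU for plotting. A consolidated sketch of the pattern, assuming `trained_model` and `eval_dataset` are defined as in the notebooks and that the model already lives on `device`:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumes trained_model is already on `device` and eval_dataset is a float
# tensor of images, as set up earlier in each notebook.
reconstructions = (
    trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()
)
interpolations = (
    trained_model.interpolate(
        eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10
    )
    .detach()
    .cpu()
)
```
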
7 changes: 5 additions & 2 deletions examples/notebooks/models_training/ae_training.ipynb
@@ -16,8 +16,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.datasets as datasets\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
@@ -252,7 +255,7 @@
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
"reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()"
]
},
{
@@ -300,7 +303,7 @@
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
"interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()"
]
},
{
7 changes: 5 additions & 2 deletions examples/notebooks/models_training/beta_tc_vae_training.ipynb
@@ -16,8 +16,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.datasets as datasets\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
@@ -256,7 +259,7 @@
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
"reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()"
]
},
{
@@ -304,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
"interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()"
]
},
{
7 changes: 5 additions & 2 deletions examples/notebooks/models_training/beta_vae_training.ipynb
@@ -16,8 +16,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.datasets as datasets\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
@@ -254,7 +257,7 @@
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
"reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()"
]
},
{
@@ -302,7 +305,7 @@
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
"interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()"
]
},
{
@@ -16,8 +16,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.datasets as datasets\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
@@ -256,7 +259,7 @@
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
"reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()"
]
},
{
@@ -304,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
"interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()"
]
},
{
9 changes: 6 additions & 3 deletions examples/notebooks/models_training/factor_vae_training.ipynb
@@ -16,8 +16,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.datasets as datasets\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
@@ -254,7 +257,7 @@
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
"reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()"
]
},
{
@@ -302,7 +305,7 @@
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
"interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()"
]
},
{
@@ -340,7 +343,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.12"
},
"orig_nbformat": 4
},
9 changes: 6 additions & 3 deletions examples/notebooks/models_training/hvae_training.ipynb
@@ -16,8 +16,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.datasets as datasets\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
@@ -255,7 +258,7 @@
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
"reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()"
]
},
{
@@ -303,7 +306,7 @@
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
"interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()"
]
},
{
@@ -341,7 +344,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
"version": "3.8.12"
},
"orig_nbformat": 4
},
9 changes: 6 additions & 3 deletions examples/notebooks/models_training/info_vae_training.ipynb
@@ -16,8 +16,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.datasets as datasets\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
@@ -263,7 +266,7 @@
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
"reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()"
]
},
{
@@ -311,7 +314,7 @@
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
"interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()"
]
},
{
@@ -349,7 +352,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
"version": "3.8.12"
},
"orig_nbformat": 4
},