Skip to content

Commit

Permalink
Interpolations (#38)
Browse files Browse the repository at this point in the history
* add interpolate and reconstruct methods

* update doc

* update tests

* update demo

* black & isort

* prepare release

* update README
  • Loading branch information
clementchadebec authored Jul 22, 2022
1 parent 37d4e07 commit 0f5c0cc
Show file tree
Hide file tree
Showing 51 changed files with 2,909 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ And now build the model
You can also find predefined neural network architectures for the most common data sets (*i.e.* MNIST, CIFAR, CELEBA ...) that can be loaded as follows

```python
>>> for pythae.models.nn.benchmark.mnist import (
>>> from pythae.models.nn.benchmark.mnist import (
... Encoder_Conv_AE_MNIST, # For AE based model (only return embeddings)
... Encoder_Conv_VAE_MNIST, # For VAE based model (return embeddings and log_covariances)
... Decoder_Conv_AE_MNIST
Expand Down
6 changes: 6 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ we can conduct benchmark analysis and reproducible research!
Setup
~~~~~~~~~~~~~

To install the latest stable release of this library run the following using ``pip``

.. code-block:: bash
$ pip install pythae
To install the latest version of this library run the following using ``pip``

.. code-block:: bash
Expand Down
10 changes: 10 additions & 0 deletions docs/source/models/autoencoders/auto_model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
**********************************
AutoModel
**********************************


.. automodule::
pythae.models.auto_model

.. autoclass:: pythae.models.AutoModel
:members:
2 changes: 2 additions & 0 deletions docs/source/models/autoencoders/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Autoencoders
:maxdepth: 1

baseAE
auto_model
ae
vae
betavae
Expand Down Expand Up @@ -37,6 +38,7 @@ Available Models

.. autosummary::
~pythae.models.BaseAE
~pythae.models.AutoModel
~pythae.models.AE
~pythae.models.VAE
~pythae.models.BetaVAE
Expand Down
1 change: 1 addition & 0 deletions docs/source/models/pythae.models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Available Autoencoders

.. autosummary::
~pythae.models.BaseAE
~pythae.models.AutoModel
~pythae.models.AE
~pythae.models.VAE
~pythae.models.BetaVAE
Expand Down
80 changes: 80 additions & 0 deletions examples/notebooks/models_training/adversarial_ae_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,86 @@
"source": [
"## ... the other samplers work the same"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing reconstructions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show reconstructions\n",
"fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
"\n",
"for i in range(5):\n",
" for j in range(5):\n",
" axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show the true data\n",
"fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
"\n",
"for i in range(5):\n",
" for j in range(5):\n",
" axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing interpolations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show interpolations\n",
"fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n",
"\n",
"for i in range(5):\n",
" for j in range(10):\n",
" axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
}
],
"metadata": {
Expand Down
80 changes: 80 additions & 0 deletions examples/notebooks/models_training/ae_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,86 @@
"source": [
"## ... the other samplers work the same"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing reconstructions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show reconstructions\n",
"fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
"\n",
"for i in range(5):\n",
" for j in range(5):\n",
" axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show the true data\n",
"fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
"\n",
"for i in range(5):\n",
" for j in range(5):\n",
" axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing interpolations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show interpolations\n",
"fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n",
"\n",
"for i in range(5):\n",
" for j in range(10):\n",
" axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
}
],
"metadata": {
Expand Down
80 changes: 80 additions & 0 deletions examples/notebooks/models_training/beta_tc_vae_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,86 @@
"source": [
"## ... the other samplers work the same"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing reconstructions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show reconstructions\n",
"fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
"\n",
"for i in range(5):\n",
" for j in range(5):\n",
" axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show the true data\n",
"fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
"\n",
"for i in range(5):\n",
" for j in range(5):\n",
" axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing interpolations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show interpolations\n",
"fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n",
"\n",
"for i in range(5):\n",
" for j in range(10):\n",
" axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n",
" axes[i][j].axis('off')\n",
"plt.tight_layout(pad=0.)"
]
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 0f5c0cc

Please sign in to comment.