Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@
<br>

### Overview
Pytorch implementation of Deep Variational Information Bottleneck([paper], [original code])
A more modern adaptation of the PyTorch implementation of Deep Variational Information Bottleneck([paper], [original code])

Original Pytorch Implementation: ([https://github.com/1Konny/VIB-pytorch])

![ELBO](misc/ELBO.PNG)
![monte_carlo](misc/monte_carlo.PNG)
<br>

### Dependencies
### Setup
1. Download Mini Conda: https://www.anaconda.com/docs/getting-started/miniconda/main
2. Create and activate Conda Environment
```
conda create -n myenv python=3.11
conda activate myenv
```
3. Install the required packages:
```
python 3.6.4
pytorch 0.3.1.post2
tensorboardX(optional)
tensorflow(optional)
pip install -r requirements.txt
```
<br>

Expand All @@ -22,6 +28,10 @@ tensorflow(optional)
```
python main.py --mode train --beta 1e-3 --tensorboard True --env_name [NAME]
```
2. TensorBoard
```
tensorboard --logdir=summary/[NAME]/
```
3. test
```
python main.py --mode test --env_name [NAME] --load_ckpt best_acc.tar
Expand Down
45 changes: 45 additions & 0 deletions data_process/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch, os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

class UnknownDatasetError(Exception):
    """Raised when an unsupported dataset name is requested."""

    def __init__(self, name):
        super().__init__(f"Unknown dataset: {name}")


def return_data(args):
    """Build train/test DataLoaders for the dataset named in ``args``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``dataset`` (currently only ``'MNIST'``),
        ``dset_dir`` (root directory for downloads), and ``batch_size``.

    Returns
    -------
    dict
        ``{'train': DataLoader, 'test': DataLoader}``; the train loader
        shuffles and drops the last partial batch, the test loader keeps
        a fixed order and every sample.

    Raises
    ------
    UnknownDatasetError
        If ``args.dataset`` is not a supported dataset name.
    """
    name = args.dataset
    # Validate the dataset name up front so callers get a clear error
    # before any filesystem/transform work happens.
    if name != 'MNIST':
        # BUG FIX: the original called UnknownDatasetError() with no
        # argument, so the raise itself failed with a TypeError and
        # masked the intended "Unknown dataset" message.
        raise UnknownDatasetError(name)

    dset_dir = args.dset_dir
    batch_size = args.batch_size
    # Worker subprocesses only pay off when a GPU is consuming batches;
    # on CPU-only runs keep loading in the main process.
    num_workers = os.cpu_count() if torch.cuda.is_available() else 0
    # ToTensor maps pixels to [0, 1]; Normalize then centers them to
    # roughly [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))
                                    ])

    root = os.path.join(dset_dir, 'MNIST')
    dset = MNIST
    # Only the train split downloads; the test split reuses the same root.
    train_kwargs = {'root': root, 'train': True, 'transform': transform, 'download': True}
    test_kwargs = {'root': root, 'train': False, 'transform': transform, 'download': False}

    train_loader = DataLoader(dset(**train_kwargs), batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, drop_last=True)
    test_loader = DataLoader(dset(**test_kwargs), batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, drop_last=False)

    return {'train': train_loader, 'test': test_loader}

if __name__ == '__main__' :
    # Smoke-test the loaders from the repository root (this file lives
    # in a subpackage, hence the chdir).
    import argparse
    os.chdir('..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='MNIST', type=str)
    parser.add_argument('--dset_dir', default='datasets', type=str)
    parser.add_argument('--batch_size', default=64, type=int)
    args = parser.parse_args()

    data_loader = return_data(args)
    # Print a summary instead of dropping into a debugger: the original
    # `import ipdb; ipdb.set_trace()` was a leftover breakpoint that
    # hangs any non-interactive run.
    for split, loader in data_loader.items():
        print(f"{split}: {len(loader)} batches of size {loader.batch_size}")
57 changes: 0 additions & 57 deletions datasets/datasets.py

This file was deleted.

84 changes: 40 additions & 44 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from utils import cuda

import time
from numbers import Number

class ToyNet(nn.Module):

Expand All @@ -19,50 +14,51 @@ def __init__(self, K=256):
nn.ReLU(True),
nn.Linear(1024, 1024),
nn.ReLU(True),
nn.Linear(1024, 2*self.K))
nn.Linear(1024, 2*self.K)
)

self.decode = nn.Sequential(
nn.Linear(self.K, 10))

def forward(self, x, num_sample=1):
if x.dim() > 2 : x = x.view(x.size(0),-1)

nn.Linear(self.K, 10)
)

self.apply(self._init_weights) #Not always necessary

def forward(self, x,
num_sample=1, stab_factor = 5,
beta = 1):
# Flatten the input for the MLP
x = x.flatten(start_dim = 1)

# Encode the input
statistics = self.encode(x)
mu = statistics[:,:self.K]
std = F.softplus(statistics[:,self.K:]-5,beta=1)

encoding = self.reparametrize_n(mu,std,num_sample)
logit = self.decode(encoding)
# Extract the distribution descriptor
mu, raw_std = statistics[:,:self.K], statistics[:,self.K:]
std = F.softplus(raw_std-stab_factor,
beta=beta)

# Encode (Random Sampling) -> Decode
encoding = self.reparametrize_n(mu,std,
num_sample)
logits = self.decode(encoding)

if num_sample == 1 : pass
elif num_sample > 1 : logit = F.softmax(logit, dim=2).mean(0)
if num_sample > 1:
logits = F.softmax(logits, dim=-1).mean(dim=0)

return (mu, std), logit
return (mu, std), logits

def reparametrize_n(self, mu, std, n=1):
# reference :
# http://pytorch.org/docs/0.3.1/_modules/torch/distributions.html#Distribution.sample_n
def expand(v):
if isinstance(v, Number):
return torch.Tensor([v]).expand(n, 1)
else:
return v.expand(n, *v.size())

if n != 1 :
mu = expand(mu)
std = expand(std)

eps = Variable(cuda(std.data.new(std.size()).normal_(), std.is_cuda))

return mu + eps * std

def weight_init(self):
for m in self._modules:
xavier_init(self._modules[m])


def xavier_init(ms):
for m in ms :
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
nn.init.xavier_uniform(m.weight,gain=nn.init.calculate_gain('relu'))
m.bias.data.zero_()
if n == 1:
eps = torch.randn_like(std)
return mu + eps * std
else:
mu_exp = mu.unsqueeze(0).expand(n, *mu.shape)
std_exp = std.unsqueeze(0).expand(n, *std.shape)
eps = torch.randn_like(std_exp)
return mu_exp + eps * std_exp

def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
if m.bias is not None:
nn.init.zeros_(m.bias)
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch>=2.2.1
ipdb==0.13.13
tensorboardX==2.6.2.2
tensorboard==2.19.0
Loading