Commit d7ef6b1 (0 parents)
Showing 4 changed files with 532 additions and 0 deletions.
@@ -0,0 +1,199 @@
import os
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Optimizer

from utils.modules import Encoder, Decoder

import torch.optim as optimizer_module


# Utility function to initialize an optimizer from its name
def init_optimizer(optimizer_name, params):
    assert hasattr(optimizer_module, optimizer_name)
    OptimizerClass = getattr(optimizer_module, optimizer_name)
    return OptimizerClass(params)
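

# Hypothetical usage sketch (not part of the original commit): init_optimizer
# resolves the optimizer class by name from torch.optim and instantiates it
# with the given parameter groups. The module and learning rate below are
# illustrative only.
def _example_init_optimizer_usage():
    example_model = nn.Linear(10, 2)
    return init_optimizer('Adam', [{'params': example_model.parameters(), 'lr': 1e-3}])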


##########################
# Generic training class #
##########################
class Trainer(nn.Module):
    def __init__(self, log_loss_every=10, writer=None):
        super(Trainer, self).__init__()
        self.iterations = 0

        self.writer = writer
        self.log_loss_every = log_loss_every

        self.loss_items = {}

    def get_device(self):
        return next(self.parameters()).device

    def train_step(self, data):
        # Set all the models in training mode
        self.train(True)

        # Log the values in loss_items every log_loss_every iterations
        if self.writer is not None:
            if (self.iterations + 1) % self.log_loss_every == 0:
                self._log_loss()

        # Move the data to the appropriate device
        device = self.get_device()

        for i, item in enumerate(data):
            data[i] = item.to(device)

        # Perform the training step and update the iteration count
        self._train_step(data)
        self.iterations += 1

    def _add_loss_item(self, name, value):
        assert isinstance(name, str)
        assert isinstance(value, (float, int))

        if name not in self.loss_items:
            self.loss_items[name] = []

        self.loss_items[name].append(value)

    def _log_loss(self):
        # Log the expected value of the items in loss_items
        for key, values in self.loss_items.items():
            self.writer.add_scalar(tag=key, scalar_value=np.mean(values), global_step=self.iterations)
            self.loss_items[key] = []

    def save(self, model_path):
        items_to_save = self._get_items_to_store()
        items_to_save['iterations'] = self.iterations

        # Save the model and increment the checkpoint count
        torch.save(items_to_save, model_path)

    def load(self, model_path):
        items_to_load = torch.load(model_path)
        for key, value in items_to_load.items():
            assert hasattr(self, key)
            attribute = getattr(self, key)

            # Load the state dictionary for the stored modules and optimizers
            if isinstance(attribute, nn.Module) or isinstance(attribute, Optimizer):
                attribute.load_state_dict(value)

                # Move the optimizer state to the correct device,
                # see https://github.com/pytorch/pytorch/issues/2830 for further details
                if isinstance(attribute, Optimizer):
                    device = list(value['state'].values())[0]['exp_avg'].device  # Hack to identify the device
                    for state in attribute.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.to(device)

            # Otherwise just copy the value
            else:
                setattr(self, key, value)

    def _get_items_to_store(self):
        return dict()

    def _train_step(self, data):
        raise NotImplementedError()
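

# Hypothetical driver sketch (not part of the original commit): any Trainer
# subclass that implements _train_step can be driven like this. Note that
# train_step mutates `data` in place, so batches are passed as lists.
def _example_training_loop(trainer, train_loader, n_epochs=10, checkpoint_path='checkpoint.pt'):
    for _ in range(n_epochs):
        for batch in train_loader:
            trainer.train_step(list(batch))
        trainer.save(checkpoint_path)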


##########################
# Representation Trainer #
##########################

# Generic class to train a model with a (stochastic) neural network encoder

class RepresentationTrainer(Trainer):
    def __init__(self, z_dim, optimizer_name='Adam', encoder_lr=1e-4, **params):
        super(RepresentationTrainer, self).__init__(**params)

        self.z_dim = z_dim

        # Initialization of the encoder
        self.encoder = Encoder(z_dim)

        self.opt = init_optimizer(optimizer_name, [
            {'params': self.encoder.parameters(), 'lr': encoder_lr},
        ])

    def _get_items_to_store(self):
        items_to_store = super(RepresentationTrainer, self)._get_items_to_store()

        # Store the encoder and optimizer parameters
        items_to_store['encoder'] = self.encoder.state_dict()
        items_to_store['opt'] = self.opt.state_dict()

        return items_to_store

    def _train_step(self, data):
        loss = self._compute_loss(data)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    def _compute_loss(self, data):
        raise NotImplementedError()
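

# Hypothetical extension sketch (not part of the original commit): subclasses
# implement _compute_loss and return a scalar tensor; the optimizer step is
# handled by _train_step above. This assumes Encoder/Decoder return torch
# distributions (with rsample/log_prob), as in the DVIB trainer of this commit.
class _ExampleAutoencodingTrainer(RepresentationTrainer):
    def __init__(self, z_dim, decoder_lr=1e-4, **params):
        super(_ExampleAutoencodingTrainer, self).__init__(z_dim=z_dim, **params)
        self.decoder = Decoder(z_dim)
        self.opt.add_param_group({'params': self.decoder.parameters(), 'lr': decoder_lr})

    def _compute_loss(self, data):
        x = data[0]
        z = self.encoder(x).rsample()   # sample a latent code with reparametrization
        q_x_given_z = self.decoder(z)   # reconstruction distribution
        loss = -q_x_given_z.log_prob(x.view(x.shape[0], -1)).mean()
        self._add_loss_item('loss/reconstruction', loss.item())
        return loss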


##########################
#      Merge Trainer     #
##########################

class MergeTrainer(Trainer):
    def __init__(self, z_dim, optimizer_name='Adam', encoder_lr=1e-4, decoder_lr=1e-4, **params):
        super(MergeTrainer, self).__init__(**params)

        self.z_dim = z_dim

        # Initialization of the encoders
        self.encoder_x_s = Encoder(z_dim)
        self.encoder_x_p = Encoder(z_dim)
        self.encoder_y_s = Encoder(z_dim)
        self.encoder_y_p = Encoder(z_dim)

        # Initialization of the decoders
        self.decoder_x_s = Decoder(z_dim)
        self.decoder_x_p = Decoder(z_dim)
        self.decoder_y_s = Decoder(z_dim)
        self.decoder_y_p = Decoder(z_dim)

        self.opt = init_optimizer(optimizer_name, [
            {'params': self.encoder_x_s.parameters(), 'lr': encoder_lr},
            {'params': self.encoder_x_p.parameters(), 'lr': encoder_lr},
            {'params': self.encoder_y_s.parameters(), 'lr': encoder_lr},
            {'params': self.encoder_y_p.parameters(), 'lr': encoder_lr},

            {'params': self.decoder_x_s.parameters(), 'lr': decoder_lr},
            {'params': self.decoder_x_p.parameters(), 'lr': decoder_lr},
            {'params': self.decoder_y_s.parameters(), 'lr': decoder_lr},
            {'params': self.decoder_y_p.parameters(), 'lr': decoder_lr},
        ])

    def _get_items_to_store(self):
        items_to_store = super(MergeTrainer, self)._get_items_to_store()

        # Store the encoder and optimizer parameters
        items_to_store['encoder_x_s'] = self.encoder_x_s.state_dict()
        items_to_store['encoder_x_p'] = self.encoder_x_p.state_dict()
        items_to_store['encoder_y_s'] = self.encoder_y_s.state_dict()
        items_to_store['encoder_y_p'] = self.encoder_y_p.state_dict()
        items_to_store['opt'] = self.opt.state_dict()

        return items_to_store

    def _train_step(self, data):
        loss = self._compute_loss(data)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    def _compute_loss(self, data):
        raise NotImplementedError()

@@ -0,0 +1,145 @@
# from training.multiview_infomax import MVInfoMaxTrainer
from utils.schedulers import ExponentialScheduler
from utils.schedulers import LinearScheduler
from training.base import MergeTrainer
from utils.modules import MIEstimator, Encoder, Decoder, Feature_extractor, Alex_extractor, Alex_extractor_avg, Alex_extractor_fea
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent
from torchvision import models


################
# DVIB Trainer #
################


class DVIBTrainer(MergeTrainer):
    def __init__(self, beta=1,
                 lambda_start_value=1e-5, lambda_end_value=1,
                 lambda_n_iterations=100000, lambda_start_iteration=50000,
                 miest_lr=1e-4, **params):
        # The neural network architectures and the initialization procedure are analogous to Multi-View InfoMax
        super(DVIBTrainer, self).__init__(**params)

        # Definition of the scheduler to update the value of the regularization coefficient lambda over time
        self.beta = beta
        self.lambda_scheduler = ExponentialScheduler(start_value=lambda_start_value, end_value=lambda_end_value,
                                                     n_iterations=lambda_n_iterations, start_iteration=lambda_start_iteration)
        # self.lambda_scheduler = LinearScheduler(start_value=lambda_start_value, end_value=lambda_end_value,
        #                                         n_iterations=lambda_n_iterations, start_iteration=lambda_start_iteration)

        # Initialization of the mutual information estimation networks
        self.mi_estimator_x = MIEstimator(self.z_dim, self.z_dim)
        self.mi_estimator_y = MIEstimator(self.z_dim, self.z_dim)

        # Adding the parameters of the estimators to the optimizer
        self.opt.add_param_group(
            {'params': self.mi_estimator_x.parameters(), 'lr': miest_lr},
        )
        self.opt.add_param_group(
            {'params': self.mi_estimator_y.parameters(), 'lr': miest_lr},
        )

        # Defining the prior distribution as a factorized normal distribution
        self.mu = nn.Parameter(torch.zeros(self.z_dim), requires_grad=False)
        self.sigma = nn.Parameter(torch.ones(self.z_dim), requires_grad=False)
        self.prior = Normal(loc=self.mu, scale=self.sigma)
        self.prior = Independent(self.prior, 1)

        # If x and y follow the same distribution, the encoders for the shared representation can share parameters
        # self.encoder_y_s = self.encoder_x_s  # reuse
        self.mu2 = nn.Parameter(torch.full((self.z_dim,), 1.), requires_grad=False)
        self.mu3 = nn.Parameter(torch.full((self.z_dim,), 2.), requires_grad=False)
        self.prior2 = Normal(loc=self.mu2, scale=self.sigma)
        self.prior2 = Independent(self.prior2, 1)
        self.prior3 = Normal(loc=self.mu3, scale=self.sigma)
        self.prior3 = Independent(self.prior3, 1)

        ####
        # Create pretrained ResNet-18 and AlexNet feature extractors
        ####
        self.res = Feature_extractor(output_layer='avgpool')
        self.alex = Alex_extractor()

        # Iteration offset for the lambda scheduler
        self.labda_start = 0

    def _compute_loss(self, data):
        # Read the two views x and y and ignore the remaining items
        x, y, _, _ = data
        x = self.res(x)
        y = self.res(y)
        # x = self.alex(x)
        # y = self.alex(y)

        # Encode a batch of data
        p_z_xs_given_x = self.encoder_x_s(x)  # [z_dim * 2]
        p_z_xp_given_x = self.encoder_x_p(x)  # [z_dim * 2]
        p_z_ys_given_y = self.encoder_y_s(y)  # [z_dim * 2]
        p_z_yp_given_y = self.encoder_y_p(y)  # [z_dim * 2]

        # Sample from the posteriors with reparametrization
        z_xs = p_z_xs_given_x.rsample()
        z_xp = p_z_xp_given_x.rsample()
        z_ys = p_z_ys_given_y.rsample()
        z_yp = p_z_yp_given_y.rsample()

        # Reconstruction loss from private and shared latents
        q_x_given_z_xs = self.decoder_x_s(z_xs)
        q_x_given_z_xp = self.decoder_x_p(z_xp)
        q_y_given_z_ys = self.decoder_y_s(z_ys)
        q_y_given_z_yp = self.decoder_y_p(z_yp)

        reconstruct_loss_xs = -q_x_given_z_xs.log_prob(x.view(x.shape[0], -1))
        reconstruct_loss_xp = -q_x_given_z_xp.log_prob(x.view(x.shape[0], -1))
        reconstruct_loss_ys = -q_y_given_z_ys.log_prob(y.view(y.shape[0], -1))
        reconstruct_loss_yp = -q_y_given_z_yp.log_prob(y.view(y.shape[0], -1))

        neg_I_x = reconstruct_loss_xs.mean() + reconstruct_loss_xp.mean()
        neg_I_y = reconstruct_loss_ys.mean() + reconstruct_loss_yp.mean()

        # Mutual information estimation between the shared representations of the two views
        mi_gradient_x, mi_estimation_x = self.mi_estimator_x(z_xs, z_ys)
        mi_gradient_x = mi_gradient_x.mean()
        mi_estimation_x = mi_estimation_x.mean()

        mi_gradient_y, mi_estimation_y = self.mi_estimator_y(z_ys, z_xs)
        mi_gradient_y = mi_gradient_y.mean()
        mi_estimation_y = mi_estimation_y.mean()

        mi_gradient = mi_gradient_x + mi_gradient_y
        mi_estimation = mi_estimation_x + mi_estimation_y

        # Upper bound of the mutual information between the different views
        pos_I_y_zxp = p_z_yp_given_y.log_prob(z_yp) - self.prior2.log_prob(z_xp)
        pos_I_x_zyp = p_z_xp_given_x.log_prob(z_xp) - self.prior3.log_prob(z_yp)

        pos_beta_I = pos_I_y_zxp.mean() + pos_I_x_zyp.mean()

        # Update the value of lambda according to the scheduling policy
        labda = self.lambda_scheduler(self.iterations - self.labda_start)

        # Logging the components
        self._add_loss_item('loss/neg_I_x', neg_I_x.item())
        self._add_loss_item('loss/neg_I_y', neg_I_y.item())
        self._add_loss_item('loss/I_z1_z2', mi_estimation.item())
        self._add_loss_item('loss/softplus_I_12', mi_gradient.item())
        # self._add_loss_item('loss/ckl', neg_ckl.item())
        self._add_loss_item('loss/I_beta', pos_beta_I.item())
        self._add_loss_item('loss/beta', self.beta)
        self._add_loss_item('loss/lambda', labda)

        # Computing the loss function
        loss = neg_I_x + neg_I_y - labda * mi_gradient + self.beta * pos_beta_I

        return loss
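

# Hypothetical construction sketch (not part of the original commit): the values
# below are illustrative defaults; the resulting trainer is driven through the
# inherited train_step / save / load interface defined in training.base.
def _example_build_dvib_trainer(writer=None):
    return DVIBTrainer(z_dim=64, beta=1,
                       lambda_start_value=1e-5, lambda_end_value=1,
                       lambda_n_iterations=100000, lambda_start_iteration=50000,
                       encoder_lr=1e-4, decoder_lr=1e-4, miest_lr=1e-4,
                       writer=writer)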