Skip to content

Commit

Permalink
added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrvinod committed Jun 4, 2018
1 parent 9e6dabb commit 01f749f
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 46 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
This is a **PyTorch Tutorial to Image Captioning**.
This is a **[PyTorch](https://pytorch.org) Tutorial to Image Captioning**.

This is the first of a series of tutorials I plan to write about _implementing_ cool models on your own with the amazing [PyTorch](https://pytorch.org) library.
This is the first in a series of tutorials I plan to write about _implementing_ cool models on your own with the amazing PyTorch library.

Basic knowledge of PyTorch, convolutional and recurrent neural networks is assumed.

Expand All @@ -24,9 +24,9 @@ I'm using `PyTorch 0.4` in `Python 3.6`.

**To build a model that can generate a descriptive caption for an image we provide it.**

In the interest of keeping things simple, let's choose to implement the [_Show, Attend, and Tell_](https://arxiv.org/abs/1502.03044) paper. This is by no means the current state-of-the-art, but is still pretty darn amazing.
In the interest of keeping things simple, let's implement the [_Show, Attend, and Tell_](https://arxiv.org/abs/1502.03044) paper. This is by no means the current state-of-the-art, but is still pretty darn amazing.

**This model learns _where_ to look.**
This model learns _where_ to look.

As you generate a caption, word by word, you can see the the model's gaze shifting across the image.

Expand Down Expand Up @@ -78,7 +78,7 @@ There are more examples at the [end of the tutorial](https://github.com/sgrvinod

# Overview

In this section, I will present a broad overview of this model. I don't really get into the _minutiae_ here - feel free to skip to the implementation section and commented code for details.
In this section, I will present a broad overview of this model. If you're already familiar with it, you can skip straight to the implementation section or the commented code.

### Encoder

Expand Down
29 changes: 26 additions & 3 deletions caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@


def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
"""
Reads an image and captions it with beam search.
:param encoder: encoder model
:param decoder: decoder model
:param image_path: path to image
:param word_map: word map
:param beam_size: number of sequences to consider at each decode-step
:return: caption, weights for visualization
"""

k = beam_size
vocab_size = len(word_map)

# Read image and process
img = imread(image_path)
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
Expand All @@ -27,9 +42,6 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=
transform = transforms.Compose([normalize])
image = transform(img) # (3, 256, 256)

k = beam_size
vocab_size = len(word_map)

# Encode
image = image.unsqueeze(0) # (1, 3, 256, 256)
encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim)
Expand Down Expand Up @@ -136,6 +148,17 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=


def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True):
"""
Visualizes caption with weights at every word.
Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb
:param image_path: path to image that has been captioned
:param seq: caption
:param alphas: weights
:param rev_word_map: reverse word mapping, i.e. ix2word
:param smooth: smooth weights?
"""
image = Image.open(image_path)
image = image.resize([14 * 24, 14 * 24], Image.LANCZOS)

Expand Down
11 changes: 11 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@


class CaptionDataset(Dataset):
"""
A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
"""

def __init__(self, data_folder, data_name, split, transform=None):
"""
:param data_folder: folder where data files are stored
:param data_name: base name of processed datasets
:param split: split, one of 'TRAIN', 'VAL', or 'TEST'
:param transform: image transform pipeline
"""
self.split = split
assert self.split in {'TRAIN', 'VAL', 'TEST'}

Expand Down
106 changes: 87 additions & 19 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,43 @@


class Encoder(nn.Module):
"""
Encoder.
"""

def __init__(self, encoded_image_size=14):
super(Encoder, self).__init__()
self.enc_image_size = encoded_image_size

resnet = torchvision.models.resnet101(pretrained=True)
resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101

# Remove linear and pool layers (since we're not doing classification)
modules = list(resnet.children())[:-2]
self.resnet = nn.Sequential(*modules)

# Resize image to fixed size to allow input images of variable size
self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

self.fine_tune()

def forward(self, images):
# images.shape = (batch_size, 3, image_size, image_size)
"""
Forward propagation.
:param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
:return: encoded images
"""
out = self.resnet(images) # (batch_size, 2048, image_size/32, image_size/32)
out = self.adaptive_pool(out) # (batch_size, 2048, encoded_image_size, encoded_image_size)
out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 2048)
return out

def fine_tune(self, fine_tune=True):
"""
Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.
:param fine_tune: Allow?
"""
for p in self.resnet.parameters():
p.requires_grad = False
# If fine-tuning, only fine-tune convolutional blocks 2 through 4
Expand All @@ -35,18 +52,31 @@ def fine_tune(self, fine_tune=True):


class Attention(nn.Module):
"""
Attention Network.
"""

def __init__(self, encoder_dim, decoder_dim, attention_dim):
"""
:param encoder_dim: feature size of encoded images
:param decoder_dim: size of decoder's RNN
:param attention_dim: size of the attention network
"""
super(Attention, self).__init__()
self.encoder_att = nn.Linear(encoder_dim, attention_dim)
self.decoder_att = nn.Linear(decoder_dim, attention_dim)
self.full_att = nn.Linear(attention_dim, 1)
self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image
self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output
self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=1)
self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights

def forward(self, encoder_out, decoder_hidden):
# encoder_out.shape = (batch_size, num_pixels, encoder_dim)
# decoder_hidden.shape = (batch_size, decoder_dim)
"""
Forward propagation.
:param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
:param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
:return: attention weighted encoding, weights
"""
att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim)
att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim)
att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels)
Expand All @@ -57,8 +87,21 @@ def forward(self, encoder_out, decoder_hidden):


class DecoderWithAttention(nn.Module):
"""
Decoder.
"""

def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, decoder_layers=1, encoder_dim=2048,
dropout=0.5):
"""
:param attention_dim: size of attention network
:param embed_dim: embedding size
:param decoder_dim: size of decoder's RNN
:param vocab_size: size of vocabulary
:param decoder_layers: number of layers in the decoder
:param encoder_dim: feature size of encoded images
:param dropout: dropout
"""
super(DecoderWithAttention, self).__init__()

self.encoder_dim = encoder_dim
Expand All @@ -69,40 +112,65 @@ def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, decoder_la
self.decoder_layers = decoder_layers
self.dropout = dropout

self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network

self.embedding = nn.Embedding(vocab_size, embed_dim)
self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer
self.dropout = nn.Dropout(p=self.dropout)
self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, decoder_layers)
self.init_h = nn.Linear(encoder_dim, decoder_dim)
self.init_c = nn.Linear(encoder_dim, decoder_dim)
self.f_beta = nn.Linear(decoder_dim, encoder_dim)
self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, decoder_layers) # decoding LSTMCell
self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell
self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell
self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate
self.sigmoid = nn.Sigmoid()
self.fc = nn.Linear(decoder_dim, vocab_size)
self.init_weights()
self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary
self.init_weights() # initialize some layers with the uniform distribution

def init_weights(self):
"""
Initializes some parameters with values from the uniform distribution, for easier convergence.
"""
self.embedding.weight.data.uniform_(-0.1, 0.1)
self.fc.bias.data.fill_(0)
self.fc.weight.data.uniform_(-0.1, 0.1)

def load_pretrained_embeddings(self, embeddings):
"""
Loads embedding layer with pre-trained embeddings.
:param embeddings: pre-trained embeddings
"""
self.embedding.weight = nn.Parameter(embeddings)

def fine_tune_embeddings(self, fine_tune=True):
"""
Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
:param fine_tune: Allow?
"""
for p in self.embedding.parameters():
p.requires_grad = fine_tune

def init_hidden_state(self, encoder_out):
"""
Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
:param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
:return: hidden state, cell state
"""
mean_encoder_out = encoder_out.mean(dim=1)
h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim)
c = self.init_c(mean_encoder_out)
return h, c

def forward(self, encoder_out, encoded_captions, caption_lengths):
# encoder_out.shape = (batch_size, image_size, image_size, encoder_dim), image_size being the pixel width/height
# encoded_captions.shape = (batch_size, max_caption_length)
# caption_lengths.shape = (batch_size, 1)
"""
Forward propagation.
:param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
:param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
:param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
:return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
"""

batch_size = encoder_out.size(0)
encoder_dim = encoder_out.size(-1)
vocab_size = self.vocab_size
Expand Down
Loading

0 comments on commit 01f749f

Please sign in to comment.