
Commit

Added a latent transformer module with n_layers decoder layers, and added a perceiver block consisting of a cross-attention layer and one latent transformer; polished docstrings.
lorenzopolicar committed Oct 12, 2023
1 parent 9bfe005 commit b77a133
Showing 2 changed files with 98 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
AD_NC
module1.py
100 changes: 97 additions & 3 deletions module.py
@@ -2,15 +2,19 @@
import torch.nn as nn

class PositionalImageEmbedding(nn.Module):
    """
    Initialise the PositionalImageEmbedding module.
    Args:
        input_channels (int): Number of input channels in the image.
        embed_dim (int): Dimension of the embedded image representation.
        bands (int): Number of Fourier feature bands for positional encoding.
    Returns:
        (Tensor): Embedded image representation of shape (HEIGHT*WIDTH x BATCH_SIZE x EMBED_DIM).
    """

    def __init__(self, input_channels, embed_dim, bands=4):
        super().__init__()

        # Initialise the Fourier features for positional encoding
@@ -123,5 +127,95 @@ def forward(self, latent, image):

        return out

class LatentTransformer(nn.Module):
    """
    Latent Transformer module with multiple decoder layers.
    Args:
        embed_dim (int): Dimension of the embedded representations.
        mlp_dim (int): Dimension of the feedforward network hidden layer.
        n_heads (int): Number of attention heads.
        dropout (float): Dropout probability.
        n_layers (int): Number of decoder layers.
    Returns:
        Tensor: Transformed latent tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
    """

    def __init__(self, embed_dim, mlp_dim, n_heads, dropout, n_layers):
        super().__init__()

        # Create a list of decoder layers (PerceiverAttention blocks)
        self.transformer = nn.ModuleList([
            PerceiverAttentionBlock(
                embed_dim=embed_dim,
                mlp_dim=mlp_dim,
                n_heads=n_heads,
                dropout=dropout)
            for _ in range(n_layers)
        ])

    def forward(self, l):
        """
        Forward pass of the LatentTransformer module.
        Args:
            l (Tensor): Latent tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
        Returns:
            Tensor: Transformed latent tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
        """
        for transform in self.transformer:
            l = transform(l, l)
        return l
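A minimal usage sketch of LatentTransformer follows. The tensor sizes and hyperparameters are illustrative only, and the sketch assumes the PerceiverAttentionBlock class defined earlier in module.py, which preserves the latent tensor's shape:

import torch

# Illustrative sizes, not taken from the repository
latent_dim, batch_size, embed_dim = 64, 8, 128

# Latent array of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM]
latent = torch.randn(latent_dim, batch_size, embed_dim)

latent_transformer = LatentTransformer(
    embed_dim=embed_dim, mlp_dim=256, n_heads=4, dropout=0.1, n_layers=6)

out = latent_transformer(latent)
print(out.shape)  # expected: torch.Size([64, 8, 128])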

class PerceiverBlock(nn.Module):
    """
    Block consisting of one cross-attention layer and one latent transformer.
    Args:
        embed_dim (int): Dimension of the embedded representations.
        attn_mlp_dim (int): Dimension of the cross-attention's feedforward network hidden layer.
        trnfr_mlp_dim (int): Dimension of the latent transformer's feedforward network hidden layer.
        trnfr_heads (int): Number of attention heads for the latent transformer.
        dropout (float): Dropout probability.
        trnfr_layers (int): Number of layers in the latent transformer.
    """

    def __init__(self, embed_dim, attn_mlp_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers):
        super().__init__()

        # Cross-attention layer
        self.cross_attn = PerceiverAttention(
            embed_dim, attn_mlp_dim, n_heads=1, dropout=dropout)

        # Latent Transformer module
        self.latent_transformer = LatentTransformer(
            embed_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers)

    def forward(self, x, l):
        """
        Forward pass of the PerceiverBlock module.
        Args:
            x (Tensor): Input tensor of shape [PIXELS x BATCH_SIZE x EMBED_DIM].
            l (Tensor): Latent tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
        Returns:
            Tensor: Transformed latent tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
        """
        # Apply cross-attention between the input and the latent tensor
        l = self.cross_attn(x, l)

        # Apply the latent transformer
        l = self.latent_transformer(l)

        return l
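An end-to-end sketch of the PerceiverBlock is shown below. Shapes and hyperparameters are illustrative, and the sketch assumes the PerceiverAttention and PerceiverAttentionBlock classes defined earlier in module.py:

import torch

# Illustrative sizes: e.g. a 32x32 image flattened to 1024 pixel tokens
pixels, latent_dim, batch_size, embed_dim = 1024, 64, 8, 128

x = torch.randn(pixels, batch_size, embed_dim)      # embedded image [PIXELS x BATCH_SIZE x EMBED_DIM]
l = torch.randn(latent_dim, batch_size, embed_dim)  # latent array [LATENT_DIM x BATCH_SIZE x EMBED_DIM]

block = PerceiverBlock(
    embed_dim=embed_dim,
    attn_mlp_dim=256,
    trnfr_mlp_dim=256,
    trnfr_heads=4,
    dropout=0.1,
    trnfr_layers=6)

# Cross-attend the image into the latent array, then refine with the latent transformer
l = block(x, l)  # expected shape: [LATENT_DIM x BATCH_SIZE x EMBED_DIM]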


