
Commit

added a classifier for AD or NC classification
lorenzopolicar committed Oct 12, 2023
1 parent b77a133 commit e81c3ed
Showing 1 changed file with 47 additions and 19 deletions.
module.py: 66 changes (47 additions, 19 deletions)
@@ -36,17 +36,11 @@ def forward(self, x):
         # Create position encoding with the same shape as the input
         enc = self.ff.unsqueeze(0).expand(x.shape[0], -1, -1).type_as(x)
 
-        # Concatenate the position encoding along the channel dimension
-        x = torch.cat([x, enc], dim=1)
+        # Concatenate the position encoding along the channel dimension and flatten the input
+        x = torch.cat([x, enc], dim=1).flatten(2)
 
-        # Flatten the input
-        x = x.flatten(2)
-
-        # Apply 1D convolution
-        x = self.conv(x)
-
-        # Permute the dimensions for the final output
-        x = x.permute(2, 0, 1)
+        # Apply 1D convolution and permute the dimensions for the final output
+        x = self.conv(x).permute(2, 0, 1)
 
         return x
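
For orientation, here is a minimal standalone sketch (not the repository's module; the sizes, N_FREQ and EMBED_DIM values are illustrative assumptions, and ff/conv stand in for self.ff and self.conv) of how the fused concatenate, flatten, convolve and permute pipeline rearranges a batch into the [PIXELS x BATCH_SIZE x EMBED_DIM] layout the attention blocks expect:

import torch
import torch.nn as nn

B, C, N = 4, 1, 1024        # assumed batch size, image channels, pixels per image
N_FREQ, EMBED_DIM = 8, 64   # assumed position-encoding channels and embedding width

x = torch.randn(B, C, N)                                 # flattened input image
ff = torch.randn(N_FREQ, N)                              # stands in for self.ff
conv = nn.Conv1d(C + N_FREQ, EMBED_DIM, kernel_size=1)   # stands in for self.conv

enc = ff.unsqueeze(0).expand(x.shape[0], -1, -1).type_as(x)   # [B, N_FREQ, N]
out = torch.cat([x, enc], dim=1).flatten(2)                   # [B, C + N_FREQ, N]
out = conv(out).permute(2, 0, 1)                              # [N, B, EMBED_DIM]
print(out.shape)  # torch.Size([1024, 4, 64])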

@@ -123,7 +117,7 @@ def forward(self, latent, image):
         out = self.drop(out)
 
         # Compute the second residual connection
-        out = out + resid
+        out += resid
 
         return out
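
The change above is purely stylistic; the in-place add produces the same values. A quick standalone check (toy tensors, not the module itself) that autograd still tracks the residual path through the in-place update:

import torch
import torch.nn.functional as F

resid = torch.randn(3, 4, requires_grad=True)
out = F.dropout(resid * 2.0, p=0.1, training=True)   # stands in for self.drop(out)
out += resid                                         # in-place second residual connection
out.sum().backward()                                 # gradients flow through both paths
print(resid.grad.shape)                              # torch.Size([3, 4])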

@@ -140,8 +134,6 @@ class LatentTransformer(nn.Module):
     Returns:
         Tensor: Transformed latent tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
     """
 
     def __init__(self, embed_dim, mlp_dim, n_heads, dropout, n_layers):
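
The body of LatentTransformer is collapsed in this diff. For intuition, a rough stand-in built from PyTorch's stock encoder layers (assumed hyperparameters, not the repository's implementation) illustrates the shape contract in the docstring: the latent array goes in and comes out as [LATENT_DIM x BATCH_SIZE x EMBED_DIM].

import torch
import torch.nn as nn

embed_dim, mlp_dim, n_heads, n_layers = 64, 128, 4, 3   # assumed hyperparameters
latent_dim, batch_size = 32, 4                          # assumed latent/batch sizes

# Stand-in for the latent self-attention stack (sequence-first layout, batch_first=False)
layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_heads,
                                   dim_feedforward=mlp_dim, dropout=0.1)
latent_transformer = nn.TransformerEncoder(layer, num_layers=n_layers)

l = torch.randn(latent_dim, batch_size, embed_dim)   # [LATENT_DIM, BATCH_SIZE, EMBED_DIM]
out = latent_transformer(l)                          # same shape as the input latent
print(out.shape)  # torch.Size([32, 4, 64])
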
@@ -185,17 +177,16 @@ class PerceiverBlock(nn.Module):
         trnfr_layers (int): Number of layers in the latent transformer.
     """
 
-    def __init__(self, embed_dim, attn_mlp_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers):
+    def __init__(self, embed_dim, attention_mlp_dim, transformer_mlp_dim, transformer_heads, dropout, transformer_layers):
         super().__init__()
 
         # Cross-Attention layer
-        self.cross_attn = PerceiverAttention(
-            embed_dim, attn_mlp_dim, n_heads=1, dropout=dropout)
+        self.cross_attention = PerceiverAttentionBlock(
+            embed_dim, attention_mlp_dim, n_heads=1, dropout=dropout)
 
         # Latent Transformer module
         self.latent_transformer = LatentTransformer(
-            embed_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers)
+            embed_dim, transformer_mlp_dim, transformer_heads, dropout, transformer_layers)
 
     def forward(self, x, l):
         """
@@ -210,12 +201,49 @@ def forward(self, x, l):
         """
         # Apply cross-attention on the input and latent tensor
-        l = self.cross_attn(x, l)
+        l = self.cross_attention(x, l)
 
         # Apply the latent transformer
         l = self.latent_transformer(l)
 
         return l
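
A hedged usage sketch of how a stack of these blocks is typically driven (the hyperparameter values and depth are assumptions; the import assumes this file is importable as module, and the shapes follow the docstrings in this file): the embedded byte array x stays fixed while the latent l is refined block by block.

import torch
import torch.nn as nn
from module import PerceiverBlock

embed_dim, latent_dim, batch_size, n_pixels = 64, 32, 4, 1024   # assumed sizes

blocks = nn.ModuleList([
    PerceiverBlock(embed_dim, attention_mlp_dim=128, transformer_mlp_dim=128,
                   transformer_heads=4, dropout=0.1, transformer_layers=3)
    for _ in range(2)])

x = torch.randn(n_pixels, batch_size, embed_dim)     # embedded byte array
l = torch.randn(latent_dim, batch_size, embed_dim)   # latent array

for block in blocks:
    l = block(x, l)   # cross-attend to x, then self-attend within the latent
print(l.shape)        # expected: torch.Size([32, 4, 64])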

+class Classifier(nn.Module):
+    """
+    Classifier for the Perceiver model for binary classification (AD or NC).
+    Args:
+        embed_dim (int): Dimension of the embedded representations.
+        n_classes (int): Number of target classes. Defaults to 2 for AD vs. NC.
+    """
+
+    def __init__(self, embed_dim, n_classes=2):
+        super().__init__()
+
+        # First fully connected layer
+        self.fc1 = nn.Linear(embed_dim, embed_dim)
+
+        # Second fully connected layer for classification
+        self.fc2 = nn.Linear(embed_dim, n_classes)
+
+    def forward(self, x):
+        """
+        Forward pass of the Classifier module.
+        Args:
+            x (Tensor): Input tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
+        Returns:
+            Tensor: Classification logits of shape [BATCH_SIZE x n_classes].
+        """
+
+        # Apply the first fully connected layer and average over the latent dimension
+        x = self.fc1(x).mean(dim=0)
+
+        # Apply the second fully connected layer for classification
+        x = self.fc2(x)
+
+        return x
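
A quick shape check for the new classification head (sizes are illustrative; the import assumes the file is importable as module): the latent output is projected, averaged over the latent dimension, and mapped to two logits, one score each for AD and NC per image in the batch.

import torch
from module import Classifier

latent_dim, batch_size, embed_dim = 32, 4, 64   # assumed sizes
clf = Classifier(embed_dim, n_classes=2)

l = torch.randn(latent_dim, batch_size, embed_dim)   # final latent from the Perceiver
logits = clf(l)
print(logits.shape)   # torch.Size([4, 2]), one (AD, NC) logit pair per image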

