From e81c3edfaf6b3399eb5a57470d80d73b00654a62 Mon Sep 17 00:00:00 2001
From: Lorenzo Policar
Date: Thu, 12 Oct 2023 13:33:33 +1000
Subject: [PATCH] added a classifier for AD or NC classification

---
 module.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 47 insertions(+), 19 deletions(-)

diff --git a/module.py b/module.py
index 25e0895419..9aa9921be7 100644
--- a/module.py
+++ b/module.py
@@ -36,17 +36,11 @@ def forward(self, x):
         # Create position encoding with the same shape as the input
         enc = self.ff.unsqueeze(0).expand(x.shape[0], -1, -1).type_as(x)
 
-        # Concatenate the position encoding along the channel dimension
-        x = torch.cat([x, enc], dim=1)
+        # Concatenate the position encoding along the channel dimension and flatten the input
+        x = torch.cat([x, enc], dim=1).flatten(2)
 
-        # Flatten the input
-        x = x.flatten(2)
-
-        # Apply 1D convolution
-        x = self.conv(x)
-
-        # Permute the dimensions for the final output
-        x = x.permute(2, 0, 1)
+        # Apply 1D convolution and permute the dimensions for the final output
+        x = self.conv(x).permute(2, 0, 1)
 
         return x
 
@@ -123,7 +117,7 @@ def forward(self, latent, image):
         out = self.drop(out)
 
         # Compute the second residual connection
-        out = out + resid
+        out += resid
 
         return out
 
@@ -140,8 +134,6 @@ class LatentTransformer(nn.Module):
 
     Returns:
         Tensor: Transformed latent tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
-
-
     """
 
     def __init__(self, embed_dim, mlp_dim, n_heads, dropout, n_layers):
@@ -185,17 +177,16 @@ class PerceiverBlock(nn.Module):
         trnfr_layers (int): Number of layers in the latent transformer.
     """
 
-
-    def __init__(self, embed_dim, attn_mlp_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers):
+    def __init__(self, embed_dim, attention_mlp_dim, transformer_mlp_dim, transformer_heads, dropout, transformer_layers):
         super().__init__()
 
         # Cross-Attention layer
-        self.cross_attn = PerceiverAttention(
-            embed_dim, attn_mlp_dim, n_heads=1, dropout=dropout)
+        self.cross_attention = PerceiverAttentionBlock(
+            embed_dim, attention_mlp_dim, n_heads=1, dropout=dropout)
 
         # Latent Transformer module
         self.latent_transformer = LatentTransformer(
-            embed_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers)
+            embed_dim, transformer_mlp_dim, transformer_heads, dropout, transformer_layers)
 
     def forward(self, x, l):
         """
@@ -210,12 +201,49 @@ def forward(self, x, l):
         """
 
         # Apply cross-attention on the input and latent tensor
-        l = self.cross_attn(x, l)
+        l = self.cross_attention(x, l)
 
         # Apply the latent transformer
         l = self.latent_transformer(l)
 
         return l
 
 
 
+class Classifier(nn.Module):
+    """
+    Classifier head for the Perceiver model for binary classification (AD or NC).
+
+    Args:
+        embed_dim (int): Dimension of the embedded representations.
+        n_classes (int): Number of target classes. Defaults to 2 to classify AD or NC.
+    """
+
+    def __init__(self, embed_dim, n_classes=2):
+        super().__init__()
+
+        # First fully connected layer
+        self.fc1 = nn.Linear(embed_dim, embed_dim)
+
+        # Second fully connected layer for classification
+        self.fc2 = nn.Linear(embed_dim, n_classes)
+
+    def forward(self, x):
+        """
+        Forward pass of the Classifier module.
+
+        Args:
+            x (Tensor): Input tensor of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM].
+
+        Returns:
+            Tensor: Output logits of shape [BATCH_SIZE x n_classes].
+
+        """
+
+        # Apply the first fully connected layer and take the mean along the latent dimension
+        x = self.fc1(x).mean(dim=0)
+
+        # Apply the second fully connected layer for classification
+        x = self.fc2(x)
+
+        return x
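
Usage sketch (editor's note, not part of the patch): a minimal example of driving the new Classifier head with a dummy latent tensor, to show the expected shapes. The `from module import Classifier` import path and the sizes LATENT_DIM=32, BATCH_SIZE=4, EMBED_DIM=64 are assumptions for illustration only.

    import torch
    from module import Classifier  # assumed import path; Classifier is added by this patch

    LATENT_DIM, BATCH_SIZE, EMBED_DIM = 32, 4, 64  # hypothetical example sizes
    latents = torch.randn(LATENT_DIM, BATCH_SIZE, EMBED_DIM)  # [LATENT_DIM x BATCH_SIZE x EMBED_DIM]

    clf = Classifier(embed_dim=EMBED_DIM, n_classes=2)
    logits = clf(latents)   # fc1 -> mean over the latent dimension -> fc2
    print(logits.shape)     # torch.Size([4, 2]); one AD/NC logit pair per image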