-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathANROT_HELANet.py
36 lines (32 loc) · 1.33 KB
/
ANROT_HELANet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class ANROT_HELANet(nn.Module):
    """Few-shot classifier that scores queries against per-class prototypes.

    A backbone network embeds both the support and the query images. One
    prototype per class is built by averaging the support embeddings of that
    class, and every query is scored by the negated Hellinger distance to
    each prototype.
    """

    def __init__(self, backbone: nn.Module):
        """Store the feature extractor used to embed images.

        Args:
            backbone: module applied to the support and query image batches.
        """
        super().__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """Predict query labels using labeled support images.

        Args:
            support_images: batch of labeled images fed to the backbone.
            support_labels: class label of each support image
                (assumed 1-D and aligned with ``support_images`` -- TODO confirm).
            query_images: batch of images to classify.

        Returns:
            The negated output of ``Hellinger_dist(z_query, z_proto)``, so a
            larger score means a smaller distance. Prototypes are ordered by
            ascending distinct label value; presumably column ``i`` of the
            result refers to the i-th prototype (depends on ``Hellinger_dist``,
            which is defined outside this class -- verify).
        """
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the backbone are executed.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Iterate over the labels that actually occur (torch.unique returns
        # them sorted) instead of assuming they are exactly 0..n_way-1; a
        # missing label would otherwise produce an empty selection and a NaN
        # prototype. For contiguous 0-based labels the result is unchanged.
        class_labels = torch.unique(support_labels)

        # Prototype for a class = mean of the support features carrying that
        # label. keepdim=True keeps a leading axis of size 1 so the per-class
        # means can be concatenated along dim 0.
        z_proto = torch.cat(
            [
                z_support[support_labels == label].mean(dim=0, keepdim=True)
                for label in class_labels
            ]
        )

        # Compute the Hellinger distance from queries to prototypes.
        dists = Hellinger_dist(z_query, z_proto)

        # Closer prototype => higher classification score.
        scores = -dists
        return scores
# Build the feature extractor. NOTE: despite its name, this variable holds
# whatever resnet12() returns and is passed below as the backbone, not an
# output tensor. resnet12 is defined/imported outside this view.
convolutional_network_output = resnet12()
# Wrap the backbone in the few-shot classifier and move it to the GPU via
# nn.Module.cuda(); this runs at import time and requires a CUDA device.
model = ANROT_HELANet(convolutional_network_output).cuda()