Create lka.py
MenghaoGuo authored Jun 6, 2022
1 parent 731d700 commit 6c54a42
Showing 1 changed file with 50 additions and 0 deletions.
code/channel_spatial_attentions/lka.py (+50 −0)
@@ -0,0 +1,50 @@
# Visual Attention Network: Large Kernel Attention (LKA)
import jittor as jt
import jittor.nn as nn


class AttentionModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # 5x5 depthwise conv captures local structure.
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # 7x7 depthwise conv with dilation 3 captures long-range context;
        # together the two approximate a single large-kernel depthwise conv.
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        # 1x1 conv mixes channels to produce the attention map.
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def execute(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)

        # Gate the input with the learned attention map.
        return u * attn
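
# Side note (illustrative, not part of the original file): a quick
# receptive-field check of the two stacked depthwise convolutions above.
# The dilated 7x7 conv spans 3*(7-1)+1 = 19 positions, and stacking it
# after the 5x5 gives a 23x23 receptive field; the VAN paper describes
# this pairing as a decomposition of a 21x21 large-kernel convolution.
def stacked_rf(prev_rf, kernel, dilation=1):
    # For stride-1 convs, each layer grows the receptive field by the
    # dilated kernel extent minus one.
    return prev_rf + dilation * (kernel - 1)

rf = stacked_rf(1, 5)        # conv0: 5x5 depthwise          -> RF 5
rf = stacked_rf(rf, 7, 3)    # conv_spatial: 7x7, dilation 3 -> RF 23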


class SpatialAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def execute(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        # Residual connection around the gated attention.
        x = x + shortcut
        return x


def main():
    attention_block = SpatialAttention(64)
    input = jt.rand([4, 64, 32, 32])
    output = attention_block(input)
    print(input.size(), output.size())


if __name__ == '__main__':
    main()
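
For context, a minimal sketch (an assumption based on the VAN paper's block design, not part of this commit) of how SpatialAttention is typically wrapped with a pre-norm residual:

# Hypothetical wrapper, assuming the VAN paper's pre-norm residual layout.
class VANAttentionBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.BatchNorm2d(dim)
        self.attn = SpatialAttention(dim)

    def execute(self, x):
        # x + LKA(BN(x)); the full VAN block also adds layer scale and an MLP.
        return x + self.attn(self.norm(x))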
