Make SpatialAttention a non-layer (#42)
puddly authored Dec 28, 2024
1 parent ac6502b commit 56e4f7d
Showing 1 changed file with 9 additions and 13 deletions.
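Since SpatialAttention only composes existing Keras layers and ops inside its call, it does not need the tf.keras.layers.Layer machinery (super().__init__, a call() method, get_config() for serialization): a plain Python class with __call__ can be applied during functional model construction, and the layers it instantiates are still wired into the graph. A minimal sketch of the pattern, using a hypothetical DoublePool class rather than code from this repository:

import tensorflow as tf


class DoublePool:
    """Plain callable, not a Layer: it only wires existing Keras layers together."""

    def __init__(self, pool_size):
        self.pool_size = pool_size

    def __call__(self, inputs):
        avg = tf.keras.layers.AveragePooling2D(self.pool_size)(inputs)
        mx = tf.keras.layers.MaxPooling2D(self.pool_size)(inputs)
        return tf.keras.layers.Concatenate(axis=-1)([avg, mx])


inputs = tf.keras.layers.Input(shape=(8, 8, 3))
outputs = DoublePool(pool_size=(2, 2))(inputs)  # called like a layer while building the graph
model = tf.keras.Model(inputs, outputs)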
microwakeword/mixednet.py: 22 changes (9 additions & 13 deletions)
@@ -231,29 +231,28 @@ def __call__(self, inputs):
         return x
 
 
-class SpatialAttention(tf.keras.layers.Layer):
+class SpatialAttention:
     """Spatial Attention Layer based on CBAM: Convolutional Block Attention Module
     https://arxiv.org/pdf/1807.06521v2
     Args:
         object (_type_): _description_
     """
 
-    def __init__(self, kernel_size, ring_buffer_size, **kwargs):
-        super().__init__(**kwargs)
 
+    def __init__(self, kernel_size, ring_buffer_size):
         self.kernel_size = kernel_size
         self.ring_buffer_size = ring_buffer_size
 
-    def call(self, inputs):
-        tranposed = tf.transpose(inputs, perm=[0, 1, 3, 2])
+    def __call__(self, inputs):
+        tranposed = tf.keras.ops.transpose(inputs, axes=[0, 1, 3, 2])
         channel_avg = tf.keras.layers.AveragePooling2D(
             pool_size=(1, tranposed.shape[2]), strides=(1, tranposed.shape[2])
         )(tranposed)
         channel_max = tf.keras.layers.MaxPooling2D(
             pool_size=(1, tranposed.shape[2]), strides=(1, tranposed.shape[2])
         )(tranposed)
         pooled = tf.keras.layers.Concatenate(axis=-1)([channel_avg, channel_max])
 
         attention = stream.Stream(
             cell=tf.keras.layers.Conv2D(
                 1,
@@ -275,12 +274,6 @@ def call(self, inputs):
 
         return net * attention
 
-    def get_config(self):
-        return {
-            "kernel_size": self.kernel_size,
-            "ring_buffer_size": self.ring_buffer_size,
-        }
-
 
 def model(flags, shape, batch_size):
     """MixedNet model.
@@ -368,7 +361,10 @@ def model(flags, shape, batch_size):
 
     if net.shape[1] > 1:
         if flags.spatial_attention:
-            net = SpatialAttention(4, net.shape[1] - 1)(net)
+            net = SpatialAttention(
+                kernel_size=4,
+                ring_buffer_size=net.shape[1] - 1,
+            )(net)
         else:
             net = stream.Stream(
                 cell=tf.keras.layers.Identity(),
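For reference, the visible hunks omit the middle of __call__ (the streaming Conv2D and sigmoid that turn the pooled maps into an attention mask). Below is a rough, non-streaming sketch of the CBAM-style computation the class performs; the function name, kernel size, and padding choice are illustrative, and the stream.Stream ring-buffer wrapper used in the actual code is left out:

import tensorflow as tf


def spatial_attention(inputs, kernel_size=4):
    # inputs: [batch, time, 1, features]
    # Collapse the feature axis with average- and max-pooling to get two maps per frame.
    transposed = tf.keras.ops.transpose(inputs, axes=[0, 1, 3, 2])
    channel_avg = tf.keras.layers.AveragePooling2D(
        pool_size=(1, transposed.shape[2]), strides=(1, transposed.shape[2])
    )(transposed)
    channel_max = tf.keras.layers.MaxPooling2D(
        pool_size=(1, transposed.shape[2]), strides=(1, transposed.shape[2])
    )(transposed)
    pooled = tf.keras.layers.Concatenate(axis=-1)([channel_avg, channel_max])

    # A small convolution over time plus a sigmoid yields a [batch, time, 1, 1]
    # attention mask that rescales the original features. ("same" padding is an
    # assumption; the streaming version buffers past frames instead.)
    attention = tf.keras.layers.Conv2D(
        1, (kernel_size, 1), padding="same", activation="sigmoid"
    )(pooled)
    return inputs * attention


masked = spatial_attention(tf.random.normal((1, 49, 1, 40)))  # e.g. 49 frames x 40 features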
