Using flax.nn.scan with modules whose `__call__` takes kwargs #3764
-
I am revising my Flax model implementation, as the GPU utilisation statistics seem lower than they should/could be. In doing so, I am replacing loops with `scan`, as suggested in the docs. At the moment I am using `scan` to stack a number of encoder layers, and I have a partially successful implementation with the following (please note I am using Hydra to `instantiate` layers; these are for the most part analogous to calling `nn.<layer_name>` with an appropriate set of args/kwargs):

```python
import flax.linen as nn
from hydra.utils import instantiate
from omegaconf import DictConfig


class Encoder1DBlock(nn.Module):
    """Transformer encoder layer."""

    layer_norm: DictConfig
    dropout: DictConfig
    self_attention: DictConfig
    mlp_block: DictConfig

    @nn.compact
    def __call__(self, inputs, train=False, mask=None):
        # Attention block.
        x = instantiate(self.layer_norm)(inputs)
        x = instantiate(self.self_attention)(x, mask, not train)
        x = instantiate(self.dropout)(x, not train)
        # Skip connection.
        x = x + inputs
        # MLP block.
        y = instantiate(self.layer_norm)(x)
        y = instantiate(self.mlp_block, _recursive_=False)(y, train)
        return x + y, None


class StackedEncoder1DBlock(nn.Module):
    """Stacking Transformer encoder layers."""

    num_blocks: int
    encoder_1d_block: DictConfig

    @nn.compact
    def __call__(self, x, train=False, mask=None):
        attention_scan = nn.scan(
            Encoder1DBlock,
            variable_axes="params",
            variable_broadcast=False,
            split_rngs={'params': True},
            length=self.num_blocks,
        )
        x, _ = attention_scan(
            layer_norm=self.encoder_1d_block["layer_norm"],
            dropout=self.encoder_1d_block["dropout"],
            self_attention=self.encoder_1d_block["self_attention"],
            mlp_block=self.encoder_1d_block["mlp_block"],
        )(x, None)
        return x
```

However, the above implementation does not pass the `train` and `mask` kwargs through to `Encoder1DBlock`'s `__call__`. I am still debugging and working out the best way to accomplish my use case with `scan`, but I wanted to post here in case (1) someone more experienced knows the solution off hand, and (2) to share what I learn if/when I solve this.
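One possible alternative to threading these values through module attributes, sketched below and untested, is to pass `train` and `mask` positionally and mark them as broadcast arguments via `in_axes`, so each scan step receives them unchanged rather than having them scanned over. This assumes a Flax version that exports `nn.broadcast`, and it would replace the body of `StackedEncoder1DBlock.__call__` while keeping `Encoder1DBlock`'s `(inputs, train, mask)` signature from above:

```python
# Sketch only: broadcast `train` and `mask` to every scan step instead of
# passing them as module attributes. Assumes `nn.broadcast` is available.
attention_scan = nn.scan(
    Encoder1DBlock,
    variable_axes={'params': 0},            # a separate set of params per step
    variable_broadcast=False,
    split_rngs={'params': True, 'dropout': True},
    in_axes=(nn.broadcast, nn.broadcast),   # train and mask are not scanned over
    length=self.num_blocks,
)
x, _ = attention_scan(
    layer_norm=self.encoder_1d_block["layer_norm"],
    dropout=self.encoder_1d_block["dropout"],
    self_attention=self.encoder_1d_block["self_attention"],
    mlp_block=self.encoder_1d_block["mlp_block"],
)(x, train, mask)
return x
```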
Replies: 1 comment
-
My potentially suboptimal solution is below; any pointers, or discussion if others have encountered the same issue, would be appreciated:

```python
from typing import Callable, Optional

import flax.linen as nn
from hydra.utils import instantiate
from jax.typing import ArrayLike
from omegaconf import DictConfig


class Encoder1DBlock(nn.Module):
    """Transformer encoder layer."""

    layer_norm: DictConfig
    dropout: DictConfig
    self_attention: DictConfig
    mlp_block: DictConfig
    train: Optional[bool] = None
    mask: Optional[ArrayLike] = None

    @nn.compact
    def __call__(self, inputs, mask=None, train=None):
        # merge_param picks whichever of the attribute / call-time value is
        # not None (exactly one of the two must be provided).
        train = nn.merge_param('train', self.train, train)
        mask = nn.merge_param('mask', self.mask, mask)
        # Attention block.
        x = instantiate(self.layer_norm)(inputs)
        x = instantiate(self.self_attention)(x, mask, not train)
        x = instantiate(self.dropout)(x, not train)
        # Skip connection.
        x = x + inputs
        # MLP block.
        y = instantiate(self.layer_norm)(x)
        y = instantiate(self.mlp_block, _recursive_=False)(y, train)
        return x + y, None


class AddPositionEmbedding(nn.Module):
    """Adds learned positional embeddings to the inputs.

    Attributes:
      posemb_init: positional embedding initializer.
    """

    posemb_init: Callable

    @nn.compact
    def __call__(self, inputs):
        """Applies the AddPositionEmbedding module.

        Args:
          inputs: Inputs to the layer.

        Returns:
          Output tensor with shape `(bs, timesteps, in_dim)`.
        """
        # inputs.shape is (batch_size, seq_len, emb_dim).
        assert inputs.ndim == 3, (
            f"Number of dimensions should be 3, but it is: {inputs.ndim}"
        )
        pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
        pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape)
        return inputs + pe


class StackedEncoder1DBlock(nn.Module):
    """Stacking Transformer encoder layers."""

    num_blocks: int
    encoder_1d_block: DictConfig

    @nn.compact
    def __call__(self, x, train=False, mask=None):
        # Apply learnt position embedding.
        x = AddPositionEmbedding(
            posemb_init=nn.initializers.normal(stddev=0.02),
            name="posembed_input",
        )(x)
        # Use scan to iterate over Encoder1DBlock layers.
        attention_stack = nn.scan(
            Encoder1DBlock,
            variable_axes={'params': 0},
            variable_broadcast=False,
            split_rngs={'params': True, 'dropout': True},
            length=self.num_blocks,
        )
        # `train` and `mask` are threaded through as module attributes; the
        # call-time `mask` positional argument is left as None.
        x, _ = attention_stack(
            layer_norm=self.encoder_1d_block["layer_norm"],
            dropout=self.encoder_1d_block["dropout"],
            self_attention=self.encoder_1d_block["self_attention"],
            mlp_block=self.encoder_1d_block["mlp_block"],
            train=train,
            mask=mask,
        )(x, None)
        return x
```
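For completeness, a rough usage sketch (untested; the `encoder_block_cfg` DictConfig, array shapes, and mask shape are placeholders and will depend on how the individual layers are configured). Because `variable_axes={'params': 0}`, every parameter leaf of the scanned block gains a leading axis of size `num_blocks`:

```python
import jax
import jax.numpy as jnp

# `encoder_block_cfg` is a hypothetical DictConfig holding the Hydra configs
# for layer_norm / dropout / self_attention / mlp_block.
model = StackedEncoder1DBlock(num_blocks=6, encoder_1d_block=encoder_block_cfg)

x = jnp.ones((2, 16, 128))                   # (batch, seq_len, emb_dim)
mask = jnp.ones((2, 1, 16, 16), dtype=bool)  # shape depends on the attention layer

variables = model.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
    x, train=False, mask=mask,
)
# Each leaf in variables["params"] now has a leading axis of size num_blocks.
y = model.apply(
    variables, x, train=True, mask=mask,
    rngs={"dropout": jax.random.PRNGKey(2)},
)
```

Note that `mask` has to be supplied somewhere: `nn.merge_param` raises if both the attribute and the call-time value are None.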