How to extract intermediate values from "underneath" vmap?
#1934
-
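I have a model (a simplified version below) which uses `nn.vmap`, and I would like to extract intermediate values computed by the modules underneath it (e.g. with `capture_intermediates=True`, as in the example usage at the end). […] I believe a way forward would be to fork […]

from functools import partial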
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
class CrossAttention(nn.Module):
    """Multi-head cross-attention from normalized queries to memory, vmapped over axis 0.

    The lifted ``nn.vmap`` maps ``queries`` and ``memory`` over their leading
    axis and broadcasts ``deterministic``. One parameter set is shared by every
    mapped slice (``'params': None``, ``split_rngs['params'] = False``), while
    each slice draws its own dropout RNG (``split_rngs['dropout'] = True``).

    Attributes:
        nhead: number of attention heads.
        dropout: dropout rate forwarded to ``nn.MultiHeadDotProductAttention``.
        norm_fn: normalization module *class* (default ``nn.LayerNorm``) that is
            instantiated and applied to ``queries`` before attention.
    """

    nhead: int
    dropout: float = 0.1
    # NOTE(review): holds a module constructor, not an instance; a
    # ``Callable[..., nn.Module]`` annotation would be more precise.
    norm_fn: nn.Module = nn.LayerNorm

    @partial(
        nn.vmap,
        in_axes=(0, 0, None),
        out_axes=0,
        # 'params': None   -> parameters are broadcast (shared) across the mapped axis.
        # 'intermediates': 0 -> lift the collection filled by ``capture_intermediates``
        #   (or ``sow``) so values recorded inside the vmap are stacked along
        #   axis 0 and can be returned to the caller. Without this entry they
        #   cannot be extracted from "underneath" the vmap.
        variable_axes={'params': None, 'intermediates': 0},
        split_rngs={'params': False, 'dropout': True},
    )
    @nn.compact
    def __call__(
        self,
        queries: 'Q C',
        memory: 'M C',
        deterministic: bool,
    ) -> 'Q C':
        """Attend from normalized ``queries`` to ``memory`` for one mapped slice.

        Args:
            queries: per-slice query array (leading vmap axis already removed).
            memory: per-slice key/value array.
            deterministic: forwarded to the attention layer; disables dropout
                when True. Broadcast, not mapped (``in_axes`` entry is None).

        Returns:
            The attention output for this slice.
        """
        return nn.MultiHeadDotProductAttention(
            self.nhead,
            dropout_rate=self.dropout,
        )(
            self.norm_fn(name='mha-norm')(queries),
            memory,
            deterministic=deterministic,
        )
class ParallelDecoder(nn.Module):
    """Refine learned queries against memory with a stack of ``CrossAttention`` layers.

    Each item along the leading axis of ``memory`` gets its own set of learned
    queries. The items are processed independently by the vmapped
    ``CrossAttention`` layers.

    Attributes:
        n_query: number of learned query vectors per item.
        depth: number of stacked ``CrossAttention`` layers.
        nhead: number of attention heads in each ``CrossAttention`` layer.
    """

    n_query: int = 8
    depth: int = 4
    nhead: int = 8

    @nn.compact
    def __call__(self, memory: 'N M C') -> 'N Q C':
        """Decode ``memory`` of shape (items, M, C) into (items, n_query, C)."""
        n_items, _, channels = memory.shape
        # The query parameter's leading size must equal memory's leading size,
        # because CrossAttention vmaps both arguments over axis 0. Deriving it
        # from the input (instead of hard-coding 2) keeps them in sync.
        queries = self.param(
            'q',
            nn.initializers.normal(stddev=1.),
            (n_items, self.n_query, channels),
        )
        for _ in range(self.depth):
            queries = CrossAttention(nhead=self.nhead)(queries, memory, True)
            # in a more complete example I would apply an MLP on queries
            # to enable mixing between the items in the leading dimension
        return queries
## Example usage
model = ParallelDecoder()
# Seeded generator so the example is reproducible; avoid shadowing builtin ``input``.
inputs = np.random.default_rng(0).standard_normal((2, 16, 128))
state = model.init(jax.random.PRNGKey(42), inputs)
# With ``mutable=['intermediates']``, ``apply`` returns ``(output, mutated_variables)``;
# ``capture_intermediates=True`` fills the 'intermediates' collection.
# Values recorded by modules inside ``nn.vmap`` only show up when
# 'intermediates' is listed in that vmap's ``variable_axes``.
output, interms = model.apply(
    state, inputs, capture_intermediates=True, mutable=['intermediates']
)
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
You should also include the `intermediates` variable collection in the `variable_axes` of the lifted `vmap` call. We should probably document this in our HOWTO.
Beta Was this translation helpful? Give feedback.
You should also include the `intermediates` variable collection in the `variable_axes` of the lifted `vmap` call. So replace `variable_axes={'params': None}` with `variable_axes={'params': None, 'intermediates': 0}`. We should probably document this in our HOWTO.
@jheek