fix train.py
cgarciae committed Nov 23, 2023
1 parent abe66c6 commit d995657
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion flax/experimental/nnx/examples/lm1b/train.py
@@ -252,7 +252,7 @@ def predict_step(
   for _path, m in module.iter_submodules():
     if isinstance(m, nnx.HasCacheInitializer):
       input_shape = (inputs.shape[0], max_decode_len, config.emb_dim)
-      m.init_cache(input_shape)
+      m.init_cache(input_shape, dtype=config.dtype)

   cache = module.extract(nnx.Cache)

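Why the added dtype argument matters: init_cache (see the attention.py hunk below) defaults to jnp.float32, so without passing config.dtype the decode cache was always allocated in float32 even when the model computes in a lower-precision dtype such as bfloat16. A minimal sketch in plain jax.numpy, not the nnx API; the concrete sizes and the bfloat16 choice are illustrative assumptions, and the cache is simplified to the input shape:

import jax.numpy as jnp

# Illustrative stand-ins for the config values used in predict_step.
batch, max_decode_len, emb_dim = 8, 128, 512
compute_dtype = jnp.bfloat16  # stand-in for config.dtype

input_shape = (batch, max_decode_len, emb_dim)

# Before the fix: init_cache fell back to its jnp.float32 default.
cache_before = jnp.zeros(input_shape, jnp.float32)
# After the fix: the cache is allocated in the model's compute dtype.
cache_after = jnp.zeros(input_shape, compute_dtype)

print(cache_before.dtype, cache_before.nbytes)  # float32, 2097152 bytes
print(cache_after.dtype, cache_after.nbytes)    # bfloat16, 1048576 bytes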
2 changes: 1 addition & 1 deletion flax/experimental/nnx/nnx/nn/attention.py
@@ -578,7 +578,7 @@ def __call__(

   def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
     """Initializes cache for fast autoregressive decoding."""
-    cache_shape = (*input_shape[:-1], self.num_heads, self.features_out)
+    cache_shape = (*input_shape[:-1], self.num_heads, self.head_dim)
     self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
     self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
     self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
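Why the shape change matters: the cached key/value tensors are the per-head projections, whose trailing dimensions are (num_heads, head_dim), while features_out is the width of the output projection and need not equal head_dim. A minimal sketch in plain jax.numpy, not the nnx module itself; the concrete sizes are illustrative assumptions:

import jax.numpy as jnp

# Illustrative sizes; in a typical config head_dim = emb_dim // num_heads.
batch, max_decode_len, emb_dim = 8, 128, 512
num_heads, head_dim = 8, 64
features_out = 512  # output projection width, generally != head_dim

input_shape = (batch, max_decode_len, emb_dim)

# Shape of the per-head keys/values that get written into the cache.
key = jnp.zeros((batch, max_decode_len, num_heads, head_dim))

# After the fix the cache shape matches the key/value shape exactly.
cache_shape = (*input_shape[:-1], num_heads, head_dim)
assert cache_shape == key.shape

# The old shape used features_out and only matched when features_out
# happened to equal head_dim.
old_cache_shape = (*input_shape[:-1], num_heads, features_out)
assert old_cache_shape != key.shape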
