decode test p1
cgarciae committed Nov 27, 2023
1 parent d760141 commit 320e0ee
Showing 1 changed file with 139 additions and 1 deletion.
140 changes: 139 additions & 1 deletion flax/experimental/nnx/examples/lm1b/models_test.py
@@ -134,7 +134,92 @@ def copy_var(nnx_name, linen_name):
copy_var('decoder/logitdense/kernel', 'decoder/logitdense/kernel')
copy_var('decoder/logitdense/bias', 'decoder/logitdense/bias')

def transfer_cache(
self,
config: TransformerConfig,
cache_nnx: nnx.State,
cache_linen: dict[str, Any],
):
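    # Copy every Linen cache variable into the matching NNX cache entry,
    # translating between the two modules' path naming schemes.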
rules = dataclasses.asdict(config.rules)
flat_cache_nnx = cache_nnx.flat_state()
flat_cache_linen = traverse_util.flatten_dict(cache_linen, sep='/')

def apply_rules(names: tuple[str, ...]):
return tuple(rules[name] for name in names)

def copy_var(nnx_name, linen_name):
assert (
flat_cache_nnx[nnx_name].value.shape
== flat_cache_linen[linen_name].shape
)
flat_cache_nnx[nnx_name].value = flat_cache_linen[linen_name]

    # NNX cache structure (flattened paths; values show shapes):
# {
# 'decoder/encoderdecoderblock_0/attention/cache_index': Cache(value=()),
# 'decoder/encoderdecoderblock_0/attention/cached_key': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_0/attention/cached_value': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_1/attention/cache_index': Cache(value=()),
# 'decoder/encoderdecoderblock_1/attention/cached_key': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_1/attention/cached_value': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_2/attention/cache_index': Cache(value=()),
# 'decoder/encoderdecoderblock_2/attention/cached_key': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_2/attention/cached_value': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_3/attention/cache_index': Cache(value=()),
# 'decoder/encoderdecoderblock_3/attention/cached_key': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_3/attention/cached_value': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_4/attention/cache_index': Cache(value=()),
# 'decoder/encoderdecoderblock_4/attention/cached_key': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_4/attention/cached_value': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_5/attention/cache_index': Cache(value=()),
# 'decoder/encoderdecoderblock_5/attention/cached_key': Cache(value=(2, 2048, 8, 64)),
# 'decoder/encoderdecoderblock_5/attention/cached_value': Cache(value=(2, 2048, 8, 64)),
# 'decoder/posembed_output/cache_index': Cache(value=())
# }

    # Linen cache structure (flattened paths; values show shapes):
# {
# 'decoder/encoderdecoderblock_0/MultiHeadDotProductAttention_0/cache_index': (),
# 'decoder/encoderdecoderblock_0/MultiHeadDotProductAttention_0/cached_key': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_0/MultiHeadDotProductAttention_0/cached_value': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_1/MultiHeadDotProductAttention_0/cache_index': (),
# 'decoder/encoderdecoderblock_1/MultiHeadDotProductAttention_0/cached_key': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_1/MultiHeadDotProductAttention_0/cached_value': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_2/MultiHeadDotProductAttention_0/cache_index': (),
# 'decoder/encoderdecoderblock_2/MultiHeadDotProductAttention_0/cached_key': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_2/MultiHeadDotProductAttention_0/cached_value': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_3/MultiHeadDotProductAttention_0/cache_index': (),
# 'decoder/encoderdecoderblock_3/MultiHeadDotProductAttention_0/cached_key': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_3/MultiHeadDotProductAttention_0/cached_value': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_4/MultiHeadDotProductAttention_0/cache_index': (),
# 'decoder/encoderdecoderblock_4/MultiHeadDotProductAttention_0/cached_key': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_4/MultiHeadDotProductAttention_0/cached_value': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_5/MultiHeadDotProductAttention_0/cache_index': (),
# 'decoder/encoderdecoderblock_5/MultiHeadDotProductAttention_0/cached_key': (1, 3, 8, 64),
# 'decoder/encoderdecoderblock_5/MultiHeadDotProductAttention_0/cached_value': (1, 3, 8, 64),
# 'decoder/posembed_output/cache_index': ()
# }

for idx in range(config.num_layers):
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/cache_index',
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/cache_index',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/cached_key',
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/cached_key',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/cached_value',
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/cached_value',
)

copy_var(
'decoder/posembed_output/cache_index',
'decoder/posembed_output/cache_index',
)
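
  # Aside (illustrative, not part of the commit): transfer_cache keys both
  # caches by '/'-joined paths, via flat_state() on the NNX side and
  # flax.traverse_util.flatten_dict on the Linen side. A minimal sketch of
  # the latter, using a hypothetical toy dict:
  def _flatten_dict_example(self):
    from flax import traverse_util

    cache = {'decoder': {'posembed_output': {'cache_index': ()}}}
    flat = traverse_util.flatten_dict(cache, sep='/')
    assert list(flat) == ['decoder/posembed_output/cache_index']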

def test_forward_eval(self):
config = CompatTransformerConfig(
vocab_size=20,
output_vocab_size=20,
@@ -167,6 +252,59 @@ def test_forward(self):

assert jnp.allclose(output_nnx, output_linen, atol=1e-5)

def test_forward_decode(self):
batch_size = 2

config = CompatTransformerConfig(
vocab_size=20,
output_vocab_size=20,
max_len=3,
emb_dim=16,
qkv_dim=16,
num_heads=2,
deterministic=True,
decode=True,
rules=MeshRules(
embed='model',
mlp='data',
kv=None,
vocab=None,
),
)

model_nnx = TransformerLM.create_abstract(config, rngs=nnx.Rngs(0))
for _path, m in model_nnx.iter_submodules():
if isinstance(m, nnx.HasCacheInitializer):
input_shape = (batch_size, config.max_len, config.emb_dim)
m.init_cache(input_shape, dtype=config.dtype)

params_nnx, cache_nnx, _ = model_nnx.split(nnx.Param, nnx.Cache)

model_linen = TransformerLinen(config)

sample_inputs = random.randint(
random.PRNGKey(0), (batch_size, config.max_len), 0, config.vocab_size
)
variables = model_linen.init(random.key(0), sample_inputs)
params_linen = variables['params']
cache_linen = variables['cache']

self.transfer_params(config, params_nnx, params_linen)
self.transfer_cache(config, cache_nnx, cache_linen)
model_nnx.update(params_nnx, cache_nnx)

with nnx.flags(deterministic=True, decode=True):
output_nnx = model_nnx(sample_inputs)

output_linen: jax.Array
output_linen, updates = model_linen.apply(
{'params': params_linen, 'cache': cache_linen},
sample_inputs,
mutable=['cache'],
)

assert jnp.allclose(output_nnx, output_linen, atol=1e-5)
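
  # Aside (illustrative, not part of the commit): with decode=True the cache
  # advances one position per call. A minimal greedy decoding loop for the
  # Linen model, threading the mutated cache through each step (the helper
  # name and the 3-step horizon are hypothetical):
  def _greedy_decode_sketch(self, model_linen, params_linen, cache_linen, tokens):
    # tokens: current input of shape (batch, 1).
    cache = cache_linen
    outputs = []
    for _ in range(3):  # up to config.max_len steps
      # Feed a single position; Linen returns the updated cache via `mutable`.
      logits, updates = model_linen.apply(
          {'params': params_linen, 'cache': cache},
          tokens,
          mutable=['cache'],
      )
      cache = updates['cache']
      tokens = logits[:, -1:].argmax(-1)  # greedy next token, shape (batch, 1)
      outputs.append(tokens)
    return outputs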


if __name__ == '__main__':
absltest.main()
