From e165db20a6d04635991afcad63bb70f16384ff9a Mon Sep 17 00:00:00 2001
From: Flax Team
Date: Wed, 22 Jan 2025 09:39:04 -0800
Subject: [PATCH] Update einsum layer for Gemma example

PiperOrigin-RevId: 718420528
---
 examples/gemma/layers.py      | 11 ++++++-----
 examples/gemma/layers_test.py |  5 ++---
 examples/gemma/modules.py     | 16 ++++++++++------
 3 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/examples/gemma/layers.py b/examples/gemma/layers.py
index 174ec7272c..5dafbd15d0 100644
--- a/examples/gemma/layers.py
+++ b/examples/gemma/layers.py
@@ -16,8 +16,8 @@
 
 from __future__ import annotations
 
-from typing import Any, Union
 from collections.abc import Sequence
+from typing import Any, Union
 
 from flax import nnx
 import flax.linen as nn
@@ -31,11 +31,12 @@
 class Einsum(nnx.Module):
   """Einsum is a convenience module for parameterized tensor multiplication."""
 
-  def __init__(self, shape: Shape, *, rngs: nnx.Rngs):
+  def __init__(self, einsum_str: str, shape: Shape, *, rngs: nnx.Rngs):
+    self.einsum_str = einsum_str
     self.w = nnx.Param(nn.initializers.normal()(rngs.params(), shape))
 
-  def __call__(self, eqn: str, x: ArrayLike) -> Array:
-    return jnp.einsum(eqn, x, self.w.value)
+  def __call__(self, x: ArrayLike) -> Array:
+    return jnp.einsum(self.einsum_str, x, self.w.value)
 
   @property
   def shape(self) -> Shape:
@@ -48,7 +49,7 @@ class RMSNorm(nnx.Module):
   def __init__(self, dim: int, *, rngs: nnx.Rngs):
     self.scale = nnx.Param(nn.initializers.zeros_init()(rngs.params(), dim))
 
-  def __call__(self, x):
+  def __call__(self, x: Array) -> Array:
     var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
     normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
     # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
diff --git a/examples/gemma/layers_test.py b/examples/gemma/layers_test.py
index d71071bdbf..68c8c609b1 100644
--- a/examples/gemma/layers_test.py
+++ b/examples/gemma/layers_test.py
@@ -38,9 +38,8 @@ class EinsumTest(parameterized.TestCase):
       ),
   )
   def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape):
-    einsum = layers.Einsum(params_shape, rngs=nnx.Rngs(params=0))
+    einsum = layers.Einsum(eqn, params_shape, rngs=nnx.Rngs(params=0))
     output = einsum(
-        eqn,
         jnp.ones(inputs_shape),
     )
     self.assertEqual(output.shape, expected_shape)
@@ -54,7 +53,7 @@ def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape):
       ),
   )
   def test_shape(self, shape):
-    einsum = layers.Einsum(shape, rngs=nnx.Rngs(params=0))
+    einsum = layers.Einsum('ij->ji', shape, rngs=nnx.Rngs(params=0))
     self.assertEqual(einsum.shape, shape)
 
 
diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py
index d09939491e..38708803c7 100644
--- a/examples/gemma/modules.py
+++ b/examples/gemma/modules.py
@@ -16,14 +16,14 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 import enum
 from typing import Any, Union
-from collections.abc import Sequence
 
 from flax import nnx
-import flax.linen as nn
 import layers
 import positional_embeddings
+import flax.linen as nn
 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, ArrayLike  # pylint: disable=g-importing-member,g-multiple-import
@@ -94,21 +94,25 @@ def __init__(
     self.sliding_window_size = sliding_window_size
     self.attn_logits_soft_cap = attn_logits_soft_cap
     self.attn_vec_einsum = layers.Einsum(
+        einsum_str='BTNH,NHD->BTD',
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
 
     if num_heads == num_kv_heads:
       self.qkv_einsum = layers.Einsum(
+          einsum_str='BTD,SNDH->SBTNH',
          shape=(3, num_heads, features, head_dim),
          rngs=rngs,
       )
     else:
       self.q_einsum = layers.Einsum(
+          einsum_str='BTD,NDH->BTNH',
          shape=(num_heads, features, head_dim),
          rngs=rngs,
       )
       self.kv_einsum = layers.Einsum(
+          einsum_str='BSD,CKDH->CBSKH',
          shape=(2, num_kv_heads, features, head_dim),
          rngs=rngs,
       )
@@ -123,10 +127,10 @@ def __call__(
     seq_len = x.shape[1]
 
     if self.use_qkv_einsum:
-      query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x)
+      query_proj, key_proj, value_proj = self.qkv_einsum(x)
     else:
-      query_proj = self.q_einsum('BTD,NDH->BTNH', x)
-      key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)
+      query_proj = self.q_einsum(x)
+      key_proj, value_proj = self.kv_einsum(x)
 
     query_proj = positional_embeddings.apply_rope(
         query_proj,
@@ -173,7 +177,7 @@ def __call__(
     padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
     probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
     encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
-    attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
+    attn_output = self.attn_vec_einsum(encoded)
 
     if cache is not None:
       new_cache = {