Skip to content

Commit

Permalink
Update einsum layer for Gemma example
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718899289
  • Loading branch information
Flax Team authored and RaghuSpaceRajan committed Jan 24, 2025
1 parent dfa0fe8 commit ac91584
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
11 changes: 6 additions & 5 deletions examples/gemma/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from __future__ import annotations

from typing import Any, Union
from collections.abc import Sequence
from typing import Any, Union

from flax import nnx
import flax.linen as nn
Expand All @@ -31,11 +31,12 @@
class Einsum(nnx.Module):
"""Einsum is a convenience module for parameterized tensor multiplication."""

def __init__(self, shape: Shape, *, rngs: nnx.Rngs):
def __init__(self, einsum_str: str, shape: Shape, *, rngs: nnx.Rngs):
self.einsum_str = einsum_str
self.w = nnx.Param(nn.initializers.normal()(rngs.params(), shape))

def __call__(self, eqn: str, x: ArrayLike) -> Array:
return jnp.einsum(eqn, x, self.w.value)
def __call__(self, x: ArrayLike) -> Array:
return jnp.einsum(self.einsum_str, x, self.w.value)

@property
def shape(self) -> Shape:
Expand All @@ -48,7 +49,7 @@ class RMSNorm(nnx.Module):
def __init__(self, dim: int, *, rngs: nnx.Rngs):
self.scale = nnx.Param(nn.initializers.zeros_init()(rngs.params(), dim))

def __call__(self, x):
def __call__(self, x: Array) -> Array:
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
# normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
Expand Down
5 changes: 2 additions & 3 deletions examples/gemma/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ class EinsumTest(parameterized.TestCase):
),
)
def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape):
einsum = layers.Einsum(params_shape, rngs=nnx.Rngs(params=0))
einsum = layers.Einsum(eqn, params_shape, rngs=nnx.Rngs(params=0))
output = einsum(
eqn,
jnp.ones(inputs_shape),
)
self.assertEqual(output.shape, expected_shape)
Expand All @@ -54,7 +53,7 @@ def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape):
),
)
def test_shape(self, shape):
einsum = layers.Einsum(shape, rngs=nnx.Rngs(params=0))
einsum = layers.Einsum('ij->ji', shape, rngs=nnx.Rngs(params=0))
self.assertEqual(einsum.shape, shape)


Expand Down
16 changes: 10 additions & 6 deletions examples/gemma/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

from __future__ import annotations

from collections.abc import Sequence
import enum
from typing import Any, Union
from collections.abc import Sequence

from flax import nnx
import flax.linen as nn
import layers
import positional_embeddings
import flax.linen as nn
import jax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike # pylint: disable=g-importing-member,g-multiple-import
Expand Down Expand Up @@ -94,21 +94,25 @@ def __init__(
self.sliding_window_size = sliding_window_size
self.attn_logits_soft_cap = attn_logits_soft_cap
self.attn_vec_einsum = layers.Einsum(
einsum_str='BTNH,NHD->BTD',
shape=(num_heads, head_dim, features),
rngs=rngs,
)

if num_heads == num_kv_heads:
self.qkv_einsum = layers.Einsum(
einsum_str='BTD,SNDH->SBTNH',
shape=(3, num_heads, features, head_dim),
rngs=rngs,
)
else:
self.q_einsum = layers.Einsum(
einsum_str='BTD,NDH->BTNH',
shape=(num_heads, features, head_dim),
rngs=rngs,
)
self.kv_einsum = layers.Einsum(
einsum_str='BSD,CKDH->CBSKH',
shape=(2, num_kv_heads, features, head_dim),
rngs=rngs,
)
Expand All @@ -123,10 +127,10 @@ def __call__(
seq_len = x.shape[1]

if self.use_qkv_einsum:
query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x)
query_proj, key_proj, value_proj = self.qkv_einsum(x)
else:
query_proj = self.q_einsum('BTD,NDH->BTNH', x)
key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)
query_proj = self.q_einsum(x)
key_proj, value_proj = self.kv_einsum(x)

query_proj = positional_embeddings.apply_rope(
query_proj,
Expand Down Expand Up @@ -173,7 +177,7 @@ def __call__(
padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
attn_output = self.attn_vec_einsum(encoded)

if cache is not None:
new_cache = {
Expand Down

0 comments on commit ac91584

Please sign in to comment.