Update einsum layer for Gemma example #4498

Merged · 1 commit · Jan 23, 2025
11 changes: 6 additions & 5 deletions examples/gemma/layers.py
@@ -16,8 +16,8 @@
 
 from __future__ import annotations
 
-from typing import Any, Union
 from collections.abc import Sequence
+from typing import Any, Union
 
 from flax import nnx
 import flax.linen as nn
@@ -31,11 +31,12 @@
 class Einsum(nnx.Module):
   """Einsum is a convenience module for parameterized tensor multiplication."""
 
-  def __init__(self, shape: Shape, *, rngs: nnx.Rngs):
+  def __init__(self, einsum_str: str, shape: Shape, *, rngs: nnx.Rngs):
+    self.einsum_str = einsum_str
     self.w = nnx.Param(nn.initializers.normal()(rngs.params(), shape))
 
-  def __call__(self, eqn: str, x: ArrayLike) -> Array:
-    return jnp.einsum(eqn, x, self.w.value)
+  def __call__(self, x: ArrayLike) -> Array:
+    return jnp.einsum(self.einsum_str, x, self.w.value)
 
   @property
   def shape(self) -> Shape:
@@ -48,7 +49,7 @@ class RMSNorm(nnx.Module):
   def __init__(self, dim: int, *, rngs: nnx.Rngs):
     self.scale = nnx.Param(nn.initializers.zeros_init()(rngs.params(), dim))
 
-  def __call__(self, x):
+  def __call__(self, x: Array) -> Array:
     var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
     normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
     # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
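For context, a minimal usage sketch of the updated Einsum module as it appears in this diff: the einsum equation is now bound at construction time instead of being passed on every call. The dimension sizes below are illustrative, and the plain `import layers` assumes the script runs from the examples/gemma directory.

import jax.numpy as jnp
from flax import nnx

import layers  # examples/gemma/layers.py (assumes examples/gemma is on the path)

# Equation and parameter shape are fixed at construction; illustrative sizes:
# N=8 heads, H=32 head_dim, D=256 features.
attn_out_proj = layers.Einsum(
    einsum_str='BTNH,NHD->BTD',
    shape=(8, 32, 256),
    rngs=nnx.Rngs(params=0),
)

x = jnp.ones((2, 16, 8, 32))  # (batch B, seq T, heads N, head_dim H)
y = attn_out_proj(x)          # no equation argument per call anymore
assert y.shape == (2, 16, 256)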
5 changes: 2 additions & 3 deletions examples/gemma/layers_test.py
@@ -38,9 +38,8 @@ class EinsumTest(parameterized.TestCase):
       ),
   )
   def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape):
-    einsum = layers.Einsum(params_shape, rngs=nnx.Rngs(params=0))
+    einsum = layers.Einsum(eqn, params_shape, rngs=nnx.Rngs(params=0))
     output = einsum(
-        eqn,
         jnp.ones(inputs_shape),
     )
     self.assertEqual(output.shape, expected_shape)
@@ -54,7 +53,7 @@ def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape):
       ),
   )
   def test_shape(self, shape):
-    einsum = layers.Einsum(shape, rngs=nnx.Rngs(params=0))
+    einsum = layers.Einsum('ij->ji', shape, rngs=nnx.Rngs(params=0))
     self.assertEqual(einsum.shape, shape)


16 changes: 10 additions & 6 deletions examples/gemma/modules.py
@@ -16,14 +16,14 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 import enum
 from typing import Any, Union
-from collections.abc import Sequence
 
 from flax import nnx
+import flax.linen as nn
 import layers
 import positional_embeddings
-import flax.linen as nn
 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, ArrayLike  # pylint: disable=g-importing-member,g-multiple-import
@@ -94,21 +94,25 @@ def __init__(
     self.sliding_window_size = sliding_window_size
     self.attn_logits_soft_cap = attn_logits_soft_cap
     self.attn_vec_einsum = layers.Einsum(
+        einsum_str='BTNH,NHD->BTD',
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
 
     if num_heads == num_kv_heads:
       self.qkv_einsum = layers.Einsum(
+          einsum_str='BTD,SNDH->SBTNH',
           shape=(3, num_heads, features, head_dim),
           rngs=rngs,
       )
     else:
       self.q_einsum = layers.Einsum(
+          einsum_str='BTD,NDH->BTNH',
           shape=(num_heads, features, head_dim),
           rngs=rngs,
       )
       self.kv_einsum = layers.Einsum(
+          einsum_str='BSD,CKDH->CBSKH',
           shape=(2, num_kv_heads, features, head_dim),
           rngs=rngs,
       )
@@ -123,10 +127,10 @@ def __call__(
     seq_len = x.shape[1]
 
     if self.use_qkv_einsum:
-      query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x)
+      query_proj, key_proj, value_proj = self.qkv_einsum(x)
     else:
-      query_proj = self.q_einsum('BTD,NDH->BTNH', x)
-      key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x)
+      query_proj = self.q_einsum(x)
+      key_proj, value_proj = self.kv_einsum(x)
 
     query_proj = positional_embeddings.apply_rope(
         query_proj,
@@ -173,7 +177,7 @@ def __call__(
     padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
     probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
     encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
-    attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
+    attn_output = self.attn_vec_einsum(encoded)
 
     if cache is not None:
       new_cache = {
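As a shape sanity check (not part of the PR), here is a small sketch of how the grouped key/value equation now baked into kv_einsum above maps shapes. The sizes are illustrative and it uses plain jnp.einsum with dummy weights rather than the module itself.

import jax.numpy as jnp

# Illustrative sizes: B=2 batch, S=16 seq, D=256 features, K=4 kv heads, H=32 head_dim.
x = jnp.ones((2, 16, 256))        # (B, S, D) activations
w_kv = jnp.ones((2, 4, 256, 32))  # matches shape=(2, num_kv_heads, features, head_dim)

# Same equation the kv_einsum layer is constructed with in this diff.
kv = jnp.einsum('BSD,CKDH->CBSKH', x, w_kv)  # (C, B, S, K, H)
key_proj, value_proj = kv                    # unpack along C; each (B, S, K, H)
assert key_proj.shape == (2, 16, 4, 32) and value_proj.shape == (2, 16, 4, 32)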