Skip to content

Commit e99e27f

Browse files
Qwix Developerscopybara-github
authored andcommitted
adds ragged_dot_tpu_test.py
PiperOrigin-RevId: 811867181
1 parent b3143b3 commit e99e27f

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed

tests/core/ragged_dot_tpu_test.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
import jax
20+
from jax import numpy as jnp
21+
from qwix._src.core import qarray
22+
from qwix._src.core import ragged_dot
23+
24+
25+
def rel_mae(x, y):
26+
assert x.dtype == y.dtype and x.shape == y.shape
27+
return jnp.abs(x - y).mean() / jnp.abs(x).mean()
28+
29+
30+
class RaggedDotTpuTest(parameterized.TestCase):
31+
"""More expensive TPU tests for ragged_dot, mainly on numerics."""
32+
33+
def setUp(self):
34+
super().setUp()
35+
self._random_key = jax.random.key(42)
36+
37+
def _make_array(self, shape, asymmetric=False):
38+
self._random_key, key = jax.random.split(self._random_key)
39+
if asymmetric:
40+
return jax.random.uniform(key, shape, jnp.float32)
41+
return jax.random.normal(key, shape, jnp.float32)
42+
43+
@parameterized.named_parameters(
44+
dict(
45+
testcase_name='int8',
46+
lhs_shape=(128, 256),
47+
lhs_qtype=jnp.int8,
48+
rhs_shape=(4, 256, 64),
49+
rhs_qtype=jnp.int8,
50+
group_sizes=(64, 32, 16, 16),
51+
expected_mae=0.03,
52+
),
53+
dict(
54+
testcase_name='lhs_asymmetric',
55+
lhs_shape=(128, 256),
56+
lhs_qtype=jnp.int8,
57+
lhs_asymmetric=True,
58+
rhs_shape=(4, 256, 64),
59+
rhs_qtype=jnp.int8,
60+
group_sizes=(50, 50, 28, 0),
61+
expected_mae=0.07,
62+
disable_fast_ragged_dot=True,
63+
),
64+
dict(
65+
testcase_name='rhs_group_channelwise',
66+
lhs_shape=(128, 256),
67+
lhs_qtype=jnp.int8,
68+
rhs_shape=(4, 256, 64),
69+
rhs_qtype=jnp.int8,
70+
rhs_channelwise_axes=(0,),
71+
group_sizes=(128, 0, 0, 0),
72+
expected_mae=0.03,
73+
disable_fast_ragged_dot=True,
74+
),
75+
dict(
76+
testcase_name='rhs_contracting_tiled',
77+
lhs_shape=(128, 256),
78+
lhs_qtype=jnp.int8,
79+
rhs_shape=(4, 256, 64),
80+
rhs_qtype=jnp.int8,
81+
rhs_tiled_axes={1: 128},
82+
group_sizes=(10, 20, 30, 68),
83+
expected_mae=0.03,
84+
disable_fast_ragged_dot=True,
85+
),
86+
)
87+
def test_ragged_dot(
88+
self,
89+
*,
90+
lhs_shape: tuple[int, ...],
91+
lhs_qtype: jax.typing.DTypeLike | None,
92+
lhs_asymmetric: bool = False,
93+
rhs_shape: tuple[int, ...],
94+
rhs_qtype: jax.typing.DTypeLike | None,
95+
rhs_channelwise_axes: tuple[int, ...] = (),
96+
rhs_tiled_axes: dict[int, int] | None = None,
97+
group_sizes: tuple[int, ...],
98+
expected_mae: float,
99+
disable_fast_ragged_dot: bool = False,
100+
):
101+
lhs = self._make_array(lhs_shape, lhs_asymmetric)
102+
rhs = self._make_array(rhs_shape, False)
103+
rhs_tiled_axes = rhs_tiled_axes or {}
104+
group_sizes = jnp.array(group_sizes)
105+
106+
if lhs_qtype:
107+
lhs_how = qarray.HowToQuantize(
108+
qtype=lhs_qtype,
109+
channelwise_axes=(),
110+
tiled_axes={},
111+
calibration_method='minmax' if lhs_asymmetric else 'absmax',
112+
)
113+
q_lhs = qarray.quantize(lhs, lhs_how)
114+
else:
115+
q_lhs = lhs
116+
117+
if rhs_qtype:
118+
rhs_how = qarray.HowToQuantize(
119+
qtype=rhs_qtype,
120+
channelwise_axes=rhs_channelwise_axes,
121+
tiled_axes=rhs_tiled_axes,
122+
calibration_method='absmax',
123+
)
124+
q_rhs = qarray.quantize(rhs, rhs_how)
125+
else:
126+
q_rhs = rhs
127+
128+
@jax.jit
129+
def _multi_ragged_dot(lhs, rhs, fp_res):
130+
slow_res = ragged_dot._slow_ragged_dot(lhs, rhs, group_sizes)
131+
if disable_fast_ragged_dot:
132+
fast_res = slow_res
133+
else:
134+
fast_res = ragged_dot._fast_ragged_dot(lhs, rhs, group_sizes)
135+
return (
136+
rel_mae(slow_res, fp_res),
137+
rel_mae(slow_res, fast_res),
138+
)
139+
140+
fp_res = jax.lax.ragged_dot(lhs, rhs, group_sizes)
141+
fp_mae, fast_mae = _multi_ragged_dot(q_lhs, q_rhs, fp_res)
142+
143+
logging.info('fp_mae=%s fast_mae=%s', fp_mae, fast_mae)
144+
self.assertLessEqual(fp_mae, expected_mae)
145+
self.assertLessEqual(fast_mae, 0.003)
146+
147+
148+
if __name__ == '__main__':
149+
absltest.main()

0 commit comments

Comments
 (0)