|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import logging |
| 16 | + |
| 17 | +from absl.testing import absltest |
| 18 | +from absl.testing import parameterized |
| 19 | +import jax |
| 20 | +from jax import numpy as jnp |
| 21 | +from qwix._src.core import qarray |
| 22 | +from qwix._src.core import ragged_dot |
| 23 | + |
| 24 | + |
| 25 | +def rel_mae(x, y): |
| 26 | + assert x.dtype == y.dtype and x.shape == y.shape |
| 27 | + return jnp.abs(x - y).mean() / jnp.abs(x).mean() |
| 28 | + |
| 29 | + |
| 30 | +class RaggedDotTpuTest(parameterized.TestCase): |
| 31 | + """More expensive TPU tests for ragged_dot, mainly on numerics.""" |
| 32 | + |
| 33 | + def setUp(self): |
| 34 | + super().setUp() |
| 35 | + self._random_key = jax.random.key(42) |
| 36 | + |
| 37 | + def _make_array(self, shape, asymmetric=False): |
| 38 | + self._random_key, key = jax.random.split(self._random_key) |
| 39 | + if asymmetric: |
| 40 | + return jax.random.uniform(key, shape, jnp.float32) |
| 41 | + return jax.random.normal(key, shape, jnp.float32) |
| 42 | + |
| 43 | + @parameterized.named_parameters( |
| 44 | + dict( |
| 45 | + testcase_name='int8', |
| 46 | + lhs_shape=(128, 256), |
| 47 | + lhs_qtype=jnp.int8, |
| 48 | + rhs_shape=(4, 256, 64), |
| 49 | + rhs_qtype=jnp.int8, |
| 50 | + group_sizes=(64, 32, 16, 16), |
| 51 | + expected_mae=0.03, |
| 52 | + ), |
| 53 | + dict( |
| 54 | + testcase_name='lhs_asymmetric', |
| 55 | + lhs_shape=(128, 256), |
| 56 | + lhs_qtype=jnp.int8, |
| 57 | + lhs_asymmetric=True, |
| 58 | + rhs_shape=(4, 256, 64), |
| 59 | + rhs_qtype=jnp.int8, |
| 60 | + group_sizes=(50, 50, 28, 0), |
| 61 | + expected_mae=0.07, |
| 62 | + disable_fast_ragged_dot=True, |
| 63 | + ), |
| 64 | + dict( |
| 65 | + testcase_name='rhs_group_channelwise', |
| 66 | + lhs_shape=(128, 256), |
| 67 | + lhs_qtype=jnp.int8, |
| 68 | + rhs_shape=(4, 256, 64), |
| 69 | + rhs_qtype=jnp.int8, |
| 70 | + rhs_channelwise_axes=(0,), |
| 71 | + group_sizes=(128, 0, 0, 0), |
| 72 | + expected_mae=0.03, |
| 73 | + disable_fast_ragged_dot=True, |
| 74 | + ), |
| 75 | + dict( |
| 76 | + testcase_name='rhs_contracting_tiled', |
| 77 | + lhs_shape=(128, 256), |
| 78 | + lhs_qtype=jnp.int8, |
| 79 | + rhs_shape=(4, 256, 64), |
| 80 | + rhs_qtype=jnp.int8, |
| 81 | + rhs_tiled_axes={1: 128}, |
| 82 | + group_sizes=(10, 20, 30, 68), |
| 83 | + expected_mae=0.03, |
| 84 | + disable_fast_ragged_dot=True, |
| 85 | + ), |
| 86 | + ) |
| 87 | + def test_ragged_dot( |
| 88 | + self, |
| 89 | + *, |
| 90 | + lhs_shape: tuple[int, ...], |
| 91 | + lhs_qtype: jax.typing.DTypeLike | None, |
| 92 | + lhs_asymmetric: bool = False, |
| 93 | + rhs_shape: tuple[int, ...], |
| 94 | + rhs_qtype: jax.typing.DTypeLike | None, |
| 95 | + rhs_channelwise_axes: tuple[int, ...] = (), |
| 96 | + rhs_tiled_axes: dict[int, int] | None = None, |
| 97 | + group_sizes: tuple[int, ...], |
| 98 | + expected_mae: float, |
| 99 | + disable_fast_ragged_dot: bool = False, |
| 100 | + ): |
| 101 | + lhs = self._make_array(lhs_shape, lhs_asymmetric) |
| 102 | + rhs = self._make_array(rhs_shape, False) |
| 103 | + rhs_tiled_axes = rhs_tiled_axes or {} |
| 104 | + group_sizes = jnp.array(group_sizes) |
| 105 | + |
| 106 | + if lhs_qtype: |
| 107 | + lhs_how = qarray.HowToQuantize( |
| 108 | + qtype=lhs_qtype, |
| 109 | + channelwise_axes=(), |
| 110 | + tiled_axes={}, |
| 111 | + calibration_method='minmax' if lhs_asymmetric else 'absmax', |
| 112 | + ) |
| 113 | + q_lhs = qarray.quantize(lhs, lhs_how) |
| 114 | + else: |
| 115 | + q_lhs = lhs |
| 116 | + |
| 117 | + if rhs_qtype: |
| 118 | + rhs_how = qarray.HowToQuantize( |
| 119 | + qtype=rhs_qtype, |
| 120 | + channelwise_axes=rhs_channelwise_axes, |
| 121 | + tiled_axes=rhs_tiled_axes, |
| 122 | + calibration_method='absmax', |
| 123 | + ) |
| 124 | + q_rhs = qarray.quantize(rhs, rhs_how) |
| 125 | + else: |
| 126 | + q_rhs = rhs |
| 127 | + |
| 128 | + @jax.jit |
| 129 | + def _multi_ragged_dot(lhs, rhs, fp_res): |
| 130 | + slow_res = ragged_dot._slow_ragged_dot(lhs, rhs, group_sizes) |
| 131 | + if disable_fast_ragged_dot: |
| 132 | + fast_res = slow_res |
| 133 | + else: |
| 134 | + fast_res = ragged_dot._fast_ragged_dot(lhs, rhs, group_sizes) |
| 135 | + return ( |
| 136 | + rel_mae(slow_res, fp_res), |
| 137 | + rel_mae(slow_res, fast_res), |
| 138 | + ) |
| 139 | + |
| 140 | + fp_res = jax.lax.ragged_dot(lhs, rhs, group_sizes) |
| 141 | + fp_mae, fast_mae = _multi_ragged_dot(q_lhs, q_rhs, fp_res) |
| 142 | + |
| 143 | + logging.info('fp_mae=%s fast_mae=%s', fp_mae, fast_mae) |
| 144 | + self.assertLessEqual(fp_mae, expected_mae) |
| 145 | + self.assertLessEqual(fast_mae, 0.003) |
| 146 | + |
| 147 | + |
| 148 | +if __name__ == '__main__': |
| 149 | + absltest.main() |
0 commit comments