Commit d1d7d64

liudangyi authored and copybara-github committed
Implement quantized jax.lax.ragged_dot.
This change introduces `qwix._src.core.ragged_dot`, which provides a quantized version of `jax.lax.ragged_dot`. It includes a "fast" path that performs the dot product on quantized values and applies the scales afterwards, and a "slow" path that dequantizes before calling the standard `ragged_dot`. The choice between the fast and slow paths depends on whether the quantization involves zero points or certain channelwise scale shapes.

This is a very basic version; notable limitations are:

* Only the default dimension_numbers are supported (thus it's ragged_dot rather than ragged_dot_general).
* Tiling is not supported.

We will address these later.

Also move the dtype-handling logic into a new function, `qarray.get_accumulator_and_result_type`, which is shared by multiple ops.

PiperOrigin-RevId: 811488890
1 parent c6f5032 commit d1d7d64
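
To make the fast/slow distinction concrete, here is a minimal sketch of the algebra the fast path relies on, written in plain JAX for a non-ragged dot (all names below are illustrative, not qwix APIs): per-row and per-column scales factor out of the contraction, so the dot product can run on the quantized values and the scales can be applied to the result afterwards.

import jax
from jax import numpy as jnp

# Toy per-row/per-column int8 quantization of an [m, k] x [k, n] dot.
lhs = jax.random.normal(jax.random.key(0), (4, 8), jnp.float32)
rhs = jax.random.normal(jax.random.key(1), (8, 2), jnp.float32)
lhs_scale = jnp.abs(lhs).max(axis=1, keepdims=True) / 127  # [m, 1]
rhs_scale = jnp.abs(rhs).max(axis=0, keepdims=True) / 127  # [1, n]
lhs_q = jnp.round(lhs / lhs_scale).astype(jnp.int8)
rhs_q = jnp.round(rhs / rhs_scale).astype(jnp.int8)

# Slow path: dequantize first, then one floating-point dot.
slow = (lhs_q * lhs_scale) @ (rhs_q * rhs_scale)

# Fast path: exact integer dot on qvalues, scales applied to the result.
fast = lhs_q.astype(jnp.int32) @ rhs_q.astype(jnp.int32) * lhs_scale * rhs_scale

assert jnp.allclose(slow, fast, rtol=1e-4)

Zero points break this factorization (the cross terms no longer separate), which is why quantization with zero points falls back to the slow path.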

File tree

7 files changed: +271 -36 lines changed

  qwix/_src/core/conv_general.py
  qwix/_src/core/dot_general.py
  qwix/_src/core/einsum.py
  qwix/_src/core/qarray.py
  qwix/_src/core/ragged_dot.py
  tests/core/einsum_test.py
  tests/core/ragged_dot_test.py

qwix/_src/core/conv_general.py

Lines changed: 1 addition & 0 deletions
@@ -177,6 +177,7 @@ def conv_general_dilated(
     dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
     feature_group_count: int = 1,
     batch_group_count: int = 1,
+    # TODO(dangyi): Add preferred_element_type.
 ) -> jax.Array:
   """Dispatches to fast or slow conv_general_dilated depending on the inputs."""
   if isinstance(lhs, qarray.QArray) and isinstance(rhs, qarray.QArray):

qwix/_src/core/dot_general.py

Lines changed: 9 additions & 28 deletions
@@ -203,20 +203,9 @@ def _fast_dot_general(
     rhs_scale = qarray.split_axis(rhs_scale, {a: 1 for a in rhs_tiled_ca})
     rhs_scale = qarray.transpose_array(rhs_scale, rhs_scale_transpose)

-  if preferred_element_type is None:
-    # We want to override the preferred_element_type to int32 for int8 x int8
-    # dot_general, or bfloat16/float32 for fp8 x fp8 dot_general.
-    if all('int' in x.dtype.name for x in (lhs_value, rhs_value)):
-      preferred_element_type = jnp.int32
-    elif lhs_scale is not None:
-      preferred_element_type = lhs_scale.dtype
-    elif rhs_scale is not None:
-      preferred_element_type = rhs_scale.dtype
-  else:
-    if lhs_scale is not None:
-      lhs_scale = lhs_scale.astype(preferred_element_type)
-    if rhs_scale is not None:
-      rhs_scale = rhs_scale.astype(preferred_element_type)
+  preferred_element_type, result_type = qarray.get_accumulator_and_result_type(
+      lhs, rhs, preferred_element_type=preferred_element_type
+  )

   res = jax.lax.dot_general(
       lhs_value,

@@ -259,7 +248,7 @@ def _fast_dot_general(
     res = qarray.call_with_generic_broadcast(jnp.multiply, res, rhs_scale)
   if sum_axes:
     res = jnp.sum(res, axis=sum_axes)
-  return res
+  return res.astype(result_type)


 def _slow_dot_general(

@@ -321,15 +310,9 @@ def loop_dot_general(
     else:
       ca_tile_counts.append(1)

-  acc_dtype = None
-  if all('int' in x.dtype.name for x in (lhs_value, rhs_value)):
-    acc_dtype = jnp.int32
-  elif preferred_element_type is not None:
-    acc_dtype = preferred_element_type
-  elif lhs_scale is not None:
-    acc_dtype = lhs_scale.dtype
-  elif rhs_scale is not None:
-    acc_dtype = rhs_scale.dtype
+  preferred_element_type, result_type = qarray.get_accumulator_and_result_type(
+      lhs, rhs, preferred_element_type=preferred_element_type
+  )

   lhs_scale_transpose, rhs_scale_transpose = _get_scale_transpose(
       dimension_numbers, (len(lhs_value.shape), len(rhs_value.shape))

@@ -357,7 +340,7 @@ def take_slice(
         take_slice(lhs_value, lhs_ca, ca_tile_indices),
         take_slice(rhs_value, rhs_ca, ca_tile_indices),
         dimension_numbers=dimension_numbers,
-        preferred_element_type=acc_dtype,
+        preferred_element_type=preferred_element_type,
         **kwargs,
     )
     if lhs_scale is not None:

@@ -370,9 +353,7 @@ def take_slice(
       out = qarray.call_with_generic_broadcast(jnp.multiply, out, scale)
     acc = out if acc is None else acc + out
   assert acc is not None
-  if preferred_element_type is not None:
-    acc = acc.astype(preferred_element_type)
-  return acc
+  return acc.astype(result_type)


 # If a contracting dimension has a tile size smaller than this threshold, tiled
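
As a side note on the dtype logic being factored out here: the int32 accumulator for int x int matters because, without preferred_element_type, an int8 dot_general stays in int8 and wraps around. A quick illustration in plain JAX (not part of the commit):

import jax
from jax import numpy as jnp

a = jnp.full((1, 64), 100, jnp.int8)
b = jnp.full((64, 1), 100, jnp.int8)
dn = (((1,), (0,)), ((), ()))  # contract a's axis 1 with b's axis 0

wrapped = jax.lax.dot_general(a, b, dn)  # int8 output; 64 * 100 * 100 wraps
widened = jax.lax.dot_general(a, b, dn, preferred_element_type=jnp.int32)
print(wrapped.dtype, int(widened[0, 0]))  # int8 vs. 640000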

qwix/_src/core/einsum.py

Lines changed: 7 additions & 7 deletions
@@ -108,23 +108,23 @@ def einsum(
   Returns:
     The result of the einsum, a floating-point jax.Array.
   """
+  # preferred_element_type has to be set for jnp.einsum so that it won't infer
+  # the type from qvalue x qvalue.
+  _, preferred_element_type = qarray.get_accumulator_and_result_type(
+      *[a for a in args if isinstance(a, qarray.MaybeQArray)],
+      preferred_element_type=preferred_element_type,
+  )
+
   # We want to use jnp.einsum with quantized dot_general to avoid duplicating
   # the implementation. However, jnp.einsum will check the inputs to be
   # jax Arrays. To work around this, we send the qvalue to jnp.einsum and
   # restore the actual QArray before actually passing them to dot_general.
   args = list(args)
   qvalue_to_qarray = {}
-
-  # preferred_element_type needs to be set for jnp.einsum so that it won't infer
-  # the type from qvalue x qvalue.
-  scale_dtypes = []
   for i, arg in enumerate(args):
     if isinstance(arg, qarray.QArray):
       args[i] = arg.qvalue
       qvalue_to_qarray[id(arg.qvalue)] = arg
-      scale_dtypes.append(arg.scale.dtype)
-  if preferred_element_type is None and scale_dtypes:
-    preferred_element_type = jnp.result_type(*scale_dtypes)

   def _dot_general(*args, **kwargs):
     args = [qvalue_to_qarray.pop(id(a), a) for a in args]
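
Pinning preferred_element_type before the jnp.einsum call matters for the same reason: jnp.einsum infers its output dtype from the inputs it sees, and here those inputs are the raw int8 qvalues, so it would otherwise accumulate in int8 with the wrap-around shown above. A two-line check, independent of qwix:

from jax import numpy as jnp

a = b = jnp.ones((8, 8), jnp.int8)
print(jnp.einsum('mk,kn->mn', a, b).dtype)  # int8
print(jnp.einsum('mk,kn->mn', a, b, preferred_element_type=jnp.bfloat16).dtype)  # bfloat16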

qwix/_src/core/qarray.py

Lines changed: 45 additions & 0 deletions
@@ -297,6 +297,8 @@ def transpose_array(
   Returns:
     The transposed array.
   """
+  if any(l > 1 for a, l in enumerate(array.shape) if a not in transpose):
+    raise ValueError(f'Cannot transpose {array.shape} as {transpose}.')
   used_axes = [a for a in transpose if a is not None and array.shape[a] != 1]
   # If used_axes is already in order, no actual transpose is needed and we can
   # just reshape the array.

@@ -553,3 +555,46 @@ def clip_to_calibration(
   else:
     raise ValueError(f'Unsupported calibration: {calibration}')
   return array.reshape(original_shape)
+
+
+def get_accumulator_and_result_type(
+    *args: MaybeQArray,
+    preferred_element_type: jax.typing.DTypeLike | None,
+) -> tuple[jax.typing.DTypeLike, jax.typing.DTypeLike]:
+  """jnp.result_type for QArray.
+
+  Accumulator type is the dtype used for the dot_general computation.
+  Result type is the dtype of the final result.
+
+  Args:
+    *args: The arguments to dot_general.
+    preferred_element_type: The preferred element type for dot_general.
+
+  Returns:
+    A tuple of the accumulator type and the result type.
+  """
+  qvalue_dtypes, dequant_dtypes = [], []
+  for arg in args:
+    if isinstance(arg, QArray):
+      qvalue_dtypes.append(arg.qvalue.dtype)  # note qtype can be different.
+      dequant_dtypes.append(arg.scale.dtype)
+    else:
+      qvalue_dtypes.append(arg.dtype)
+      dequant_dtypes.append(arg.dtype)
+
+  # Result type should only depend on dequant_dtype and preferred_element_type.
+  result_type = preferred_element_type
+  if result_type is None:
+    # There's no dtype promotion path for fp8 or lower, and int4 or lower.
+    # We manually upcast them to bf16 or int32.
+    for i, t in enumerate(dequant_dtypes):
+      if t.itemsize <= 1:
+        dequant_dtypes[i] = jnp.int32 if 'int' in t.name else jnp.bfloat16
+    result_type = jnp.result_type(*dequant_dtypes)
+
+  # Accumulator type should be the same as result type except for int x int.
+  accumulator_type = result_type
+  if all('int' in t.name for t in qvalue_dtypes):
+    accumulator_type = jnp.int32
+
+  return accumulator_type, result_type
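
A quick sketch of what the new helper returns in two common cases, using quantize/HowToQuantize the same way the new test below does (this assumes, as in that test, that the scale dtype follows the bfloat16 input):

import jax
from jax import numpy as jnp
from qwix._src.core import qarray

x = jax.random.normal(jax.random.key(0), (256, 16), jnp.bfloat16)
how = qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[])
qx = qarray.quantize(x, how)

# int8 x int8: accumulate in int32, return in the scale dtype.
acc, res = qarray.get_accumulator_and_result_type(qx, qx, preferred_element_type=None)
print(jnp.dtype(acc), jnp.dtype(res))  # int32 bfloat16

# Quantized x unquantized: no int x int accumulation, so both are bfloat16.
acc, res = qarray.get_accumulator_and_result_type(qx, x, preferred_element_type=None)
print(jnp.dtype(acc), jnp.dtype(res))  # bfloat16 bfloat16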

qwix/_src/core/ragged_dot.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Quantized jax.lax.ragged_dot."""
+
+import jax
+from jax import numpy as jnp
+from qwix._src.core import qarray
+
+
+def _fast_ragged_dot(
+    lhs: qarray.MaybeQArray,
+    rhs: qarray.MaybeQArray,
+    group_sizes: jax.Array,
+    precision: jax.lax.PrecisionLike = None,
+    preferred_element_type: jax.typing.DTypeLike | None = None,
+    group_offset: jax.Array | None = None,
+) -> jax.Array:
+  """Quantized jax.lax.ragged_dot."""
+  if isinstance(lhs, qarray.QArray):
+    assert lhs.zero_point is None, 'not supported yet'
+    lhs_value = lhs.qvalue
+    lhs_scale = lhs.scale
+  else:
+    lhs_value = lhs
+    lhs_scale = None
+  if isinstance(rhs, qarray.QArray):
+    assert rhs.zero_point is None, 'not supported yet'
+    rhs_value = rhs.qvalue
+    rhs_scale = rhs.scale
+  else:
+    rhs_value = rhs
+    rhs_scale = None
+
+  preferred_element_type, result_type = qarray.get_accumulator_and_result_type(
+      lhs, rhs, preferred_element_type=preferred_element_type
+  )
+
+  out = jax.lax.ragged_dot(
+      lhs_value,
+      rhs_value,
+      group_sizes,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      group_offset=group_offset,
+  )
+
+  # ragged_dot has fixed dimension numbers which makes the implementation a
+  # lot easier, i.e., lhs: [m, k], rhs: [g, k, n], res: [m, n].
+  # TODO(dangyi): support arbitrary dimension numbers.
+  if lhs_scale is not None:  # [m, 1]
+    lhs_scale = qarray.transpose_array(lhs_scale, (0, None))
+    out = qarray.call_with_generic_broadcast(jnp.multiply, out, lhs_scale)
+  if rhs_scale is not None:  # [1, 1, n] or [g, 1, n]
+    if rhs_scale.shape[0] == 1:
+      # It's possible to apply the scale to the out directly.
+      rhs_scale = qarray.transpose_array(rhs_scale, (None, 2))
+    else:
+      # We need another ragged_dot to apply the scale to the out.
+      rhs_scale = jax.lax.ragged_dot(
+          jnp.ones((out.shape[0], 1), rhs_scale.dtype),
+          rhs_scale,
+          group_sizes,
+          group_offset=group_offset,
+      )
+    out = qarray.call_with_generic_broadcast(jnp.multiply, out, rhs_scale)
+
+  return out.astype(result_type)
+
+
+def _slow_ragged_dot(
+    lhs: qarray.MaybeQArray,
+    rhs: qarray.MaybeQArray,
+    group_sizes: jax.Array,
+    **kwargs,
+) -> jax.Array:
+  """Quantized jax.lax.ragged_dot which dequantizes first."""
+  if isinstance(lhs, qarray.QArray):
+    lhs = qarray.dequantize(lhs)
+  if isinstance(rhs, qarray.QArray):
+    rhs = qarray.dequantize(rhs)
+  return jax.lax.ragged_dot(lhs, rhs, group_sizes, **kwargs)
+
+
+def ragged_dot(
+    lhs: qarray.MaybeQArray,
+    rhs: qarray.MaybeQArray,
+    group_sizes: jax.Array,
+    precision: jax.lax.PrecisionLike = None,
+    preferred_element_type: jax.typing.DTypeLike | None = None,
+    group_offset: jax.Array | None = None,
+) -> jax.Array:
+  """Quantized jax.lax.ragged_dot."""
+  use_fast_ragged_dot = True
+
+  # _fast_ragged_dot doesn't support channelwise scales on the group axis,
+  # tiled scales on contracting axes, or zero_point.
+  if isinstance(lhs, qarray.QArray):  # [m, k]
+    if lhs.zero_point is not None or lhs.scale.shape[1] > 1:
+      use_fast_ragged_dot = False
+  if isinstance(rhs, qarray.QArray):  # [g, k, n]
+    if (
+        rhs.zero_point is not None
+        or rhs.scale.shape[0] > 1
+        or rhs.scale.shape[1] > 1
+    ):
+      use_fast_ragged_dot = False
+
+  if use_fast_ragged_dot:
+    return _fast_ragged_dot(
+        lhs,
+        rhs,
+        group_sizes,
+        precision=precision,
+        preferred_element_type=preferred_element_type,
+        group_offset=group_offset,
+    )
+  else:
+    return _slow_ragged_dot(
+        lhs,
+        rhs,
+        group_sizes,
+        precision=precision,
+        preferred_element_type=preferred_element_type,
+        group_offset=group_offset,
+    )
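
The second ragged_dot inside _fast_ragged_dot is effectively a per-group gather: contracting an [m, 1] column of ones against [g, 1, n] scales with the same group_sizes yields an [m, n] array whose row i is the scale row of the group that owns row i, which is exactly what the output needs to be multiplied by. A toy check of that identity (shapes and values are illustrative):

import jax
from jax import numpy as jnp

g, n = 3, 4
group_sizes = jnp.array([2, 0, 3], jnp.int32)  # m = 5 rows in total
scales = jax.random.uniform(jax.random.key(0), (g, 1, n))

ones = jnp.ones((int(group_sizes.sum()), 1), scales.dtype)
per_row = jax.lax.ragged_dot(ones, scales, group_sizes)  # [m, n]

# Row i of per_row equals the scale row of the group that owns row i.
group_ids = jnp.repeat(jnp.arange(g), group_sizes, total_repeat_length=5)
assert jnp.allclose(per_row, scales[group_ids, 0, :])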

tests/core/einsum_test.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def _einsum(lhs, rhs):
         lhs_shape=(10, 256, 16),
         rhs_shape=(256, 16, 128),
         lhs_asymmetric=True,
-        expected_rel_mae=0.0130005,
+        expected_rel_mae=0.0129395,
     ),
     dict(
         testcase_name='lhs_asymmetric_subchannel',

tests/core/ragged_dot_test.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+from jax import numpy as jnp
+from qwix._src.core import qarray
+from qwix._src.core import ragged_dot
+
+
+def mae(a, b):
+  assert a.dtype == b.dtype and a.shape == b.shape
+  return jnp.abs(a - b).mean() / jnp.abs(a).mean()
+
+
+class RaggedDotTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='no_channelwise',
+          lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[]),
+          rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[]),
+      ),
+      dict(
+          testcase_name='channelwise',
+          lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0]),
+          rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[2]),
+      ),
+      dict(
+          testcase_name='more_channelwise',
+          lhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0]),
+          rhs_how=qarray.HowToQuantize(qtype=jnp.int8, channelwise_axes=[0, 2]),
+      ),
+  )
+  def test_ragged_dot(
+      self,
+      lhs_how,
+      rhs_how,
+      disable_fast_path=False,
+  ):
+    lhs = jax.random.normal(jax.random.key(0), (256, 16), jnp.bfloat16)
+    rhs = jax.random.normal(jax.random.key(1), (10, 16, 64), jnp.bfloat16)
+    group_sizes = jnp.array([10, 20, 30, 40, 0, 115, 6, 7, 1, 27], jnp.int32)
+
+    fp_res = jax.lax.ragged_dot(lhs, rhs, group_sizes)
+
+    qlhs = qarray.quantize(lhs, lhs_how)
+    qrhs = qarray.quantize(rhs, rhs_how)
+
+    slow_res = ragged_dot._slow_ragged_dot(qlhs, qrhs, group_sizes)
+    self.assertLess(mae(slow_res, fp_res), 0.02)
+
+    if not disable_fast_path:
+      fast_res = ragged_dot._fast_ragged_dot(qlhs, qrhs, group_sizes)
+      self.assertLess(mae(fast_res, slow_res), 0.005)
+
+
+if __name__ == '__main__':
+  absltest.main()
