Skip to content

Commit b3143b3

Browse files
nicolagp and copybara-github
authored and committed
Support stochastic rounding in Qwix.
PiperOrigin-RevId: 811569429
1 parent d1d7d64 commit b3143b3

File tree

9 files changed

+217
-3
lines changed

9 files changed

+217
-3
lines changed

qwix/_src/core/dot_general_qt.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ class DotGeneralQtConfig:
5353
disable_channelwise_axes: bool = False
5454
bwd_use_original_residuals: bool = False # what to use as residuals
5555

56+
# Configs for stochastic rounding.
57+
dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
58+
drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
59+
5660
# Deprecated. No longer used.
5761
dlhs_lhs_qtype: jax.typing.DTypeLike | None = None # incoming gradient
5862
dlhs_rhs_qtype: jax.typing.DTypeLike | None = None # residual rhs
@@ -227,6 +231,17 @@ def _compute_gradient_for_operand(
227231
)
228232
if config.disable_channelwise_axes:
229233
g_how = dataclasses.replace(g_how, channelwise_axes=[])
234+
235+
if for_dlhs and config.dlhs_stochastic_rounding_noise_fn:
236+
g_how = dataclasses.replace(
237+
g_how,
238+
noise_fn=config.dlhs_stochastic_rounding_noise_fn,
239+
)
240+
if not for_dlhs and config.drhs_stochastic_rounding_noise_fn:
241+
g_how = dataclasses.replace(
242+
g_how,
243+
noise_fn=config.drhs_stochastic_rounding_noise_fn,
244+
)
230245
g = qarray.quantize(g, g_how)
231246

232247
grad_res = dot_general.dot_general(g, y, bwd_dnums)

qwix/_src/core/numerics.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313
# limitations under the License.
1414
"""Numerics for quantization."""
1515

16+
from typing import Callable, Sequence
1617
import jax
1718
from jax import numpy as jnp
1819

20+
# A function that generates noise for stochastic rounding.
21+
# args: shape: The shape of the noise to generate.
22+
# returns: An array of noise that broadcasts to the given shape (e.g. with
# size 1 on every axis along which the noise is shared).
23+
NoiseFn = Callable[[Sequence[int]], jax.Array]
24+
1925

2026
def should_quantize(dtype: jax.typing.DTypeLike) -> bool:
2127
"""Returns True if the dtype should be quantized."""
@@ -64,7 +70,11 @@ def get_symmetric_bound(qtype: jax.typing.DTypeLike) -> float:
6470
return jnp.iinfo(qtype).max + 0.5
6571

6672

67-
def convert_to(x: jax.Array, qtype: jax.typing.DTypeLike) -> jax.Array:
73+
def convert_to(
74+
x: jax.Array,
75+
qtype: jax.typing.DTypeLike,
76+
noise_fn: NoiseFn | None = None,
77+
) -> jax.Array:
6878
"""Rounds and converts x to the given qtype."""
6979
match qtype:
7080
case 'nf4':
@@ -84,8 +94,15 @@ def convert_to(x: jax.Array, qtype: jax.typing.DTypeLike) -> jax.Array:
8494
try:
8595
finfo = jnp.finfo(qtype)
8696
except ValueError:
97+
finfo = None
98+
if finfo is None:
8799
# dtype is an integer type. We need to round manually but clipping can
88100
# be handled by "astype".
101+
if noise_fn is not None:
102+
# Stochastic rounding is done in fp32 to avoid bias from bf16, e.g.
103+
# round(bf16(41)-bf16(0.4)) ~= round(40.5) = 40, rather than
104+
# round(41-0.4) = round(40.6) = 41.
105+
x = x.astype(jnp.float32) + noise_fn(x.shape)
89106
return jnp.round(x).astype(qtype)
90107
# dtype is a floating point type. No rounding needed, but we need to
91108
# clip to the range to avoid inf or nan (e.g. for e4m3fn).

qwix/_src/core/qarray.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,8 @@ class HowToQuantize:
254254
# The calibration method to use. The format is <method>[,<args>], e.g.
255255
# "absmax" or "fixed,-10,10". Check calibrate() for supported methods.
256256
calibration_method: str = 'absmax'
257+
# Noise function to use for stochastic rounding.
258+
noise_fn: numerics.NoiseFn | None = None
257259

258260

259261
ShapeT: TypeAlias = Sequence[int]
@@ -476,6 +478,7 @@ def quantize_with_scale_zero_point(
476478
qtype: jax.typing.DTypeLike,
477479
scale: jax.Array,
478480
zero_point: jax.Array | None,
481+
noise_fn: numerics.NoiseFn | None = None,
479482
) -> QArray:
480483
"""Quantizes an array with the given scale and zero_point.
481484
@@ -484,6 +487,8 @@ def quantize_with_scale_zero_point(
484487
qtype: The logical type used for quantization.
485488
scale: The scale to use.
486489
zero_point: The zero_point to use.
490+
noise_fn: The noise function to add to the quantized array for stochastic
491+
rounding.
487492
488493
Returns:
489494
The quantized array.
@@ -504,7 +509,7 @@ def quantize_with_scale_zero_point(
504509
qvalue = call_with_generic_broadcast(
505510
jnp.add, qvalue, zero_point.astype(qvalue.dtype)
506511
)
507-
qvalue = numerics.convert_to(qvalue, qtype)
512+
qvalue = numerics.convert_to(qvalue, qtype, noise_fn)
508513
return QArray(qvalue, scale, zero_point, qtype)
509514

510515

@@ -515,7 +520,9 @@ def quantize(
515520
"""Quantizes an array using a dynamic range."""
516521
calibration = calibrate(array, how)
517522
scale, zero_point = compute_scale_zero_point(calibration, how.qtype)
518-
return quantize_with_scale_zero_point(array, how.qtype, scale, zero_point)
523+
return quantize_with_scale_zero_point(
524+
array, how.qtype, scale, zero_point, how.noise_fn
525+
)
519526

520527

521528
def dequantize(array: QArray) -> jax.Array:
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Stochastic rounding utilities."""
15+
16+
from typing import Sequence
17+
import jax
18+
19+
20+
def uniform_noise(
    shape: Sequence[int],
    *,
    key: jax.Array,
    channelwise_noise_axes: Sequence[int] = (0,),
) -> jax.Array:
  """Generates zero-centered uniform noise for stochastic rounding.

  Independent samples are drawn only along `channelwise_noise_axes`; every
  other axis of the result has size 1, so the noise is shared along those axes
  once it is broadcast against an array of the full `shape`.

  Args:
    shape: The shape of the array the noise will be added to.
    key: The PRNG key passed to `jax.random.uniform`.
    channelwise_noise_axes: Axes of `shape` that keep their full size. Negative
      values count from the last axis. Axes outside the rank of `shape` are
      ignored.

  Returns:
    An array with the same rank as `shape`, holding samples of
    `jax.random.uniform(...) - 0.5`.
  """
  ndim = len(shape)
  # `enumerate` below only yields non-negative indices, so negative axes are
  # mapped to their non-negative equivalent; otherwise they would never match
  # and would be silently dropped.
  kept_axes = {
      axis + ndim if axis < 0 else axis for axis in channelwise_noise_axes
  }
  # Keep shape dimensions only for the channelwise noise axes.
  noise_shape = tuple(
      dim if axis in kept_axes else 1 for axis, dim in enumerate(shape)
  )
  return jax.random.uniform(key, noise_shape) - 0.5

qwix/_src/flax_util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,20 @@ def _check_shape(value: Any, init_fn: Callable[[], Any]):
327327
abs_init_value = jax.eval_shape(lambda: unbox(init_fn()))
328328
if abs_value != abs_init_value:
329329
raise ValueError(f'{abs_value} != {abs_init_value}')
330+
331+
332+
def make_rng(rng_stream: str) -> jax.Array:
  """Draws a random key from the named rng stream of the current module.

  Args:
    rng_stream: The name of the rng stream. For nnx modules only
      'stochastic_rounding' is accepted.

  Returns:
    The key produced by the current module: `make_rng(rng_stream)` for a linen
    module, `rngs.stochastic_rounding()` for an nnx module.

  Raises:
    ValueError: If the current module is neither a linen nor an nnx module, or
      if an nnx module is asked for a stream other than 'stochastic_rounding'.
  """
  current = get_current_module()

  # Linen modules expose named streams directly through make_rng.
  if isinstance(current, nn.Module):
    return current.make_rng(rng_stream)

  if not isinstance(current, nnx.Module):
    raise ValueError('Current module is not known.')

  # For nnx the key is read from the module's `rngs` attribute, and only the
  # stochastic rounding stream is wired up.
  if rng_stream != 'stochastic_rounding':
    raise ValueError(f'Unsupported nnx rng_stream: {rng_stream}')
  return current.rngs.stochastic_rounding()

qwix/_src/providers/qt.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from qwix._src import qconfig
2828
from qwix._src.core import conv_general_qt
2929
from qwix._src.core import dot_general_qt
30+
from qwix._src.core import stochastic_rounding
3031

3132

3233
@dataclasses.dataclass(frozen=True, kw_only=True)
@@ -52,6 +53,13 @@ class QtRule(qconfig.QuantizationRule):
5253
# residuals for backward pass.
5354
bwd_use_original_residuals: bool = False
5455

56+
# Use stochastic rounding for the gradients. (Only 'uniform' is supported.)
57+
bwd_stochastic_rounding: str | None = None
58+
59+
# Use channelwise noise for stochastic rounding. By default, it will generate
60+
# noise for the 0th dimension and broadcast it over remaining dimensions.
61+
channelwise_noise_axes: Sequence[int] = (0,)
62+
5563
# Override any fields in DotGeneralQtConfig.
5664
additional_qt_config: Mapping[str, Any] | None = None
5765

@@ -385,6 +393,26 @@ def _create_dot_general_qt_config(
385393
if rhs_is_weight:
386394
drhs_tile_size = rule.bwd_weight_grad_tile_size
387395

396+
if rule.bwd_stochastic_rounding == 'uniform':
397+
dlhs_stochastic_rounding_noise_fn = functools.partial(
398+
stochastic_rounding.uniform_noise,
399+
key=flax_util.make_rng('stochastic_rounding'),
400+
channelwise_noise_axes=rule.channelwise_noise_axes,
401+
)
402+
drhs_stochastic_rounding_noise_fn = functools.partial(
403+
stochastic_rounding.uniform_noise,
404+
key=flax_util.make_rng('stochastic_rounding'),
405+
channelwise_noise_axes=rule.channelwise_noise_axes,
406+
)
407+
elif rule.bwd_stochastic_rounding is not None:
408+
raise ValueError(
409+
'Stochastic rounding should be "uniform" or None, got:'
410+
f' {rule.bwd_stochastic_rounding}'
411+
)
412+
else:
413+
dlhs_stochastic_rounding_noise_fn = None
414+
drhs_stochastic_rounding_noise_fn = None
415+
388416
qt_config = dot_general_qt.DotGeneralQtConfig(
389417
# fwd configs.
390418
lhs_qtype=lhs_qtype,
@@ -405,6 +433,8 @@ def _create_dot_general_qt_config(
405433
# misc.
406434
disable_channelwise_axes=rule.disable_channelwise_axes,
407435
bwd_use_original_residuals=rule.bwd_use_original_residuals,
436+
dlhs_stochastic_rounding_noise_fn=dlhs_stochastic_rounding_noise_fn,
437+
drhs_stochastic_rounding_noise_fn=drhs_stochastic_rounding_noise_fn,
408438
)
409439

410440
if rule.additional_qt_config:

tests/core/numerics_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import functools
1516
from absl.testing import absltest
17+
import jax
1618
from jax import numpy as jnp
1719
from qwix._src.core import numerics
20+
from qwix._src.core import stochastic_rounding
1821

1922

2023
class NumericsTest(absltest.TestCase):
@@ -77,6 +80,24 @@ def test_uint(self):
7780
jnp.array([0, 4, 129, 255], jnp.uint8),
7881
)
7982

83+
def test_stochastic_rounding(self):
  """The mean of stochastically rounded +/-0.5 stays close to +/-0.5."""
  base_key = jax.random.PRNGKey(0)
  _, negative_key = jax.random.split(base_key)

  # The positive case uses the base key; the negative case uses a split key.
  for value, noise_key in ((0.5, base_key), (-0.5, negative_key)):
    x = jnp.full((10000,), value)
    noise_fn = functools.partial(
        stochastic_rounding.uniform_noise, key=noise_key
    )
    y = numerics.convert_to(x, jnp.int8, noise_fn=noise_fn)
    # Without stochastic rounding, this would be rounded to all zeros based
    # on round-half-to-even.
    self.assertAlmostEqual(jnp.mean(y), value, delta=0.1)
100+
80101
def test_nf4(self):
81102
self._assert_equal(
82103
numerics.convert_to(jnp.array([-1.0, -0.5, 0.0, 0.8, 1.0]), "nf4"),
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import functools
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
import jax
20+
import jax.numpy as jnp
21+
from qwix._src.core import stochastic_rounding
22+
23+
24+
class StochasticRoundingTest(parameterized.TestCase):
  """Tests for the noise functions in `stochastic_rounding`."""

  def test_uniform_noise(self):
    """Noise is sampled per index of axis 0 and shared along axis 1."""
    full_shape = (2, 3)
    make_noise = functools.partial(
        stochastic_rounding.uniform_noise,
        key=jax.random.PRNGKey(0),
        channelwise_noise_axes=(0,),
    )

    noise = make_noise(full_shape)
    # Only the channelwise axis keeps its size.
    self.assertEqual(noise.shape, (2, 1))

    noise = jnp.broadcast_to(noise, full_shape)
    # Check that the noise is the same along the shared axis.
    for row in range(2):
      self.assertTrue(jnp.all(noise[row, 0] == noise[row, 1]))
    # Check that the noise is different along the non-shared axis.
    self.assertFalse(jnp.all(noise[0, 0] == noise[1, 0]))
40+
41+
42+
if __name__ == "__main__":
43+
absltest.main()

tests/flax_util_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,37 @@ def test_update_boxed(self):
186186
self.assertIsInstance(updated, nnx.Param)
187187
self.assertEqual(updated.sharding_names, ("b", "a", None))
188188

189+
def test_make_rng_linen(self):
  """make_rng returns a key of shape (2,) when called in a linen module."""

  class MyModule(nn.Module):

    @nn.compact
    def __call__(self, x):
      # The input is unused; the module only returns the drawn key.
      return flax_util.make_rng("stochastic_rounding")

  seed = jax.random.PRNGKey(0)
  model = MyModule()
  init_rngs = {"params": seed, "stochastic_rounding": seed}
  variables = model.init(init_rngs, jnp.ones((1,)))
  drawn = model.apply(
      variables, jnp.ones((1,)), rngs={"stochastic_rounding": seed}
  )
  self.assertEqual(drawn.shape, (2,))
206+
207+
def test_make_rng_nnx(self):
  """make_rng returns a key of shape () when called in an nnx module."""

  class MyModule(nnx.Module):

    def __init__(self, *, rngs: nnx.Rngs):
      # make_rng reads the key from the module's `rngs` attribute.
      self.rngs = rngs

    def __call__(self):
      return flax_util.make_rng("stochastic_rounding")

  stream = nnx.Rngs(stochastic_rounding=0)
  drawn = MyModule(rngs=stream)()
  self.assertEqual(drawn.shape, ())
219+
189220

190221
if __name__ == "__main__":
191222
absltest.main()

0 commit comments

Comments
 (0)