2727from qwix ._src import qconfig
2828from qwix ._src .core import conv_general_qt
2929from qwix ._src .core import dot_general_qt
30+ from qwix ._src .core import stochastic_rounding
3031
3132
3233@dataclasses .dataclass (frozen = True , kw_only = True )
@@ -52,6 +53,13 @@ class QtRule(qconfig.QuantizationRule):
5253 # residuals for backward pass.
5354 bwd_use_original_residuals : bool = False
5455
56+ # Stochastic rounding mode for the gradients in the backward pass: 'uniform', or None to disable it. Any other value raises ValueError.
57+ bwd_stochastic_rounding : str | None = None
58+
59+ # Axes along which stochastic rounding noise is generated (channelwise noise); only used when
60+ # bwd_stochastic_rounding is 'uniform'. By default, noise is generated for the 0th dimension and broadcast over the remaining dimensions.
61+ channelwise_noise_axes : Sequence [int ] = (0 ,)
62+
5563 # Override any fields in DotGeneralQtConfig.
5664 additional_qt_config : Mapping [str , Any ] | None = None
5765
@@ -385,6 +393,26 @@ def _create_dot_general_qt_config(
385393 if rhs_is_weight :
386394 drhs_tile_size = rule .bwd_weight_grad_tile_size
387395
396+ if rule .bwd_stochastic_rounding == 'uniform' :
397+ dlhs_stochastic_rounding_noise_fn = functools .partial (
398+ stochastic_rounding .uniform_noise ,
399+ key = flax_util .make_rng ('stochastic_rounding' ),
400+ channelwise_noise_axes = rule .channelwise_noise_axes ,
401+ )
402+ drhs_stochastic_rounding_noise_fn = functools .partial (
403+ stochastic_rounding .uniform_noise ,
404+ key = flax_util .make_rng ('stochastic_rounding' ),
405+ channelwise_noise_axes = rule .channelwise_noise_axes ,
406+ )
407+ elif rule .bwd_stochastic_rounding is not None :
408+ raise ValueError (
409+ 'Stochastic rounding should be "uniform" or None, got:'
410+ f' { rule .bwd_stochastic_rounding } '
411+ )
412+ else :
413+ dlhs_stochastic_rounding_noise_fn = None
414+ drhs_stochastic_rounding_noise_fn = None
415+
388416 qt_config = dot_general_qt .DotGeneralQtConfig (
389417 # fwd configs.
390418 lhs_qtype = lhs_qtype ,
@@ -405,6 +433,8 @@ def _create_dot_general_qt_config(
405433 # misc.
406434 disable_channelwise_axes = rule .disable_channelwise_axes ,
407435 bwd_use_original_residuals = rule .bwd_use_original_residuals ,
436+ dlhs_stochastic_rounding_noise_fn = dlhs_stochastic_rounding_noise_fn ,
437+ drhs_stochastic_rounding_noise_fn = drhs_stochastic_rounding_noise_fn ,
408438 )
409439
410440 if rule .additional_qt_config :
0 commit comments