diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py
index 767ad45a1..b6643191d 100644
--- a/flax/linen/fp8_ops.py
+++ b/flax/linen/fp8_ops.py
@@ -140,8 +140,10 @@ def dequantize(x, dq_dtype, scale):
   return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape)
 
 
-def quantize_dequantize(x, q_dtype, scale, compute_dtype):
+def quantize_dequantize(x, q_dtype, scale, compute_dtype, quantize_only=False):
   qx = quantize(x, q_dtype, scale, compute_dtype)
+  if quantize_only:
+    return qx
   return dequantize(qx, x.dtype, scale)
 
 
@@ -166,7 +168,8 @@ def compute_amax_history(x, amax_history):
 
 
 def quantize_and_update(
-  x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False
+  x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False,
+  quantize_only=False,
 ):
   is_fmax32 = (scale.dtype == fm32 and amax_history.dtype == fm32)
   # convert fm32->f32 so we can do math
@@ -181,21 +184,19 @@ def quantize_and_update(
   new_scale = compute_scale(amax_from_history, scale, dtype_max)
   new_history = compute_amax_history(x, amax_history)
 
+  if not use_direct_quant:
+    qx = quantize_dequantize(
+      x, q_dtype, new_scale, compute_dtype, quantize_only=quantize_only
+    )
+
   # convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly
   if is_fmax32:
     new_history = lax.convert_element_type(new_history, fp32_max_grad)
     new_scale = lax.convert_element_type(new_scale, fp32_max_grad)
 
-  # Quantize the input
   if not use_direct_quant:
-    qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
     return qx, new_scale, new_history
-
   return new_scale, new_history
 
 
 def dot_general_transpose_lhs(g, x, y, *, dimension_numbers,
                               precision,
                               preferred_element_type: DTypeLike | None,
                               swap_ans=False):
@@ -284,196 +285,25 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g):
 out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)
 
 
-def q_dot_dq_impl(
-  lhs,
-  rhs,
-  lhs_scale,
-  rhs_scale,
-  out_grad_scale,
-  lhs_amax_history,
-  rhs_amax_history,
-  out_grad_amax_history,
-  compute_dtype,
-  dimension_numbers,
-  precision,
-  preferred_element_type,
-  is_training
-):
-  new_lhs_scale, new_lhs_amax_history = quantize_and_update(
-    lhs,
-    jnp.float8_e4m3fn,
-    lhs_scale,
-    lhs_amax_history,
-    compute_dtype,
-    use_direct_quant=True
-  )
-  new_rhs_scale, new_rhs_amax_history = quantize_and_update(
-    rhs,
-    jnp.float8_e4m3fn,
-    rhs_scale,
-    rhs_amax_history,
-    compute_dtype,
-    use_direct_quant=True
-  )
-
-  q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type)
-  q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type)
-
-  out = lax.dot_general(
-    q_lhs,
-    q_rhs,
-    dimension_numbers,
-    preferred_element_type=preferred_element_type,
-    precision=lax.Precision.DEFAULT,
-  )
-
-  out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale)
-  if is_training:
-    res = (
-      lhs,
-      rhs,
-      q_lhs,
-      q_rhs,
-      new_lhs_scale,
-      new_rhs_scale,
-      out_grad_scale,
-      new_lhs_amax_history,
-      new_rhs_amax_history,
-      out_grad_amax_history,
-    )
-    return out, res
-  else:
-    return out
-
+@partial(custom_vjp, nondiff_argnums=(0, 1))
+def in_q(compute_dtype, q_dtype, inp, scale, amax_history):
+  qin, new_scale, new_history = quantize_and_update(
+    inp, q_dtype, scale, amax_history, compute_dtype, quantize_only=True
+  )
+  return qin, new_scale
 
-@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11))
-def q_dot_dq(
-  lhs,
-  rhs,
-  lhs_scale,
-  rhs_scale,
-  out_grad_scale,
-  lhs_amax_history,
-  rhs_amax_history,
-  out_grad_amax_history,
-  compute_dtype,
-  dimension_numbers,
-  precision=None,
-  preferred_element_type=None
-):
-  return q_dot_dq_impl(
-    lhs,
-    rhs,
-    lhs_scale,
-    rhs_scale,
-    out_grad_scale,
-    lhs_amax_history,
-    rhs_amax_history,
-    out_grad_amax_history,
-    compute_dtype,
-    dimension_numbers,
-    precision,
-    preferred_element_type,
-    is_training=False,
-  )
-
-
-def q_dot_dq_fwd(
-  lhs,
-  rhs,
-  lhs_scale,
-  rhs_scale,
-  out_grad_scale,
-  lhs_amax_history,
-  rhs_amax_history,
-  out_grad_amax_history,
-  compute_dtype,
-  dimension_numbers,
-  precision,
-  preferred_element_type,
-):
-  return q_dot_dq_impl(
-    lhs,
-    rhs,
-    lhs_scale,
-    rhs_scale,
-    out_grad_scale,
-    lhs_amax_history,
-    rhs_amax_history,
-    out_grad_amax_history,
-    compute_dtype,
-    dimension_numbers,
-    precision,
-    preferred_element_type,
-    is_training=True
-  )
-
-
-def q_dot_dq_bwd(
-  compute_dtype,
-  dimension_numbers,
-  precision,
-  preferred_element_type,
-  res,
-  g
-):
-  (
-    lhs,
-    rhs,
-    q_lhs,
-    q_rhs,
-    new_lhs_scale,
-    new_rhs_scale,
-    out_grad_scale,
-    new_lhs_amax_history,
-    new_rhs_amax_history,
-    out_grad_amax_history,
-  ) = res
-
-  new_out_grad_scale, new_out_grad_amax_history = quantize_and_update(
-    g,
-    jnp.float8_e5m2,
-    out_grad_scale,
-    out_grad_amax_history,
-    compute_dtype,
-    use_direct_quant=True
-  )
-
-  q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)
-
-  grad_lhs = dot_general_transpose_lhs(
-    q_g,
-    lhs,
-    q_rhs,
-    dimension_numbers=dimension_numbers,
-    precision=lax.Precision.HIGHEST,
-    preferred_element_type=preferred_element_type,
-  )
-  grad_lhs = dequantize(grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale)
-
-  grad_rhs = dot_general_transpose_rhs(
-    q_g,
-    q_lhs,
-    rhs,
-    dimension_numbers=dimension_numbers,
-    precision=lax.Precision.HIGHEST,
-    preferred_element_type=preferred_element_type,
-  )
-  grad_rhs = dequantize(grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale)
-
-  return (
-    grad_lhs,
-    grad_rhs,
-    new_lhs_scale,
-    new_rhs_scale,
-    new_out_grad_scale,
-    new_lhs_amax_history,
-    new_rhs_amax_history,
-    new_out_grad_amax_history,
-  )
 
+def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
+  qin, new_scale, new_history = quantize_and_update(
+    inp, q_dtype, scale, amax_history, compute_dtype, quantize_only=True
+  )
+  return (qin, new_scale), (new_scale, new_history)
 
-q_dot_dq.defvjp(q_dot_dq_fwd, q_dot_dq_bwd)
+
+def in_q_bwd(compute_dtype, q_dtype, res, _):
+  new_scale, new_history = res
+  # No gradient is computed for inp; the updated scale and amax history are
+  # passed through as the cotangents of scale and amax_history so the fp8
+  # meta variables get updated by the optimizer.
+  return None, new_scale, new_history
+
+
+in_q.defvjp(in_q_fwd, in_q_bwd)
 
 
 @partial(custom_jvp, nondiff_argnums=(2, 3, 4))
 def dot_general_with_precision(
@@ -520,6 +350,99 @@ def _parse_dot_inputs(*args, **kwargs):
   x = jnp.asarray(x, comp_dtype)
   return x, k, dimension_numbers, comp_dtype
 
+
+def _fm32_to_float32(value):
+  if value.dtype == fm32:
+    return lax.convert_element_type(value, jnp.float32)
+  return value
+
+
+# Convenience wrapper for the quantize-dot-dequantize path: quantize the lhs
+# once with in_q, then delegate the rest to one_sided_q_dot_dq.
+def q_dot_dq(
+  lhs,
+  rhs,
+  lhs_scale,
+  rhs_scale,
+  out_grad_scale,
+  lhs_amax_history,
+  rhs_amax_history,
+  out_grad_amax_history,
+  compute_dtype,
+  dimension_numbers,
+  precision=None,
+  preferred_element_type=None
+):
+  q_lhs, new_lhs_scale = in_q(
+    compute_dtype, jnp.float8_e4m3fn, lhs, lhs_scale, lhs_amax_history
+  )
+
+  y = one_sided_q_dot_dq(
+    lhs,
+    q_lhs,
+    new_lhs_scale,  # the already-updated lhs scale returned by in_q
+    rhs,
+    rhs_scale,
+    out_grad_scale,
+    rhs_amax_history,
+    out_grad_amax_history,
+    compute_dtype,
+    dimension_numbers,
+    precision,
+    preferred_element_type
+  )
+  return y  # type: ignore
+
+
+# This decorator factory replaces the decorated dot_general-like callable with
+# q_dot_dq, pre-binding the scales and amax histories it needs.
+def q_dot_dq_config(
+  lhs_scale, rhs_scale, out_grad_scale,
+  lhs_amax_history, rhs_amax_history, out_grad_amax_history,
+  compute_dtype
+):
+  def decorator(func):
+    def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
+      return q_dot_dq(
+        lhs,
+        rhs,
+        lhs_scale=lhs_scale,
+        rhs_scale=rhs_scale,
+        out_grad_scale=out_grad_scale,
+        lhs_amax_history=lhs_amax_history,
+        rhs_amax_history=rhs_amax_history,
+        out_grad_amax_history=out_grad_amax_history,
+        compute_dtype=compute_dtype,
+        dimension_numbers=dimension_numbers,
+        precision=precision,
+        preferred_element_type=preferred_element_type,
+      )
+    return wrapper
+  return decorator
+
+
+# This decorator factory replaces the decorated dot_general-like callable with
+# one_sided_q_dot_dq, pre-binding the pre-quantized input (q_x), its scale, and
+# the kernel/output-grad scales and amax histories. This allows an input that
+# has already been quantized to be reused for FP8 matrix multiplication while
+# still managing the quantization parameters.
+def one_sided_q_dot_dq_config(
+  compute_dtype, q_x, input_scale, kernel_scale, out_grad_scale,
+  kernel_amax_history, out_grad_amax_history
+):
+  def decorator(func):
+    def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
+      return one_sided_q_dot_dq(
+        lhs=lhs,
+        q_lhs=q_x,
+        lhs_scale=input_scale,
+        rhs=rhs,
+        rhs_scale=kernel_scale,
+        out_grad_scale=out_grad_scale,
+        rhs_amax_history=kernel_amax_history,
+        out_grad_amax_history=out_grad_amax_history,
+        compute_dtype=compute_dtype,
+        dimension_numbers=dimension_numbers,
+        precision=precision,
+        preferred_element_type=preferred_element_type
+      )
+    return wrapper
+  return decorator
+
 
 class Fp8DotGeneralBase(module.Module):
   amax_history_length: int = 1024
@@ -589,23 +512,200 @@ def __call__(self, *args, **kwargs):
     x, k, dimension_numbers, comp_dtype = _parse_dot_inputs(
       *args, **kwargs
     )
-    y = q_dot_dq(
-      x,
-      k,
-      self.input_scale.value,
-      self.kernel_scale.value,
-      self.output_grad_scale.value,
-      self.input_amax_history.value,
-      self.kernel_amax_history.value,
-      self.output_grad_amax_history.value,
-      comp_dtype,
-      dimension_numbers,
-      preferred_element_type=x.dtype
+    y = q_dot_dq(
+      x,
+      k,
+      self.input_scale.value,
+      self.kernel_scale.value,
+      self.output_grad_scale.value,
+      self.input_amax_history.value,
+      self.kernel_amax_history.value,
+      self.output_grad_amax_history.value,
+      comp_dtype,
+      dimension_numbers,
+      preferred_element_type=x.dtype,
     )
     return y  # type: ignore
 
 
+def one_sided_q_dot_dq_impl(
+  lhs,
+  q_lhs,
+  lhs_scale,  # the already-updated lhs scale produced by in_q
+  rhs,
+  rhs_scale,
+  out_grad_scale,
+  rhs_amax_history,
+  out_grad_amax_history,
+  compute_dtype,
+  dimension_numbers,
+  precision,
+  preferred_element_type,
+  is_training
+):
+  new_rhs_scale, new_rhs_amax_history = quantize_and_update(
+    rhs,
+    jnp.float8_e4m3fn,
+    rhs_scale,
+    rhs_amax_history,
+    compute_dtype,
+    use_direct_quant=True
+  )
+
+  q_rhs = quantize(rhs, jnp.float8_e4m3fn, _fm32_to_float32(new_rhs_scale), preferred_element_type)
+
+  out = lax.dot_general(
+    q_lhs,
+    q_rhs,
+    dimension_numbers,
+    preferred_element_type=preferred_element_type,
+    precision=lax.Precision.DEFAULT,
+  )
+
+  out = dequantize(out, preferred_element_type, _fm32_to_float32(new_rhs_scale) * _fm32_to_float32(lhs_scale))
+  if is_training:
+    res = (
+      lhs,
+      q_lhs,
+      lhs_scale,
+      rhs,
+      q_rhs,
+      new_rhs_scale,
+      out_grad_scale,
+      new_rhs_amax_history,
+      out_grad_amax_history,
+    )
+    return out, res
+  else:
+    return out
+
+
+@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11))
+def one_sided_q_dot_dq(
+  lhs,
+  q_lhs,
+  lhs_scale,  # the already-updated lhs scale produced by in_q
+  rhs,
+  rhs_scale,
+  out_grad_scale,
+  rhs_amax_history,
+  out_grad_amax_history,
+  compute_dtype,
+  dimension_numbers,
+  precision=None,
+  preferred_element_type=None
+):
+  return one_sided_q_dot_dq_impl(
+    lhs,
+    q_lhs,
+    lhs_scale,
+    rhs,
+    rhs_scale,
+    out_grad_scale,
+    rhs_amax_history,
+    out_grad_amax_history,
+    compute_dtype,
+    dimension_numbers,
+    precision,
+    preferred_element_type,
+    is_training=False,
+  )
+
+
+def one_sided_q_dot_dq_fwd(
+  lhs,
+  q_lhs,
+  lhs_scale,
+  rhs,
+  rhs_scale,
+  out_grad_scale,
+  rhs_amax_history,
+  out_grad_amax_history,
+  compute_dtype,
+  dimension_numbers,
+  precision,
+  preferred_element_type,
+):
+  return one_sided_q_dot_dq_impl(
+    lhs,
+    q_lhs,
+    lhs_scale,
+    rhs,
+    rhs_scale,
+    out_grad_scale,
+    rhs_amax_history,
+    out_grad_amax_history,
+    compute_dtype,
+    dimension_numbers,
+    precision,
+    preferred_element_type,
+    is_training=True
+  )
+
+
+def one_sided_q_dot_dq_bwd(
+  compute_dtype,
+  dimension_numbers,
+  precision,
+  preferred_element_type,
+  res,
+  g
+):
+  (
+    lhs,
+    q_lhs,
+    lhs_scale,
+    rhs,
+    q_rhs,
+    new_rhs_scale,
+    out_grad_scale,
+    new_rhs_amax_history,
+    out_grad_amax_history,
+  ) = res
+
+  new_out_grad_scale, new_out_grad_amax_history = quantize_and_update(
+    g,
+    jnp.float8_e5m2,
+    out_grad_scale,
+    out_grad_amax_history,
+    compute_dtype,
+    use_direct_quant=True
+  )
+
+  q_g = quantize(g, jnp.float8_e5m2, _fm32_to_float32(new_out_grad_scale), preferred_element_type)
+
+  grad_lhs = dot_general_transpose_lhs(
+    q_g,
+    lhs,
+    q_rhs,
+    dimension_numbers=dimension_numbers,
+    precision=lax.Precision.HIGHEST,
+    preferred_element_type=preferred_element_type,
+  )
+  grad_lhs = dequantize(grad_lhs, preferred_element_type, _fm32_to_float32(new_rhs_scale) * _fm32_to_float32(new_out_grad_scale))
+
+  grad_rhs = dot_general_transpose_rhs(
+    q_g,
+    q_lhs,
+    rhs,
+    dimension_numbers=dimension_numbers,
+    precision=lax.Precision.HIGHEST,
+    preferred_element_type=preferred_element_type,
+  )
+  grad_rhs = dequantize(grad_rhs, preferred_element_type, _fm32_to_float32(lhs_scale) * _fm32_to_float32(new_out_grad_scale))
+
+  return (
+    grad_lhs,
+    None,  # no gradient for the pre-quantized q_lhs
+    None,  # no gradient for lhs_scale; it was already updated by in_q
+    grad_rhs,
+    new_rhs_scale,
+    new_out_grad_scale,
+    new_rhs_amax_history,
+    new_out_grad_amax_history,
+  )
+
+
+one_sided_q_dot_dq.defvjp(one_sided_q_dot_dq_fwd, one_sided_q_dot_dq_bwd)
 
 
 class NANOOFp8DotGeneralOp(Fp8DotGeneralOp):
   e4m3_dtype: DType = jnp.float8_e4m3fnuz
   e5m2_dtype: DType = jnp.float8_e5m2fnuz
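The new two-sided entry point keeps the old q_dot_dq call signature, so Fp8DotGeneralOp.__call__ stays a drop-in user of it. The sketch below (not part of the diff) exercises that wrapper directly; it assumes the patched module imports as flax.linen.fp8_ops and a backend whose XLA build supports float8 dot_general, and the shapes, the history length of 16, and the (1,)-shaped per-tensor scales are illustrative choices only.

# Minimal, hypothetical usage sketch of the q_dot_dq convenience wrapper.
import jax.numpy as jnp
from flax.linen import fp8_ops  # assumes the patched module above

comp_dtype = jnp.bfloat16
lhs = jnp.ones((8, 32), comp_dtype)
rhs = jnp.ones((32, 16), comp_dtype)

scale = jnp.ones((1,), jnp.float32)      # per-tensor delayed-scaling scale
history = jnp.zeros((16,), jnp.float32)  # amax history buffer
dn = (((1,), (0,)), ((), ()))            # dimension numbers for a plain matmul

y = fp8_ops.q_dot_dq(
  lhs, rhs,
  scale, scale, scale,                   # lhs / rhs / out-grad scales
  history, history, history,             # lhs / rhs / out-grad amax histories
  comp_dtype,
  dn,
  preferred_element_type=comp_dtype,
)
print(y.shape)  # (8, 16)

Internally this quantizes lhs once via in_q and then calls one_sided_q_dot_dq, which handles the rhs and output-gradient metadata, so the forward result matches the old q_dot_dq path.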
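The point of the one-sided path is that an activation quantized once can be reused across several matmuls. A hypothetical sketch of that pattern, using only names defined in the diff (in_q, one_sided_q_dot_dq) plus illustrative arrays and buffer shapes, under the same backend assumption as above:

# Quantize the activation a single time, then reuse the FP8 tensor for two kernels.
import jax.numpy as jnp
from flax.linen import fp8_ops

comp_dtype = jnp.bfloat16
x = jnp.ones((8, 32), comp_dtype)
k1 = jnp.ones((32, 16), comp_dtype)
k2 = jnp.ones((32, 64), comp_dtype)

scale = jnp.ones((1,), jnp.float32)
hist = jnp.zeros((16,), jnp.float32)
dn = (((1,), (0,)), ((), ()))

# in_q returns the FP8 tensor together with the freshly computed scale used to
# produce it; only this call touches the input-side metadata.
q_x, x_scale = fp8_ops.in_q(comp_dtype, jnp.float8_e4m3fn, x, scale, hist)

# Each matmul reuses q_x; one_sided_q_dot_dq only updates the kernel and
# output-gradient scales/histories per call.
y1 = fp8_ops.one_sided_q_dot_dq(
  x, q_x, x_scale, k1, scale, scale, hist, hist, comp_dtype, dn,
  None, comp_dtype,
)
y2 = fp8_ops.one_sided_q_dot_dq(
  x, q_x, x_scale, k2, scale, scale, hist, hist, comp_dtype, dn,
  None, comp_dtype,
)

The config factories (q_dot_dq_config, one_sided_q_dot_dq_config) pre-bind exactly these scale and history arguments so the resulting callable exposes only the (lhs, rhs, dimension_numbers, ...) signature expected of a dot_general replacement.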