diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py
index 767ad45a1..4ea53d533 100644
--- a/flax/linen/fp8_ops.py
+++ b/flax/linen/fp8_ops.py
@@ -139,8 +139,7 @@ def quantize(x, q_dtype, scale, compute_dtype):
 def dequantize(x, dq_dtype, scale):
   return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape)
 
-
-def quantize_dequantize(x, q_dtype, scale, compute_dtype):
+def qdq(x, q_dtype, scale, compute_dtype):
   qx = quantize(x, q_dtype, scale, compute_dtype)
   return dequantize(qx, x.dtype, scale)
 
@@ -165,8 +164,8 @@ def compute_amax_history(x, amax_history):
   return new_history
 
 
-def quantize_and_update(
-  x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False
+def update_fp8_meta(
+  x, q_dtype, scale, amax_history
 ):
   is_fmax32 = (scale.dtype == fm32 and amax_history.dtype == fm32)
   # convert fm32->f32 so we can do math
@@ -181,20 +180,20 @@ def quantize_and_update(
   new_scale = compute_scale(amax_from_history, scale, dtype_max)
   new_history = compute_amax_history(x, amax_history)
 
-  # convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly
   if is_fmax32:
     new_history = lax.convert_element_type(new_history, fp32_max_grad)
     new_scale = lax.convert_element_type(new_scale, fp32_max_grad)
-
-  # Quantize the input
-  if not use_direct_quant:
-    qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
-    return qx, new_scale, new_history
-
   return new_scale, new_history
+def quantize_dequantize_update(x, q_dtype, scale, amax_history, compute_dtype):
+  updated_scale, updated_history = update_fp8_meta(x, q_dtype, scale, amax_history)
+  qdq_x = qdq(x, q_dtype, _fm32_to_float32(updated_scale), compute_dtype)
+  return qdq_x, updated_scale, updated_history
 
-  return qx, new_scale, new_history
+def _fm32_to_float32(value):
+  if value.dtype == fm32:
+    return lax.convert_element_type(value, jnp.float32)
+  return value
 
 
 def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                               preferred_element_type: DTypeLike | None,
@@ -242,14 +241,14 @@ def dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
 
 @partial(custom_vjp, nondiff_argnums=(0, 1))
 def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history):
-  qin, _, _ = quantize_and_update(
+  qin, _, _ = quantize_dequantize_update(
     inp, q_dtype, scale, amax_history, compute_dtype
   )
   return qin
 
 
 def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
-  qin, new_scale, new_history = quantize_and_update(
+  qin, new_scale, new_history = quantize_dequantize_update(
     inp, q_dtype, scale, amax_history, compute_dtype
   )
   return qin, (new_scale, new_history)
@@ -275,7 +274,7 @@ def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history):
 
 def out_qdq_bwd(compute_dtype, q_dtype, res, g):
   scale, amax_history = res
-  q_g, new_scale, new_history = quantize_and_update(
+  q_g, new_scale, new_history = quantize_dequantize_update(
     g, q_dtype, scale, amax_history, compute_dtype
   )
   return q_g, new_scale, new_history
@@ -284,41 +283,58 @@ out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)
 
 
-def q_dot_dq_impl(
-  lhs,
-  rhs,
-  lhs_scale,
-  rhs_scale,
-  out_grad_scale,
-  lhs_amax_history,
-  rhs_amax_history,
-  out_grad_amax_history,
-  compute_dtype,
-  dimension_numbers,
-  precision,
-  preferred_element_type,
-  is_training
-):
-  new_lhs_scale, new_lhs_amax_history = quantize_and_update(
-    lhs,
-    jnp.float8_e4m3fn,
-    lhs_scale,
-    lhs_amax_history,
-    compute_dtype,
-    use_direct_quant=True
-  )
-  new_rhs_scale, new_rhs_amax_history = quantize_and_update(
-    rhs,
-    jnp.float8_e4m3fn,
-    rhs_scale,
-    rhs_amax_history,
-    compute_dtype,
-    use_direct_quant=True
+@partial(custom_vjp, nondiff_argnums=(0, 1))
+def in_q(compute_dtype, q_dtype, inp, scale, amax_history):
+  new_scale, _ = update_fp8_meta(inp, q_dtype, scale, amax_history)
+  qin = quantize(inp, q_dtype, _fm32_to_float32(new_scale), compute_dtype)
+  return qin, new_scale
+
+def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
+  new_scale, new_history = update_fp8_meta(inp, q_dtype, scale, amax_history)
+  qin = quantize(inp, q_dtype, _fm32_to_float32(new_scale), compute_dtype)
+  return (qin, new_scale), (new_scale, new_history)
+
+def in_q_bwd(compute_dtype, q_dtype, res, _):
+  new_scale, new_history = res
+  # We don't compute gradients for inp, scale and amax_history, but we pass the
+  # updated scale and amax history through.
+  return None, new_scale, new_history
+
+in_q.defvjp(in_q_fwd, in_q_bwd)
+
+
+@partial(custom_vjp, nondiff_argnums=(0, ))
+def out_dq(dq_type, lhs_scale, rhs_scale, out):
+  q_out = dequantize(
+    out,
+    dq_type,
+    _fm32_to_float32(lhs_scale) * _fm32_to_float32(rhs_scale)
   )
+  return q_out
+
+def out_dq_fwd(dq_type, lhs_scale, rhs_scale, out):
+  return out_dq(dq_type, lhs_scale, rhs_scale, out), None
+
+def out_dq_bwd(dq_type, _, g):
+  return None, None, g
+
+out_dq.defvjp(out_dq_fwd, out_dq_bwd)
 
-  q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type)
-  q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type)
+def quantized_dot_impl(
+  lhs,
+  q_lhs,
+  lhs_scale,  # actually the new lhs scale
+  rhs,
+  q_rhs,
+  rhs_scale,  # actually the new rhs scale
+  out_grad_scale,  # old out grad scale
+  out_grad_amax_history,  # old out grad amax history
+  compute_dtype,
+  dimension_numbers,
+  precision,
+  preferred_element_type,
+  is_training
+):
 
   out = lax.dot_general(
     q_lhs,
     q_rhs,
@@ -326,49 +342,44 @@ def q_dot_dq_impl(
     preferred_element_type=preferred_element_type,
     precision=lax.Precision.DEFAULT,
   )
-
-  out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale)
   if is_training:
     res = (
       lhs,
-      rhs,
       q_lhs,
+      lhs_scale,
+      rhs,
       q_rhs,
-      new_lhs_scale,
-      new_rhs_scale,
+      rhs_scale,
       out_grad_scale,
-      new_lhs_amax_history,
-      new_rhs_amax_history,
       out_grad_amax_history,
     )
     return out, res
   else:
     return out
 
-
 @partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11))
-def q_dot_dq(
+def quantized_dot(
   lhs,
+  q_lhs,
+  lhs_scale,  # actually the new lhs scale
   rhs,
-  lhs_scale,
+  q_rhs,
   rhs_scale,
-  out_grad_scale,
-  lhs_amax_history,
-  rhs_amax_history,
-  out_grad_amax_history,
+  out_grad_scale,  # old out grad scale
+  out_grad_amax_history,  # old out grad amax history
   compute_dtype,
   dimension_numbers,
   precision=None,
   preferred_element_type=None
 ):
-  return q_dot_dq_impl(
+  return quantized_dot_impl(
     lhs,
-    rhs,
+    q_lhs,
     lhs_scale,
+    rhs,
+    q_rhs,
     rhs_scale,
     out_grad_scale,
-    lhs_amax_history,
-    rhs_amax_history,
     out_grad_amax_history,
     compute_dtype,
     dimension_numbers,
@@ -377,29 +388,28 @@ def q_dot_dq(
     is_training=False,
   )
 
-
-def q_dot_dq_fwd(
+def quantized_dot_fwd(
   lhs,
-  rhs,
+  q_lhs,
   lhs_scale,
+  rhs,
+  q_rhs,
   rhs_scale,
   out_grad_scale,
-  lhs_amax_history,
-  rhs_amax_history,
   out_grad_amax_history,
   compute_dtype,
   dimension_numbers,
   precision,
   preferred_element_type,
 ):
-  return q_dot_dq_impl(
+  return quantized_dot_impl(
     lhs,
-    rhs,
+    q_lhs,
     lhs_scale,
+    rhs,
+    q_rhs,
     rhs_scale,
     out_grad_scale,
-    lhs_amax_history,
-    rhs_amax_history,
    out_grad_amax_history,
     compute_dtype,
     dimension_numbers,
@@ -408,8 +418,7 @@ def q_dot_dq_fwd(
     is_training=True
   )
 
-
-def q_dot_dq_bwd(
+def quantized_dot_bwd(
   compute_dtype,
   dimension_numbers,
   precision,
@@ -419,27 +428,23 @@ def q_dot_dq_bwd(
 ):
   (
     lhs,
-    rhs,
     q_lhs,
+    lhs_scale,
+    rhs,
     q_rhs,
-    new_lhs_scale,
-    new_rhs_scale,
+    rhs_scale,
     out_grad_scale,
-    new_lhs_amax_history,
-    new_rhs_amax_history,
     out_grad_amax_history,
   ) = res
 
-  new_out_grad_scale, new_out_grad_amax_history = quantize_and_update(
+  new_out_grad_scale, new_out_grad_amax_history = update_fp8_meta(
     g,
     jnp.float8_e5m2,
     out_grad_scale,
     out_grad_amax_history,
-    compute_dtype,
-    use_direct_quant=True
   )
 
-  q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type)
+  q_g = quantize(g, jnp.float8_e5m2, _fm32_to_float32(new_out_grad_scale), preferred_element_type)
 
   grad_lhs = dot_general_transpose_lhs(
     q_g,
@@ -449,7 +454,11 @@ def q_dot_dq_bwd(
     precision=lax.Precision.HIGHEST,
     preferred_element_type=preferred_element_type,
   )
-  grad_lhs = dequantize(grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale)
+  grad_lhs = dequantize(
+    grad_lhs,
+    preferred_element_type,
+    _fm32_to_float32(rhs_scale) * _fm32_to_float32(new_out_grad_scale)
+  )
 
   grad_rhs = dot_general_transpose_rhs(
     q_g,
@@ -459,21 +468,67 @@ def q_dot_dq_bwd(
     precision=lax.Precision.HIGHEST,
     preferred_element_type=preferred_element_type,
   )
-  grad_rhs = dequantize(grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale)
+  grad_rhs = dequantize(
+    grad_rhs,
+    preferred_element_type,
+    _fm32_to_float32(lhs_scale) * _fm32_to_float32(new_out_grad_scale)
+  )
 
   return (
     grad_lhs,
+    None,
+    None,
     grad_rhs,
-    new_lhs_scale,
-    new_rhs_scale,
+    None,
+    None,
     new_out_grad_scale,
-    new_lhs_amax_history,
-    new_rhs_amax_history,
     new_out_grad_amax_history,
   )
 
 
-q_dot_dq.defvjp(q_dot_dq_fwd, q_dot_dq_bwd)
+quantized_dot.defvjp(quantized_dot_fwd, quantized_dot_bwd)
+# Convenience wrapper for the quantize-dot-dequantize pattern.
+def q_dot_dq(
+  lhs,
+  rhs,
+  lhs_scale,
+  rhs_scale,
+  out_grad_scale,
+  lhs_amax_history,
+  rhs_amax_history,
+  out_grad_amax_history,
+  compute_dtype,
+  dimension_numbers,
+  precision=None,
+  preferred_element_type=None
+):
+  q_lhs, new_lhs_scale = in_q(
+    compute_dtype, jnp.float8_e4m3fn, lhs, lhs_scale, lhs_amax_history
+  )
+  q_rhs, new_rhs_scale = in_q(
+    compute_dtype, jnp.float8_e4m3fn, rhs, rhs_scale, rhs_amax_history
+  )
+  y = quantized_dot(
+    lhs,
+    q_lhs,
+    new_lhs_scale,
+    rhs,
+    q_rhs,
+    new_rhs_scale,
+    out_grad_scale,
+    out_grad_amax_history,
+    compute_dtype,
+    dimension_numbers,
+    precision,
+    preferred_element_type
+  )
+  y = out_dq(
+    dq_type=preferred_element_type,
+    lhs_scale=new_lhs_scale,
+    rhs_scale=new_rhs_scale,
+    out=y
+  )
+  return y  # type: ignore
 
 @partial(custom_jvp, nondiff_argnums=(2, 3, 4))
 def dot_general_with_precision(
diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py
index 8e2fd51cd..ff31efdde 100644
--- a/tests/linen/linen_test.py
+++ b/tests/linen/linen_test.py
@@ -1263,7 +1263,7 @@ def test_fp8_dot_general_injection(self, fp8_genre, use_direct_quant):
     # Used to cast the inputs to be representable in FP8, so that the difference
     # of the results from the original gemm and fp8 gemm is small.
     cast_to_representable = functools.partial(
-      fp8_ops.quantize_dequantize,
+      fp8_ops.qdq,
      scale=jnp.ones((1,)),
       compute_dtype=jnp.float32,
     )
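For reference, a minimal usage sketch of the new q_dot_dq wrapper introduced above. This is not part of the diff: the tensor shapes, the amax-history length of 1024, and the choice of jnp.float32 as compute/output dtype are illustrative assumptions, and running the FP8 lax.dot_general path end to end requires a JAX/XLA build and backend with FP8 support.

```python
# Illustrative sketch only; names follow the functions added in this diff,
# but the shapes, history length, and dtypes are assumptions.
import jax
import jax.numpy as jnp
from flax.linen import fp8_ops

key = jax.random.PRNGKey(0)
lhs = jax.random.normal(key, (16, 32), jnp.float32)
rhs = jax.random.normal(key, (32, 64), jnp.float32)

scale = jnp.ones((1,), jnp.float32)             # per-tensor scales start at 1
amax_history = jnp.zeros((1024,), jnp.float32)  # rolling amax window (length assumed)

out = fp8_ops.q_dot_dq(
  lhs,
  rhs,
  lhs_scale=scale,
  rhs_scale=scale,
  out_grad_scale=scale,
  lhs_amax_history=amax_history,
  rhs_amax_history=amax_history,
  out_grad_amax_history=amax_history,
  compute_dtype=jnp.float32,
  dimension_numbers=(((1,), (0,)), ((), ())),   # plain (16, 32) @ (32, 64) matmul
  preferred_element_type=jnp.float32,
)
```

In practice the scale and amax-history arrays would come from module state (for example Flax's FP8 dot-general op) rather than being constructed by hand as above.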