From a4d0265562157602e586f2fe03de62e2e24b2918 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 23 Sep 2024 08:52:26 -0700 Subject: [PATCH 1/7] design 1 --- flax/linen/__init__.py | 1 + flax/linen/fp8_ops.py | 406 ++++++++++++++++++++++++++--------------- 2 files changed, 258 insertions(+), 149 deletions(-) diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 24a33d873..13e4a480e 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -89,6 +89,7 @@ from .fp8_ops import ( Fp8DotGeneralOp as Fp8DotGeneralOp, Fp8DirectDotGeneralOp as Fp8DirectDotGeneralOp, + Fp8DirectDoubleDotGeneralOp as Fp8DirectDoubleDotGeneralOp, NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp, ) from .initializers import ( diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 767ad45a1..d1ab52302 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -283,197 +283,259 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g): out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) - -def q_dot_dq_impl( +@partial(custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def q_multi_dot_dq( lhs, - rhs, lhs_scale, - rhs_scale, - out_grad_scale, lhs_amax_history, - rhs_amax_history, - out_grad_amax_history, compute_dtype, dimension_numbers, precision, preferred_element_type, - is_training + *rhs_and_scales, ): - new_lhs_scale, new_lhs_amax_history = quantize_and_update( - lhs, - jnp.float8_e4m3fn, - lhs_scale, - lhs_amax_history, - compute_dtype, - use_direct_quant=True - ) - new_rhs_scale, new_rhs_amax_history = quantize_and_update( - rhs, - jnp.float8_e4m3fn, - rhs_scale, - rhs_amax_history, - compute_dtype, - use_direct_quant=True - ) - - q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type) - q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type) - - out = lax.dot_general( - q_lhs, - q_rhs, - dimension_numbers, - preferred_element_type=preferred_element_type, - precision=lax.Precision.DEFAULT, - ) - - out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale) - if is_training: - res = ( - lhs, - rhs, - q_lhs, - q_rhs, - new_lhs_scale, - new_rhs_scale, - out_grad_scale, - new_lhs_amax_history, - new_rhs_amax_history, - out_grad_amax_history, + return q_multi_dot_dq_impl( + lhs, + lhs_scale, + lhs_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, + *rhs_and_scales, + is_training=False ) - return out, res - else: - return out - -@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11)) -def q_dot_dq( - lhs, - rhs, - lhs_scale, - rhs_scale, - out_grad_scale, - lhs_amax_history, - rhs_amax_history, - out_grad_amax_history, - compute_dtype, - dimension_numbers, - precision=None, - preferred_element_type=None -): - return q_dot_dq_impl( +def q_multi_dot_dq_impl( lhs, - rhs, lhs_scale, - rhs_scale, - out_grad_scale, lhs_amax_history, - rhs_amax_history, - out_grad_amax_history, compute_dtype, dimension_numbers, precision, preferred_element_type, - is_training=False, - ) - - -def q_dot_dq_fwd( - lhs, - rhs, - lhs_scale, - rhs_scale, - out_grad_scale, - lhs_amax_history, - rhs_amax_history, - out_grad_amax_history, - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, + *rhs_and_scales, + is_training, ): - return q_dot_dq_impl( + # Quantize lhs (A) once + new_lhs_scale, new_lhs_amax_history = quantize_and_update( + lhs, jnp.float8_e4m3fn, lhs_scale, lhs_amax_history, + compute_dtype, use_direct_quant=True + ) + q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type) + + 
outputs = [] + rhs_list = [] + q_rhs_list = [] + new_rhs_scales = [] + new_rhs_amax_histories = [] + out_grad_scales = [] + out_grad_amax_histories = [] + + # We use 5 as the grouping factor because each RHS input requires 5 pieces of data: + # 1. rhs: the right-hand side matrix + # 2. rhs_scale: the scale for quantizing the RHS + # 3. rhs_amax_history: the amax history for the RHS + # 4. out_grad_scale: the scale for quantizing the output gradient + # 5. out_grad_amax_history: the amax history for the output gradient + assert len(rhs_and_scales) % 5 == 0, f"Expected rhs_and_scales length to be divisible by 5, but got {len(rhs_and_scales)}" + num_rhs = len(rhs_and_scales) // 5 + + for i in range(num_rhs): + rhs, rhs_scale, rhs_amax_history, out_grad_scale, out_grad_amax_history = ( + rhs_and_scales[i*5:(i+1)*5] + ) + new_rhs_scale, new_rhs_amax_history = quantize_and_update( + rhs, jnp.float8_e4m3fn, rhs_scale, rhs_amax_history, + compute_dtype, use_direct_quant=True + ) + q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type) + + out = lax.dot_general( + q_lhs, q_rhs, dimension_numbers, + preferred_element_type=preferred_element_type, + precision=lax.Precision.DEFAULT, + ) + out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale) + + outputs.append(out) + rhs_list.append(rhs) + q_rhs_list.append(q_rhs) + new_rhs_scales.append(new_rhs_scale) + new_rhs_amax_histories.append(new_rhs_amax_history) + out_grad_scales.append(out_grad_scale) + out_grad_amax_histories.append(out_grad_amax_history) + + if is_training: + res = ( + lhs, + q_lhs, + new_lhs_scale, + new_lhs_amax_history, + rhs_list, + q_rhs_list, + new_rhs_scales, + new_rhs_amax_histories, + out_grad_scales, + out_grad_amax_histories, + ) + return tuple(outputs), res + else: + return tuple(outputs) + +def q_multi_dot_dq_fwd( lhs, - rhs, lhs_scale, - rhs_scale, - out_grad_scale, lhs_amax_history, - rhs_amax_history, - out_grad_amax_history, compute_dtype, dimension_numbers, precision, preferred_element_type, - is_training=True - ) - + *rhs_and_scales +): + return q_multi_dot_dq_impl( + lhs, + lhs_scale, + lhs_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, + *rhs_and_scales, + is_training=True + ) -def q_dot_dq_bwd( - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - res, - g +def q_multi_dot_dq_bwd( + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, + res, + g ): ( lhs, - rhs, q_lhs, - q_rhs, new_lhs_scale, - new_rhs_scale, - out_grad_scale, new_lhs_amax_history, - new_rhs_amax_history, - out_grad_amax_history, + rhs_list, + q_rhs_list, + new_rhs_scales, + new_rhs_amax_histories, + out_grad_scales, + out_grad_amax_histories, ) = res - new_out_grad_scale, new_out_grad_amax_history = quantize_and_update( - g, - jnp.float8_e5m2, - out_grad_scale, - out_grad_amax_history, - compute_dtype, - use_direct_quant=True - ) - - q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type) - - grad_lhs = dot_general_transpose_lhs( - q_g, - lhs, - q_rhs, - dimension_numbers=dimension_numbers, - precision=lax.Precision.HIGHEST, - preferred_element_type=preferred_element_type, - ) - grad_lhs = dequantize(grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale) - - grad_rhs = dot_general_transpose_rhs( - q_g, - q_lhs, - rhs, - dimension_numbers=dimension_numbers, - precision=lax.Precision.HIGHEST, - preferred_element_type=preferred_element_type, - ) - grad_rhs = dequantize(grad_rhs, 
preferred_element_type, new_lhs_scale * new_out_grad_scale) + grad_lhs = 0 + grad_rhs_and_scales = [] + + for i, (rhs, q_rhs, new_rhs_scale, out_grad_scale, out_grad_amax_history, g_i) in enumerate(zip(rhs_list, q_rhs_list, new_rhs_scales, out_grad_scales, + out_grad_amax_histories, g)): + new_out_grad_scale, new_out_grad_amax_history = quantize_and_update( + g_i, jnp.float8_e5m2, out_grad_scale, out_grad_amax_history, + compute_dtype, use_direct_quant=True + ) + q_g = quantize(g_i, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type) + + grad_lhs_i = dot_general_transpose_lhs( + q_g, lhs, q_rhs, dimension_numbers=dimension_numbers, + precision=lax.Precision.HIGHEST, preferred_element_type=preferred_element_type, + ) + grad_lhs_i = dequantize(grad_lhs_i, preferred_element_type, new_rhs_scales[i] * new_out_grad_scale) + # lhs is used multiple times + grad_lhs += grad_lhs_i + + grad_rhs = dot_general_transpose_rhs( + q_g, q_lhs, rhs, dimension_numbers=dimension_numbers, + precision=lax.Precision.HIGHEST, preferred_element_type=preferred_element_type, + ) + grad_rhs = dequantize(grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale) + grad_rhs_and_scales.extend([ + grad_rhs, + new_rhs_scales[i], + new_rhs_amax_histories[i], + new_out_grad_scale, + new_out_grad_amax_history + ]) return ( grad_lhs, - grad_rhs, new_lhs_scale, - new_rhs_scale, - new_out_grad_scale, new_lhs_amax_history, - new_rhs_amax_history, - new_out_grad_amax_history, + *grad_rhs_and_scales ) -q_dot_dq.defvjp(q_dot_dq_fwd, q_dot_dq_bwd) +q_multi_dot_dq.defvjp(q_multi_dot_dq_fwd, q_multi_dot_dq_bwd) + +# Convenience wrappers for the common cases +def q_dot_dq( + lhs, + rhs, + lhs_scale, + rhs_scale, + out_grad_scale, + lhs_amax_history, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision=None, + preferred_element_type=None +): + return q_multi_dot_dq( + lhs, + lhs_scale, + lhs_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, + rhs, + rhs_scale, + rhs_amax_history, + out_grad_scale, + out_grad_amax_history + )[0] + +def q_double_dot_dq( + lhs, + rhs1, + rhs2, + lhs_scale, + rhs1_scale, + rhs2_scale, + out1_grad_scale, + out2_grad_scale, + lhs_amax_history, + rhs1_amax_history, + rhs2_amax_history, + out1_grad_amax_history, + out2_grad_amax_history, + compute_dtype, + dimension_numbers, + precision=None, + preferred_element_type=None +): + return q_multi_dot_dq( + lhs, + lhs_scale, + lhs_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, + rhs1, + rhs1_scale, + rhs1_amax_history, + out1_grad_scale, + out1_grad_amax_history, + rhs2, + rhs2_scale, + rhs2_amax_history, + out2_grad_scale, + out2_grad_amax_history + ) @partial(custom_jvp, nondiff_argnums=(2, 3, 4)) def dot_general_with_precision( @@ -606,6 +668,52 @@ def __call__(self, *args, **kwargs): return y # type: ignore +class Fp8DirectDoubleDotGeneralOp(Fp8DotGeneralBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Initialize additional scales and amax histories for the second kernel + self.kernel2_scale = self.variable('quantization', 'kernel2_scale', jnp.array, 1.0) + self.kernel2_amax_history = self.variable('quantization', 'kernel2_amax_history', jnp.array, jnp.zeros(self.amax_history_len)) + self.output2_grad_scale = self.variable('quantization', 'output2_grad_scale', jnp.array, 1.0) + self.output2_grad_amax_history = self.variable('quantization', 'output2_grad_amax_history', jnp.array, 
jnp.zeros(self.amax_history_len)) + + def __call__(self, *args, **kwargs): + x, k1, k2, dimension_numbers, comp_dtype = _parse_double_dot_inputs(*args, **kwargs) + + y1, y2 = q_double_dot_dq( + x, + k1, + k2, + self.input_scale.value, + self.kernel_scale.value, + self.kernel2_scale.value, + self.output_grad_scale.value, + self.output2_grad_scale.value, + self.input_amax_history.value, + self.kernel_amax_history.value, + self.kernel2_amax_history.value, + self.output_grad_amax_history.value, + self.output2_grad_amax_history.value, + comp_dtype, + dimension_numbers, + preferred_element_type=x.dtype + ) + + return y1, y2 # type: ignore + +def _parse_double_dot_inputs(*args, **kwargs): + if len(args) == 4: + x, k1, k2, dimension_numbers = args + comp_dtype = kwargs.get('precision', None) + elif len(args) == 3: + x, k1, k2 = args + dimension_numbers = kwargs.get('dimension_numbers', (((x.ndim - 1,), (0,)), ((), ()))) + comp_dtype = kwargs.get('precision', None) + else: + raise ValueError(f'Unexpected arguments: {args}, {kwargs}') + + return x, k1, k2, dimension_numbers, comp_dtype + class NANOOFp8DotGeneralOp(Fp8DotGeneralOp): e4m3_dtype: DType = jnp.float8_e4m3fnuz e5m2_dtype: DType = jnp.float8_e5m2fnuz From 60a9a91ba3be0e162c100445a3c36d9de9da7f04 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 23 Sep 2024 14:12:24 -0700 Subject: [PATCH 2/7] oneside --- flax/linen/fp8_ops.py | 536 ++++++++++++++++++------------------------ 1 file changed, 225 insertions(+), 311 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index d1ab52302..9276f5c01 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -140,8 +140,10 @@ def dequantize(x, dq_dtype, scale): return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape) -def quantize_dequantize(x, q_dtype, scale, compute_dtype): +def quantize_dequantize(x, q_dtype, scale, compute_dtype, quanitze_only=False): qx = quantize(x, q_dtype, scale, compute_dtype) + if quanitze_only: + return qx return dequantize(qx, x.dtype, scale) @@ -166,7 +168,8 @@ def compute_amax_history(x, amax_history): def quantize_and_update( - x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False + x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False, + quanitze_only=False, ): is_fmax32 = (scale.dtype == fm32 and amax_history.dtype == fm32) # convert fm32->f32 so we can do math @@ -188,14 +191,12 @@ def quantize_and_update( # Quantize the input if not use_direct_quant: - qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype) + qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype, quanitze_only=quanitze_only) return qx, new_scale, new_history return new_scale, new_history - return qx, new_scale, new_history - def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, swap_ans=False): @@ -283,259 +284,26 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g): out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) -@partial(custom_vjp, nondiff_argnums=(3, 4, 5, 6)) -def q_multi_dot_dq( - lhs, - lhs_scale, - lhs_amax_history, - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - *rhs_and_scales, -): - return q_multi_dot_dq_impl( - lhs, - lhs_scale, - lhs_amax_history, - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - *rhs_and_scales, - is_training=False - ) -def q_multi_dot_dq_impl( - lhs, - lhs_scale, - lhs_amax_history, - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - 
*rhs_and_scales, - is_training, -): - # Quantize lhs (A) once - new_lhs_scale, new_lhs_amax_history = quantize_and_update( - lhs, jnp.float8_e4m3fn, lhs_scale, lhs_amax_history, - compute_dtype, use_direct_quant=True - ) - q_lhs = quantize(lhs, jnp.float8_e4m3fn, new_lhs_scale, preferred_element_type) - - outputs = [] - rhs_list = [] - q_rhs_list = [] - new_rhs_scales = [] - new_rhs_amax_histories = [] - out_grad_scales = [] - out_grad_amax_histories = [] - - # We use 5 as the grouping factor because each RHS input requires 5 pieces of data: - # 1. rhs: the right-hand side matrix - # 2. rhs_scale: the scale for quantizing the RHS - # 3. rhs_amax_history: the amax history for the RHS - # 4. out_grad_scale: the scale for quantizing the output gradient - # 5. out_grad_amax_history: the amax history for the output gradient - assert len(rhs_and_scales) % 5 == 0, f"Expected rhs_and_scales length to be divisible by 5, but got {len(rhs_and_scales)}" - num_rhs = len(rhs_and_scales) // 5 - - for i in range(num_rhs): - rhs, rhs_scale, rhs_amax_history, out_grad_scale, out_grad_amax_history = ( - rhs_and_scales[i*5:(i+1)*5] - ) - new_rhs_scale, new_rhs_amax_history = quantize_and_update( - rhs, jnp.float8_e4m3fn, rhs_scale, rhs_amax_history, - compute_dtype, use_direct_quant=True - ) - q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type) - - out = lax.dot_general( - q_lhs, q_rhs, dimension_numbers, - preferred_element_type=preferred_element_type, - precision=lax.Precision.DEFAULT, - ) - out = dequantize(out, preferred_element_type, new_lhs_scale * new_rhs_scale) - - outputs.append(out) - rhs_list.append(rhs) - q_rhs_list.append(q_rhs) - new_rhs_scales.append(new_rhs_scale) - new_rhs_amax_histories.append(new_rhs_amax_history) - out_grad_scales.append(out_grad_scale) - out_grad_amax_histories.append(out_grad_amax_history) - - if is_training: - res = ( - lhs, - q_lhs, - new_lhs_scale, - new_lhs_amax_history, - rhs_list, - q_rhs_list, - new_rhs_scales, - new_rhs_amax_histories, - out_grad_scales, - out_grad_amax_histories, - ) - return tuple(outputs), res - else: - return tuple(outputs) - -def q_multi_dot_dq_fwd( - lhs, - lhs_scale, - lhs_amax_history, - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - *rhs_and_scales -): - return q_multi_dot_dq_impl( - lhs, - lhs_scale, - lhs_amax_history, - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - *rhs_and_scales, - is_training=True +@partial(custom_vjp, nondiff_argnums=(0, 1, 2)) +def in_q(compute_dtype, q_dtype, inp, scale, amax_history): + qin, new_scale, new_history = quantize_and_update( + inp, q_dtype, scale, amax_history, compute_dtype, quantize_only=True ) + return qin, new_scale -def q_multi_dot_dq_bwd( - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - res, - g -): - ( - lhs, - q_lhs, - new_lhs_scale, - new_lhs_amax_history, - rhs_list, - q_rhs_list, - new_rhs_scales, - new_rhs_amax_histories, - out_grad_scales, - out_grad_amax_histories, - ) = res - - grad_lhs = 0 - grad_rhs_and_scales = [] - - for i, (rhs, q_rhs, new_rhs_scale, out_grad_scale, out_grad_amax_history, g_i) in enumerate(zip(rhs_list, q_rhs_list, new_rhs_scales, out_grad_scales, - out_grad_amax_histories, g)): - new_out_grad_scale, new_out_grad_amax_history = quantize_and_update( - g_i, jnp.float8_e5m2, out_grad_scale, out_grad_amax_history, - compute_dtype, use_direct_quant=True - ) - q_g = quantize(g_i, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type) - - 
grad_lhs_i = dot_general_transpose_lhs( - q_g, lhs, q_rhs, dimension_numbers=dimension_numbers, - precision=lax.Precision.HIGHEST, preferred_element_type=preferred_element_type, - ) - grad_lhs_i = dequantize(grad_lhs_i, preferred_element_type, new_rhs_scales[i] * new_out_grad_scale) - # lhs is used multiple times - grad_lhs += grad_lhs_i - - grad_rhs = dot_general_transpose_rhs( - q_g, q_lhs, rhs, dimension_numbers=dimension_numbers, - precision=lax.Precision.HIGHEST, preferred_element_type=preferred_element_type, - ) - grad_rhs = dequantize(grad_rhs, preferred_element_type, new_lhs_scale * new_out_grad_scale) - grad_rhs_and_scales.extend([ - grad_rhs, - new_rhs_scales[i], - new_rhs_amax_histories[i], - new_out_grad_scale, - new_out_grad_amax_history - ]) - - return ( - grad_lhs, - new_lhs_scale, - new_lhs_amax_history, - *grad_rhs_and_scales - ) - +def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history): + qin, new_scale, new_history = quantize_and_update( + inp, q_dtype, scale, amax_history, compute_dtype, quantize_only=True + ) + return (qin, new_scale), (new_scale, new_history) -q_multi_dot_dq.defvjp(q_multi_dot_dq_fwd, q_multi_dot_dq_bwd) +def in_q_bwd(compute_dtype, q_dtype, inp, res, g): + new_scale, new_history = res + # We don't compute gradients for inp, scale and amax_history, but we pass through scale and history + return new_scale, new_history -# Convenience wrappers for the common cases -def q_dot_dq( - lhs, - rhs, - lhs_scale, - rhs_scale, - out_grad_scale, - lhs_amax_history, - rhs_amax_history, - out_grad_amax_history, - compute_dtype, - dimension_numbers, - precision=None, - preferred_element_type=None -): - return q_multi_dot_dq( - lhs, - lhs_scale, - lhs_amax_history, - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - rhs, - rhs_scale, - rhs_amax_history, - out_grad_scale, - out_grad_amax_history - )[0] - -def q_double_dot_dq( - lhs, - rhs1, - rhs2, - lhs_scale, - rhs1_scale, - rhs2_scale, - out1_grad_scale, - out2_grad_scale, - lhs_amax_history, - rhs1_amax_history, - rhs2_amax_history, - out1_grad_amax_history, - out2_grad_amax_history, - compute_dtype, - dimension_numbers, - precision=None, - preferred_element_type=None -): - return q_multi_dot_dq( - lhs, - lhs_scale, - lhs_amax_history, - compute_dtype, - dimension_numbers, - precision, - preferred_element_type, - rhs1, - rhs1_scale, - rhs1_amax_history, - out1_grad_scale, - out1_grad_amax_history, - rhs2, - rhs2_scale, - rhs2_amax_history, - out2_grad_scale, - out2_grad_amax_history - ) +in_q.defvjp(in_q_fwd, in_q_bwd) @partial(custom_jvp, nondiff_argnums=(2, 3, 4)) def dot_general_with_precision( @@ -652,67 +420,213 @@ def __call__(self, *args, **kwargs): *args, **kwargs ) - y = q_dot_dq( - x, - k, - self.input_scale.value, - self.kernel_scale.value, - self.output_grad_scale.value, - self.input_amax_history.value, - self.kernel_amax_history.value, - self.output_grad_amax_history.value, - comp_dtype, - dimension_numbers, - preferred_element_type=x.dtype - ) - - return y # type: ignore - -class Fp8DirectDoubleDotGeneralOp(Fp8DotGeneralBase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Initialize additional scales and amax histories for the second kernel - self.kernel2_scale = self.variable('quantization', 'kernel2_scale', jnp.array, 1.0) - self.kernel2_amax_history = self.variable('quantization', 'kernel2_amax_history', jnp.array, jnp.zeros(self.amax_history_len)) - self.output2_grad_scale = self.variable('quantization', 'output2_grad_scale', 
jnp.array, 1.0) - self.output2_grad_amax_history = self.variable('quantization', 'output2_grad_amax_history', jnp.array, jnp.zeros(self.amax_history_len)) - - def __call__(self, *args, **kwargs): - x, k1, k2, dimension_numbers, comp_dtype = _parse_double_dot_inputs(*args, **kwargs) - - y1, y2 = q_double_dot_dq( - x, - k1, - k2, - self.input_scale.value, - self.kernel_scale.value, - self.kernel2_scale.value, - self.output_grad_scale.value, - self.output2_grad_scale.value, - self.input_amax_history.value, - self.kernel_amax_history.value, - self.kernel2_amax_history.value, - self.output_grad_amax_history.value, - self.output2_grad_amax_history.value, - comp_dtype, - dimension_numbers, - preferred_element_type=x.dtype + q_x, new_input_scale = in_q( + comp_dtype, jnp.float8_e4m3fn, x, self.input_scale.value, self.input_amax_history) + + y = one_sided_q_dot_dq( + x, + q_x, + new_input_scale, # actualy new lhs scale + k, + self.kernel_scale.value, + self.output_grad_scale.value, + self.kernel_amax_history.value, + self.output_grad_amax_history.value, + comp_dtype, + dimension_numbers, + preferred_element_type=x.dtype ) + return y + + # y = q_dot_dq( + # x, + # k, + # self.input_scale.value, + # self.kernel_scale.value, + # self.output_grad_scale.value, + # self.input_amax_history.value, + # self.kernel_amax_history.value, + # self.output_grad_amax_history.value, + # comp_dtype, + # dimension_numbers, + # preferred_element_type=x.dtype + # ) + + # return y # type: ignore +def one_sided_q_dot_dq_impl( + lhs, + q_lhs, + lhs_scale, # actualy new lhs scale + rhs, + rhs_scale, + out_grad_scale, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, + is_training +): + new_rhs_scale, new_rhs_amax_history = quantize_and_update( + rhs, + jnp.float8_e4m3fn, + rhs_scale, + rhs_amax_history, + compute_dtype, + use_direct_quant=True + ) + + q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type) + + out = lax.dot_general( + q_lhs, + q_rhs, + dimension_numbers, + preferred_element_type=preferred_element_type, + precision=lax.Precision.DEFAULT, + ) + + out = dequantize(out, preferred_element_type, new_rhs_scale * lhs_scale) + if is_training: + res = ( + lhs, + q_lhs, + rhs, + q_rhs, + new_rhs_scale, + out_grad_scale, + new_rhs_amax_history, + out_grad_amax_history, + ) + return out, res + else: + return out + +@partial(custom_vjp, nondiff_argnums=(1,2,8, 9, 10,11)) +def one_sided_q_dot_dq( + lhs, + q_lhs, + lhs_scale, # actualy new lhs scale + rhs, + rhs_scale, + out_grad_scale, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision=None, + preferred_element_type=None +): + return one_sided_q_dot_dq_impl( + lhs, + q_lhs, + lhs_scale, + rhs, + rhs_scale, + out_grad_scale, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, + is_training=False, + ) + +def one_sided_q_dot_dq_fwd( + lhs, + q_lhs, + lhs_scale, + rhs, + rhs_scale, + out_grad_scale, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, +): + return one_sided_q_dot_dq_impl( + lhs, + q_lhs, + rhs, + rhs_scale, + out_grad_scale, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type, + is_training=True + ) + +def one_sided_q_dot_dq_bwd( + q_lhs, + lhs_scale, + compute_dtype, + dimension_numbers, + precision, + 
preferred_element_type, + res, + g +): + ( + lhs, + q_lhs, + rhs, + q_rhs, + new_rhs_scale, + out_grad_scale, + new_rhs_amax_history, + out_grad_amax_history, + ) = res + + new_out_grad_scale, new_out_grad_amax_history = quantize_and_update( + g, + jnp.float8_e5m2, + out_grad_scale, + out_grad_amax_history, + compute_dtype, + use_direct_quant=True + ) + + q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type) + + grad_lhs = dot_general_transpose_lhs( + q_g, + lhs, + q_rhs, + dimension_numbers=dimension_numbers, + precision=lax.Precision.HIGHEST, + preferred_element_type=preferred_element_type, + ) + grad_lhs = dequantize(grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale) + + grad_rhs = dot_general_transpose_rhs( + q_g, + q_lhs, + rhs, + dimension_numbers=dimension_numbers, + precision=lax.Precision.HIGHEST, + preferred_element_type=preferred_element_type, + ) + grad_rhs = dequantize(grad_rhs, preferred_element_type, lhs_scale * new_out_grad_scale) + + return ( + grad_lhs, + grad_rhs, + new_rhs_scale, + new_out_grad_scale, + new_rhs_amax_history, + new_out_grad_amax_history, + ) + +one_sided_q_dot_dq.defvjp(one_sided_q_dot_dq_fwd, one_sided_q_dot_dq_bwd) - return y1, y2 # type: ignore - -def _parse_double_dot_inputs(*args, **kwargs): - if len(args) == 4: - x, k1, k2, dimension_numbers = args - comp_dtype = kwargs.get('precision', None) - elif len(args) == 3: - x, k1, k2 = args - dimension_numbers = kwargs.get('dimension_numbers', (((x.ndim - 1,), (0,)), ((), ()))) - comp_dtype = kwargs.get('precision', None) - else: - raise ValueError(f'Unexpected arguments: {args}, {kwargs}') - return x, k1, k2, dimension_numbers, comp_dtype class NANOOFp8DotGeneralOp(Fp8DotGeneralOp): e4m3_dtype: DType = jnp.float8_e4m3fnuz From 920099e92a5654c19bb6a4fbc4548654ed41f244 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 23 Sep 2024 15:04:46 -0700 Subject: [PATCH 3/7] pass test --- flax/linen/__init__.py | 1 - flax/linen/fp8_ops.py | 44 +++++++++++++++--------------------------- 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 13e4a480e..24a33d873 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -89,7 +89,6 @@ from .fp8_ops import ( Fp8DotGeneralOp as Fp8DotGeneralOp, Fp8DirectDotGeneralOp as Fp8DirectDotGeneralOp, - Fp8DirectDoubleDotGeneralOp as Fp8DirectDoubleDotGeneralOp, NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp, ) from .initializers import ( diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 9276f5c01..324603b20 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -140,9 +140,9 @@ def dequantize(x, dq_dtype, scale): return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape) -def quantize_dequantize(x, q_dtype, scale, compute_dtype, quanitze_only=False): +def quantize_dequantize(x, q_dtype, scale, compute_dtype, quantize_only=False): qx = quantize(x, q_dtype, scale, compute_dtype) - if quanitze_only: + if quantize_only: return qx return dequantize(qx, x.dtype, scale) @@ -169,7 +169,7 @@ def compute_amax_history(x, amax_history): def quantize_and_update( x, q_dtype, scale, amax_history, compute_dtype, use_direct_quant=False, - quanitze_only=False, + quantize_only=False, ): is_fmax32 = (scale.dtype == fm32 and amax_history.dtype == fm32) # convert fm32->f32 so we can do math @@ -191,7 +191,7 @@ def quantize_and_update( # Quantize the input if not use_direct_quant: - qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype, 
quanitze_only=quanitze_only) + qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype, quantize_only=quantize_only) return qx, new_scale, new_history return new_scale, new_history @@ -285,7 +285,7 @@ def out_qdq_bwd(compute_dtype, q_dtype, res, g): out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) -@partial(custom_vjp, nondiff_argnums=(0, 1, 2)) +@partial(custom_vjp, nondiff_argnums=(0, 1)) def in_q(compute_dtype, q_dtype, inp, scale, amax_history): qin, new_scale, new_history = quantize_and_update( inp, q_dtype, scale, amax_history, compute_dtype, quantize_only=True @@ -298,10 +298,10 @@ def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history): ) return (qin, new_scale), (new_scale, new_history) -def in_q_bwd(compute_dtype, q_dtype, inp, res, g): +def in_q_bwd(compute_dtype, q_dtype, res, _): new_scale, new_history = res # We don't compute gradients for inp, scale and amax_history, but we pass through scale and history - return new_scale, new_history + return None, new_scale, new_history in_q.defvjp(in_q_fwd, in_q_bwd) @@ -421,7 +421,7 @@ def __call__(self, *args, **kwargs): ) q_x, new_input_scale = in_q( - comp_dtype, jnp.float8_e4m3fn, x, self.input_scale.value, self.input_amax_history) + comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value) y = one_sided_q_dot_dq( x, @@ -436,23 +436,8 @@ def __call__(self, *args, **kwargs): dimension_numbers, preferred_element_type=x.dtype ) - return y - - # y = q_dot_dq( - # x, - # k, - # self.input_scale.value, - # self.kernel_scale.value, - # self.output_grad_scale.value, - # self.input_amax_history.value, - # self.kernel_amax_history.value, - # self.output_grad_amax_history.value, - # comp_dtype, - # dimension_numbers, - # preferred_element_type=x.dtype - # ) - - # return y # type: ignore + return y # type: ignore + def one_sided_q_dot_dq_impl( lhs, q_lhs, @@ -492,6 +477,7 @@ def one_sided_q_dot_dq_impl( res = ( lhs, q_lhs, + lhs_scale, rhs, q_rhs, new_rhs_scale, @@ -503,7 +489,7 @@ def one_sided_q_dot_dq_impl( else: return out -@partial(custom_vjp, nondiff_argnums=(1,2,8, 9, 10,11)) +@partial(custom_vjp, nondiff_argnums=(8, 9, 10, 11)) def one_sided_q_dot_dq( lhs, q_lhs, @@ -551,6 +537,7 @@ def one_sided_q_dot_dq_fwd( return one_sided_q_dot_dq_impl( lhs, q_lhs, + lhs_scale, rhs, rhs_scale, out_grad_scale, @@ -564,8 +551,6 @@ def one_sided_q_dot_dq_fwd( ) def one_sided_q_dot_dq_bwd( - q_lhs, - lhs_scale, compute_dtype, dimension_numbers, precision, @@ -576,6 +561,7 @@ def one_sided_q_dot_dq_bwd( ( lhs, q_lhs, + lhs_scale, rhs, q_rhs, new_rhs_scale, @@ -617,6 +603,8 @@ def one_sided_q_dot_dq_bwd( return ( grad_lhs, + None, + None, grad_rhs, new_rhs_scale, new_out_grad_scale, From c9b0fc5be1249b95f7c04687a9908e0eb62056f7 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 23 Sep 2024 15:26:58 -0700 Subject: [PATCH 4/7] pass test --- flax/linen/fp8_ops.py | 48 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 324603b20..a28ba7034 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -350,6 +350,41 @@ def _parse_dot_inputs(*args, **kwargs): x = jnp.asarray(x, comp_dtype) return x, k, dimension_numbers, comp_dtype +# Convenience wrappers for the quantize-dot-dequantize +def q_dot_dq( + lhs, + rhs, + lhs_scale, + rhs_scale, + out_grad_scale, + lhs_amax_history, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision=None, + preferred_element_type=None +): + q_lhs, 
new_lhs_scale = in_q( + compute_dtype, jnp.float8_e4m3fn, lhs, lhs_scale, lhs_amax_history + ) + + y = one_sided_q_dot_dq( + lhs, + q_lhs, + new_lhs_scale, # actualy new lhs scale + rhs, + rhs_scale, + out_grad_scale, + rhs_amax_history, + out_grad_amax_history, + compute_dtype, + dimension_numbers, + precision, + preferred_element_type + ) + return y # type: ignore + class Fp8DotGeneralBase(module.Module): amax_history_length: int = 1024 @@ -419,23 +454,20 @@ def __call__(self, *args, **kwargs): x, k, dimension_numbers, comp_dtype = _parse_dot_inputs( *args, **kwargs ) - - q_x, new_input_scale = in_q( - comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value) - - y = one_sided_q_dot_dq( + y = q_dot_dq( x, - q_x, - new_input_scale, # actualy new lhs scale k, + self.input_scale.value, self.kernel_scale.value, self.output_grad_scale.value, + self.input_amax_history.value, self.kernel_amax_history.value, self.output_grad_amax_history.value, comp_dtype, dimension_numbers, - preferred_element_type=x.dtype + preferred_element_type=x.dtype, ) + return y # type: ignore def one_sided_q_dot_dq_impl( From 9d0f844651fc9ae38820b127caf0ff86f81f4497 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 24 Sep 2024 10:24:31 -0700 Subject: [PATCH 5/7] fix quantize with scale type --- flax/linen/fp8_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index a28ba7034..cf6ee671d 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -184,16 +184,16 @@ def quantize_and_update( new_scale = compute_scale(amax_from_history, scale, dtype_max) new_history = compute_amax_history(x, amax_history) + if not use_direct_quant: + qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype) + # convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly if is_fmax32: new_history = lax.convert_element_type(new_history, fp32_max_grad) new_scale = lax.convert_element_type(new_scale, fp32_max_grad) - # Quantize the input if not use_direct_quant: - qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype, quantize_only=quantize_only) return qx, new_scale, new_history - return new_scale, new_history From de8d158084454dedf5513c58eaa4978ed8e8ba42 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 24 Sep 2024 11:52:16 -0700 Subject: [PATCH 6/7] fix scale dtype and add decorator --- flax/linen/fp8_ops.py | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index cf6ee671d..49b892106 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -350,6 +350,11 @@ def _parse_dot_inputs(*args, **kwargs): x = jnp.asarray(x, comp_dtype) return x, k, dimension_numbers, comp_dtype +def _fm32_to_float32(value): + if value.dtype == fm32: + return lax.convert_element_type(value, jnp.float32) + return value + # Convenience wrappers for the quantize-dot-dequantize def q_dot_dq( lhs, @@ -385,6 +390,29 @@ def q_dot_dq( ) return y # type: ignore +# This decorator wraps a function to perform one-sided quantized dot product and dequantization. +# It prepares the arguments for fp8_ops.one_sided_q_dot_dq, including the pre-quantized input, +# scales, and amax histories. This allows for efficient FP8 matrix multiplication while +# managing quantization parameters. 
+def one_sided_q_dot_dq_config(comp_dtype, q_x, input_scale, kernel_scale, out_grad_scale, kernel_amax_history, out_grad_amax_history): + def decorator(func): + def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): + return one_sided_q_dot_dq( + lhs=lhs, + q_lhs=q_x, + lhs_scale=input_scale, + rhs=rhs, + rhs_scale=kernel_scale, + out_grad_scale=out_grad_scale, + rhs_amax_history=kernel_amax_history, + out_grad_amax_history=out_grad_amax_history, + compute_dtype=comp_dtype, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type + ) + return wrapper + return decorator class Fp8DotGeneralBase(module.Module): amax_history_length: int = 1024 @@ -494,7 +522,7 @@ def one_sided_q_dot_dq_impl( use_direct_quant=True ) - q_rhs = quantize(rhs, jnp.float8_e4m3fn, new_rhs_scale, preferred_element_type) + q_rhs = quantize(rhs, jnp.float8_e4m3fn, _fm32_to_float32(new_rhs_scale), preferred_element_type) out = lax.dot_general( q_lhs, @@ -504,7 +532,7 @@ def one_sided_q_dot_dq_impl( precision=lax.Precision.DEFAULT, ) - out = dequantize(out, preferred_element_type, new_rhs_scale * lhs_scale) + out = dequantize(out, preferred_element_type, _fm32_to_float32(new_rhs_scale) * _fm32_to_float32(lhs_scale)) if is_training: res = ( lhs, @@ -611,7 +639,7 @@ def one_sided_q_dot_dq_bwd( use_direct_quant=True ) - q_g = quantize(g, jnp.float8_e5m2, new_out_grad_scale, preferred_element_type) + q_g = quantize(g, jnp.float8_e5m2, _fm32_to_float32(new_out_grad_scale), preferred_element_type) grad_lhs = dot_general_transpose_lhs( q_g, @@ -621,7 +649,7 @@ def one_sided_q_dot_dq_bwd( precision=lax.Precision.HIGHEST, preferred_element_type=preferred_element_type, ) - grad_lhs = dequantize(grad_lhs, preferred_element_type, new_rhs_scale * new_out_grad_scale) + grad_lhs = dequantize(grad_lhs, preferred_element_type, _fm32_to_float32(new_rhs_scale) * _fm32_to_float32(new_out_grad_scale)) grad_rhs = dot_general_transpose_rhs( q_g, @@ -631,7 +659,7 @@ def one_sided_q_dot_dq_bwd( precision=lax.Precision.HIGHEST, preferred_element_type=preferred_element_type, ) - grad_rhs = dequantize(grad_rhs, preferred_element_type, lhs_scale * new_out_grad_scale) + grad_rhs = dequantize(grad_rhs, preferred_element_type, _fm32_to_float32(lhs_scale) * _fm32_to_float32(new_out_grad_scale)) return ( grad_lhs, From 2bf769480c3281a9b7393b0f21b7ea67c492bcae Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 24 Sep 2024 12:17:31 -0700 Subject: [PATCH 7/7] all in decorator --- flax/linen/fp8_ops.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 49b892106..b6643191d 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -390,11 +390,40 @@ def q_dot_dq( ) return y # type: ignore +# This decorator wraps a function to perform quantized dot product and dequantization. +# It prepares the arguments for q_dot_dq. 
+def q_dot_dq_config( + lhs_scale, rhs_scale, out_grad_scale, + lhs_amax_history, rhs_amax_history, out_grad_amax_history, + compute_dtype +): + def decorator(func): + def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): + return q_dot_dq( + lhs, + rhs, + lhs_scale=lhs_scale, + rhs_scale=rhs_scale, + out_grad_scale=out_grad_scale, + lhs_amax_history=lhs_amax_history, + rhs_amax_history=rhs_amax_history, + out_grad_amax_history=out_grad_amax_history, + compute_dtype=compute_dtype, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + ) + return wrapper + return decorator + # This decorator wraps a function to perform one-sided quantized dot product and dequantization. -# It prepares the arguments for fp8_ops.one_sided_q_dot_dq, including the pre-quantized input, +# It prepares the arguments for one_sided_q_dot_dq, including the pre-quantized input, # scales, and amax histories. This allows for efficient FP8 matrix multiplication while # managing quantization parameters. -def one_sided_q_dot_dq_config(comp_dtype, q_x, input_scale, kernel_scale, out_grad_scale, kernel_amax_history, out_grad_amax_history): +def one_sided_q_dot_dq_config( + compute_dtype, q_x, input_scale, kernel_scale, out_grad_scale, + kernel_amax_history, out_grad_amax_history +): def decorator(func): def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None): return one_sided_q_dot_dq( @@ -406,7 +435,7 @@ def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type= out_grad_scale=out_grad_scale, rhs_amax_history=kernel_amax_history, out_grad_amax_history=out_grad_amax_history, - compute_dtype=comp_dtype, + compute_dtype=compute_dtype, dimension_numbers=dimension_numbers, precision=precision, preferred_element_type=preferred_element_type @@ -414,6 +443,7 @@ def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type= return wrapper return decorator + class Fp8DotGeneralBase(module.Module): amax_history_length: int = 1024 e4m3_dtype: DType = jnp.float8_e4m3fn
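
A minimal usage sketch (not part of the patch series) of the one-sided FP8 path that the series converges on: quantize the activations once with in_q, then reuse the quantized activations across two dots via one_sided_q_dot_dq, mirroring the call pattern in Fp8DirectDotGeneralOp.__call__. It assumes the patches above are applied and that in_q and one_sided_q_dot_dq remain importable from flax.linen.fp8_ops; the shapes, the bfloat16 compute dtype, and the amax-history length of 1024 are illustrative choices, not values mandated by the patch.

    import jax
    import jax.numpy as jnp
    from flax.linen import fp8_ops

    key = jax.random.PRNGKey(0)
    x  = jax.random.normal(key, (16, 64), dtype=jnp.bfloat16)   # activations (lhs)
    k1 = jax.random.normal(key, (64, 32), dtype=jnp.bfloat16)   # first kernel (rhs)
    k2 = jax.random.normal(key, (64, 32), dtype=jnp.bfloat16)   # second kernel (rhs)

    # Fresh per-tensor scales and amax histories, as Fp8DotGeneralBase would hold them.
    def scale():
      return jnp.array(1.0, dtype=jnp.float32)

    def history():
      return jnp.zeros((1024,), dtype=jnp.float32)

    # Contract the last dim of x with the first dim of each kernel (dense-layer style).
    dimension_numbers = (((x.ndim - 1,), (0,)), ((), ()))

    # Quantize the activations once; in_q returns the fp8 tensor and its updated scale.
    q_x, new_x_scale = fp8_ops.in_q(
        jnp.bfloat16, jnp.float8_e4m3fn, x, scale(), history()
    )

    # Reuse the quantized activations for both dots; each dot quantizes only its own
    # kernel and carries its own output-gradient scale and amax history.
    y1 = fp8_ops.one_sided_q_dot_dq(
        x, q_x, new_x_scale,
        k1, scale(), scale(), history(), history(),
        jnp.bfloat16, dimension_numbers, preferred_element_type=x.dtype,
    )
    y2 = fp8_ops.one_sided_q_dot_dq(
        x, q_x, new_x_scale,
        k2, scale(), scale(), history(), history(),
        jnp.bfloat16, dimension_numbers, preferred_element_type=x.dtype,
    )
    print(y1.shape, y2.shape)  # (16, 32) (16, 32)

Under jax.grad, the custom VJPs defined in the patches compute the input and kernel gradients with the transposed fp8 dots and return the updated scales and amax histories in the gradient slots of the corresponding scale arguments; that is the fm32 gradient channel the surrounding module relies on to refresh its quantization state each step.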