diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b43c644a51..66eea09cb2 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1074,6 +1074,7 @@ def fp8_gemm_impl( bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, out: Optional[ArrayLike] = None, + extra_out: Optional[ArrayLike] = None, out_amax: Optional[ArrayLike] = None, out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 59d1045080..1a275ceed7 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -209,27 +209,30 @@ def _gemm_bwd_rule( mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) ) + # Recover DGRAD and WGRAD comm+GEMM overlap configs + dgrad_overlap_name = None dgrad_overlap_config = None + wgrad_overlap_name = None wgrad_overlap_config = None - dgrad_pre_rs = None if comm_overlap_config is not None: dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) - if ( - dgrad_overlap_config["method"] == "bulk" - and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.AG - ): - # If DGRAD is bulk overlap, copy input X into comm buffer to be all-gathered in - # preparation for WGRAD. - wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" - wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) - assert wgrad_overlap_config is not None, "Internal TE error!" - copy_into_overlap_buffer(x, dgrad_overlap_name, True) + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + dgrad_pre_rs = None + if dgrad_overlap_config is not None: + if dgrad_overlap_config["method"] == "bulk": # Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the - # bulk RS overlap without an extra memcpy + # bulk RS overlap without an extra memcpy. + assert wgrad_overlap_config is not None, ( + f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" + ) dgrad_pre_rs = tex.get_overlap_buffer(wgrad_overlap_name, False) + # Copy transposed input into the DGRAD overlap buffer for bulk AG. + copy_into_overlap_buffer(jnp.matrix_transpose(x), dgrad_overlap_name, True) + # FWD MODE: # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) # @@ -246,7 +249,7 @@ def _gemm_bwd_rule( # AG+GEMM w/ DGRAD+RS Overlap: ([B], M, N/P) x (K, N/P)^T ---(RS)---> ([B], M/P, K) # # AG+GEMM w/ Bulk AG Overlap: ([B], M, N/P) x (K, N/P)^T -----> ([B], M, K) (deferred RS) - # ([B], M, K/P) --(Bulk AG)--> ([B], M, K) (needed in WGRAD) + # ([B], M, K/P)^T --(Bulk AG)--> ([B], M, K)^T (needed in WGRAD) # # GEMM+RS: ([B], M/P, N) --(AG)--> ([B], M, N) x (K/P, N)^T ----> ([B], M, K/P) dgrad, dgelu, _, dgrad_extra_out = gemm_impl( @@ -272,13 +275,14 @@ def _gemm_bwd_rule( # Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor dgrad = dgrad_extra_out + # WGRAD w/o Overlap: # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K, N) # # WGRAD w/ Overlap: - # AG+GEMM w/ DGRAD+RS Overlap: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # AG+GEMM w/ DGRAD+RS Overlap: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # # AG+GEMM w/ Bulk Overlaps: ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # ([B], M, K) --(Bulk RS)--> ([B], M/P, K) (finalize DGRAD) @@ -299,7 +303,11 @@ def _gemm_bwd_rule( comm_overlap_config=wgrad_overlap_config, ) - if wgrad_overlap_config is not None: + if ( + wgrad_overlap_config is not None + and wgrad_overlap_config["method"] == "bulk" + and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): # DGRAD was reduce-scattered during WGRAD GEMM, so set DGRAD to WGRAD extra output here dgrad = wgrad_extra_out @@ -317,6 +325,7 @@ def fp8_gemm( kernel_t: ArrayLike, fp8_meta: FP8MetaPackage, bias: Optional[ArrayLike] = None, + out: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, fuse_gelu: bool = False, accumulate: bool = False, @@ -340,10 +349,12 @@ def fp8_gemm( FP8MetaPackage object carrying amax, scale and scale_inv information for the GEMM operands. bias : Optional[ArrayLike], default = `None` Optional bias term to add onto the (LHS x RHS) result. + out: Optional[ArrayLike], default = `None` + Optional empty buffer for FP8 GEMM output. out_dtype : jnp.dtype, default = `jnp.bfloat16` Data type of the FP8 GEMM output. If chosen as an FP8 dtype (i.e. `jnp.float8_e4m3fn` or `jnp.float8_e5m2`), the `fp8_meta` must also contain amax and scale information for the - GEMM output. + GEMM output. This option is overridden by the data type of the `out` buffer, if given. fuse_gelu : bool, default = `False` Enable the GELU epilogue for GEMM. This applies GELU after the bias-addition if the bias term is not `None`. @@ -389,13 +400,14 @@ def fp8_gemm( ) -@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10)) def _fp8_gemm( x: ArrayLike, kernel_t: ArrayLike, bias: ArrayLike, amax_list: ArrayLike, scale_list: ArrayLike, + out: ArrayLike, out_dtype: jnp.dtype, fuse_gelu: bool, accumulate: bool, @@ -501,14 +513,14 @@ def _fp8_gemm_fwd_rule( buffer_scale_inv = None if comm_overlap_config is not None: overlap_name = comm_overlap_config["name"] - if comm_overlap_config["method"] != "bulk" and tex.overlap_buffer_is_fp8(overlap_name): - match comm_overlap_config["comm_type"]: - case tex.CommOverlapType.AG: - buffer_scale_inv = x_scale_inv + if comm_overlap_config["comm_type"] == tex.CommOverlapType.AG: + buffer_scale_inv = x_scale_inv - case tex.CommOverlapType.RS: - buffer_scale_inv = jnp.reciprocal(out_scale) + elif comm_overlap_config["comm_type"] == tex.CommOverlapType.RS: + out_dtype = fwd_dtype + out_scale = scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] + buffer_scale_inv = jnp.reciprocal(out_scale) tex.set_overlap_buffer_scale_inverse( overlap_name, @@ -531,9 +543,6 @@ def _fp8_gemm_fwd_rule( use_split_accumulator=use_split_accumulator, comm_overlap_config=comm_overlap_config, ) - if not jax_dtype_is_fp8(out_dtype): - updated_out_amax = None - updated_out_scale = None # Update returned and saved arrays based on comm+GEMM overlap config final_out = out @@ -542,6 +551,10 @@ def _fp8_gemm_fwd_rule( # RS overlap puts the reduce-scattered sharded output into extra_out final_out = extra_out + if not jax_dtype_is_fp8(final_out): + updated_out_amax = None + updated_out_scale = None + ctx = ( casted_x_t, casted_kernel, @@ -583,9 +596,21 @@ def _fp8_gemm_bwd_rule( maybe_fp32_to_fm32, batched_input, ) = ctx - + del out_dtype bwd_dtype = FP8Helper.BWD_DTYPE + # Recover DGRAD and WGRAD comm+GEMM overlap configs + dgrad_overlap_name = None + dgrad_overlap_config = None + wgrad_overlap_name = None + wgrad_overlap_config = None + if comm_overlap_config is not None: + dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" + dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + + # Cast-transpose grad with potential fusions grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] @@ -633,28 +658,29 @@ def _fp8_gemm_bwd_rule( ) bgrad = None - # Recover dgrad comm+GEMM overlap config - dgrad_overlap_config = None - if comm_overlap_config is not None: - dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" - dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) - # Set scale_inv for comm overlap buffer - dgrad_out_dtype = jnp.bfloat16 dgrad_amax = None dgrad_scale = None - if ( - dgrad_overlap_config is not None - and dgrad_overlap_config["method"] != "bulk" - and tex.overlap_buffer_is_fp8(dgrad_overlap_name) - ): - dgrad_out_dtype = bwd_dtype - dgrad_amax = grad_amax - dgrad_scale = grad_scale - tex.set_overlap_buffer_scale_inverse( - dgrad_overlap_name, - jax.dlpack.to_dlpack(grad_scale_inv), - ) + if dgrad_overlap_config is not None: + if dgrad_overlap_config["method"] == "bulk": + assert wgrad_overlap_config is not None, ( + f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" + ) + # Set WGRAD buffer as output of DGRAD in order to avoid a memcpy for bulk RS overlap + dgrad_pre_rs = jax.dlpack.from_dlpack( + tex.get_overlap_buffer(wgrad_overlap_name, False) + ) + # Copy input into overlap buffer for all-gather + copy_into_overlap_buffer(casted_x_t, dgrad_overlap_name, True) + + elif tex.overlap_buffer_is_fp8(dgrad_overlap_name): + # Non-bulk RS DGRAD overlap needs output amax and scale if buffer type is FP8 + dgrad_amax = grad_amax + dgrad_scale = grad_scale + tex.set_overlap_buffer_scale_inverse( + dgrad_overlap_name, + jax.dlpack.to_dlpack(grad_scale_inv), + ) # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] @@ -663,11 +689,9 @@ def _fp8_gemm_bwd_rule( grad_scale_inv, casted_kernel, kernel_scale_inv, - None, - None, - dgrad_amax, - dgrad_scale, - out_dtype=dgrad_out_dtype, + out=dgrad_pre_rs, + out_amax=dgrad_amax, + out_scale=dgrad_scale, batched_output=batched_input, accumulate=accumulate, use_split_accumulator=use_split_accumulator, @@ -682,65 +706,29 @@ def _fp8_gemm_bwd_rule( ): dgrad = dgrad_extra_out - if fuse_gelu and fuse_bias: - # Fuse bgrad with dGELU. - _, casted_dgelu_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose( - grad, - pre_gelu_out, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - activation_type=("gelu",), - ) - elif fuse_gelu: - # No bias grad to fuse so we just do dGELU. - _, casted_dgelu_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) - bgrad = None - - # Recover wgrad config - wgrad_overlap_config = None - if comm_overlap_config is not None: - wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" - wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + # Prepare comm+GEMM overlap for WGRAD + if wgrad_overlap_config is not None: + if wgrad_overlap_config["method"] == "bulk": + # Get all-gathered input from DGRAD bulk overlap + casted_x_t = jax.dlpack.from_dlpack( + tex.get_overlap_buffer(dgrad_overlap_name, False) + ) - # Set scale_inv for comm overlap buffer - wgrad_out_dtype = jnp.bfloat16 - wgrad_amax = None - wgrad_scale = None - if ( - wgrad_overlap_config is not None - and wgrad_overlap_config["method"] != "bulk" - and tex.overlap_buffer_is_fp8(wgrad_overlap_name) - ): - match wgrad_overlap_config["comm_type"]: - case tex.CommOverlapType.AG: - buffer_scale_inv = x_scale_inv - case tex.CommOverlapType.RS: - buffer_scale_inv = grad_scale_inv - wgrad_out_dtype = bwd_dtype - wgrad_amax = grad_amax - wgrad_scale = grad_scale - tex.set_overlap_buffer_scale_inverse( - dgrad_overlap_name, - jax.dlpack.to_dlpack(buffer_scale_inv), - ) + elif tex.overlap_buffer_is_fp8(wgrad_overlap_name): + # Set FP8 scale inverse for non-bulk AG overlap + tex.set_overlap_buffer_scale_inverse( + wgrad_overlap_name, + jax.dlpack.to_dlpack(x_scale_inv) + ) # WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K) - wgrad_rhs_t = casted_dgelu_t if fuse_gelu else casted_grad_t x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] wgrad, *_, wgrad_extra_out = fp8_gemm_impl( casted_x_t, x_scale_inv, - wgrad_rhs_t, + casted_grad_t, grad_scale_inv, - None, - None, - wgrad_amax, - wgrad_scale, - out_dtype=wgrad_out_dtype, + out_dtype=jnp.bfloat16, batched_output=False, accumulate=accumulate, use_split_accumulator=use_split_accumulator, @@ -753,7 +741,7 @@ def _fp8_gemm_bwd_rule( and wgrad_overlap_config["method"] != "bulk" and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS ): - wgrad = wgrad_extra_out + dgrad = wgrad_extra_out amax_list[FP8MetaPackage.INPUT_IDX] = ( amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) @@ -764,7 +752,7 @@ def _fp8_gemm_bwd_rule( amax_list[FP8MetaPackage.GRAD_IDX] = ( amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) ) - if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if updated_out_amax is not None: amax_list[FP8MetaPackage.OUTPUT_IDX] = ( amax_list[FP8MetaPackage.OUTPUT_IDX].at[0].set(updated_out_amax[0]) ) @@ -782,8 +770,9 @@ def type_safe_gemm( x: ArrayLike, kernel: ArrayLike, bias: Optional[ArrayLike] = None, - fp8_meta: Optional[FP8MetaPackage] = None, + out: Optional[ArrayLike] = None, out_dtype: Optional[jnp.dtype] = None, + fp8_meta: Optional[FP8MetaPackage] = None, contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, accumulate: bool = False, @@ -802,24 +791,25 @@ def type_safe_gemm( return fp8_gemm( x, kernel, - bias, fp8_meta, - out_dtype, - fuse_gelu, - accumulate, - use_split_accumulator, - comm_overlap_name, + bias=bias, + out=out, + out_dtype=out_dtype, + fuse_gelu=fuse_gelu, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_name=comm_overlap_name, ) else: return gemm( x, kernel, - bias, - contracting_dims, - fuse_gelu, - accumulate, - use_split_accumulator, - comm_overlap_name, + bias=bias, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_name=comm_overlap_name, )