From f4400944892304427b782ab047a5682a5b5430e9 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 14 Nov 2024 17:59:20 +0000 Subject: [PATCH] fixed batching rules to accommodate batched RHS operand for GEMM Signed-off-by: Alp Dener --- .../common/util/pybind_helper.h | 138 ++++++++++-------- transformer_engine/jax/cpp_extensions/gemm.py | 133 ++++++----------- .../jax/csrc/extensions/pybind.cpp | 59 +------- 3 files changed, 123 insertions(+), 207 deletions(-) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 432ac815ec..a36ff3f0f9 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -8,72 +8,88 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include +#include #include #include #include #include "cuda_runtime.h" -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - 
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - 
.value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType") \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type") \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type") \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Format") \ + .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ + .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); \ + pybind11::enum_(m, "NVTE_QKV_Layout") \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "NVTE_Activation_Type") \ + .value("GELU", NVTE_Activation_Type::GELU) \ + .value("GEGLU", NVTE_Activation_Type::GEGLU) \ + .value("SILU", NVTE_Activation_Type::SILU) \ + .value("SWIGLU", NVTE_Activation_Type::SWIGLU) \ + .value("RELU", NVTE_Activation_Type::RELU) \ + .value("REGLU", NVTE_Activation_Type::REGLU) \ + .value("QGELU", NVTE_Activation_Type::QGELU) \ + .value("QGEGLU", NVTE_Activation_Type::QGEGLU) \ 
+ .value("SRELU", NVTE_Activation_Type::SRELU) \ + .value("SREGLU", NVTE_Activation_Type::SREGLU); \ + pybind11::enum_(m, "CommOverlapType") \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo") \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + pybind11::call_guard()); #endif diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 677fabca59..ceafce46e1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,10 +24,10 @@ jax_dtype_is_fp8, get_padded_spec, is_ffi_enabled, + check_valid_batch_dims, ) from ..sharding import ( global_mesh_resource, - get_mesh_axis_size, lax_paral_op, all_reduce_max_along_all_axes_except_PP, ) @@ -83,9 +83,6 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 ), "Missing RHS operand scale inverse in FP8 GEMM." 
- # Disallow batching for RHS - assert rhs_aval.ndim == 2, "GEMM does not support batching the RHS operand." - # Validate operand layouts lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, @@ -97,12 +94,12 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 assert ( not (lhs_trans and rhs_trans) ), "GEMM does not support transposed LHS and transposed RHS at the same time." if is_fp8: - assert lhs_trans, "FP8 GEMM does not support transposed LHS." + assert not lhs_trans, "FP8 GEMM does not support transposed LHS." assert rhs_trans, "FP8 GEMM requires transposed RHS." # Validate output dtype @@ -124,11 +121,18 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av out_scale_updated_dtype = jnp.float32 # Infer output shape - rhs_outer_dim = 0 if rhs_trans else 1 lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 lhs_bdims = [dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + rhs_bdims = [dim for dim in range(rhs_aval.ndim) + if dim not in [rhs_outer_dim, rhs_inner_dim]] + rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) + assert ( + lhs_batch_size == rhs_batch_size + ), "LHS and RHS operands must have the same batched sizes." 
out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) # Validate bias/bias_grad shape against inferred output @@ -201,7 +205,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ (lhs_aval.ndim, rhs_aval.ndim) ) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 operand_output_aliases = { 4: 4, # bias <--> bias_grad @@ -248,12 +252,9 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - rhs_outer_dim = 0 if rhs_trans else 1 lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - lhs_bdims = [dim for dim in range(lhs_aval.ndim) - if dim not in [lhs_outer_dim, lhs_inner_dim]] - lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] - m = reduce(operator.mul, lhs_batch_shape, 1) * lhs_aval.shape[lhs_outer_dim] + rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] n = rhs_aval.shape[rhs_outer_dim] workspace_size = get_cublas_workspace_size_bytes() @@ -308,77 +309,32 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator): assert CollectiveGemmPrimitive.outer_primitive is not None + check_valid_batch_dims(batch_dims) + lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale = batched_args - assert rhs.ndim == 2, "TE/JAX GEMM custom op does not support batching RHS operands." 
- - # Get contracting and batch dimensions out - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs.ndim, rhs.ndim) - ) - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 - lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 - lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] - - # FP8 GEMM only supports lhs_trans = False and rhs_trans = True so we may need to - # reorder the axes here to match - if jax_dtype_is_fp8(lhs.dtype): - lhs = jnp.transpose(lhs, (*lhs_bdims, lhs_outer_dim, lhs_inner_dim)) - lhs_trans = False - rhs = jnp.transpose(rhs, (rhs_outer_dim, rhs_inner_dim)) - rhs_trans = True - contracting_dims = (1, 1) - - # Collapse all non-contracting dimensions - batch_shape = [lhs.shape[dim] for dim in lhs_bdims] - batch_size = reduce(operator.mul, batch_shape, 1) - lhs_outer_size = lhs.shape[lhs_outer_dim] - lhs_shape_2d = ( - (lhs.shape[lhs_inner_dim], batch_size * lhs_outer_size) - if lhs_trans - else (batch_size * lhs_outer_size, lhs.shape[lhs_inner_dim]) - ) - lhs = jnp.reshape(lhs, lhs_shape_2d) - if fuse_gelu: - gelu_input = jnp.reshape( - gelu_input, (batch_size * lhs_outer_size, rhs.shape[rhs_outer_dim]) - ) - - outputs = CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - lhs_scale_inv, - rhs, - rhs_scale_inv, - bias, - gelu_input, - out_amax, - out_scale, - out_dtype=out_dtype, - contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=grad, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - # Reshape output to recover original LHS batch shape - outputs[0] = jnp.reshape( - outputs[0], - (*batch_shape, lhs_outer_size, rhs.shape[rhs_outer_dim]) - ) - gelu_bdims = batch_dims[3] - if fuse_gelu: - outputs[3] = jnp.reshape(outputs[3], outputs[0].shape) - gelu_bdims = lhs_bdims + # FP8 GEMM 
only supports non-transposed LHS and transposed RHS + lhs, _, rhs, *_ = batched_args + lhs_trans = contracting_dims[0] != lhs.ndim - 1 + rhs_trans = contracting_dims[1] == rhs.ndim - 1 + lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs + rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs + contracting_dims = (1, 1) return ( - outputs, - (lhs_bdims, batch_dims[1], batch_dims[2], gelu_bdims, batch_dims[4]) + CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + batched_args[1], + rhs, + *batched_args[3:], + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) ) @staticmethod @@ -400,9 +356,9 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi + "not already partitioned correctly.") lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs.ndim - 1 lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 + rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -440,9 +396,9 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat ) lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs.ndim - 1 lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 + rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in 
lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -558,7 +514,7 @@ def fp8_gemm_impl( gelu_input = jnp.zeros(0, dtype=bias.dtype) elif gelu_input is None: lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) @@ -599,7 +555,7 @@ def gemm_impl( dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: @@ -618,9 +574,6 @@ def gemm_impl( gelu_input is not None ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." elif gelu_input is None: - lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 - out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 7b8ebdcdd2..ddf98d9d78 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include "common/util/pybind_helper.h" #include "extensions.h" namespace transformer_engine { @@ -107,6 +108,8 @@ pybind11::dict Registrations() { } PYBIND11_MODULE(transformer_engine_jax, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("registrations", &Registrations); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0); @@ -129,62 +132,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); - - pybind11::enum_(m, "DType", pybind11::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kInt64", DType::kInt64) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - pybind11::enum_(m, "NVTE_QKV_Layout", 
pybind11::module_local()) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) - .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) - .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); - - pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) - .value("GELU", NVTE_Activation_Type::GELU) - .value("GEGLU", NVTE_Activation_Type::GEGLU) - .value("SILU", NVTE_Activation_Type::SILU) - .value("SWIGLU", NVTE_Activation_Type::SWIGLU) - .value("RELU", NVTE_Activation_Type::RELU) - .value("REGLU", NVTE_Activation_Type::REGLU) - .value("QGELU", NVTE_Activation_Type::QGELU) - .value("QGEGLU", NVTE_Activation_Type::QGEGLU) - .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU) - .export_values(); - - pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); } } // namespace jax