[JAX] Bug fix for distributed normalization (#1366)
* fix ctx.avals_out indexing for the workspace aval
* add cuDNN handle init to the prepare phase of the normalization FFI custom calls
* add thread_local to the NormalizationPlanRegistry singleton instance
---------

Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng authored Dec 12, 2024
1 parent: e4c99b0 · commit: 0e1d9fa
Showing 3 changed files with 25 additions and 14 deletions.
transformer_engine/common/normalization/common.h (3 changes: 1 addition, 2 deletions)
@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {

 class NormalizationPlanRegistry {
  public:
-  // TODO thread-safe
   static NormalizationPlanRegistry& getInstance() {
-    static NormalizationPlanRegistry instance;
+    static thread_local NormalizationPlanRegistry instance;
     return instance;
   }
transformer_engine/jax/cpp_extensions/normalization.py (12 changes: 6 additions, 6 deletions)
@@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
     batch_shape = out_shape[:-1]
     batch_size = reduce(operator.mul, x_shape) // hidden_size

-    wkspace_aval = ctx.avals_out[-2:]
+    wkspace_aval = ctx.avals_out[-1]

     out_types = [
         ir.RankedTensorType.get(out_shape, output_type),
@@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):

     sm_margin = get_backward_sm_margin()

-    wkspace_aval = ctx.avals_out[-4:]
+    wkspace_aval = ctx.avals_out[-1]
     opaque = transformer_engine_jax.pack_norm_descriptor(
         batch_size,
         hidden_size,
@@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon):
     batch_shape = out_shape[:-1]
     batch_size = reduce(operator.mul, x_shape) // hidden_size

-    wkspace_aval = ctx.avals_out[-2:]
+    wkspace_aval = ctx.avals_out[-1]

     out_types = [
         ir.RankedTensorType.get(out_shape, x_type.element_type),
@@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
     hidden_size = reduce(operator.mul, g_shape)
     batch_size = reduce(operator.mul, x_shape) // hidden_size

-    wkspace_aval = ctx.avals_out[-3:]
+    wkspace_aval = ctx.avals_out[-1]

     out_types = [
         ir.RankedTensorType.get(x_shape, x_type.element_type),
@@ -1088,7 +1088,7 @@ def lowering(
     batch_shape = out_shape[:-1]
     batch_size = reduce(operator.mul, x_shape) // hidden_size

-    wkspace_aval = ctx.avals_out[-2:]
+    wkspace_aval = ctx.avals_out[-1]

     out_types = [
         ir.RankedTensorType.get(out_shape, ir_out_dtype),
@@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
     batch_shape = out_shape[:-1]
     batch_size = reduce(operator.mul, x_shape) // hidden_size

-    wkspace_aval = ctx.avals_out[-2:]
+    wkspace_aval = ctx.avals_out[-1]

     out_types = [
         ir.RankedTensorType.get(out_shape, ir_out_dtype),
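In each of the six lowering rules above, the fix is the same one-liner: the workspace is the single last entry of ctx.avals_out, not a slice of the trailing avals. A minimal sketch of the intent follows; the helper name _workspace_aval is illustrative only and does not appear in normalization.py.

# Illustrative sketch only. The abstract-eval rule of these primitives returns
# the real outputs followed by exactly one workspace aval, for example
# (out_aval, mu_aval, rsigma_aval, wkspace_aval), and ctx.avals_out mirrors
# that ordering inside the lowering rule.
def _workspace_aval(ctx):
    # A single ShapedArray describing the cuDNN workspace buffer.
    return ctx.avals_out[-1]
    # The previous code took ctx.avals_out[-2:] (or [-3:], [-4:]), i.e. a
    # tuple of several avals, so any later use of wkspace_aval as one aval
    # (e.g. its shape or dtype when packing the norm descriptor) would see a
    # tuple instead of the workspace tensor.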
transformer_engine/jax/csrc/extensions/pybind.cpp (24 changes: 18 additions, 6 deletions)
@@ -83,12 +83,24 @@ pybind11::dict Registrations() {
       EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

   // Normalization
-  dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
-  dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
-  dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
-  dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
-  dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
-  dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
+  dict["te_layernorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
+  dict["te_layernorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
+  dict["te_layernorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
+  dict["te_rmsnorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
+  dict["te_rmsnorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
+  dict["te_rmsnorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));

   // Attention
   pybind11::dict fused_attn_forward_ffi;
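Each normalization entry now carries a small dict of FFI stage handlers instead of a bare capsule, so XLA can run CudnnHandleInitHandler in the prepare stage before the kernel's execute stage. Below is a hedged sketch of how such a bundle might be consumed on the Python side; the registrations() accessor name and the exact registration call are assumptions, not part of this diff.

# Sketch under assumptions: transformer_engine_jax exposes the dict built by
# Registrations(), and XLA's typed-FFI registration (api_version=1) accepts
# either a single handler capsule or a {"prepare": ..., "execute": ...} bundle.
import transformer_engine_jax
from jax.lib import xla_client

for name, handler in transformer_engine_jax.registrations().items():
    # For te_layernorm_forward_ffi and friends, `handler` is now
    # {"prepare": CudnnHandleInitHandler, "execute": LayerNormForwardHandler},
    # so the cuDNN handle is created in the prepare phase rather than lazily
    # inside the first execute call.
    xla_client.register_custom_call_target(name, handler, platform="CUDA",
                                           api_version=1)

Keeping handle creation out of the execute path is the point of this part of the change for distributed normalization.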
