diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h
index 8a8df63ba4..d1d56d5cc9 100644
--- a/transformer_engine/common/normalization/common.h
+++ b/transformer_engine/common/normalization/common.h
@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
 
 class NormalizationPlanRegistry {
  public:
-  // TODO thread-safe
   static NormalizationPlanRegistry& getInstance() {
-    static NormalizationPlanRegistry instance;
+    static thread_local NormalizationPlanRegistry instance;
     return instance;
   }
 
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 0b7df0b5a8..69d7962b62 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
             batch_shape = out_shape[:-1]
             batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-            wkspace_aval = ctx.avals_out[-2:]
+            wkspace_aval = ctx.avals_out[-1]
 
             out_types = [
                 ir.RankedTensorType.get(out_shape, output_type),
@@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
 
             sm_margin = get_backward_sm_margin()
 
-            wkspace_aval = ctx.avals_out[-4:]
+            wkspace_aval = ctx.avals_out[-1]
             opaque = transformer_engine_jax.pack_norm_descriptor(
                 batch_size,
                 hidden_size,
@@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon):
             batch_shape = out_shape[:-1]
             batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-            wkspace_aval = ctx.avals_out[-2:]
+            wkspace_aval = ctx.avals_out[-1]
 
             out_types = [
                 ir.RankedTensorType.get(out_shape, x_type.element_type),
@@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
             hidden_size = reduce(operator.mul, g_shape)
             batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-            wkspace_aval = ctx.avals_out[-3:]
+            wkspace_aval = ctx.avals_out[-1]
 
             out_types = [
                 ir.RankedTensorType.get(x_shape, x_type.element_type),
@@ -1088,7 +1088,7 @@ def lowering(
             batch_shape = out_shape[:-1]
             batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-            wkspace_aval = ctx.avals_out[-2:]
+            wkspace_aval = ctx.avals_out[-1]
 
             out_types = [
                 ir.RankedTensorType.get(out_shape, ir_out_dtype),
@@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
             batch_shape = out_shape[:-1]
             batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-            wkspace_aval = ctx.avals_out[-2:]
+            wkspace_aval = ctx.avals_out[-1]
 
             out_types = [
                 ir.RankedTensorType.get(out_shape, ir_out_dtype),
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index 9b5c156e5d..a319b74d76 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -83,12 +83,24 @@ pybind11::dict Registrations() {
       EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
 
   // Normalization
-  dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
-  dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
-  dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
-  dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
-  dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
-  dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
+  dict["te_layernorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
+  dict["te_layernorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
+  dict["te_layernorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
+  dict["te_rmsnorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
+  dict["te_rmsnorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
+  dict["te_rmsnorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));
 
   // Attention
   pybind11::dict fused_attn_forward_ffi;