
Performance gap between manual nvfuser definition and thunder.jit for rmsnorm #3629

Open
Priya2698 opened this issue Dec 20, 2024 · 21 comments


@Priya2698
Collaborator

Priya2698 commented Dec 20, 2024

I am seeing lower performance for thunder.jit (with the nvfuserex executor) than for the manual nvfuser definition that exists in the Python benchmark suite: http://nv/etb. This came up while testing PR #3394.

For size = (2048, 8192), dtype=torch.bfloat16 (on my local system with an Ada card):

--------------------------------------------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                      Min                 Max                Mean            StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.bfloat16-size=[2048_8192]]                             136.8970 (1.0)      145.0250 (1.0)      140.3881 (1.0)      2.3346 (1.68)     140.0140 (1.0)      3.4200 (1.84)          2;0        7.1231 (1.0)          10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.bfloat16-size=[2048_8192]-executor='thunder']     223.9020 (1.64)     228.9010 (1.58)     226.1649 (1.61)     1.3899 (1.0)      226.0655 (1.61)     1.8540 (1.0)           2;0        4.4216 (0.62)         10           1
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.float32-size=[2048_8192]]                              256.4510 (1.87)     265.5080 (1.83)     260.5773 (1.86)     3.0545 (2.20)     259.8870 (1.86)     4.8270 (2.60)          4;0        3.8376 (0.54)         10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.float32-size=[2048_8192]-executor='thunder']      271.0090 (1.98)     274.9130 (1.90)     273.3845 (1.95)     1.4553 (1.05)     273.9035 (1.96)     2.7580 (1.49)          5;0        3.6579 (0.51)         10           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

The numbers above use an rmsnorm composed of primitives:

def rmsnorm_prims(inputs: list):
    inp, weights = inputs
    squared_mean = (inp**2).mean(1, keepdim=True)
    rms_eps = torch.sqrt(squared_mean + 1e-5)
    output = weights * (inp / rms_eps)
    return output
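For reference, in math form this computes, for input x of shape [M, N] and weight w of length N, with ε = 1e-5:

$$\mathrm{rmsnorm}(x)_{ij} \;=\; w_j \cdot \frac{x_{ij}}{\sqrt{\tfrac{1}{N}\sum_{k=1}^{N} x_{ik}^{2} + \varepsilon}}$$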

I recover some of the performance using torch.nn.functional.rms_norm (note that the manual nvfuser definition was originally generated through Thunder using the above rmsnorm_prims):

def rmsnorm_func(inputs: list):
    inp, weights = inputs
    output = F.rms_norm(inp, inp.shape[1:], weights, eps=1e-5)
    return output
--------------------------------------------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                      Min                 Max                Mean            StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.bfloat16-size=[2048_8192]]                             137.6300 (1.0)      143.1660 (1.0)      140.5177 (1.0)      1.9396 (1.91)     139.8885 (1.0)      3.2640 (2.55)          4;0        7.1165 (1.0)          10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.bfloat16-size=[2048_8192]-executor='thunder']     175.1710 (1.27)     178.3350 (1.25)     176.9573 (1.26)     1.0168 (1.0)      176.9435 (1.26)     1.2810 (1.0)           4;0        5.6511 (0.79)         10           1
test_rmsnorm_bwd_baseline_benchmark[dtype=torch.float32-size=[2048_8192]-executor='thunder']      255.0390 (1.85)     264.3810 (1.85)     258.7758 (1.84)     2.6816 (2.64)     258.5290 (1.85)     2.9120 (2.27)          3;0        3.8643 (0.54)         10           1
test_rmsnorm_bwd_nvf_benchmark[dtype=torch.float32-size=[2048_8192]]                              258.3390 (1.88)     267.1710 (1.87)     261.6898 (1.86)     2.7284 (2.68)     261.1510 (1.87)     2.7240 (2.13)          4;1        3.8213 (0.54)         10           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
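For anyone reproducing these numbers outside the pytest-benchmark harness, here is a minimal CUDA-event timing sketch (illustrative only; note it times forward + backward together, whereas the bwd benchmarks above isolate the backward pass):

import torch

def time_fwd_bwd(fn, args, grads, iters=10):
    # Warmup iteration so compilation and caching are excluded from the measurement
    fn(args).backward(grads)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(args).backward(grads)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # average time per fwd+bwd iteration in us

Usage would be, e.g., time_fwd_bwd(thunder.jit(rmsnorm_func, executors=[nvfuserex]), [inputs, weights], grads).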
@Priya2698
Collaborator Author

Priya2698 commented Dec 20, 2024

I have a (mostly) standalone script for nsys profiling here.

I'll run a sweep using F.rms_norm. The existing fusion definition in the Python benchmarks was obtained through Thunder but modified to allow dynamic shapes and dtypes. Some casts and broadcast ops may have been simplified in the process, which may be responsible for the performance gap.
I'll look at the differences in the operators present.

@Priya2698
Collaborator Author

CC: @kevinstephano @mruberry

@mruberry

I filed Lightning-AI/lightning-thunder#1582 to also track this in the thunder repository. Looking forward to hearing the results of your analysis, @Priya2698!

@Priya2698
Collaborator Author

Priya2698 commented Dec 31, 2024

I compared the existing nvfuser definition (nvf_rmsnorm) with the one generated from Thunder when using F.rms_norm (thunder_rmsnorm). I am no longer using the primitives-based implementation for the comparison since Thunder now uses F.rms_norm, which is also faster.

I have been looking at input size [2048, 8192], dtype=bfloat16 on my local machine with an Ada card.

  1. The launch parameters for both cases are the same: BlockDim.x = 16, BlockDim.y = 16, BlockDim.z = 1, GridDim.x = -1, GridDim.y = 142, GridDim.z = -1, Smem Size = 50176

  2. One of the snippets I noted in the thunder_rmsnorm implementation is:

   T29 = fd.ops.cast(T3, dtype=DataType.BFloat16)
   T37 = fd.ops.broadcast_in_dim(T29, shape=[2048, 8192], broadcast_dims=[0, 1])
   T42 = fd.ops.cast(T37, dtype=DataType.Float)

where T3 = fd.define_tensor(shape=[2048, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])

Simplifying this to avoid the roundtrip cast, i.e. T42 = fd.ops.broadcast_in_dim(T3, shape=[2048, 8192], broadcast_dims=[0, 1]), reduces the time from 175 us to 162 us (input size: [2048, 8192], dtype=bfloat16).

  3. I see similar instructions for memory loads, shared memory accesses, warp reductions, waits, etc. in the CUDA kernels for both definitions.
  4. Here is a link to the benchmarking run: http://nv/etT. The performance seems better than what I previously saw, which might have been due to an infra difference. I will run it again for verification.

My next step will be to isolate common computations in the two definitions and identify which instructions show the largest performance difference between the two.
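For reference, one way to dump the fusion that thunder.jit hands to nvFuser (a sketch; it assumes the NVFUSER_DUMP=python_definition dump option is available in your nvfuser build, and uses thunder.last_backward_traces to retrieve the backward trace):

import os
os.environ["NVFUSER_DUMP"] = "python_definition"  # ask nvFuser to print each fusion's python definition

import torch
import torch.nn.functional as F
import thunder

jfn = thunder.jit(lambda inp, w: F.rms_norm(inp, inp.shape[1:], weight=w, eps=1e-5))
inp = torch.randn(2048, 8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)
jfn(inp, w).backward(torch.randn_like(inp))

# The final backward trace shows the nvFusion region(s) thunder created
print(thunder.last_backward_traces(jfn)[-1])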

@Priya2698
Collaborator Author

Priya2698 commented Jan 28, 2025

Thunder definition (nvfuser version: '0.2.25+gitfd4a7f1', thunder version: 0.2.0dev):

def rmsnorm_thunder_func_fusion(
  fd, dtype
):
    T0 = fd.define_tensor(shape=[8192], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[2048, 8192], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[2048, 8192], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T3 = fd.define_tensor(shape=[2048, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T7 = fd.ops.broadcast_in_dim(T0, shape=[2048, 8192], broadcast_dims=[1])
    T8 = fd.ops.cast(T1, dtype=DataType.Float)
    T9 = fd.ops.cast(T7, dtype=DataType.Float)
    T10 = fd.ops.mul(T9, T8)
    T11 = fd.ops.cast(T2, dtype=DataType.Float)
    T12 = fd.ops.mul(T11, T10)
    T13 = fd.ops.sum(T12, dims=[1], keepdim=False, dtype=DataType.Null)
    T14 = fd.ops.cast(T13, dtype=DataType.BFloat16)
    T18 = fd.ops.broadcast_in_dim(T14, shape=[2048, 1], broadcast_dims=[0])
    T19 = fd.ops.cast(T18, dtype=DataType.Float)
    S20 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T21 = fd.ops.pow(T3, S20)
    S22 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T23 = fd.ops.mul(S22, T19)
    T24 = fd.ops.mul(T23, T21)
    S25 = fd.define_scalar(8192.00, dtype=DataType.Double)
    S26 = fd.ops.reciprocal(S25)
    T27 = fd.ops.mul(T24, S26)
    T28 = fd.ops.sum(T27, dims=[1], keepdim=False, dtype=DataType.Null)
    T29 = fd.ops.cast(T3, dtype=DataType.BFloat16)
    T33 = fd.ops.broadcast_in_dim(T28, shape=[2048, 1], broadcast_dims=[0])
    T37 = fd.ops.broadcast_in_dim(T29, shape=[2048, 8192], broadcast_dims=[0, 1])
    T41 = fd.ops.broadcast_in_dim(T33, shape=[2048, 8192], broadcast_dims=[0, 1])
    T42 = fd.ops.cast(T37, dtype=DataType.Float)
    T43 = fd.ops.mul(T11, T41)
    T44 = fd.ops.mul(T42, T10)
    T45 = fd.ops.mul(T11, T42)
    T46 = fd.ops.add(T44, T43)
    T47 = fd.ops.mul(T45, T8)
    T48 = fd.ops.add(T46, T43)
    T49 = fd.ops.sum(T47, dims=[0], keepdim=False, dtype=DataType.Null)
    T50 = fd.ops.cast(T48, dtype=DataType.BFloat16)
    T51 = fd.ops.cast(T49, dtype=DataType.BFloat16)
    fd.add_output(T51)
    fd.add_output(T50)

Function to generate this definition:

import torch
import torch.nn.functional as F

import thunder
from thunder.executors.nvfuserex import nvfuserex  # import path may differ across thunder versions

def run_thunder_func():
    size = (2048, 8192)
    dtype = torch.bfloat16
    inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True)
    grads = torch.randn(size, device="cuda", dtype=dtype)
    weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True)

    def rmsnorm_func(inputs):
        inp, weights = inputs
        output = F.rms_norm(
            inp,
            inp.shape[1:],
            weight=weights,
            eps=1e-5,
        )
        return output

    # Compile the fwd fn with thunder.jit using the nvFuser executor
    fwd_fn = thunder.jit(rmsnorm_func, executors=[nvfuserex])
    outputs = fwd_fn([inputs, weights])
    outputs.backward(grads)

nvfuser pre-generated definition:

def rmsnorm_bwd_fusion(
    fd: FusionDefinition,
    dtype: DataType,
):
    T4 = fd.define_tensor(
        shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False
    )
    T5 = fd.define_tensor(
        shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False
    )
    T6 = fd.define_tensor(
        shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False
    )
    T7 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False)
    S0 = fd.define_scalar(2.0, dtype=DataType.Double)
    if dtype in PROMOTE_DTYPES:
        T4 = fd.ops.cast(T4, dtype=DataType.Float)
        T6 = fd.ops.cast(T6, dtype=DataType.Float)
        T7 = fd.ops.cast(T7, dtype=DataType.Float)
    T14 = fd.ops.broadcast_in_dim(T5, shape=T4.shape(), broadcast_dims=[0, 1])
    T15 = fd.ops.reciprocal(T14)
    T16 = fd.ops.mul(T4, T15)
    T20 = fd.ops.broadcast_in_dim(T7, shape=T4.shape(), broadcast_dims=[1])
    T23 = fd.ops.mul(T6, T16)
    T24 = fd.ops.mul(T6, T20)
    T25 = fd.ops.sum(T23, dims=[0], keepdim=False, dtype=DataType.Null)
    T28 = fd.ops.mul(T24, T15)
    T29 = fd.ops.neg(T24)
    T30 = fd.ops.mul(T29, T4)
    T32 = fd.ops.pow(T14, S0)
    T33 = fd.ops.reciprocal(T32)
    T34 = fd.ops.mul(T30, T33)
    T35 = fd.ops.sum(T34, dims=[1], keepdim=False, dtype=DataType.Null)
    V39 = fd.define_vector([T4.size(0), 1], dtype=DataType.Int)
    T41 = fd.ops.broadcast_in_dim(T35, shape=V39, broadcast_dims=[0])
    T43 = fd.ops.mul(S0, T5)
    T44 = fd.ops.reciprocal(T43)
    T45 = fd.ops.mul(T41, T44)
    S48 = fd.ops.reciprocal(T4.size(1))
    T49 = fd.ops.mul(T45, S48)
    T50 = fd.ops.sum(T49, dims=[1], keepdim=False, dtype=DataType.Null)
    T54 = fd.ops.broadcast_in_dim(T50, shape=V39, broadcast_dims=[0])
    T58 = fd.ops.broadcast_in_dim(T54, shape=T4.shape(), broadcast_dims=[0, 1])
    T59 = fd.ops.mul(T58, S0)
    T62 = fd.ops.mul(T59, T4)
    T63 = fd.ops.add(T28, T62)
    if dtype in PROMOTE_DTYPES:
        T63 = fd.ops.cast(T63, dtype=dtype)
        T25 = fd.ops.cast(T25, dtype=dtype)
    fd.add_output(T63)
    fd.add_output(T25)

@Priya2698
Collaborator Author

Priya2698 commented Jan 28, 2025

I see different register usage in the two definitions.
The above Thunder definition:

Segments: 1
Scheduler:inner_outer_persistent
Grid: [1, 142, 1]
Block: [16, 16, 1]
Cluster: [0, 0, 0]
Shared Memory [Dyn,Stat]: [50176, 16]
Register: 80

Manual nvfuser definition in current benchmarks:

Segments: 1
Scheduler:inner_outer_persistent
Grid: [1, 142, 1]
Block: [16, 16, 1]
Cluster: [0, 0, 0]
Shared Memory [Dyn,Stat]: [50176, 16]
Register: 96

@naoyam
Collaborator

naoyam commented Jan 28, 2025

This isn't segmented, right? Which scheduler is used? I suppose it's one of the normalization schedulers. Maybe @liqiangxl can take a look too?

@Priya2698
Collaborator Author

Priya2698 commented Jan 28, 2025

(For my reference) What I have tried so far:

  1. Measuring thunder.jit performance when computing rmsnorm from primitives vs. using F.rms_norm: F.rms_norm is much faster.
  2. Experimenting with the F.rms_norm decomposition in Thunder: the decomposition seems equivalent to or faster than other approaches (e.g., https://github.com/Lightning-AI/lightning-thunder/blob/3abe4c20c4e0673bb1508ed2614067d756e0775d/examples/llama2.c/model.py#L27-L38).
  3. Updating downcast-broadcast-upcast patterns in the thunder.jit definition above: as summarized in the earlier comment, removing one of these roundtrip casts does improve performance (see the sketch after this list). An approach similar to #3699 ("Merge up-cast, ops, down-cast sequences as minimal units of segments") could be useful here.
  4. Isolating common computation in the two definitions to identify which instructions cause the largest performance difference: this did not give much insight. The math differs between the two, so there are few exact matches. For instance, the current thunder.jit definition stores 1/rms for the backward pass, whereas the existing benchmark definition stores rms.
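A minimal sketch of item 3, using the tensor names from the Thunder definition above (the "after" form is the simplification measured in the earlier comment):

# Before: fp32 -> bf16 -> broadcast -> fp32 roundtrip
T29 = fd.ops.cast(T3, dtype=DataType.BFloat16)
T37 = fd.ops.broadcast_in_dim(T29, shape=[2048, 8192], broadcast_dims=[0, 1])
T42 = fd.ops.cast(T37, dtype=DataType.Float)

# After: broadcast the fp32 tensor directly, skipping the roundtrip cast
T42 = fd.ops.broadcast_in_dim(T3, shape=[2048, 8192], broadcast_dims=[0, 1])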

@Priya2698
Collaborator Author

Priya2698 commented Jan 28, 2025

This isn't segmented, right?

No, 1 segment.

Which scheduler is used? I suppose it's one of the normalization schedulers.

inner_outer_persistent

Can the fusion definition using dynamic shapes (the existing nvfuser definition) vs. static shapes (the above Thunder definition) cause different register usage?

@naoyam
Collaborator

naoyam commented Jan 28, 2025

Can the fusion definition using dynamic shapes (the existing nvfuser definition) vs. static shapes (the above Thunder definition) cause different register usage?

Yes, sometimes significantly.

IIUC, the nvFuser version is dynamic shape and uses more registers, which makes sense. What's interesting is that the nvFuser version is actually faster than the static-shape Thunder version, right?

@Priya2698
Collaborator Author

Can the fusion definition using dynamic shapes (the existing nvfuser definition) vs. static shapes (the above Thunder definition) cause different register usage?

Yes, sometimes significantly.

IIUC, the nvFuser version is dynamic shape and uses more registers, which makes sense. What's interesting is that the nvFuser version is actually faster than the static-shape Thunder version, right?

Yes, that is correct.

@Priya2698 Priya2698 changed the title Performance gap between manual nvfuser definition and thunder.jit Performance gap between manual nvfuser definition and thunder.jit for rmsnorm Jan 28, 2025
@naoyam
Collaborator

naoyam commented Jan 28, 2025

What do the heuristic parameters look like?

Is this the only case like this? Or are you finding similar gaps with other benchmarks too?

@Priya2698
Collaborator Author

Priya2698 commented Jan 28, 2025

What do the heuristic parameters look like?

Thunder version

===== Combined InnerOuter Reduction Stats ========
outer_dim_numel: 2048
inner_dim_numel: 8192
regs_buffer_size: 32768
smem_buffer_size: 69120
smem_overhead: 3072
vectorize_factor_input: 8
vectorization_factor_tmp_gmem_write: 4
vectorization_factor_outer: 4
multiple_reds_per_blk: 0
warps_per_sm: 8
gdimy: 142
block(16, 16, 1)
===== Reduction Parameters ========
Tag: InnerOuter Register Persistent Heuristic.

Red On Fastest Dim
Persistent Kernel
Project Persistent Buffers
Batches per block: 4

Iteration Domain: blockIdx.y / split grid dimension outer / 
Inner Reduction Domain: persistent batch - 4 / vectorize / factor 8
Launch Parameters: BlockDim.x = 16, BlockDim.y = 16, BlockDim.z = -1, GridDim.x = -1, GridDim.y = 142, GridDim.z = -1, Smem Size = 0
Compile Parameters: index_type = int, maxrregcount = 256, enable_magic_zero = 1, enable_ptxas_verbose = 0

Nvfuser version

===== Combined InnerOuter Reduction Stats ========
outer_dim_numel: 2048
inner_dim_numel: 8192
regs_buffer_size: 32768
smem_buffer_size: 69120
smem_overhead: 3072
vectorize_factor_input: 8
vectorization_factor_tmp_gmem_write: 4
vectorization_factor_outer: 4
multiple_reds_per_blk: 0
warps_per_sm: 8
gdimy: 142
block(16, 16, 1)
===== Reduction Parameters ========
Tag: InnerOuter Register Persistent Heuristic.

Red On Fastest Dim
Persistent Kernel
Project Persistent Buffers
Batches per block: 4

Iteration Domain: blockIdx.y / split grid dimension outer / 
Inner Reduction Domain: persistent batch - 4 / vectorize / factor 8
Launch Parameters: BlockDim.x = 16, BlockDim.y = 16, BlockDim.z = -1, GridDim.x = -1, GridDim.y = 142, GridDim.z = -1, Smem Size = 0
Compile Parameters: index_type = int, maxrregcount = 256, enable_magic_zero = 1, enable_ptxas_verbose = 0
====================================

The heuristics look the same.

@Priya2698
Collaborator Author

Priya2698 commented Jan 28, 2025

Is this the only case like this? Or are you finding similar gaps with other benchmarks too?

There is a similar trend for layernorm_bwd and dropout_layernorm/rmsnorm_bwd for bfloat16, as per my last benchmark run. I started with rms_norm since the gap was widest there, but the gap is smaller now when using F.rms_norm for thunder.jit.

@naoyam
Collaborator

naoyam commented Jan 28, 2025

Did you see the same trend on A100 or H100 too?

@Priya2698
Collaborator Author

Priya2698 commented Jan 28, 2025

Did you see the same trend on A100 or H100 too?

Yes, you can see the H100 results at http://nv/etT. They are slightly old though (late December); I have not done a complete sweep recently. I'll start one to be sure.

@naoyam
Collaborator

naoyam commented Jan 28, 2025

Thanks. @liqiangxl, can you please look into it?

@liqiangxl
Collaborator

I checked on A100; one reason is this instruction in the generated kernel:

          T43[0] = pow(T60[0], (float) 1.00000000000000000e+00);

genPowerWithMul() only handles the cases with exponents of 2 and 3 (see the sketch at the end of this comment).

The fusions from Thunder and the manually defined nvFuser one are different: the nvFuser version has y = pow(x, 2), which is lowered to y = x * x, while the Thunder version has y = pow(x, 1), which is lowered to y = pow(x, 1.0). If I change it to y = x, performance increases from 17% SOL to 50% SOL. The nvFuser version is at 60% SOL. The remaining 10% difference may be due to the following reasons:

(1) The fusions are different. The Thunder version has `5 bf16` input tensors while the nvFuser version has `3 bf16 and 1 fp32` input tensors.
(2) Unrolling of a for-loop with a constant number of iterations.

Thunder:

Inputs:
  T0_g___bfloat[iS0{2048}, bS1{1}]
  T1_g___bfloat[iS2{8192}]
  T2_g___bfloat[iS3{2048}, iS4{8192}]
  T3_g___bfloat[iS5{2048}, iS6{8192}]
  T4_g___bfloat[iS7{2048}, bS8{1}]
Outputs:
  T53_g___bfloat[iS103{8192}]
  T52_g___bfloat[iS101{2048}, iS102{8192}]

nvFuser:

Inputs:
  T0_g___bfloat[iS0{i0}, iS1{i1}]
  T1_g_float[iS83{i0}, bS3{1}]
  T2_g___bfloat[iS74{i0}, iS75{i1}]
  T3_g___bfloat[iS78{i1}]
Outputs:
  T37_g___bfloat[iS100{i0}, iS101{i1}]
  T38_g___bfloat[iS106{i1}]
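To illustrate the pow lowering, a hypothetical Python sketch of the multiply-based special-casing (the real genPowerWithMul lives in nvFuser's C++ codegen; the exponent-1 branch is the missing case identified above):

def gen_power(operand: str, exponent: float) -> str:
    # Illustrative only: mirrors the exponent special-casing genPowerWithMul
    # performs for 2 and 3; exponent 1 currently falls through to pow().
    if exponent == 1.0:
        return operand                               # y = x (the manual fix above)
    if exponent == 2.0:
        return f"{operand} * {operand}"              # y = x * x
    if exponent == 3.0:
        return f"{operand} * {operand} * {operand}"  # y = x * x * x
    return f"pow({operand}, {exponent})"             # general case: library pow()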

@naoyam
Collaborator

naoyam commented Jan 28, 2025

These are interesting findings. Thanks, @liqiangxl.

@Priya2698 I thought the only differences were static vs. dynamic shapes plus some additional cast ops. What Liqiang found seems to indicate the fusions are fundamentally different, like pow(x, 1) vs pow(x, 2), as well as the input differences. Can you confirm Liqiang's findings?

@Priya2698
Collaborator Author

Thanks, @liqiangxl, for taking a look at this.

@Priya2698 I thought the only differences were static vs. dynamic shapes plus some additional cast ops. What Liqiang found seems to indicate the fusions are fundamentally different, like pow(x, 1) vs pow(x, 2), as well as the input differences. Can you confirm Liqiang's findings?

Sorry, I did not mean to indicate that this is the only difference.
The math is different between the two as well. For instance, the current thunder.jit definition stores 1/rms for the backward pass, whereas the existing benchmark definition stores rms.
The nvfuser version also comes from thunder.jit but is older. The pow discrepancy can be seen in the Thunder definition:

S20 = fd.define_scalar(3.00000, dtype=DataType.Double)
T21 = fd.ops.pow(T3, S20)

whereas nvfuser does pow(x, 2) as @liqiangxl mentioned.
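For context, these exponents fall out of the chain rule. As a sketch, with m = mean(x²), s = 1/rms (what the Thunder trace stores as T3), and r = rms (what the manual definition stores):

$$\frac{\partial}{\partial m}(m+\varepsilon)^{-1/2} \;=\; -\tfrac{1}{2}(m+\varepsilon)^{-3/2} \;=\; -\tfrac{1}{2}\,s^{3}$$

which matches the S22 = -0.5 / pow(T3, 3) pair above, while the manual definition, storing r instead, reaches the same quantity through reciprocal(pow(r, 2)) and reciprocal(2 * r).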

Where is pow(x, 1) coming from, @liqiangxl?

The inputs to the fusion definition are 3 bf16 tensors (input, weights, grads) and 1 fp32 tensor (rms in nvfuser, 1/rms in thunder).
@liqiangxl, how did you print out the difference in inputs?

@Priya2698
Collaborator Author

Priya2698 commented Jan 29, 2025

The dashboard for the latest run: http://nv/evT.
The latest run shows a thunder/nvfuser geomean for RMSNorm bwd of 0.9, aggregated over all batch sizes. This is using F.rms_norm. There seemed to be some failures, so I am re-running to be sure.

The main branch currently has the decomposed rmsnorm (below), which @liqiangxl used; in my experiments it yields lower performance than F.rms_norm:

def rmsnorm_prims(inputs: list):
    inp, weights = inputs
    squared_mean = (inp**2).mean(1, keepdim=True)
    rms_eps = torch.sqrt(squared_mean + 1e-5)
    output = weights * (inp / rms_eps)
    return output

I have put up a PR to update this: #3783.

There are still examples in Thunder that use the decomposed RMSNorm, so I'll check whether this version is still being used by any models or is important to Thunder/nvfuser.
