Performance gap between manual nvfuser definition and thunder.jit for rmsnorm (#3629)
Comments
I have a (mostly) standalone script for nsys profiling here. I'll run a sweep using it.
I filed Lightning-AI/lightning-thunder#1582 to also track this in the thunder repository. Looking forward to hearing the results of your analysis, @Priya2698!
I compared the existing nvfuser definition (nvf_rmsnorm) with the one generated from Thunder when using F.rms_norm. I have been looking at the Thunder-generated definition, where there is a roundtrip cast. Simplifying this to avoid the roundtrip cast:

Thunder definition (nvfuser version …): …
Function to generate this definition: …
nvfuser pre-generated definition: Fuser/benchmarks/python/test_rmsnorm_bwd.py, lines 20 to 77 at 03e7e34

My next steps will be to isolate the common computations in the two definitions and identify which instructions show the largest performance difference between the two fusion definitions.
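For context, here is a minimal sketch of the roundtrip-cast pattern being discussed, written against the nvfuser python frontend (the actual generated definitions are elided above, so the surrounding ops and shapes are assumptions, and API details may vary by version):

```python
from nvfuser import FusionDefinition, DataType

# Sketch: an fp32 intermediate is cast down to bf16 and immediately back up
# to fp32 before the next op consumes it. Removing the cast pair does not
# change the math; it only drops the bf16 roundtrip.
with FusionDefinition() as fd:
    t0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True],
                          dtype=DataType.BFloat16)
    t1 = fd.ops.cast(t0, dtype=DataType.Float)      # bf16 -> fp32
    t2 = fd.ops.cast(t1, dtype=DataType.BFloat16)   # fp32 -> bf16 (roundtrip...)
    t3 = fd.ops.cast(t2, dtype=DataType.Float)      # bf16 -> fp32 (...cast pair)
    t4 = fd.ops.mul(t3, t1)                         # the simplified version would consume t1 directly
    fd.add_output(t4)
```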
I see different register usage in the two definitions:
Thunder definition: …
Manual nvfuser definition in current benchmarks: …
This isn't segmented, right? Which scheduler is used? I suppose it's one of the normalization schedulers. Maybe @liqiangxl can take a look too?
(For my reference) What I have tried so far: …
No, it's 1 segment, using the inner_outer_persistent scheduler. Can a fusion definition using dynamic shapes (the existing nvfuser definition) vs. static shapes (the above Thunder definition) cause different register usage?
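To illustrate the question, a minimal sketch of the two styles in the nvfuser python frontend (names and the trivial op are placeholders; exact API details may vary by version). With dynamic shapes the kernel must index with runtime extents, while static shapes let the compiler fold index arithmetic, which can change register usage:

```python
from nvfuser import FusionDefinition, DataType

# Dynamic shapes (like the existing nvfuser benchmark definition):
# -1 extents mean one definition serves many input sizes.
with FusionDefinition() as fd_dyn:
    t0 = fd_dyn.define_tensor(shape=[-1, -1], contiguity=[True, True],
                              dtype=DataType.BFloat16)
    fd_dyn.add_output(fd_dyn.ops.cast(t0, dtype=DataType.Float))

# Static shapes (like the Thunder-generated definition above):
# extents are baked into the definition at build time.
with FusionDefinition() as fd_static:
    t0 = fd_static.define_tensor(shape=[2048, 8192], contiguity=[True, True],
                                 dtype=DataType.BFloat16)
    fd_static.add_output(fd_static.ops.cast(t0, dtype=DataType.Float))
```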
Yes, sometimes significantly. IIUC, the nvFuser version is dynamic shape and uses more registers, which makes sense. What's interesting is that the nvFuser version is actually faster than the static-shape Thunder version, right?
Yes, that is correct.
What do the heuristic parameters look like? Is this the only case like this? Or are you finding similar gaps with other benchmarks too?
Thunder version: …
nvFuser version: …
The heuristics look the same.
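For anyone reproducing the comparison: one way to print the heuristic parameters nvfuser chose is the NVFUSER_DUMP debug option (the option name scheduler_params is an assumption here and may differ across nvfuser versions):

```python
import os
os.environ["NVFUSER_DUMP"] = "scheduler_params"  # set before importing nvfuser (assumed option name)

import torch
from nvfuser import FusionDefinition, DataType

# Trivial reduction fusion; executing it should print the chosen scheduler's
# heuristic parameters to stderr when the dump option is recognized.
with FusionDefinition() as fd:
    t0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True],
                          dtype=DataType.BFloat16)
    t1 = fd.ops.cast(t0, dtype=DataType.Float)
    fd.add_output(fd.ops.sum(t1, [1]))

fd.execute([torch.randn(2048, 8192, dtype=torch.bfloat16, device="cuda")])
```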
There is a similar trend for …
Did you see the same trend on A100 or H100 too?
Yes, you can see the results at http://nv/etT for H100. This is slightly old though (end of December). I have not done a complete sweep recently; I'll start one to be sure.
Thanks. @liqiangxl, can you please look into it?
I checked A100; one reason is that the fusions from Thunder and the manually defined nvFuser one are different: the nvFuser version has …

Thunder: …
nvFuser: …
These are interesting findings, thanks @liqiangxl. @Priya2698, I thought the only differences were static vs. dynamic shapes plus some additional cast ops. What Liqiang found seems to indicate they are fundamentally different, like …
Thanks, @liqiangxl, for taking a look at this. Sorry, I did not mean to indicate that this is the only difference. Thunder does …, whereas nvfuser does …. The inputs to the fusion definition are 3 bf16 tensors (input, weight, grads) and 1 fp32 tensor (rms in nvfuser and 1/rms in thunder).
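To make the rms vs. 1/rms distinction concrete, a small sketch (tensor names are placeholders; the real fusion definitions are elided above). The two formulations are mathematically equivalent, but one fusion receives rms and divides, while the other receives the precomputed reciprocal and multiplies:

```python
import torch

x = torch.randn(2048, 8192)
rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)

y_nvfuser_style = x / rms     # fusion input is rms; normalization is a division
rinv = 1.0 / rms              # Thunder precomputes the reciprocal instead
y_thunder_style = x * rinv    # fusion input is 1/rms; normalization is a multiply

# Equivalent up to floating-point rounding.
assert torch.allclose(y_nvfuser_style, y_thunder_style)
```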
The dashboard for the latest run: http://nv/evT. The …

I have put up a PR to update this: #3783. There are still examples in Thunder that use the decomposed RMSNorm, so I'll check whether this version is still being used by any models or is important to Thunder/nvfuser.
I am seeing lower performance for thunder.jit (with the nvfuserex executor) than the manual nvfuser definition existing in the python benchmark suite: http://nv/etb. This came up in testing PR #3394.

For size = (2048, 8192), dtype=torch.bfloat16 (on my local system with an Ada card): …

The above numbers are using rmsnorm composed of primitives (rmsnorm_prims): …

I recover some of the performance using torch.nn.functional.rms_norm: …

(Note that the manual nvfuser definition was generated through Thunder using the above rmsnorm_prims.)
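For reference, a hedged sketch of the decomposed variant versus the fused op (the benchmark's actual rmsnorm_prims is elided above, so this helper is an assumption; torch.nn.functional.rms_norm requires PyTorch >= 2.4):

```python
import torch
import torch.nn.functional as F

def rmsnorm_prims(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Decomposed RMSNorm: upcast, reduce, normalize, scale, downcast.
    # A guess at the benchmark helper, not its actual code.
    xf = x.float()
    rms = torch.sqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return ((xf / rms) * weight.float()).to(x.dtype)

x = torch.randn(2048, 8192, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w = torch.randn(8192, dtype=torch.bfloat16, device="cuda", requires_grad=True)

y_prims = rmsnorm_prims(x, w)                                # decomposed version
y_fused = F.rms_norm(x, (x.shape[-1],), weight=w, eps=1e-6)  # fused composite op
```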