HF Transformers ViT slower than torch.compile and raw pytorch #1502
Comments
Hi, thank you for the detailed report!
I ran:
import timeit
import nvtx  # provides the push_range/pop_range calls used below
import thunder
import torch
from transformers import AutoModel
model = AutoModel.from_pretrained("WinKawaks/vit-tiny-patch16-224").cuda()
jmodel = thunder.jit(model)
jmodel(torch.randn((10, 3, 224, 224), device='cuda:0'))
cmodel = torch.compile(model)
cmodel(torch.randn((10, 3, 224, 224), device='cuda:0'))
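# timeit.timeit needs a callable (or a statement string), so each call is wrapped in a lambda.
# number=100 is an arbitrary choice here; the default of 1,000,000 runs would take far too long.
mdl = timeit.timeit(lambda: model(torch.randn((10, 3, 224, 224), device='cuda:0')), number=100)
tc = timeit.timeit(lambda: cmodel(torch.randn((10, 3, 224, 224), device='cuda:0')), number=100)
th = timeit.timeit(lambda: jmodel(torch.randn((10, 3, 224, 224), device='cuda:0')), number=100)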
print(f"timings: {mdl=}, {tc=}, {th=}")
nvtx.push_range("eager")
model(torch.randn(10, 3, 224, 224).cuda())
nvtx.pop_range()
nvtx.push_range("torch.compile alone")
cmodel(torch.randn(10, 3, 224, 224).cuda())
nvtx.pop_range()
nvtx.push_range("thunder")
jmodel(torch.randn(10, 3, 224, 224).cuda())
nvtx.pop_range()
On my ada6k I see:
Inside nsys, Thunder shows up as quite a bit slower:
In this particular case we're hit by significant host latency, as evidenced by the white gaps in the nsys profile.
The following issues detail fixes that are more granular and direct but probably relevant to this issue:
With #1467 (or #981?) being the most relevant.
Assigning to Melissa so we can discuss and figure out next steps: do we want to keep this open as a tracking issue, or should we just defer to the related issues?
To the original poster:
readme + toy example
The first example is in the README.md.
I tried
and got
It seems that raw PyTorch is faster.
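A comparison like this depends heavily on how the timing is taken. Here is a minimal sketch of a timing helper, not the code used in this report: `time_call` and its parameters are hypothetical names. It times a callable (not the result of calling it), discards a few warm-up runs, and accepts an optional `sync` hook so that asynchronously launched device work is included in the measurement.

```python
import statistics
import time
from typing import Callable, Optional


def time_call(fn: Callable[[], object],
              warmup: int = 3,
              repeats: int = 20,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock seconds of a single call to fn.

    `sync` is an optional hook (an assumption of this sketch) that blocks
    until outstanding device work has finished.
    """
    def run_once() -> float:
        if sync is not None:
            sync()  # drain work queued before the timed region
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the work launched by fn itself
        return time.perf_counter() - start

    # Warm-up runs are executed but their timings are discarded.
    for _ in range(warmup):
        run_once()
    return statistics.median(run_once() for _ in range(repeats))
```

On a CUDA model this might be called as `time_call(lambda: model(x), sync=torch.cuda.synchronize)`, with the input `x` created once outside the timed region so that input generation is not counted.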
The forward trace of thunder is
ViT Example
The toy example may be too simple, so overhead dominates the runtime. So I tried a more practical example.
and I got
The compile time is also very slow: the first run took 1m42.8s for jmodel, versus only 23.5s for torch.compile.
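To report compile time and steady-state speed separately, the first call can be timed on its own. The following is only a sketch under that assumption: `first_vs_steady` is a hypothetical helper, not part of thunder or PyTorch. For GPU work the callable should block until the device has finished (for example by calling `torch.cuda.synchronize()` inside it).

```python
import time
from typing import Callable, Tuple


def first_vs_steady(fn: Callable[[], object],
                    steady_runs: int = 10) -> Tuple[float, float]:
    """Return (first_call_seconds, mean_seconds_of_later_calls)."""
    # The first call pays any one-time tracing/compilation cost.
    start = time.perf_counter()
    fn()
    first = time.perf_counter() - start

    # Later calls reuse whatever was compiled or cached by the first call.
    start = time.perf_counter()
    for _ in range(steady_runs):
        fn()
    steady = (time.perf_counter() - start) / steady_runs
    return first, steady
```

Calling it once per variant (eager, `torch.compile`, `thunder.jit`) gives one pair of numbers each, which keeps a slow first run from being mixed into the per-call comparison.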
The trace of thunder is