Should the S5 layer be faster than a RNN? #6
Yes — it seems like in its current state it's ~2.3-6x (forward) or ~3.4-9x (forward+backward) slower on CUDA (S5 gets faster with increasing sequence length; tested 1200-144000). On CPU, S5 is 1.07-1.44x faster (forward+backward) or 1.03-1.4x slower (forward).

There are some newer implementations since I wrote my own (see the PyTorch thread: pytorch/pytorch#95408), but reports on their speed are mixed; from profiling, quite a bit of time seems to be spent in stack/interleave functions that don't do any actual computation. From the paper, an optimized version of S5 could potentially be ~10-60x faster than a GRU (which should perform similarly to an LSTM), but the reported figures may reflect a naive baseline implementation rather than an optimized kernel.

Note that naive wall-clock benchmarking on CUDA is unreliable because kernel launches are asynchronous, so I adapted your example to use `torch.utils.benchmark`, which handles synchronization:

```python
import os

import torch
import torch.utils.benchmark as benchmark

from s5 import S5

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

L = 1200      # sequence length
B = 2         # batch size
x_dim = 128   # input feature dimension
x = torch.randn(B, L, x_dim).cuda()

t0 = benchmark.Timer(
    stmt='v, _ = lstm(x)',  # append '; v.sum().backward(); lstm.zero_grad()' to also time backward
    # batch_first=True so the LSTM reads x as (B, L, x_dim) like S5 does
    setup='lstm = torch.nn.LSTM(x_dim, 512, batch_first=True).cuda()',
    globals={'x': x, 'torch': torch, 'x_dim': x_dim})
t1 = benchmark.Timer(
    stmt='model(x)',  # append '.sum().backward(); model.zero_grad()' to also time backward
    setup='model = S5(x_dim, 512).cuda()',
    globals={'x': x, 'S5': S5, 'x_dim': x_dim})

print(t0.timeit(50))
print(t1.timeit(50))
```
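For context on where the potential speedup (and the stack/interleave overhead) comes from: per the S5 paper, the layer evaluates a linear recurrence with a parallel associative scan. The sketch below is purely illustrative — a scalar toy recurrence, not the repo's code; `combine` and `tree_scan` are hypothetical names. It shows that the recurrence h[t] = a[t]*h[t-1] + b[t] can be computed by composing affine updates tree-style in O(log L) depth instead of an O(L) sequential loop, and that the tree version spends steps purely pairing and interleaving data.

```python
def combine(e1, e2):
    """Compose two affine updates: applying (a1, b1) then (a2, b2) to h
    gives a2*(a1*h + b1) + b2 = (a2*a1)*h + (a2*b1 + b2).
    Associativity of this operator is what enables a parallel scan."""
    a1, b1 = e1
    a2, b2 = e2
    return (a2 * a1, a2 * b1 + b2)

def sequential_scan(a, b, h0=0.0):
    """RNN-style loop: O(L) strictly sequential steps."""
    h, out = h0, []
    for at, bt in zip(a, b):
        h = at * h + bt
        out.append(h)
    return out

def tree_scan(elems):
    """Recursive inclusive scan (assumes len(elems) is a power of two):
    combine adjacent pairs, recurse on the half-length sequence, then
    interleave results back. The pairing/interleaving steps move data
    without computing anything -- the kind of stack/interleave overhead
    the profiling above points at."""
    if len(elems) == 1:
        return list(elems)
    pairs = [combine(elems[i], elems[i + 1]) for i in range(0, len(elems), 2)]
    scanned = tree_scan(pairs)  # inclusive prefixes ending at odd positions
    out = []
    for i, e in enumerate(elems):
        if i == 0:
            out.append(e)
        elif i % 2 == 1:
            out.append(scanned[i // 2])
        else:
            out.append(combine(scanned[i // 2 - 1], e))
    return out

L = 8
a = [0.9] * L                         # recurrent weight (scalar state for clarity)
b = [float(t + 1) for t in range(L)]  # per-step input
ref = sequential_scan(a, b)                         # loop result (h0 = 0)
par = [B for (_, B) in tree_scan(list(zip(a, b)))]  # scan result (h0 = 0)
```

On a GPU each level of the tree runs in parallel, which is where the paper's projected speedup over a step-by-step GRU/LSTM comes from — provided the data-movement steps are fused into an optimized kernel rather than issued as separate ops.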
I tried running the following script and found that S5 is far slower than PyTorch's LSTM. Is this supposed to be the case? Perhaps the scale at which I'm testing it is too small to realize the benefit?
I would greatly appreciate any comment on this. Thanks in advance, and thanks for the implementation!