huge huge memory usage!! #9
Comments
Indeed, it is normal that this version uses a lot of memory, as it doesn't use the recomputation technique described in the paper (see the README for more information). Here is my benchmark for training a Mamba with …
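For illustration, here is a minimal PyTorch sketch of the recomputation idea, using torch.utils.checkpoint around a toy scan. The `toy_scan` function, shapes, and names below are placeholders rather than mamba.py's actual pscan, and this is not the fused-kernel recomputation from the paper, just the same compute-for-memory trade expressed with stock PyTorch:

```python
import torch
from torch.utils.checkpoint import checkpoint

def toy_scan(a, bx):
    # Stand-in for the selective-scan recurrence h_t = a_t * h_{t-1} + bx_t.
    h = torch.zeros_like(bx[:, 0])
    hs = []
    for t in range(bx.shape[1]):
        h = a[:, t] * h + bx[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)

B, L, D = 2, 256, 16  # illustrative batch size, sequence length, hidden size
a = torch.rand(B, L, D, requires_grad=True)
bx = torch.randn(B, L, D, requires_grad=True)

# checkpoint() drops the scan's intermediate activations after the forward pass
# and recomputes them during backward, trading extra compute for lower memory.
hs = checkpoint(toy_scan, a, bx, use_reentrant=False)
hs.sum().backward()
```

Wrapping each block (or the scan inside it) this way is the usual way to cut activation memory, at the cost of roughly one extra forward pass.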
@alxndrTL thanks! This is such great work! Looking forward to your "performance update".
Yes, using Metal would be an option, but there is also MLX in the picture now; I don't yet know if they are exclusive or not. I guess it will not be too long before MLX implements an efficient (kernel-optimized, etc.) pscan.
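For context, here is a rough PyTorch sketch (not MLX, and not mamba.py's actual pscan.py) of the associative combine such a pscan has to implement for the recurrence `h_t = a_t * h_{t-1} + x_t`, written as a Hillis–Steele style log-step scan:

```python
import torch

def pscan_hillis_steele(a, x):
    # Inclusive scan over the sequence dim (dim 1) for h_t = a_t * h_{t-1} + x_t.
    # Combine rule for two segments: (a2, x2) o (a1, x1) = (a2 * a1, a2 * x1 + x2).
    a, x = a.clone(), x.clone()
    L = x.shape[1]
    step = 1
    while step < L:
        # Every position t >= step absorbs the partial result at position t - step.
        x_new = x[:, step:] + a[:, step:] * x[:, :-step]
        a_new = a[:, step:] * a[:, :-step]
        x[:, step:], a[:, step:] = x_new, a_new
        step *= 2
    return x  # x[:, t] now holds h_t

# Quick check against the naive sequential recurrence.
a = torch.rand(2, 8, 4)
x = torch.randn(2, 8, 4)
h, hs = torch.zeros_like(x[:, 0]), []
for t in range(x.shape[1]):
    h = a[:, t] * h + x[:, t]
    hs.append(h)
assert torch.allclose(pscan_hillis_steele(a, x), torch.stack(hs, dim=1), atol=1e-5)
```

The log-step structure is what a fast Metal/MLX kernel would parallelize; the full-size intermediates it produces at every round are also what autograd ends up holding for the backward pass when no recomputation is done, which ties back to the memory point above.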
Hey, any updates on this performance/pscan update? My 16GB M1 Pro just crashed from loading mamba-2.8b for fine-tuning on a classification task, so I'm looking forward to this!
Hello @ali-shihab, the performance update was pushed ≈2 months ago, but it only enables faster training; it doesn't reduce memory usage. I just worked on Jamba, so I'm taking some time off, but I'm aware that this is the main problem of mamba.py right now (and it kind of makes it unusable for real scenarios).
Hey @alxndrTL, thanks for the quick response. If I'm not mistaken, the performance update is for PyTorch only, no? Correct me if I'm wrong. Additionally, I've just scanned over jamba.py very quickly, and it seems everything in there can be implemented in MLX. Do you think I could do this, or is there something that stopped you from being able to? If not, I'll see if I can implement a local version of it in MLX and clean it up for a PR once I have some time. Also, that is completely understandable, enjoy your time off :)
Yes, the performance update is for PyTorch only, as the pscan works quite poorly on MLX as of right now (see the comments in https://github.com/alxndrTL/mamba.py/blob/main/mlx/pscan_mlx.py, tested with MLX in January). As for the Jamba implementation in MLX, that would be very welcome! The one sore point I see would be replacing the attention, since flash attention is not there in MLX (yet).
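To make the attention point concrete, here is a hedged sketch of the kind of plain (unfused) scaled dot-product attention an MLX port could fall back on while flash attention is unavailable; the function name and shapes are made up for illustration, and this is not Jamba/mamba.py code:

```python
import math
import mlx.core as mx

def vanilla_attention(q, k, v, mask=None):
    # q, k, v: (batch, heads, seq_len, head_dim).
    # Plain softmax(Q K^T / sqrt(d)) V, which materializes the full (seq_len x seq_len)
    # attention matrix that a fused flash-attention kernel would avoid.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ mx.transpose(k, (0, 1, 3, 2))) * scale
    if mask is not None:
        scores = scores + mask  # additive mask, e.g. large negatives on masked positions
    return mx.softmax(scores, axis=-1) @ v

# Illustrative shapes only.
q = mx.random.normal((1, 4, 128, 32))
k = mx.random.normal((1, 4, 128, 32))
v = mx.random.normal((1, 4, 128, 32))
out = vanilla_attention(q, k, v)  # (1, 4, 128, 32)
```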
I find that the pscan method used in this Mamba implementation uses a huge amount of memory! Any idea how to reduce memory consumption, or how to replace the pscan method with another implementation?
Great, thanks!