
huge huge memory usage!! #9

Open
eisneim opened this issue Feb 5, 2024 · 8 comments

@eisneim

eisneim commented Feb 5, 2024

I find that the pscan method used in this Mamba implementation uses a huge amount of memory! Any idea how to reduce memory consumption, or how to replace the pscan method with another implementation?

Thanks a lot!

@alxndrTL
Owner

alxndrTL commented Feb 5, 2024

Indeed, it is normal that this version uses a lot of memory, as it doesn't use the recomputation technique described in the paper (see this in the README for more information).
As of now, if you can use the default CUDA implementation, go for it (if you have a recent enough NVIDIA GPU), as mamba.py is mostly designed for educational purposes. If not, you could look here to implement the recomputation yourself.
I'm working on a "performance update" and I hope I will be able to include the recomputation technique in it.

Here is my benchmark for training a Mamba with d_model=512, n_layers=16, B=16, L=512 on an A100 80GB:

  • CUDA implementation: 3.2GB
  • mamba.py: 33.6GB

You can see how important the recomputation technique is.
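
For reference, the recomputation technique is essentially activation checkpointing. A minimal PyTorch sketch (illustrative only, not the actual mamba.py code; `blocks` stands for whatever stack of Mamba residual blocks you train):

```python
import torch
from torch.utils.checkpoint import checkpoint


def forward_with_recomputation(blocks, x):
    # Run each block under activation checkpointing: intermediate
    # activations are discarded during the forward pass and recomputed
    # during backward, trading extra compute for a much smaller
    # activation footprint.
    for block in blocks:
        x = checkpoint(block, x, use_reentrant=False)
    return x
```

Checkpointing only the most activation-heavy parts (the pscan and the surrounding SSM computation) should already recover most of the gap shown above.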

@eisneim
Author

eisneim commented Feb 6, 2024

@alxndrTL thanks! This is such great work!
I'm using an M3 Max MacBook, so the CUDA implementation is unfortunately not an option. I was trying to use your Mamba code with the method described in the Vision Mamba (Vim) paper: I got CLIP training working with Mamba as the text encoder and Vision Mamba as the vision encoder, but compared to a ViT, memory usage increased at least 5x!
Maybe creating a custom operator using Metal would solve this issue? apple doc

Looking forward to your "performance update"!

@alxndrTL
Owner

alxndrTL commented Feb 6, 2024

Yes, using Metal would be an option, but MLX is also in the picture now; I don't know yet whether the two are mutually exclusive. I guess it won't be too long before MLX ships an efficient (kernel-optimized, etc.) pscan implementation.
I will keep you posted here about the performance update!
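
For reference, here is roughly the recurrence that pscan parallelizes, written as a naive sequential loop in MLX (a sketch under assumed shapes; the names deltaA, deltaBx and h0 are illustrative, not the ones used in the repo's MLX code):

```python
import mlx.core as mx


def selective_scan_seq(deltaA, deltaBx, h0):
    # Sequential reference for the scan: h_t = deltaA_t * h_{t-1} + deltaBx_t
    # deltaA, deltaBx: (B, L, ED, N); h0: (B, ED, N)
    h = h0
    hs = []
    for t in range(deltaA.shape[1]):
        h = deltaA[:, t] * h + deltaBx[:, t]
        hs.append(h)
    return mx.stack(hs, axis=1)  # (B, L, ED, N)
```

This loop is slow on long sequences, which is why an efficient, kernel-level scan in MLX would matter so much here.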

@ali-shihab

Hey, any updates on this performance/pscan update? My 16GB M1 Pro just crashed from loading mamba-2.8b for fine-tuning on a classification task, so I'm looking forward to this!

@alxndrTL
Owner

Hello @ali-shihab, the performance update was pushed about two months ago, but it only speeds up training; it doesn't reduce memory usage. I just finished working on Jamba, so I'm taking some time off, but I'm aware that memory is the main problem with mamba.py right now (and it kind of makes it unusable for real scenarios).

@ali-shihab

Hey @alxndrTL, thanks for the quick response. If I'm not mistaken, the performance update is for PyTorch, no? Correct me if I'm wrong.

Additionally, I've just scanned over jamba.py very quickly, and it seems everything in there can be implemented in MLX. Do you think I could do this, or is there something that stopped you from doing it? If not, I'll see if I can implement a local version of it in MLX and clean it up for a PR once I have some time.

Also, that is completely understandable, enjoy your time off :)

@alxndrTL
Owner

alxndrTL commented May 1, 2024

Yes, the performance update is for PyTorch only, as pscan works quite poorly on MLX as of right now (see the comments in https://github.com/alxndrTL/mamba.py/blob/main/mlx/pscan_mlx.py; tested with MLX in January).

As for the Jamba implementation in MLX, that would be very welcome! The one sore point I see would be replacing F.scaled_dot_product_attention: I don't know if there is a FlashAttention implementation in MLX; if not, you just have to do the attention computation by hand.
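
For what it's worth, doing the attention computation by hand in MLX is short. A rough sketch (the function name and shapes are my own, and it won't have FlashAttention's memory behaviour):

```python
import math

import mlx.core as mx


def sdpa(q, k, v, mask=None):
    # Plain scaled dot-product attention, written out explicitly.
    # q, k, v: (B, n_heads, L, head_dim)
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)  # (B, n_heads, L, L)
    if mask is not None:
        scores = scores + mask  # additive mask (e.g. -inf on masked positions)
    weights = mx.softmax(scores, axis=-1)
    return weights @ v
```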

@M-I

M-I commented Jul 13, 2024

Flash attention is not there (yet):
ml-explore/mlx-examples#724 (comment)
