Metal Performance Shader (MPS) Integration #51
Comments
@animan42:
Hi @ludsvick! Thank you for trying out cube3D and for the ideas to speed it up on MPS.
The annoying part is going to be knowing when to switch to slicing (MPS, or any backend where the static kernel isn't available). With a check for that, we can use the static kernel when available and switch to good old slicing when it's not. Let me know if you're interested in giving this a spin and making a PR!
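As a rough illustration of that dispatch (a sketch under assumptions, not cube3d's actual code: the helper names and the KV-cache layout here are hypothetical), one could probe the backend once and branch on the result:

```python
import torch

def supports_index_copy(device: torch.device) -> bool:
    # Hypothetical helper: probe once whether the backend has an
    # index_copy_ kernel by running the op on a tiny tensor.
    try:
        probe = torch.zeros(2, 3, device=device)
        probe.index_copy_(0, torch.tensor([0], device=device),
                          torch.ones(1, 3, device=device))
        return True
    except (NotImplementedError, RuntimeError):
        return False

def write_kv(cache: torch.Tensor, new_kv: torch.Tensor, pos: int,
             static_kernel: bool) -> None:
    # Hypothetical KV-cache update: write new_kv into the preallocated
    # cache starting at `pos` along the sequence dimension (dim 0 here).
    if static_kernel:
        idx = torch.arange(pos, pos + new_kv.shape[0], device=cache.device)
        cache.index_copy_(0, idx, new_kv)
    else:
        # Good old slicing: supported on every backend, same result.
        cache[pos : pos + new_kv.shape[0]] = new_kv
```

The probe would run once at startup (e.g. use_static = supports_index_copy(device)), so steady-state decoding pays nothing extra for the capability check.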
@ludsvick:
@animan42 Great idea! I'll give this a try and let you know how it goes 👍
Original issue (@ludsvick):
Hi all, love the project and it being open-sourced!
I tried following along with the setup guide and noticed a bottleneck on my system that was significantly reduced using MPS with PyTorch.
System & Performance
Device - MacBook Pro (2022)
CPU - M2 8-core
GPU - 10-core w/ Metal 3 support
Memory - 16 GB
Using the command for shape generation, I had an estimated two hours for extracting geometry. After switching to torch.device("mps") in generate.py, I got that time down to about four minutes.
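For reference, the device switch can be made conditional so the script still runs on CPU-only machines (a minimal sketch using only stock PyTorch APIs; generate.py may wire this up differently):

```python
import torch

# Use MPS when the PyTorch build and the hardware both support it,
# otherwise fall back to CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

x = torch.randn(4, 4, device=device)  # any tensor or module can now target `device`
print(x.device)
```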
Problem
I would have made a pull request right off the bat with these changes, but it seems as though there is an operation on the KV caches in the transformer's attention layers that isn't implemented within MPS as of yet (index_copy_). It's not too much of a headache to get around: adding the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before running the scripts lets torch defer to the CPU for the operation.
The only frustration is that this would need to be added to each script that could reference the operation, before torch is imported, or users would have to add the environment variable to their .rc files, which could be a pain to manage/keep track of. Since there is both a command-line and a code-based implementation, I thought I would create an issue first to figure out the best way forward; a sketch of the in-script approach follows below.
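For instance, an in-script version might look like the following (a minimal sketch; the one hard constraint, as noted above, is that the variable is set before torch is imported):

```python
import os

# Must happen before `import torch`, or the flag has no effect.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch  # noqa: E402  (deliberately imported after the env var is set)

# Unimplemented MPS ops (index_copy_, in this case) now fall back to
# CPU instead of raising NotImplementedError.
```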
(Side note: I did mention the operation in PyTorch's tracker, so feel free to give it a +1 to help get their attention.)