
Metal Performance Shader (MPS) Integration #51

Open

ludsvick opened this issue Mar 21, 2025 · 2 comments

@ludsvick

Hi all, I love the project and that it's been open-sourced!

I tried following along with the setup guide and noticed a bottleneck on my system that was significantly reduced by using PyTorch's MPS backend.

System & Performance

Device - MacBook Pro (2022)
CPU - M2 8-core
GPU - 10 Core w/ Metal 3 support
Memory - 16 GB

Using the shape generation command, the estimated time for extracting geometry was about two hours. After switching to torch.device("mps") in generate.py, that came down to about four minutes.
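
For reference, the change was roughly the following (a minimal sketch; the exact structure and variable names in generate.py may differ):

import torch

# Prefer MPS when available, then CUDA, then CPU.
# Only the MPS branch is new; the rest is a typical device check.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")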

Problem

I would have made a pull request right off the bat with these changes, but it seems there is an operation on the KV caches in the attention layers that isn't implemented for MPS yet (index_copy_). It's not too much of a headache to work around: setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before running the scripts lets torch defer to the CPU for that operation.

The only frustration is that this would need to be added to each script that could hit the operation, before torch is imported (see the sketch below), or users would have to add the environment variable to their shell rc files, which could be a pain to manage and keep track of. Since there is both a command-line and a code-based interface, I thought I would open an issue first to figure out the best way forward.
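
A per-script workaround could look like this (a sketch; the key constraint is that the variable has to be set before the first import of torch):

import os

# Must be set before `import torch`, otherwise the fallback flag is ignored.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

Alternatively, it can be set per invocation from the shell (e.g. PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py ...), but that still has to be remembered for every entry point.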

(Side note: I did mention the operation in PyTorch's tracker, so feel free to give it a +1 to help get their attention)

@animan42
Collaborator

animan42 commented Mar 21, 2025

Hi @ludsvick! Thank you for trying out cube3D and for the ideas on speeding it up with MPS.

Assuming index_copy_ is the only problematic op: we only need this static-kernel op when we are in CUDA Graph mode (EngineFast), which doesn't work on MPS anyway.

For Engine, it is safe to just slice:

kv_cache.key_states[:, :, curr_pos_id:curr_pos_id + 1, ...].copy_(k)
kv_cache.value_states[:, :, curr_pos_id:curr_pos_id + 1, ...].copy_(v)

The annoying part is going to be knowing when to switch to slicing (MPS, or any backend where index_copy_ isn't supported). I propose something like this in cache.py:

from dataclasses import dataclass, field

import torch


@dataclass
class Cache:
    key_states: torch.Tensor
    value_states: torch.Tensor
    _supports_index_copy: bool = field(init=False)

    def __post_init__(self):
        self._supports_index_copy = self._check_index_copy_support()

    def _check_index_copy_support(self) -> bool:
        # Probe the cache's device with a tiny index_copy_. Backends that don't
        # implement the op raise NotImplementedError, which is a subclass of
        # RuntimeError, so a single except clause covers both.
        try:
            device = self.key_states.device
            dummy = torch.tensor([0, 0], device=device)
            dummy.index_copy_(0, torch.tensor([0], device=device), torch.tensor([1], device=device))
            return True
        except RuntimeError:
            return False

    def update(self, curr_pos_id: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        if self._supports_index_copy:
            # Static-kernel path, required for CUDA Graph mode (EngineFast).
            self.key_states.index_copy_(2, curr_pos_id, k)
            self.value_states.index_copy_(2, curr_pos_id, v)
        else:
            # Fallback: write into the cache by slicing along the sequence dim.
            self.key_states[:, :, curr_pos_id:curr_pos_id + 1, ...].copy_(k)
            self.value_states[:, :, curr_pos_id:curr_pos_id + 1, ...].copy_(v)

With this, we can use the static kernel when available and switch to good old slicing when it's not. Let me know if you're interested in giving this a spin and making a PR!
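
For anyone following along, here's a rough usage sketch (the shapes and cache size below are illustrative only, not taken from cube3D):

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# (batch, heads, max_seq_len, head_dim) -- made-up sizes for illustration
cache = Cache(
    key_states=torch.zeros(1, 8, 128, 64, device=device),
    value_states=torch.zeros(1, 8, 128, 64, device=device),
)

k = torch.randn(1, 8, 1, 64, device=device)
v = torch.randn(1, 8, 1, 64, device=device)
curr_pos_id = torch.tensor([5], device=device)

cache.update(curr_pos_id, k, v)  # index_copy_ where supported, slicing otherwise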

@ludsvick
Author

@animan42 Great idea! I'll give this a try and let you know how it goes 👍
