chore(setup): properly package the repository as a Python package

XuehaiPan committed Feb 24, 2025
1 parent 18e3277 commit 1ad4abe
Showing 8 changed files with 110 additions and 8 deletions.
7 changes: 7 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,7 @@
recursive-include flash_mla *.pyi
recursive-include flash_mla *.typed
include LICENSE

# Include source files in sdist
include .gitmodules
recursive-include csrc *
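
The sdist-related lines matter because the CUDA sources under `csrc/` and the type-marker files must ship in the source distribution for a from-source install to work. A minimal sketch of how one might check this, assuming an sdist has already been built into `dist/` (for example with `python -m build --sdist`); the archive name pattern is illustrative:

```python
# Sketch: list the members of a built sdist to confirm that csrc/ sources,
# the .pyi stub, and py.typed were picked up by MANIFEST.in.
# Assumes an archive already exists under dist/; the glob pattern is illustrative.
import glob
import tarfile

sdist_path = sorted(glob.glob("dist/flash*mla-*.tar.gz"))[-1]
with tarfile.open(sdist_path, "r:gz") as sdist:
    names = sdist.getnames()

print("csrc sources:", [n for n in names if "/csrc/" in n])
print("type markers:", [n for n in names if n.endswith((".pyi", "py.typed"))])
```
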
6 changes: 4 additions & 2 deletions README.md
@@ -11,7 +11,9 @@ Currently released:
### Install

```bash
python setup.py install
python3 -m pip install --upgrade pip setuptools
python3 -m pip install torch pybind11 --index-url https://download.pytorch.org/whl/cu126
python3 -m pip install --no-build-isolation --editable .
```
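
The new instructions replace the direct `setup.py install` call with a `pip` editable install; `--no-build-isolation` makes the build reuse the `torch` and `pybind11` installed in the first two steps instead of pulling fresh copies into an isolated build environment. A minimal post-install sanity check, assuming the build above succeeded:

```python
# Sketch: confirm the editable install exposes the packaged API and its version.
import flash_mla

print(flash_mla.__version__)
print(flash_mla.get_mla_metadata)
print(flash_mla.flash_mla_with_kvcache)
```
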

### Benchmark
@@ -52,7 +54,7 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-

```bibtex
@misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernel},
title={FlashMLA: Efficient MLA decoding kernel},
author={Jiashi Li},
year={2025},
publisher = {GitHub},
13 changes: 12 additions & 1 deletion flash_mla/__init__.py
@@ -1,6 +1,17 @@
__version__ = "1.0.0"
"""
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
"""

from flash_mla.flash_mla_interface import (
get_mla_metadata,
flash_mla_with_kvcache,
)


__all__ = [
"get_mla_metadata",
"flash_mla_with_kvcache",
]


__version__ = "1.0.0"
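
With this change the module keeps its import surface explicit: a docstring, the two re-exported functions, `__all__`, and the version marker. A minimal sketch of the resulting public interface (nothing beyond what the file above defines):

```python
# Sketch: the public names re-exported by flash_mla/__init__.py.
import flash_mla
from flash_mla import flash_mla_with_kvcache, get_mla_metadata

assert set(flash_mla.__all__) == {"get_mla_metadata", "flash_mla_with_kvcache"}
print(flash_mla.__version__)  # "1.0.0"
```
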
19 changes: 19 additions & 0 deletions flash_mla/flash_mla_cuda.pyi
@@ -0,0 +1,19 @@
import torch

def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]: ...
def fwd_kvcache_mla(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor | None,
head_dim_v: int,
cache_seqlens: torch.Tensor,
block_table: torch.Tensor,
softmax_scale: float,
causal: bool,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ...
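
The stub mirrors the two entry points of the compiled extension so type checkers can see them even though the module itself is a binary. A hedged sketch of calling the metadata helper through its new package-qualified name; the shapes, the `int32` dtype, and the need for a CUDA-capable build are assumptions made for illustration, while the argument order follows the stub above:

```python
# Sketch: exercise the compiled extension via its package-qualified submodule.
# Shapes, dtype, and device below are illustrative assumptions; the argument
# order follows the stub above.
import torch

from flash_mla import flash_mla_cuda

batch_size, num_heads_q, num_heads_k = 4, 128, 1
cache_seqlens = torch.full((batch_size,), 1024, dtype=torch.int32, device="cuda")

tile_scheduler_metadata, num_splits = flash_mla_cuda.get_mla_metadata(
    cache_seqlens,
    num_heads_q // num_heads_k,  # num_heads_per_head_k
    num_heads_k,
)
print(tile_scheduler_metadata.shape, num_splits.shape)
```
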
2 changes: 1 addition & 1 deletion flash_mla/flash_mla_interface.py
@@ -2,7 +2,7 @@

import torch

import flash_mla_cuda
from flash_mla import flash_mla_cuda


def get_mla_metadata(
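
Because the extension is now built as `flash_mla.flash_mla_cuda` (see `setup.py` below), the interface module imports it through the package rather than as a top-level module. A minimal sketch showing where the compiled artifact resolves after installation:

```python
# Sketch: the compiled module now resolves inside the flash_mla package.
import importlib.util

spec = importlib.util.find_spec("flash_mla.flash_mla_cuda")
print(spec.origin)  # expected: .../flash_mla/flash_mla_cuda*.so (or *.pyd on Windows)
```
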
Empty file added flash_mla/py.typed
Empty file.
64 changes: 64 additions & 0 deletions pyproject.toml
@@ -0,0 +1,64 @@
# Package ######################################################################

[build-system]
requires = ["setuptools", "pybind11", "torch ~= 2.0"]
build-backend = "setuptools.build_meta"

[project]
name = "flash-mla"
description = "FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving."
readme = "README.md"
requires-python = ">= 3.8"
authors = [
{ name = "FlashMLA Contributors" },
{ name = "Jiashi Li", email = "[email protected]" },
]
license = { text = "MIT" }
keywords = [
"Multi-head Latent Attention",
"MLA",
"Flash MLA",
"Flash Attention",
"CUDA",
"kernel",
]
classifiers = [
"Development Status :: 4 - Beta",
"License :: OSI Approved :: MIT License",
"Programming Language :: C++",
"Programming Language :: CUDA",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Environment :: GPU :: NVIDIA CUDA :: 12",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
]
dependencies = ["torch ~= 2.0"]
dynamic = ["version"]

[project.urls]
Homepage = "https://github.com/deepseek-ai/FlashMLA"
Repository = "https://github.com/deepseek-ai/FlashMLA"
"Bug Report" = "https://github.com/deepseek-ai/FlashMLA/issues"

[project.optional-dependencies]
test = ["triton"]
benchmark = ["triton"]

[tool.setuptools]
include-package-data = true

[tool.setuptools.packages.find]
include = ["flash_mla", "flash_mla.*"]

[tool.setuptools.package-data]
flash_mla = ['*.so', '*.pyd']
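
The distribution is published as `flash-mla` while the import package remains `flash_mla`, and the version is declared dynamic so that `setup.py` can supply it at build time. A minimal sketch of querying the installed metadata, assuming the package has been installed as above:

```python
# Sketch: query the installed distribution metadata declared in pyproject.toml.
# The distribution name is "flash-mla"; the importable package is "flash_mla".
import importlib.metadata

dist = importlib.metadata.distribution("flash-mla")
print(dist.metadata["Name"], dist.version)
print(dist.metadata["Requires-Python"])  # requires-python from the [project] table
```
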
7 changes: 3 additions & 4 deletions setup.py
@@ -3,7 +3,7 @@
from datetime import datetime
import subprocess

from setuptools import setup, find_packages
from setuptools import setup

from torch.utils.cpp_extension import (
BuildExtension,
@@ -33,7 +33,7 @@ def append_nvcc_threads(nvcc_extra_args):
ext_modules = []
ext_modules.append(
CUDAExtension(
name="flash_mla_cuda",
name="flash_mla.flash_mla_cuda",
sources=[
"csrc/flash_api.cpp",
"csrc/flash_fwd_mla_bf16_sm90.cu",
@@ -77,9 +77,8 @@


setup(
name="flash_mla",
name="flash-mla",
version="1.0.0" + rev,
packages=find_packages(include=['flash_mla']),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
)
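
With the extension renamed to `flash_mla.flash_mla_cuda` and package discovery moved to `pyproject.toml`, the remaining dynamic piece in `setup.py` is the version suffix `rev`. Its exact format is not shown in this hunk; a hedged sketch of one way such a suffix could be assembled from the current date and git commit, purely for illustration:

```python
# Sketch (illustrative only): build a local version suffix in the spirit of the
# `"1.0.0" + rev` expression above; the project's real format is not shown here.
import subprocess
from datetime import datetime

sha = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip()
rev = "+" + datetime.now().strftime("%Y%m%d") + "." + sha
print("1.0.0" + rev)  # e.g. 1.0.0+20250224.1ad4abe
```
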
