diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..373e352
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,7 @@
+recursive-include flash_mla *.pyi
+recursive-include flash_mla *.typed
+include LICENSE
+
+# Include source files in sdist
+include .gitmodules
+recursive-include csrc *
diff --git a/README.md b/README.md
index 4027334..16f3e8d 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,9 @@ Currently released:
 ### Install
 
 ```bash
-python setup.py install
+python3 -m pip install --upgrade pip setuptools
+python3 -m pip install torch pybind11 --index-url https://download.pytorch.org/whl/cu126
+python3 -m pip install --no-build-isolation --editable .
 ```
 
 ### Benchmark
@@ -52,7 +54,7 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-
 
 ```bibtex
 @misc{flashmla2025,
-      title={FlashMLA: Efficient MLA decoding kernel},
+      title={FlashMLA: Efficient MLA decoding kernel},
       author={Jiashi Li},
       year={2025},
       publisher = {GitHub},
diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py
index 51b8600..97059b6 100644
--- a/flash_mla/__init__.py
+++ b/flash_mla/__init__.py
@@ -1,6 +1,15 @@
-__version__ = "1.0.0"
+"""FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."""
 
 from flash_mla.flash_mla_interface import (
     get_mla_metadata,
     flash_mla_with_kvcache,
 )
+
+
+__all__ = [
+    "get_mla_metadata",
+    "flash_mla_with_kvcache",
+]
+
+
+__version__ = "1.0.0"
diff --git a/flash_mla/flash_mla_cuda.pyi b/flash_mla/flash_mla_cuda.pyi
new file mode 100644
index 0000000..8715ca0
--- /dev/null
+++ b/flash_mla/flash_mla_cuda.pyi
@@ -0,0 +1,19 @@
+import torch
+
+def get_mla_metadata(
+    cache_seqlens: torch.Tensor,
+    num_heads_per_head_k: int,
+    num_heads_k: int,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
+def fwd_kvcache_mla(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor | None,
+    head_dim_v: int,
+    cache_seqlens: torch.Tensor,
+    block_table: torch.Tensor,
+    softmax_scale: float,
+    causal: bool,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index b2922af..2766cec 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -2,7 +2,7 @@
 
 import torch
 
-import flash_mla_cuda
+from flash_mla import flash_mla_cuda
 
 
 def get_mla_metadata(
diff --git a/flash_mla/py.typed b/flash_mla/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..5382bed
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,64 @@
+# Package ######################################################################
+
+[build-system]
+requires = ["setuptools", "pybind11", "torch ~= 2.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flash-mla"
+description = "FlashMLA: An efficient MLA decoding kernel for Hopper GPUs."
+readme = "README.md"
+requires-python = ">= 3.8"
+authors = [
+    { name = "FlashMLA Contributors" },
+    { name = "Jiashi Li", email = "450993438@qq.com" },
+]
+license = { text = "MIT" }
+keywords = [
+    "Multi-head Latent Attention",
+    "MLA",
+    "Flash MLA",
+    "Flash Attention",
+    "CUDA",
+    "kernel",
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: C++",
+    "Programming Language :: CUDA",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "Environment :: GPU :: NVIDIA CUDA :: 12",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Education",
+    "Intended Audience :: Science/Research",
+]
+dependencies = ["torch ~= 2.0"]
+dynamic = ["version"]
+
+[project.urls]
+Homepage = "https://github.com/deepseek-ai/FlashMLA"
+Repository = "https://github.com/deepseek-ai/FlashMLA"
+"Bug Report" = "https://github.com/deepseek-ai/FlashMLA/issues"
+
+[project.optional-dependencies]
+test = ["triton"]
+benchmark = ["triton"]
+
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.packages.find]
+include = ["flash_mla", "flash_mla.*"]
+
+[tool.setuptools.package-data]
+flash_mla = ['*.so', '*.pyd']
diff --git a/setup.py b/setup.py
index 6377b1e..96ef2a4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 from datetime import datetime
 import subprocess
 
-from setuptools import setup, find_packages
+from setuptools import setup
 
 from torch.utils.cpp_extension import (
     BuildExtension,
@@ -51,7 +51,7 @@ def get_features_args():
 ext_modules = []
 ext_modules.append(
     CUDAExtension(
-        name="flash_mla_cuda",
+        name="flash_mla.flash_mla_cuda",
         sources=get_sources(),
         extra_compile_args={
             "cxx": cxx_args + get_features_args(),
@@ -92,9 +92,8 @@ def get_features_args():
 
 
 setup(
-    name="flash_mla",
+    name="flash-mla",
     version="1.0.0" + rev,
-    packages=find_packages(include=['flash_mla']),
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
 )
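A quick way to sanity-check the new packaging layout after the editable install from the updated README is a small import/smoke script. The sketch below is not part of the diff: it relies only on names introduced above (the in-package extension `flash_mla.flash_mla_cuda`, its `fwd_kvcache_mla` entry point, and the `__all__` re-exports), while the `cache_seqlens` dtype/device and the head counts passed to `get_mla_metadata` are illustrative assumptions.

```python
# smoke_test.py -- hedged sketch, not part of this diff.
# Checks that the compiled extension now lives inside the package
# (CUDAExtension "flash_mla.flash_mla_cuda") and that the public API
# is re-exported from flash_mla/__init__.py.
import torch

from flash_mla import get_mla_metadata, flash_mla_with_kvcache  # __all__ re-exports
from flash_mla import flash_mla_cuda                            # in-package extension module

# fwd_kvcache_mla is the low-level entry point declared in flash_mla_cuda.pyi.
assert hasattr(flash_mla_cuda, "fwd_kvcache_mla")

if torch.cuda.is_available():
    # Assumption: per-batch sequence lengths are an int32 tensor on the GPU;
    # the head counts below are placeholder values for illustration only.
    cache_seqlens = torch.full((2,), 1024, dtype=torch.int32, device="cuda")
    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, 128, 1)
    print(tile_scheduler_metadata.shape, num_splits.shape)

print("flash-mla import check passed")
```

Because the extension is now built as `flash_mla.flash_mla_cuda` and `flash_mla_interface.py` imports it from the package, the script above should work from any working directory once the wheel or editable install is in place, rather than depending on a top-level `flash_mla_cuda` module on `sys.path`.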