chore(setup): properly package the repository as a Python package
Showing 8 changed files with 110 additions and 8 deletions.
MANIFEST.in (new file)

```diff
@@ -0,0 +1,7 @@
+recursive-include flash_mla *.pyi
+recursive-include flash_mla *.typed
+include LICENSE
+
+# Include source files in sdist
+include .gitmodules
+recursive-include csrc *
```
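MANIFEST.in only affects the source distribution, so a quick way to check the new rules is to build an sdist and list its contents. A minimal sketch, assuming an sdist has already been produced under dist/ (for example with `python -m build --sdist`):

```python
# Sketch: confirm the sdist picked up the CUDA sources and LICENSE.
# Assumes a flash_mla sdist already exists under dist/.
import glob
import tarfile

sdist = sorted(glob.glob("dist/flash*mla-*.tar.gz"))[-1]  # newest sdist
with tarfile.open(sdist) as tf:
    names = tf.getnames()

assert any(n.endswith("/LICENSE") for n in names), "LICENSE missing"
assert any("/csrc/" in n for n in names), "csrc/ sources missing"
print(f"{sdist}: {len(names)} files; csrc/ and LICENSE present")
```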
flash_mla/__init__.py

```diff
@@ -1,6 +1,17 @@
-__version__ = "1.0.0"
+"""
+FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
+"""
+
+from flash_mla.flash_mla_interface import (
+    get_mla_metadata,
+    flash_mla_with_kvcache,
+)
+
+__all__ = [
+    "get_mla_metadata",
+    "flash_mla_with_kvcache",
+]
+
+__version__ = "1.0.0"
```
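With these re-exports in place, callers can pull both public entry points straight from the package root. A minimal sketch of the resulting import surface, assuming the package is installed:

```python
# Sketch: the public API surface established by this __init__.py.
import flash_mla

print(flash_mla.__version__)  # "1.0.0"
print(flash_mla.__all__)      # ['get_mla_metadata', 'flash_mla_with_kvcache']

# Equivalent to importing from flash_mla.flash_mla_interface directly:
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
```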
flash_mla/flash_mla_cuda.pyi (new file; filename inferred from the extension module it stubs)

```diff
@@ -0,0 +1,19 @@
+import torch
+
+def get_mla_metadata(
+    cache_seqlens: torch.Tensor,
+    num_heads_per_head_k: int,
+    num_heads_k: int,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
+def fwd_kvcache_mla(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor | None,
+    head_dim_v: int,
+    cache_seqlens: torch.Tensor,
+    block_table: torch.Tensor,
+    softmax_scale: float,
+    causal: bool,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
```
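The stub gives type checkers and IDEs signatures for the compiled extension. A hedged call sketch based only on the signatures above; the tensor shapes and argument values are illustrative assumptions, not documented semantics:

```python
# Sketch: calling the stubbed metadata helper through the public wrapper.
# Batch size, sequence lengths, and head counts below are assumptions for
# illustration only; a CUDA device is required.
import torch
from flash_mla import get_mla_metadata

cache_seqlens = torch.full((4,), 1024, dtype=torch.int32, device="cuda")

# Per the stub, both head arguments are plain ints, and the call returns
# (tile_scheduler_metadata, num_splits), which are later fed back into
# the forward kernel.
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, 128, 1)
print(tile_scheduler_metadata.shape, num_splits.shape)
```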
flash_mla/flash_mla_interface.py

```diff
@@ -2,7 +2,7 @@

 import torch

-import flash_mla_cuda
+from flash_mla import flash_mla_cuda


 def get_mla_metadata(
```
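This one-line change matters because the compiled extension now ships inside the flash_mla package (see the `*.so`/`*.pyd` package-data globs in pyproject.toml below) rather than as a top-level module sitting next to it. A hedged sketch of a backward-compatible variant, purely for illustration and not part of the commit:

```python
# Sketch: tolerate both extension locations during a transition.
# The in-package import is what this commit actually uses.
try:
    from flash_mla import flash_mla_cuda  # extension bundled inside the package
except ImportError:
    import flash_mla_cuda  # legacy in-place build at the repository root
```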
New empty file (most likely flash_mla/py.typed, the PEP 561 typing marker, matching the `*.typed` pattern added in MANIFEST.in).
pyproject.toml (new file)

```diff
@@ -0,0 +1,64 @@
+# Package ######################################################################
+
+[build-system]
+requires = ["setuptools", "pybind11", "torch ~= 2.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flash-mla"
+description = "FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving."
+readme = "README.md"
+requires-python = ">= 3.8"
+authors = [
+    { name = "FlashMLA Contributors" },
+    { name = "Jiashi Li", email = "[email protected]" },
+]
+license = { text = "MIT" }
+keywords = [
+    "Multi-head Latent Attention",
+    "MLA",
+    "Flash MLA",
+    "Flash Attention",
+    "CUDA",
+    "kernel",
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: C++",
+    "Programming Language :: CUDA",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "Environment :: GPU :: NVIDIA CUDA :: 12",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Education",
+    "Intended Audience :: Science/Research",
+]
+dependencies = ["torch ~= 2.0"]
+dynamic = ["version"]
+
+[project.urls]
+Homepage = "https://github.com/deepseek-ai/FlashMLA"
+Repository = "https://github.com/deepseek-ai/FlashMLA"
+"Bug Report" = "https://github.com/deepseek-ai/FlashMLA/issues"
+
+[project.optional-dependencies]
+test = ["triton"]
+benchmark = ["triton"]
+
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.packages.find]
+include = ["flash_mla", "flash_mla.*"]
+
+[tool.setuptools.package-data]
+flash_mla = ['*.so', '*.pyd']
```
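Since the version is declared `dynamic`, setuptools resolves it elsewhere, presumably in one of the setup files among the eight changed. A minimal sketch for sanity-checking the static metadata with the standard library (tomllib requires Python 3.11+):

```python
# Sketch: read back the new packaging metadata from pyproject.toml.
import tomllib

with open("pyproject.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["project"]["name"])             # flash-mla
print(cfg["project"]["requires-python"])  # >= 3.8
print(cfg["build-system"]["requires"])    # ['setuptools', 'pybind11', 'torch ~= 2.0']
print(cfg["tool"]["setuptools"]["package-data"]["flash_mla"])  # ['*.so', '*.pyd']
```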