Skip to content

Commit 9288860

Browse files
committed
build: Add CachedWheel
Signed-off-by: oliver könig <[email protected]>
1 parent 9b1e960 commit 9288860

File tree

2 files changed

+107
-18
lines changed

2 files changed

+107
-18
lines changed

.github/workflows/publish.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ jobs:
8282
pip install torch --index-url https://download.pytorch.org/whl/cpu
8383
- name: Build core package
8484
env:
85-
GROUPED_GEMM_SKIP_CUDA_BUILD: "TRUE"
85+
DEEP_GEMM_NO_LOCAL_VERSION: "TRUE"
86+
DEEP_GEMM_SKIP_CUDA_BUILD: "TRUE"
8687
run: |
8788
python setup.py sdist --dist-dir=dist
8889
- name: Deploy

setup.py

Lines changed: 105 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from calendar import c
12
import os
23
import setuptools
34
import shutil
@@ -9,11 +10,84 @@
910
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
1011
from pathlib import Path
1112
import subprocess
13+
import sys
14+
import platform
15+
from packaging.version import parse
16+
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
17+
import urllib
1218

1319
def _env_flag(name):
    """Return True when environment variable *name* holds a truthy value.

    ``TRUE`` (any letter case) and ``1`` are treated as true; an unset
    variable or any other value is false.
    """
    return os.getenv(name, "FALSE").strip().upper() in ("TRUE", "1")


# Build-time switches, all read from the environment once at import time.
SKIP_CUDA_BUILD = _env_flag("DEEP_GEMM_SKIP_CUDA_BUILD")
NO_LOCAL_VERSION = _env_flag("DEEP_GEMM_NO_LOCAL_VERSION")
FORCE_BUILD = _env_flag("DEEP_GEMM_FORCE_BUILD")

# URL template for release assets; filled in with ``tag_name`` and ``wheel_name``.
BASE_WHEEL_URL = (
    "https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}"
)
PACKAGE_NAME = "deep_gemm"
# Directory containing this file, with symlinks resolved.
current_dir = os.path.dirname(os.path.realpath(__file__))
28+
29+
30+
def get_package_version():
    """Return the package version, optionally suffixed with the git revision.

    The public version is read from the ``__version__`` assignment in
    ``deep_gemm/__init__.py`` next to this file. Unless ``NO_LOCAL_VERSION``
    is set, the short hash of the current git ``HEAD`` is appended as
    ``+<hash>``. If the hash cannot be obtained, the public version is
    returned on its own.

    Returns:
        str: ``"<public_version>"`` or ``"<public_version>+<short_hash>"``.

    Raises:
        RuntimeError: if no ``__version__`` assignment is found.
    """
    init_path = Path(current_dir) / "deep_gemm" / "__init__.py"
    with open(init_path, "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        raise RuntimeError(f"Unable to find __version__ in {init_path}")
    public_version = ast.literal_eval(version_match.group(1))

    revision = ""
    if not NO_LOCAL_VERSION:
        try:
            cmd = ["git", "rev-parse", "--short", "HEAD"]
            # Run inside the source tree so the hash describes this checkout,
            # not whatever directory the build happened to be started from.
            output = subprocess.check_output(cmd, cwd=current_dir)
            revision = "+" + output.decode("ascii").rstrip()
        except (subprocess.CalledProcessError, OSError, UnicodeDecodeError):
            # Best effort: git missing or not a git checkout -> no local segment.
            revision = ""

    return f"{public_version}{revision}"
44+
45+
46+
def get_platform():
    """
    Returns the platform name as used in wheel filenames.

    Linux yields ``linux_<machine>``, macOS yields
    ``macosx_<major>.<minor>_<machine>`` and Windows yields ``win_amd64``.

    Raises:
        ValueError: if ``sys.platform`` is none of linux, darwin or win32.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    if sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        # Use the real machine type rather than assuming Intel, so Apple
        # Silicon hosts are not reported as x86_64.
        return f"macosx_{mac_version}_{platform.machine()}"
    if sys.platform == "win32":
        return "win_amd64"
    raise ValueError("Unsupported platform: {}".format(sys.platform))
59+
60+
61+
def get_wheel_url():
    """Build the download URL and filename of the matching prebuilt wheel.

    The filename encodes the package version, the CUDA major version torch
    was built with, torch's ``major.minor`` version, torch's C++11 ABI flag,
    the CPython tag of the running interpreter and the platform name.

    Returns:
        tuple[str, str]: ``(wheel_url, wheel_filename)``.

    Raises:
        RuntimeError: if the installed torch reports no CUDA version.
    """
    torch_version_raw = parse(torch.__version__)
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    package_version = get_package_version()
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # We use the CUDA version torch was built with, not the one currently
    # installed. A CPU-only torch build reports None here, which cannot be
    # parsed, so fail with an explicit message instead.
    if torch.version.cuda is None:
        raise RuntimeError(
            "Cannot determine a prebuilt wheel: the installed torch was built "
            "without CUDA (torch.version.cuda is None)."
        )
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
    # to save CI time. Minor versions should be compatible, so only the major
    # version goes into the filename; anything that is not 11 is treated as 12.
    cuda_version = "11" if parse(torch.version.cuda).major == 11 else "12"

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    # NOTE(review): when the package version already carries a "+<git hash>"
    # local segment, the filename and tag contain it too (two "+" in the
    # filename) -- confirm this matches the release asset naming.
    wheel_filename = f"{PACKAGE_NAME}-{package_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

    wheel_url = BASE_WHEEL_URL.format(
        tag_name=f"v{package_version}", wheel_name=wheel_filename
    )

    return wheel_url, wheel_filename
89+
90+
1791
cxx_flags = [
1892
"-std=c++17",
1993
"-O3",
@@ -101,23 +175,39 @@ def prepare_includes(self):
101175
ext_modules = []
102176

103177

104-
NO_LOCAL_VERSION = os.getenv("DEEP_GEMM_NO_LOCAL_VERSION", "FALSE") == "TRUE"
178+
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all deep_gemm installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        """Download a matching prebuilt wheel, or fall back to a normal build.

        When ``FORCE_BUILD`` is set the regular ``bdist_wheel`` build runs
        unconditionally. Otherwise the wheel named by ``get_wheel_url()`` is
        downloaded and moved into ``self.dist_dir`` under the name
        ``bdist_wheel`` itself would have produced. If the download fails,
        the regular build runs instead.
        """
        # A plain ``import urllib`` does not load these submodules, so import
        # them explicitly instead of relying on another package having done so.
        import urllib.error
        import urllib.request

        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # Do not leave a truncated download behind in the working directory.
            if os.path.exists(wheel_filename):
                os.remove(wheel_filename)
            # If the wheel could not be downloaded, build from source
            return super().run()

        # Make the archive
        # Lifted from the root wheel processing command
        # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
        os.makedirs(self.dist_dir, exist_ok=True)

        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        # shutil.move also works when dist_dir is on another filesystem,
        # where os.rename would raise OSError.
        shutil.move(wheel_filename, wheel_path)
121211

122212

123213
if __name__ == "__main__":
@@ -135,7 +225,5 @@ def get_package_version():
135225
},
136226
ext_modules=ext_modules,
137227
zip_safe=False,
138-
cmdclass={
139-
"build_py": CustomBuildPy,
140-
},
228+
cmdclass={"build_py": CustomBuildPy, "bdist_wheel": CachedWheelsCommand},
141229
)

0 commit comments

Comments
 (0)