Skip to content

Commit 8442cf6

Browse files
committed
update
Signed-off-by: oliver könig <[email protected]>
1 parent 3692017 commit 8442cf6

File tree

1 file changed

+112
-1
lines changed

1 file changed

+112
-1
lines changed

setup.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1+
import ast
12
import os
3+
import re
24
import setuptools
35
import shutil
46
import subprocess
7+
import sys
8+
import urllib
59
import torch
10+
import platform
611
from setuptools import find_packages
712
from setuptools.command.build_py import build_py
813
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
14+
from pathlib import Path
15+
from packaging import version as parse
16+
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
917

1018
SKIP_CUDA_BUILD = os.getenv("DEEP_GEMM_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
1119
NO_LOCAL_VERSION = os.getenv("DEEP_GEMM_NO_LOCAL_VERSION", "FALSE") == "TRUE"
@@ -41,6 +49,63 @@
4149
if int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')):
4250
cxx_flags.append('-DDG_JIT_USE_RUNTIME_API')
4351

52+
def get_package_version():
53+
with open(Path(current_dir) / "deep_gemm" / "__init__.py", "r") as f:
54+
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
55+
public_version = ast.literal_eval(version_match.group(1))
56+
revision = ""
57+
58+
if not NO_LOCAL_VERSION:
59+
try:
60+
cmd = ["git", "rev-parse", "--short", "HEAD"]
61+
revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip()
62+
except:
63+
revision = ""
64+
65+
return f"{public_version}{revision}"
66+
67+
def get_platform():
68+
"""
69+
Returns the platform name as used in wheel filenames.
70+
"""
71+
if sys.platform.startswith("linux"):
72+
return f"linux_{platform.uname().machine}"
73+
elif sys.platform == "darwin":
74+
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
75+
return f"macosx_{mac_version}_x86_64"
76+
elif sys.platform == "win32":
77+
return "win_amd64"
78+
else:
79+
raise ValueError("Unsupported platform: {}".format(sys.platform))
80+
81+
def get_wheel_url():
82+
torch_version_raw = parse(torch.__version__)
83+
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
84+
platform_name = get_platform()
85+
grouped_gemm_version = get_package_version()
86+
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
87+
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
88+
89+
# Determine the version numbers that will be used to determine the correct wheel
90+
# We're using the CUDA version used to build torch, not the one currently installed
91+
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
92+
torch_cuda_version = parse(torch.version.cuda)
93+
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
94+
# to save CI time. Minor versions should be compatible.
95+
torch_cuda_version = (
96+
parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
97+
)
98+
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
99+
cuda_version = f"{torch_cuda_version.major}"
100+
101+
# Determine wheel URL based on CUDA version, torch version, python version and OS
102+
wheel_filename = f"{PACKAGE_NAME}-{grouped_gemm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
103+
104+
wheel_url = BASE_WHEEL_URL.format(
105+
tag_name=f"v{grouped_gemm_version}", wheel_name=wheel_filename
106+
)
107+
108+
return wheel_url, wheel_filename
44109

45110
class CustomBuildPy(build_py):
46111
def run(self):
@@ -80,6 +145,51 @@ def prepare_includes(self):
80145
# Copy the directory
81146
shutil.copytree(src_dir, dst_dir)
82147

148+
if not SKIP_CUDA_BUILD:
149+
ext_modules = [
150+
CUDAExtension(
151+
name="deep_gemm_cpp",
152+
sources=sources,
153+
include_dirs=build_include_dirs,
154+
)
155+
]
156+
else:
157+
ext_modules = []
158+
159+
class CachedWheelsCommand(_bdist_wheel):
160+
"""
161+
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
162+
find an existing wheel (which is currently the case for all grouped gemm installs). We use
163+
the environment parameters to detect whether there is already a pre-built version of a compatible
164+
wheel available and short-circuits the standard full build pipeline.
165+
"""
166+
167+
def run(self):
168+
if FORCE_BUILD:
169+
return super().run()
170+
171+
wheel_url, wheel_filename = get_wheel_url()
172+
print("Guessing wheel URL: ", wheel_url)
173+
try:
174+
urllib.request.urlretrieve(wheel_url, wheel_filename)
175+
176+
# Make the archive
177+
# Lifted from the root wheel processing command
178+
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
179+
if not os.path.exists(self.dist_dir):
180+
os.makedirs(self.dist_dir)
181+
182+
impl_tag, abi_tag, plat_tag = self.get_tag()
183+
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
184+
185+
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
186+
print("Raw wheel path", wheel_path)
187+
os.rename(wheel_filename, wheel_path)
188+
except (urllib.error.HTTPError, urllib.error.URLError):
189+
print("Precompiled wheel not found. Building from source...")
190+
# If the wheel could not be downloaded, build from source
191+
super().run()
192+
83193

84194
if __name__ == '__main__':
85195
# noinspection PyBroadException
@@ -92,7 +202,7 @@ def prepare_includes(self):
92202
# noinspection PyTypeChecker
93203
setuptools.setup(
94204
name='deep_gemm',
95-
version='2.0.0' + revision,
205+
version=get_package_version(),
96206
packages=find_packages('.'),
97207
package_data={
98208
'deep_gemm': [
@@ -112,5 +222,6 @@ def prepare_includes(self):
112222
zip_safe=False,
113223
cmdclass={
114224
'build_py': CustomBuildPy,
225+
'bdist_wheel': CachedWheelsCommand,
115226
},
116227
)

0 commit comments

Comments
 (0)