1+ from calendar import c
12import os
23import setuptools
34import shutil
910from torch .utils .cpp_extension import CUDAExtension , CUDA_HOME
1011from pathlib import Path
1112import subprocess
13+ import sys
14+ import platform
15+ from packaging .version import parse
16+ from wheel .bdist_wheel import bdist_wheel as _bdist_wheel
17+ import urllib
1218
# Build switches read from the environment. Each one is enabled only when the
# variable is set to the exact string "TRUE"; anything else (or unset) is off.
SKIP_CUDA_BUILD = os.getenv("DEEP_GEMM_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# When on, get_package_version() does not append the "+<git short hash>" suffix.
NO_LOCAL_VERSION = os.getenv("DEEP_GEMM_NO_LOCAL_VERSION", "FALSE") == "TRUE"
# When on, the bdist_wheel command skips the prebuilt-wheel download attempt.
FORCE_BUILD = os.getenv("DEEP_GEMM_FORCE_BUILD", "FALSE") == "TRUE"

# URL template for prebuilt wheels; filled in with the release tag and the
# wheel filename.
BASE_WHEEL_URL = (
    "https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}"
)
PACKAGE_NAME = "deep_gemm"
# Directory containing this file, with symlinks resolved.
current_dir = os.path.dirname(os.path.realpath(__file__))
28+
29+
def get_package_version():
    """Return the package version, optionally suffixed with the git revision.

    The public version is read from the ``__version__`` assignment in
    ``deep_gemm/__init__.py`` next to this file. Unless ``NO_LOCAL_VERSION``
    is set, the short hash of the current git ``HEAD`` is appended as
    ``+<hash>``. If git cannot be run or exits with an error, the public
    version is returned without a suffix.

    Returns:
        str: ``"<public_version>"`` or ``"<public_version>+<short_hash>"``.

    Raises:
        RuntimeError: if no ``__version__`` assignment is found.
    """
    init_path = Path(current_dir) / "deep_gemm" / "__init__.py"
    with open(init_path, "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        raise RuntimeError(f"Unable to find __version__ in {init_path}")
    # The right-hand side is a Python literal (e.g. a quoted string); evaluate
    # it safely rather than stripping quotes by hand.
    public_version = ast.literal_eval(version_match.group(1))

    revision = ""
    if not NO_LOCAL_VERSION:
        try:
            cmd = ["git", "rev-parse", "--short", "HEAD"]
            revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip()
        except (subprocess.CalledProcessError, OSError):
            # Best effort: git is missing or this is not a git checkout.
            revision = ""

    return f"{public_version}{revision}"
44+
45+
def get_platform():
    """
    Returns the platform name as used in wheel filenames.

    Returns:
        str: ``linux_<machine>`` on Linux, ``macosx_<major>.<minor>_x86_64``
        on macOS, or ``win_amd64`` on Windows.

    Raises:
        ValueError: if ``sys.platform`` is none of the above.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    if sys.platform == "darwin":
        # Keep only "<major>.<minor>" of the macOS version string.
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        # NOTE(review): the architecture is hard-coded to x86_64 regardless of
        # the actual machine - confirm whether arm64 wheels should be matched.
        return f"macosx_{mac_version}_x86_64"
    if sys.platform == "win32":
        return "win_amd64"
    raise ValueError("Unsupported platform: {}".format(sys.platform))
59+
60+
def get_wheel_url():
    """Build the URL and filename of the prebuilt wheel for this environment.

    The filename encodes the package version, the CUDA major version torch was
    built with, the torch ``major.minor`` version, the C++11 ABI flag, the
    CPython tag and the platform tag. The URL is ``BASE_WHEEL_URL`` filled in
    with the tag ``v<package version>`` and that filename.

    Returns:
        tuple[str, str]: ``(wheel_url, wheel_filename)``.
    """
    torch_version_raw = parse(torch.__version__)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    # NOTE(review): if the package version carries a "+<rev>" local segment,
    # it ends up in both the tag and the filename (before "+cu...") - confirm
    # this matches how release assets are named.
    package_version = get_package_version()
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Use the CUDA version torch was built with, not the locally installed one.
    # NOTE(review): presumably torch.version.cuda is None on CPU-only builds,
    # in which case parse() raises - confirm whether that case must be handled.
    torch_cuda_version = parse(torch.version.cuda)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only
    # compile for CUDA 12.3 to save CI time. Minor versions should be
    # compatible. Anything that is not CUDA 11 is mapped to 12.3.
    torch_cuda_version = (
        parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
    )
    cuda_version = f"{torch_cuda_version.major}"

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = (
        f"{PACKAGE_NAME}-{package_version}+cu{cuda_version}torch{torch_version}"
        f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    )

    wheel_url = BASE_WHEEL_URL.format(
        tag_name=f"v{package_version}", wheel_name=wheel_filename
    )

    return wheel_url, wheel_filename
89+
90+
1791cxx_flags = [
1892 "-std=c++17" ,
1993 "-O3" ,
@@ -101,23 +175,39 @@ def prepare_includes(self):
101175 ext_modules = []
102176
103177
104- NO_LOCAL_VERSION = os .getenv ("DEEP_GEMM_NO_LOCAL_VERSION" , "FALSE" ) == "TRUE"
class CachedWheelsCommand(_bdist_wheel):
    """
    ``bdist_wheel`` command that first tries to download a prebuilt wheel.

    Unless ``FORCE_BUILD`` is set, ``run`` asks ``get_wheel_url()`` for the
    wheel matching the current environment and tries to download it. On
    success the file is placed in ``dist_dir`` under the name the standard
    command would have produced, short-circuiting the full build. If the
    download fails, the standard ``bdist_wheel`` build runs instead.
    """

    def run(self):
        """Fetch a matching prebuilt wheel, or fall back to building from source."""
        # Imported here explicitly: a bare ``import urllib`` does not guarantee
        # that the ``request`` and ``error`` submodules are loaded.
        import urllib.error
        import urllib.request

        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except urllib.error.URLError:
            # HTTPError is a subclass of URLError, so this covers both.
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            return super().run()

        # Make the archive
        # Lifted from the root wheel processing command
        # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
        os.makedirs(self.dist_dir, exist_ok=True)

        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        # shutil.move (unlike os.rename) also works when dist_dir is on a
        # different filesystem from the download location.
        shutil.move(wheel_filename, wheel_path)
121211
122212
123213if __name__ == "__main__" :
@@ -135,7 +225,5 @@ def get_package_version():
135225 },
136226 ext_modules = ext_modules ,
137227 zip_safe = False ,
138- cmdclass = {
139- "build_py" : CustomBuildPy ,
140- },
228+ cmdclass = {"build_py" : CustomBuildPy , "bdist_wheel" : CachedWheelsCommand },
141229 )
0 commit comments