1+ import ast
12import os
3+ import re
24import setuptools
35import shutil
46import subprocess
7+ import sys
8+ import urllib
59import torch
10+ import platform
611from setuptools import find_packages
712from setuptools .command .build_py import build_py
813from torch .utils .cpp_extension import CUDAExtension , CUDA_HOME
14+ from pathlib import Path
15+ from packaging import version as parse
16+ from wheel .bdist_wheel import bdist_wheel as _bdist_wheel
917
1018SKIP_CUDA_BUILD = os .getenv ("DEEP_GEMM_SKIP_CUDA_BUILD" , "FALSE" ) == "TRUE"
1119NO_LOCAL_VERSION = os .getenv ("DEEP_GEMM_NO_LOCAL_VERSION" , "FALSE" ) == "TRUE"
4149if int (os .environ .get ('DG_JIT_USE_RUNTIME_API' , '0' )):
4250 cxx_flags .append ('-DDG_JIT_USE_RUNTIME_API' )
4351
52+ def get_package_version ():
53+ with open (Path (current_dir ) / "deep_gemm" / "__init__.py" , "r" ) as f :
54+ version_match = re .search (r"^__version__\s*=\s*(.*)$" , f .read (), re .MULTILINE )
55+ public_version = ast .literal_eval (version_match .group (1 ))
56+ revision = ""
57+
58+ if not NO_LOCAL_VERSION :
59+ try :
60+ cmd = ["git" , "rev-parse" , "--short" , "HEAD" ]
61+ revision = "+" + subprocess .check_output (cmd ).decode ("ascii" ).rstrip ()
62+ except :
63+ revision = ""
64+
65+ return f"{ public_version } { revision } "
66+
67+ def get_platform ():
68+ """
69+ Returns the platform name as used in wheel filenames.
70+ """
71+ if sys .platform .startswith ("linux" ):
72+ return f"linux_{ platform .uname ().machine } "
73+ elif sys .platform == "darwin" :
74+ mac_version = "." .join (platform .mac_ver ()[0 ].split ("." )[:2 ])
75+ return f"macosx_{ mac_version } _x86_64"
76+ elif sys .platform == "win32" :
77+ return "win_amd64"
78+ else :
79+ raise ValueError ("Unsupported platform: {}" .format (sys .platform ))
80+
81+ def get_wheel_url ():
82+ torch_version_raw = parse (torch .__version__ )
83+ python_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
84+ platform_name = get_platform ()
85+ grouped_gemm_version = get_package_version ()
86+ torch_version = f"{ torch_version_raw .major } .{ torch_version_raw .minor } "
87+ cxx11_abi = str (torch ._C ._GLIBCXX_USE_CXX11_ABI ).upper ()
88+
89+ # Determine the version numbers that will be used to determine the correct wheel
90+ # We're using the CUDA version used to build torch, not the one currently installed
91+ # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
92+ torch_cuda_version = parse (torch .version .cuda )
93+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
94+ # to save CI time. Minor versions should be compatible.
95+ torch_cuda_version = (
96+ parse ("11.8" ) if torch_cuda_version .major == 11 else parse ("12.3" )
97+ )
98+ # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
99+ cuda_version = f"{ torch_cuda_version .major } "
100+
101+ # Determine wheel URL based on CUDA version, torch version, python version and OS
102+ wheel_filename = f"{ PACKAGE_NAME } -{ grouped_gemm_version } +cu{ cuda_version } torch{ torch_version } cxx11abi{ cxx11_abi } -{ python_version } -{ python_version } -{ platform_name } .whl"
103+
104+ wheel_url = BASE_WHEEL_URL .format (
105+ tag_name = f"v{ grouped_gemm_version } " , wheel_name = wheel_filename
106+ )
107+
108+ return wheel_url , wheel_filename
44109
45110class CustomBuildPy (build_py ):
46111 def run (self ):
@@ -80,6 +145,51 @@ def prepare_includes(self):
80145 # Copy the directory
81146 shutil .copytree (src_dir , dst_dir )
82147
148+ if not SKIP_CUDA_BUILD :
149+ ext_modules = [
150+ CUDAExtension (
151+ name = "deep_gemm_cpp" ,
152+ sources = sources ,
153+ include_dirs = build_include_dirs ,
154+ )
155+ ]
156+ else :
157+ ext_modules = []
158+
159+ class CachedWheelsCommand (_bdist_wheel ):
160+ """
161+ The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
162+ find an existing wheel (which is currently the case for all grouped gemm installs). We use
163+ the environment parameters to detect whether there is already a pre-built version of a compatible
164+ wheel available and short-circuits the standard full build pipeline.
165+ """
166+
167+ def run (self ):
168+ if FORCE_BUILD :
169+ return super ().run ()
170+
171+ wheel_url , wheel_filename = get_wheel_url ()
172+ print ("Guessing wheel URL: " , wheel_url )
173+ try :
174+ urllib .request .urlretrieve (wheel_url , wheel_filename )
175+
176+ # Make the archive
177+ # Lifted from the root wheel processing command
178+ # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
179+ if not os .path .exists (self .dist_dir ):
180+ os .makedirs (self .dist_dir )
181+
182+ impl_tag , abi_tag , plat_tag = self .get_tag ()
183+ archive_basename = f"{ self .wheel_dist_name } -{ impl_tag } -{ abi_tag } -{ plat_tag } "
184+
185+ wheel_path = os .path .join (self .dist_dir , archive_basename + ".whl" )
186+ print ("Raw wheel path" , wheel_path )
187+ os .rename (wheel_filename , wheel_path )
188+ except (urllib .error .HTTPError , urllib .error .URLError ):
189+ print ("Precompiled wheel not found. Building from source..." )
190+ # If the wheel could not be downloaded, build from source
191+ super ().run ()
192+
83193
84194if __name__ == '__main__' :
85195 # noinspection PyBroadException
@@ -92,7 +202,7 @@ def prepare_includes(self):
92202 # noinspection PyTypeChecker
93203 setuptools .setup (
94204 name = 'deep_gemm' ,
95- version = '2.0.0' + revision ,
205+ version = get_package_version () ,
96206 packages = find_packages ('.' ),
97207 package_data = {
98208 'deep_gemm' : [
@@ -112,5 +222,6 @@ def prepare_includes(self):
112222 zip_safe = False ,
113223 cmdclass = {
114224 'build_py' : CustomBuildPy ,
225+ 'bdist_wheel' : CachedWheelsCommand ,
115226 },
116227 )
0 commit comments