11import os
22import setuptools
33import shutil
4- import subprocess
54import torch
5+ import re
6+ import ast
67from setuptools import find_packages
78from setuptools .command .build_py import build_py
89from torch .utils .cpp_extension import CUDAExtension , CUDA_HOME
10+ from pathlib import Path
11+ import subprocess
12+
# True only when DEEP_GEMM_SKIP_CUDA_BUILD is set to exactly "TRUE".
SKIP_CUDA_BUILD = os.environ.get("DEEP_GEMM_SKIP_CUDA_BUILD", "FALSE") == "TRUE"

# Symlink-resolved directory that contains this script.
current_dir = os.path.dirname(os.path.realpath(__file__))

# Host compiler flags; the libstdc++ ABI macro follows torch.compiled_with_cxx11_abi().
_cxx11_abi = int(torch.compiled_with_cxx11_abi())
cxx_flags = [
    "-std=c++17",
    "-O3",
    "-fPIC",
    "-Wno-psabi",
    "-Wno-deprecated-declarations",
    f"-D_GLIBCXX_USE_CXX11_ABI={_cxx11_abi}",
]

sources = ["csrc/python_api.cpp"]

# Header search paths: CUDA toolkit headers first, then in-tree and vendored ones.
build_include_dirs = [
    f"{CUDA_HOME}/include",
    f"{CUDA_HOME}/include/cccl",
    "deep_gemm/include",
    "third-party/cutlass/include",
    "third-party/fmt/include",
]

build_libraries = ["cuda", "cudart", "nvrtc"]
build_library_dirs = [
    f"{CUDA_HOME}/lib64",
    f"{CUDA_HOME}/lib64/stubs",
]

# Vendored header trees, given relative to this script's directory.
third_party_include_dirs = [
    f"third-party/cutlass/include/{tree}" for tree in ("cute", "cutlass")
]

# Use runtime API
_use_runtime_api = int(os.environ.get("DG_JIT_USE_RUNTIME_API", "0"))
if _use_runtime_api:
    cxx_flags += ["-DDG_JIT_USE_RUNTIME_API"]
3443
3544
3645class CustomBuildPy (build_py ):
@@ -45,22 +54,30 @@ def run(self):
4554 build_py .run (self )
4655
def generate_default_envs(self):
    """Write ``deep_gemm/envs.py`` under ``self.build_lib``.

    The generated module defines ``persistent_envs`` and records the value
    of each supported ``DG_JIT_*`` variable that is set in the environment
    at the time this method runs. Variables that are not set are omitted.
    """
    lines = [
        "# Pre-installed environment variables\n",
        "persistent_envs = dict()\n",
    ]
    for name in (
        "DG_JIT_CACHE_DIR",
        "DG_JIT_PRINT_COMPILER_COMMAND",
        "DG_JIT_CPP_STANDARD",
    ):
        if name in os.environ:
            # repr() emits a correctly escaped Python literal, so values that
            # contain quotes or backslashes still round-trip unchanged.
            lines.append(f"persistent_envs[{name!r}] = {os.environ[name]!r}\n")

    # Python source files are decoded as UTF-8 by default, so write them as UTF-8.
    envs_path = os.path.join(self.build_lib, "deep_gemm", "envs.py")
    with open(envs_path, "w", encoding="utf-8") as f:
        f.write("".join(lines))
5572
5673 def prepare_includes (self ):
5774 # Create temporary build directory instead of modifying package directory
58- build_include_dir = os .path .join (self .build_lib , ' deep_gemm/include' )
75+ build_include_dir = os .path .join (self .build_lib , " deep_gemm/include" )
5976 os .makedirs (build_include_dir , exist_ok = True )
6077
6178 # Copy third-party includes to the build directory
6279 for d in third_party_include_dirs :
63- dirname = d .split ('/' )[- 1 ]
80+ dirname = d .split ("/" )[- 1 ]
6481 src_dir = os .path .join (current_dir , d )
6582 dst_dir = os .path .join (build_include_dir , dirname )
6683
@@ -72,36 +89,53 @@ def prepare_includes(self):
7289 shutil .copytree (src_dir , dst_dir )
7390
7491
75- if __name__ == '__main__' :
76- # noinspection PyBroadException
77- try :
78- cmd = ['git' , 'rev-parse' , '--short' , 'HEAD' ]
79- revision = '+' + subprocess .check_output (cmd ).decode ('ascii' ).rstrip ()
80- except :
81- revision = ''
if not SKIP_CUDA_BUILD:
    # Pass the link and compile settings along with the include paths;
    # without them cxx_flags, build_libraries and build_library_dirs
    # (including the optional -DDG_JIT_USE_RUNTIME_API define) are never used.
    ext_modules = [
        CUDAExtension(
            name="deep_gemm_cpp",
            sources=sources,
            include_dirs=build_include_dirs,
            libraries=build_libraries,
            library_dirs=build_library_dirs,
            extra_compile_args=cxx_flags,
        )
    ]
else:
    # Nothing to compile when the CUDA build is skipped.
    ext_modules = []
102+
103+
# True only when DEEP_GEMM_NO_LOCAL_VERSION is set to exactly "TRUE"; the
# "+<git revision>" suffix is then left off the package version.
NO_LOCAL_VERSION = os.getenv("DEEP_GEMM_NO_LOCAL_VERSION", "FALSE") == "TRUE"


def get_package_version():
    """Return the package version, optionally suffixed with the git revision.

    The public version is the literal assigned to ``__version__`` in
    ``deep_gemm/__init__.py`` next to this script. Unless ``NO_LOCAL_VERSION``
    is set, ``+<short git hash>`` is appended when git can report one for
    the project directory; otherwise the suffix is omitted.

    Raises:
        RuntimeError: if no ``__version__ = ...`` line is found.
    """
    init_path = Path(current_dir) / "deep_gemm" / "__init__.py"
    with open(init_path, "r", encoding="utf-8") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        raise RuntimeError(f"Unable to find __version__ in {init_path}")
    public_version = ast.literal_eval(version_match.group(1))

    revision = ""
    if not NO_LOCAL_VERSION:
        cmd = ["git", "rev-parse", "--short", "HEAD"]
        try:
            # Run git in the project directory so the revision does not
            # depend on where the build was launched from.
            output = subprocess.check_output(
                cmd, cwd=current_dir, stderr=subprocess.DEVNULL
            )
        except (OSError, subprocess.CalledProcessError):
            # Best effort: git is missing or this is not a git checkout.
            output = b""
        short_hash = output.decode("ascii", errors="ignore").strip()
        if short_hash:
            revision = "+" + short_hash

    return f"{public_version}{revision}"
121+
82122
if __name__ == "__main__":
    # Glob patterns for the header trees shipped inside the deep_gemm package.
    header_globs = [
        f"include/{tree}/**/*" for tree in ("deep_gemm", "cute", "cutlass")
    ]

    # noinspection PyTypeChecker
    setuptools.setup(
        name="deep_gemm",
        version=get_package_version(),
        packages=find_packages("."),
        package_data={"deep_gemm": header_globs},
        ext_modules=ext_modules,
        zip_safe=False,
        # Route build_py through the customised command class.
        cmdclass={"build_py": CustomBuildPy},
    )
0 commit comments