Skip to content

Commit 4bd5fe2

Browse files
committed
build: Minor tweaks for wheel build
Signed-off-by: oliver könig <[email protected]>
1 parent ea9c5d9 commit 4bd5fe2

File tree

2 files changed

+105
-59
lines changed

2 files changed

+105
-59
lines changed

deep_gemm/__init__.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
try:
66
# noinspection PyUnresolvedReferences
77
from .envs import persistent_envs
8+
89
for key, value in persistent_envs.items():
910
if key not in os.environ:
1011
os.environ[key] = value
@@ -23,19 +24,23 @@
2324
# Kernels
2425
from deep_gemm_cpp import (
2526
# FP8 GEMMs
26-
fp8_gemm_nt, fp8_gemm_nn,
27-
fp8_gemm_tn, fp8_gemm_tt,
27+
fp8_gemm_nt,
28+
fp8_gemm_nn,
29+
fp8_gemm_tn,
30+
fp8_gemm_tt,
2831
m_grouped_fp8_gemm_nt_contiguous,
2932
m_grouped_fp8_gemm_nn_contiguous,
3033
m_grouped_fp8_gemm_nt_masked,
3134
k_grouped_fp8_gemm_tn_contiguous,
3235
# BF16 GEMMs
33-
bf16_gemm_nt, bf16_gemm_nn,
34-
bf16_gemm_tn, bf16_gemm_tt,
36+
bf16_gemm_nt,
37+
bf16_gemm_nn,
38+
bf16_gemm_tn,
39+
bf16_gemm_tt,
3540
m_grouped_bf16_gemm_nt_contiguous,
3641
m_grouped_bf16_gemm_nt_masked,
3742
# Layout kernels
38-
transform_sf_into_required_layout
43+
transform_sf_into_required_layout,
3944
)
4045

4146
# Some alias for legacy supports
@@ -53,22 +58,29 @@
5358
def _find_cuda_home() -> str:
5459
# TODO: reuse PyTorch API later
5560
# For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks
56-
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
61+
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
5762
if cuda_home is None:
5863
# noinspection PyBroadException
5964
try:
60-
with open(os.devnull, 'w') as devnull:
61-
nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n')
65+
with open(os.devnull, "w") as devnull:
66+
nvcc = (
67+
subprocess.check_output(["which", "nvcc"], stderr=devnull)
68+
.decode()
69+
.rstrip("\r\n")
70+
)
6271
cuda_home = os.path.dirname(os.path.dirname(nvcc))
6372
except Exception:
64-
cuda_home = '/usr/local/cuda'
73+
cuda_home = "/usr/local/cuda"
6574
if not os.path.exists(cuda_home):
6675
cuda_home = None
6776
assert cuda_home is not None
6877
return cuda_home
6978

7079

7180
deep_gemm_cpp.init(
72-
os.path.dirname(os.path.abspath(__file__)), # Library root directory path
73-
_find_cuda_home() # CUDA home
81+
os.path.dirname(os.path.abspath(__file__)), # Library root directory path
82+
_find_cuda_home(), # CUDA home
7483
)
84+
85+
86+
__version__ = "2.0.0"

setup.py

Lines changed: 82 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,45 @@
11
import os
22
import setuptools
33
import shutil
4-
import subprocess
54
import torch
5+
import re
6+
import ast
67
from setuptools import find_packages
78
from setuptools.command.build_py import build_py
89
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
10+
from pathlib import Path
11+
import subprocess
12+
13+
SKIP_CUDA_BUILD = os.getenv("DEEP_GEMM_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
14+
915

1016
current_dir = os.path.dirname(os.path.realpath(__file__))
11-
cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations',
12-
f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}']
13-
sources = ['csrc/python_api.cpp']
14-
build_include_dirs = [
15-
f'{CUDA_HOME}/include',
16-
f'{CUDA_HOME}/include/cccl',
17-
'deep_gemm/include',
18-
'third-party/cutlass/include',
19-
'third-party/fmt/include',
17+
cxx_flags = [
18+
"-std=c++17",
19+
"-O3",
20+
"-fPIC",
21+
"-Wno-psabi",
22+
"-Wno-deprecated-declarations",
23+
f"-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}",
2024
]
21-
build_libraries = ['cuda', 'cudart', 'nvrtc']
22-
build_library_dirs = [
23-
f'{CUDA_HOME}/lib64',
24-
f'{CUDA_HOME}/lib64/stubs'
25+
sources = ["csrc/python_api.cpp"]
26+
build_include_dirs = [
27+
f"{CUDA_HOME}/include",
28+
f"{CUDA_HOME}/include/cccl",
29+
"deep_gemm/include",
30+
"third-party/cutlass/include",
31+
"third-party/fmt/include",
2532
]
33+
build_libraries = ["cuda", "cudart", "nvrtc"]
34+
build_library_dirs = [f"{CUDA_HOME}/lib64", f"{CUDA_HOME}/lib64/stubs"]
2635
third_party_include_dirs = [
27-
'third-party/cutlass/include/cute',
28-
'third-party/cutlass/include/cutlass',
36+
"third-party/cutlass/include/cute",
37+
"third-party/cutlass/include/cutlass",
2938
]
3039

3140
# Use runtime API
32-
if int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')):
33-
cxx_flags.append('-DDG_JIT_USE_RUNTIME_API')
41+
if int(os.environ.get("DG_JIT_USE_RUNTIME_API", "0")):
42+
cxx_flags.append("-DDG_JIT_USE_RUNTIME_API")
3443

3544

3645
class CustomBuildPy(build_py):
@@ -45,22 +54,30 @@ def run(self):
4554
build_py.run(self)
4655

4756
def generate_default_envs(self):
48-
code = '# Pre-installed environment variables\n'
49-
code += 'persistent_envs = dict()\n'
50-
for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_CPP_STANDARD'):
51-
code += f"persistent_envs['{name}'] = '{os.environ[name]}'\n" if name in os.environ else ''
57+
code = "# Pre-installed environment variables\n"
58+
code += "persistent_envs = dict()\n"
59+
for name in (
60+
"DG_JIT_CACHE_DIR",
61+
"DG_JIT_PRINT_COMPILER_COMMAND",
62+
"DG_JIT_CPP_STANDARD",
63+
):
64+
code += (
65+
f"persistent_envs['{name}'] = '{os.environ[name]}'\n"
66+
if name in os.environ
67+
else ""
68+
)
5269

53-
with open(os.path.join(self.build_lib, 'deep_gemm', 'envs.py'), 'w') as f:
70+
with open(os.path.join(self.build_lib, "deep_gemm", "envs.py"), "w") as f:
5471
f.write(code)
5572

5673
def prepare_includes(self):
5774
# Create temporary build directory instead of modifying package directory
58-
build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include')
75+
build_include_dir = os.path.join(self.build_lib, "deep_gemm/include")
5976
os.makedirs(build_include_dir, exist_ok=True)
6077

6178
# Copy third-party includes to the build directory
6279
for d in third_party_include_dirs:
63-
dirname = d.split('/')[-1]
80+
dirname = d.split("/")[-1]
6481
src_dir = os.path.join(current_dir, d)
6582
dst_dir = os.path.join(build_include_dir, dirname)
6683

@@ -72,36 +89,53 @@ def prepare_includes(self):
7289
shutil.copytree(src_dir, dst_dir)
7390

7491

75-
if __name__ == '__main__':
76-
# noinspection PyBroadException
77-
try:
78-
cmd = ['git', 'rev-parse', '--short', 'HEAD']
79-
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
80-
except:
81-
revision = ''
92+
if not SKIP_CUDA_BUILD:
93+
ext_modules = [
94+
CUDAExtension(
95+
name="deep_gemm_cpp",
96+
sources=sources,
97+
include_dirs=build_include_dirs,
98+
)
99+
]
100+
else:
101+
ext_modules = []
102+
103+
104+
NO_LOCAL_VERSION = os.getenv("DEEP_GEMM_NO_LOCAL_VERSION", "FALSE") == "TRUE"
105+
106+
107+
def get_package_version():
    """Return the version string passed to ``setuptools.setup``.

    The public version is taken from the ``__version__`` assignment in
    ``deep_gemm/__init__.py`` under ``current_dir``. The right-hand side of
    that assignment is parsed with ``ast.literal_eval``, so the package is
    never imported. Unless ``NO_LOCAL_VERSION`` is true, the short git
    revision of ``HEAD`` is appended as a ``+<revision>`` suffix.

    Returns:
        str: ``"<public_version>"`` or ``"<public_version>+<revision>"``.

    Raises:
        RuntimeError: if no ``__version__`` assignment is found.
    """
    init_path = Path(current_dir) / "deep_gemm" / "__init__.py"
    # Python source files are UTF-8 by default; do not depend on the locale.
    with open(init_path, "r", encoding="utf-8") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        # Fail with a clear message instead of an AttributeError on `.group`.
        raise RuntimeError(f"Unable to find a __version__ assignment in {init_path}")
    public_version = ast.literal_eval(version_match.group(1))

    revision = ""
    if not NO_LOCAL_VERSION:
        try:
            cmd = ["git", "rev-parse", "--short", "HEAD"]
            revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip()
        except Exception:
            # Best effort: git may be missing or this may not be a checkout.
            # `Exception` (not a bare except) lets Ctrl-C / SystemExit through.
            revision = ""

    return f"{public_version}{revision}"
121+
82122

123+
if __name__ == "__main__":
83124
# noinspection PyTypeChecker
84125
setuptools.setup(
85-
name='deep_gemm',
86-
version='2.0.0' + revision,
87-
packages=find_packages('.'),
126+
name="deep_gemm",
127+
version=get_package_version(),
128+
packages=find_packages("."),
88129
package_data={
89-
'deep_gemm': [
90-
'include/deep_gemm/**/*',
91-
'include/cute/**/*',
92-
'include/cutlass/**/*',
130+
"deep_gemm": [
131+
"include/deep_gemm/**/*",
132+
"include/cute/**/*",
133+
"include/cutlass/**/*",
93134
]
94135
},
95-
ext_modules=[
96-
CUDAExtension(name='deep_gemm_cpp',
97-
sources=sources,
98-
include_dirs=build_include_dirs,
99-
libraries=build_libraries,
100-
library_dirs=build_library_dirs,
101-
extra_compile_args=cxx_flags)
102-
],
136+
ext_modules=ext_modules,
103137
zip_safe=False,
104138
cmdclass={
105-
'build_py': CustomBuildPy,
139+
"build_py": CustomBuildPy,
106140
},
107141
)

0 commit comments

Comments
 (0)