-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsetup.py
More file actions
32 lines (27 loc) · 1.1 KB
/
setup.py
File metadata and controls
32 lines (27 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from setuptools import find_packages, setup
import os
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Set the CUDA arch list if cross-compiling and no GPU is available
# (e.g. during a docker build).
# https://en.wikipedia.org/wiki/CUDA#GPUs_supported
if not torch.cuda.is_available():
    # setdefault only writes when the variable is absent, so a value already
    # exported by the caller is kept as-is.
    # Default list covers: turing, ampere, ada.
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.5;8.0;8.6;8.7;8.9+PTX")
# Discover the Python packages first, then describe the compiled extensions.
discovered_packages = find_packages()

# (extension module name, source files) for each CUDA op compiled by this script.
extension_specs = [
    (
        "iterative_3d_warp_cuda._C",
        ["cuda_event_ops/iterative_3d_warp/extension.cpp", "cuda_event_ops/iterative_3d_warp/kernel.cu"],
    ),
    (
        "trilinear_splat_cuda._C",
        ["cuda_event_ops/trilinear_splat/extension.cpp", "cuda_event_ops/trilinear_splat/kernel.cu"],
    ),
]

cuda_extensions = [
    CUDAExtension(
        module_name,
        source_files,
        # Built inside the comprehension so each extension gets its own
        # dict and flag lists; "-O3" is given for both the "cxx" and "nvcc" keys.
        extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
    )
    for module_name, source_files in extension_specs
]

setup(
    name="cuda_event_ops",
    packages=discovered_packages,
    ext_modules=cuda_extensions,
    # BuildExtension is registered as the build_ext command.
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)