diff --git a/qtorch/quant/quant_cuda/bit_helper.cu b/qtorch/quant/quant_cuda/bit_helper.cu index 794255f..1df46a2 100644 --- a/qtorch/quant/quant_cuda/bit_helper.cu +++ b/qtorch/quant/quant_cuda/bit_helper.cu @@ -1,6 +1,12 @@ #define FLOAT_TO_BITS(x) (*reinterpret_cast(x)) #define BITS_TO_FLOAT(x) (*reinterpret_cast(x)) +#ifdef __HIP__ +#ifndef __forceinline__ +#define __forceinline__ inline __attribute__((always_inline)) +#endif +#endif + __device__ __forceinline__ unsigned int extract_exponent(float *a) { unsigned int temp = *(reinterpret_cast(a)); temp = (temp << 1 >> 24); // single preciision, 1 sign bit, 23 mantissa bits diff --git a/qtorch/quant/quant_cuda/fixed_point_kernel.cu b/qtorch/quant/quant_cuda/fixed_point_kernel.cu index 99b7727..6158eda 100644 --- a/qtorch/quant/quant_cuda/fixed_point_kernel.cu +++ b/qtorch/quant/quant_cuda/fixed_point_kernel.cu @@ -1,6 +1,11 @@ #include "quant_kernel.h" #include "sim_helper.cu" +#ifdef __HIP__ +#ifndef __forceinline__ +#define __forceinline__ inline __attribute__((always_inline)) +#endif +#endif template __device__ __forceinline__ T clamp_helper(T a, T min, T max) { diff --git a/qtorch/quant/quant_function.py b/qtorch/quant/quant_function.py index 061676c..6555bc6 100644 --- a/qtorch/quant/quant_function.py +++ b/qtorch/quant/quant_function.py @@ -28,6 +28,7 @@ os.path.join(current_path, "quant_cuda/fixed_point_kernel.cu"), os.path.join(current_path, "quant_cuda/quant.cu"), ], + extra_include_paths=[os.path.join(current_path, "quant_cuda")], ) else: quant_cuda = quant_cpu