Skip to content

Commit 6b7b43c

Browse files
jeffdailymalfet
authored andcommitted
[ROCm] std::clamp work-around for hip-clang compiler (pytorch#127812)
Fixes pytorch#127666. Other std math functions are replaced with those in the global namespace during hipify. HIP does not claim to support every function in the C++ standard library. std::clamp is not yet supported and we have been relying on the std implementation. For Fedora 40 + gcc 14, a host-side assert is used which is not supported. Work-around this by replacing std::clamp with min and max. Using #ifndef USE_ROCM to differentiate between CUDA using std::clamp and the ROCm replacement broke Windows builds. The replacement generates the same PTX as std::clamp, so using the replacement unconditionally. The replacement generates the same PTX as std::clamp. See https://godbolt.org/z/Wde9KW3v4 for a sample. Original patch comes from @lamikr. Modified to improve efficiency. lamikr/rocm_sdk_builder#37 Co-authored-by: Nikita Shulga <[email protected]> Pull Request resolved: pytorch#127812 Approved by: https://github.com/hongxiayang, https://github.com/malfet
1 parent 09f6bd6 commit 6b7b43c

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

aten/src/ATen/native/cuda/IndexKernel.cu

+8-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,14 @@ void index_put_kernel_quantized_cuda(TensorIterator& iter, const IntArrayRef ind
259259

260260
gpu_index_kernel(iter, index_size, index_stride, [inv_scale, zero_point, qmin, qmax]C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
261261
int64_t qvalue = static_cast<int64_t>(zero_point + nearbyintf(*(float*)in_data * inv_scale));
262-
qvalue = std::clamp(qvalue, qmin, qmax);
262+
// See https://github.com/pytorch/pytorch/issues/127666
263+
// and https://github.com/pytorch/pytorch/issues/128253.
264+
// hip-clang std::clamp __glibcxx_assert_fail host function when building on Fedora40/gcc14.
265+
// The following replaces std::clamp(qvalue, qmin, qmax) and is a viable solution for
266+
// both CUDA and ROCm since std::clamp and this replacement generates the same PTX.
267+
// Using #ifdef USE_ROCM to differentiate caused Windows build failures.
268+
// The replacement should generate the same PTX as std::clamp. See https://godbolt.org/z/Wde9KW3v4
269+
qvalue = (qvalue < qmin) ? qmin : (qmax < qvalue) ? qmax : qvalue;
263270
*(scalar_t*)(out_data + offset) = static_cast<scalar_t>(qvalue);
264271
});
265272
});

0 commit comments

Comments
 (0)