Commit 3a8ce6d

[ROCm] std::clamp work-around for hip-clang compiler
Fixes pytorch#127666.

During hipify, other std math functions are replaced with their global-namespace equivalents. HIP does not claim to support every function in the C++ standard library; std::clamp is not yet supported, and we have been relying on the std implementation. On Fedora 40 with gcc 14, that implementation uses a host-side assert (__glibcxx_assert_fail), which is not supported in device code. Work around this by replacing std::clamp with min and max for USE_ROCM builds.

The patch comes from @lamikr (lamikr/rocm_sdk_builder#37), modified to use #ifndef USE_ROCM.
1 parent db9d457 commit 3a8ce6d

1 file changed, +7 -0 lines changed

aten/src/ATen/native/cuda/IndexKernel.cu

@@ -249,7 +249,14 @@ void index_put_kernel_quantized_cuda(TensorIterator& iter, const IntArrayRef ind
 
     gpu_index_kernel(iter, index_size, index_stride, [inv_scale, zero_point, qmin, qmax]C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
       int64_t qvalue = static_cast<int64_t>(zero_point + nearbyintf(*(float*)in_data * inv_scale));
+      // See https://github.com/pytorch/pytorch/issues/127666
+      // hip-clang std::clamp __glibcxx_assert_fail host function when building on Fedora40/gcc14
+#ifndef USE_ROCM
       qvalue = std::clamp(qvalue, qmin, qmax);
+#else
+      int64_t new_max = std::max<int64_t>(qmin, qvalue);
+      qvalue = std::min<int64_t>(qmax, new_max);
+#endif
       *(scalar_t*)(out_data + offset) = static_cast<scalar_t>(qvalue);
     });
   });
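
The USE_ROCM branch relies on max-then-min giving the same result as std::clamp whenever qmin <= qmax. The following host-only sketch is one way to check that equivalence. It is not part of the commit: the helper names clamp_via_min_max and agrees_with_std_clamp are hypothetical, and the int8-style bounds -128 and 127 are chosen only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not in the commit): clamps v into [lo, hi] using the
// same max-then-min composition as the USE_ROCM branch of the diff.
// As with std::clamp, the result is only meaningful when lo <= hi.
constexpr int64_t clamp_via_min_max(int64_t v, int64_t lo, int64_t hi) {
  const int64_t raised = std::max<int64_t>(lo, v);  // enforce the lower bound
  return std::min<int64_t>(hi, raised);             // enforce the upper bound
}

// Compile-time spot checks with illustrative int8-style bounds.
static_assert(clamp_via_min_max(-200, -128, 127) == -128, "below range");
static_assert(clamp_via_min_max(0, -128, 127) == 0, "inside range");
static_assert(clamp_via_min_max(300, -128, 127) == 127, "above range");

// Hypothetical host-side check: compares the composition against std::clamp
// for every value from lo - 1000 to hi + 1000. Assumes lo <= hi.
inline bool agrees_with_std_clamp(int64_t lo, int64_t hi) {
  for (int64_t v = lo - 1000; v <= hi + 1000; ++v) {
    if (clamp_via_min_max(v, lo, hi) != std::clamp(v, lo, hi)) {
      return false;
    }
  }
  return true;
}
```

Calling agrees_with_std_clamp(-128, 127) or agrees_with_std_clamp(0, 255) from a host program built with C++17 should return true. The sketch only exercises host code and says nothing about how hip-clang compiles either form for the device.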
