Skip to content

Commit 69a4e20

Browse files
committed
Avoid including cuda fp headers in <cuda/std/limits>
1 parent 9521ece commit 69a4e20

File tree

19 files changed

+247
-195
lines changed

19 files changed

+247
-195
lines changed

libcudacxx/include/cuda/std/__floating_point/format.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
# pragma system_header
2222
#endif // no system header
2323

24-
#include <cuda/std/__floating_point/cuda_fp_types.h>
2524
#include <cuda/std/__fwd/fp.h>
2625
#include <cuda/std/__type_traits/is_same.h>
2726
#include <cuda/std/cfloat>

libcudacxx/include/cuda/std/__floating_point/storage.h

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#endif // no system header
2323

2424
#include <cuda/std/__bit/bit_cast.h>
25-
#include <cuda/std/__floating_point/cuda_fp_types.h>
2625
#include <cuda/std/__floating_point/format.h>
2726
#include <cuda/std/__floating_point/traits.h>
2827
#include <cuda/std/__type_traits/always_false.h>
@@ -72,19 +71,11 @@ using __fp_storage_t = decltype(__fp_storage_type_impl<_Fmt>());
7271
template <class _Tp>
7372
using __fp_storage_of_t = __fp_storage_t<__fp_format_of_v<_Tp>>;
7473

75-
#if _CCCL_HAS_NVFP16()
76-
struct __cccl_nvfp16_manip_helper : __half
77-
{
78-
using __half::__x;
79-
};
80-
#endif // _CCCL_HAS_NVFP16()
81-
82-
#if _CCCL_HAS_NVBF16()
83-
struct __cccl_nvbf16_manip_helper : __nv_bfloat16
74+
template <class _Tp>
75+
struct __cccl_nvfp_manip_helper : _Tp
8476
{
85-
using __nv_bfloat16::__x;
77+
using _Tp::__x;
8678
};
87-
#endif // _CCCL_HAS_NVBF16()
8879

8980
template <class _Tp>
9081
[[nodiscard]] _CCCL_API constexpr _Tp __fp_from_storage(__fp_storage_of_t<_Tp> __v) noexcept
@@ -102,39 +93,39 @@ template <class _Tp>
10293
#if _CCCL_HAS_NVFP16()
10394
else if constexpr (is_same_v<_Tp, __half>)
10495
{
105-
__cccl_nvfp16_manip_helper __helper{};
96+
__cccl_nvfp_manip_helper<_Tp> __helper{};
10697
__helper.__x = __v;
10798
return __helper;
10899
}
109100
#endif // _CCCL_HAS_NVFP16()
110101
#if _CCCL_HAS_NVBF16()
111102
else if constexpr (is_same_v<_Tp, __nv_bfloat16>)
112103
{
113-
__cccl_nvbf16_manip_helper __helper{};
104+
__cccl_nvfp_manip_helper<_Tp> __helper{};
114105
__helper.__x = __v;
115106
return __helper;
116107
}
117108
#endif // _CCCL_HAS_NVBF16()
118109
#if _CCCL_HAS_NVFP8_E4M3()
119110
else if constexpr (is_same_v<_Tp, __nv_fp8_e4m3>)
120111
{
121-
__nv_fp8_e4m3 __ret{};
112+
_Tp __ret{};
122113
__ret.__x = __v;
123114
return __ret;
124115
}
125116
#endif // _CCCL_HAS_NVFP8_E4M3()
126117
#if _CCCL_HAS_NVFP8_E5M2()
127118
else if constexpr (is_same_v<_Tp, __nv_fp8_e5m2>)
128119
{
129-
__nv_fp8_e5m2 __ret{};
120+
_Tp __ret{};
130121
__ret.__x = __v;
131122
return __ret;
132123
}
133124
#endif // _CCCL_HAS_NVFP8_E5M2()
134125
#if _CCCL_HAS_NVFP8_E8M0()
135126
else if constexpr (is_same_v<_Tp, __nv_fp8_e8m0>)
136127
{
137-
__nv_fp8_e8m0 __ret{};
128+
_Tp __ret{};
138129
__ret.__x = __v;
139130
return __ret;
140131
}
@@ -143,7 +134,7 @@ template <class _Tp>
143134
else if constexpr (is_same_v<_Tp, __nv_fp6_e2m3>)
144135
{
145136
_CCCL_ASSERT((__v & 0xc0u) == 0u, "Invalid __nv_fp6_e2m3 storage value");
146-
__nv_fp6_e2m3 __ret{};
137+
_Tp __ret{};
147138
__ret.__x = __v;
148139
return __ret;
149140
}
@@ -152,7 +143,7 @@ template <class _Tp>
152143
else if constexpr (is_same_v<_Tp, __nv_fp6_e3m2>)
153144
{
154145
_CCCL_ASSERT((__v & 0xc0u) == 0u, "Invalid __nv_fp6_e3m2 storage value");
155-
__nv_fp6_e3m2 __ret{};
146+
_Tp __ret{};
156147
__ret.__x = __v;
157148
return __ret;
158149
}
@@ -161,7 +152,7 @@ template <class _Tp>
161152
else if constexpr (is_same_v<_Tp, __nv_fp4_e2m1>)
162153
{
163154
_CCCL_ASSERT((__v & 0xf0u) == 0u, "Invalid __nv_fp4_e2m1 storage value");
164-
__nv_fp4_e2m1 __ret{};
155+
_Tp __ret{};
165156
__ret.__x = __v;
166157
return __ret;
167158
}
@@ -190,13 +181,13 @@ template <class _Tp>
190181
#if _CCCL_HAS_NVFP16()
191182
else if constexpr (is_same_v<_Tp, __half>)
192183
{
193-
return __cccl_nvfp16_manip_helper{__v}.__x;
184+
return __cccl_nvfp_manip_helper<_Tp>{__v}.__x;
194185
}
195186
#endif // _CCCL_HAS_NVFP16()
196187
#if _CCCL_HAS_NVBF16()
197188
else if constexpr (is_same_v<_Tp, __nv_bfloat16>)
198189
{
199-
return __cccl_nvbf16_manip_helper{__v}.__x;
190+
return __cccl_nvfp_manip_helper<_Tp>{__v}.__x;
200191
}
201192
#endif // _CCCL_HAS_NVBF16()
202193
#if _CCCL_HAS_NVFP8_E4M3()

libcudacxx/include/cuda/std/__floating_point/traits.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
# pragma system_header
2222
#endif // no system header
2323

24-
#include <cuda/std/__floating_point/cuda_fp_types.h>
2524
#include <cuda/std/__floating_point/properties.h>
2625
#include <cuda/std/__fwd/fp.h>
2726

libcudacxx/include/cuda/std/__limits/numeric_limits.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ enum class __numeric_limits_type
5858
};
5959

6060
template <class _Tp>
61-
_CCCL_API constexpr __numeric_limits_type __make_numeric_limits_type()
61+
[[nodiscard]] _CCCL_API _CCCL_CONSTEVAL __numeric_limits_type __make_numeric_limits_type() noexcept
6262
{
6363
if constexpr (is_same_v<_Tp, bool>)
6464
{
@@ -78,7 +78,16 @@ _CCCL_API constexpr __numeric_limits_type __make_numeric_limits_type()
7878
}
7979
}
8080

81-
template <class _Tp, __numeric_limits_type = __make_numeric_limits_type<_Tp>()>
81+
// To avoid including nvfp headers, we add a _Up parameter defaulted to _Tp, so that the specialization remains a
82+
// template, which won't be instantiated unless the numeric_limits<_Tp> class is instantiated. The specialization should
83+
// look like this:
84+
//
85+
// template <class _Tp>
86+
// class __numeric_limits_impl<__nvfp_type, __numeric_limits_type::__floating_point, _Tp>
87+
// { ... };
88+
//
89+
// and _Tp should be used everywhere instead of __nvfp_type.
90+
template <class _Tp, __numeric_limits_type = __make_numeric_limits_type<_Tp>(), class _Up = _Tp>
8291
class __numeric_limits_impl
8392
{
8493
public:

0 commit comments

Comments
 (0)