Skip to content

Commit c942dbe

Browse files
committed
Optimize cuda::std::saturating_(add|sub) for SM_120f
1 parent e1a7506 commit c942dbe

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

libcudacxx/include/cuda/std/__numeric/saturating_add.h

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,32 @@ template <class _Tp>
118118
{
119119
if constexpr (is_signed_v<_Tp>)
120120
{
121-
if constexpr (sizeof(_Tp) < sizeof(int32_t))
121+
if constexpr (sizeof(_Tp) == sizeof(int8_t))
122+
{
123+
# if __cccl_ptx_isa >= 920
124+
NV_IF_TARGET(NV_HAS_FEATURE_SM_120f, ({
125+
// Use uint32_t because we want to avoid sign extension.
126+
uint32_t __result;
127+
asm("add.sat.s8x4 %0, %1, %2;"
128+
: "=r"(__result)
129+
: "r"(static_cast<uint32_t>(__x)), "r"(static_cast<uint32_t>(__y)));
130+
return static_cast<_Tp>(__result);
131+
}))
132+
# endif // __cccl_ptx_isa >= 920
133+
return ::cuda::std::saturating_cast<_Tp>(int32_t{__x} + int32_t{__y});
134+
}
135+
else if constexpr (sizeof(_Tp) == sizeof(int16_t))
122136
{
137+
# if __cccl_ptx_isa >= 920
138+
NV_IF_TARGET(NV_HAS_FEATURE_SM_120f, ({
139+
// Use uint32_t because we want to avoid sign extension.
140+
uint32_t __result;
141+
asm("add.sat.s16x2 %0, %1, %2;"
142+
: "=r"(__result)
143+
: "r"(static_cast<uint32_t>(__x)), "r"(static_cast<uint32_t>(__y)));
144+
return static_cast<_Tp>(__result);
145+
}))
146+
# endif // __cccl_ptx_isa >= 920
123147
return ::cuda::std::saturating_cast<_Tp>(int32_t{__x} + int32_t{__y});
124148
}
125149
else if constexpr (sizeof(_Tp) == sizeof(int32_t))
@@ -135,7 +159,43 @@ template <class _Tp>
135159
}
136160
else
137161
{
138-
return ::cuda::saturating_add_overflow(__x, __y).value;
162+
if constexpr (sizeof(_Tp) == sizeof(uint8_t))
163+
{
164+
# if __cccl_ptx_isa >= 920
165+
NV_IF_TARGET(NV_HAS_FEATURE_SM_120f, ({
166+
uint32_t __result;
167+
asm("add.sat.u8x4 %0, %1, %2;" : "=r"(__result) : "r"(uint32_t{__x}), "r"(uint32_t{__y}));
168+
return static_cast<_Tp>(__result);
169+
}))
170+
# endif // __cccl_ptx_isa >= 920
171+
return ::cuda::saturating_add_overflow(__x, __y).value;
172+
}
173+
else if constexpr (sizeof(_Tp) == sizeof(uint16_t))
174+
{
175+
# if __cccl_ptx_isa >= 920
176+
NV_IF_TARGET(NV_HAS_FEATURE_SM_120f, ({
177+
uint32_t __result;
178+
asm("add.sat.u16x2 %0, %1, %2;" : "=r"(__result) : "r"(uint32_t{__x}), "r"(uint32_t{__y}));
179+
return static_cast<_Tp>(__result);
180+
}))
181+
# endif // __cccl_ptx_isa >= 920
182+
return ::cuda::saturating_add_overflow(__x, __y).value;
183+
}
184+
else if constexpr (sizeof(_Tp) == sizeof(uint32_t))
185+
{
186+
# if __cccl_ptx_isa >= 920
187+
NV_IF_TARGET(NV_HAS_FEATURE_SM_120f, ({
188+
uint32_t __result;
189+
asm("add.sat.u32 %0, %1, %2;" : "=r"(__result) : "r"(__x), "r"(__y));
190+
return __result;
191+
}))
192+
# endif // __cccl_ptx_isa >= 920
193+
return ::cuda::saturating_add_overflow(__x, __y).value;
194+
}
195+
else
196+
{
197+
return ::cuda::saturating_add_overflow(__x, __y).value;
198+
}
139199
}
140200
}
141201
#endif // _CCCL_CUDA_COMPILATION()

libcudacxx/include/cuda/std/__numeric/saturating_sub.h

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,23 @@ template <class _Tp>
118118
{
119119
if constexpr (is_signed_v<_Tp>)
120120
{
121-
if constexpr (sizeof(_Tp) < sizeof(int32_t))
121+
if constexpr (sizeof(_Tp) == sizeof(int8_t))
122122
{
123+
# if __cccl_ptx_isa >= 920
124+
NV_IF_TARGET(NV_HAS_FEATURE_SM_120f, ({
125+
// Use uint32_t because we want to avoid sign extension.
126+
uint32_t __result;
127+
asm("sub.sat.s8x4 %0, %1, %2;"
128+
: "=r"(__result)
129+
: "r"(static_cast<uint32_t>(__x)), "r"(static_cast<uint32_t>(__y)));
130+
return static_cast<_Tp>(__result);
131+
}))
132+
# endif // __cccl_ptx_isa >= 920
133+
return ::cuda::std::saturating_cast<_Tp>(int32_t{__x} - int32_t{__y});
134+
}
135+
else if constexpr (sizeof(_Tp) == sizeof(int16_t))
136+
{
137+
// sub.sat.s16x2 doesn't exist for now
123138
return ::cuda::std::saturating_cast<_Tp>(int32_t{__x} - int32_t{__y});
124139
}
125140
// Disabled due to nvbug 5033045
@@ -136,7 +151,22 @@ template <class _Tp>
136151
}
137152
else
138153
{
139-
return ::cuda::saturating_sub_overflow(__x, __y).value;
154+
# if __cccl_ptx_isa >= 920
155+
if constexpr (sizeof(_Tp) == sizeof(uint8_t))
156+
{
157+
NV_IF_TARGET(NV_HAS_FEATURE_SM_120f, ({
158+
uint32_t __result;
159+
asm("sub.sat.u8x4 %0, %1, %2;" : "=r"(__result) : "r"(uint32_t{__x}), "r"(uint32_t{__y}));
160+
return static_cast<_Tp>(__result);
161+
}))
162+
return ::cuda::saturating_sub_overflow(__x, __y).value;
163+
}
164+
else
165+
# endif // __cccl_ptx_isa >= 920
166+
{
167+
// sub.sat.u16x2 doesn't exist for now
168+
return ::cuda::saturating_sub_overflow(__x, __y).value;
169+
}
140170
}
141171
}
142172
#endif // _CCCL_CUDA_COMPILATION()

0 commit comments

Comments
 (0)