Skip to content

Commit 376f23b

Browse files
authored
SWDEV-516595 - Add __shfl functions with __hip_bfloat16 datatype (#42)
Also removes the static_asserts from the cooperative groups shfl functions, since __hip_bfloat16 shfl overloads are now present. Change-Id: I57578b6e68dccc10c2ddcd194e9cc18bc7732ce1
1 parent d9abcdd commit 376f23b

File tree

2 files changed

+55
-8
lines changed

2 files changed

+55
-8
lines changed

hipamd/include/hip/amd_detail/amd_hip_bf16.h

+55-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
* MIT License
33
*
4-
* Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved.
4+
* Copyright (c) 2019 - 2025 Advanced Micro Devices, Inc. All rights reserved.
55
*
66
* Permission is hereby granted, free of charge, to any person obtaining a copy
77
* of this software and associated documentation files (the "Software"), to deal
@@ -130,6 +130,13 @@
130130
#define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline
131131
#define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline
132132

133+
// Save any MAYBE_UNDEF definition coming from the including code; the
// pop_macro at the end of this header restores it. The #undef keeps the
// definition below from clashing with a different pre-existing one.
#pragma push_macro("MAYBE_UNDEF")
#undef MAYBE_UNDEF
// MAYBE_UNDEF expands to the maybe_undef parameter attribute when the compiler
// offers it and to nothing otherwise. The __has_attribute probe is nested
// inside the defined() check: in a combined `defined(...) && __has_attribute(...)`
// expression a preprocessor lacking __has_attribute would still have to parse
// the call-like second operand and would reject it.
#if defined(__has_attribute)
#if __has_attribute(maybe_undef)
#define MAYBE_UNDEF __attribute__((maybe_undef))
#endif
#endif
#ifndef MAYBE_UNDEF
#define MAYBE_UNDEF
#endif
139+
133140
// Named bfloat16 constants, each produced by passing a raw 16-bit pattern to
// __ushort_as_bfloat16.
// Pattern 0x3F80: the constant named ONE.
#define HIPRT_ONE_BF16 __ushort_as_bfloat16((unsigned short)0x3F80U)
134141
// Pattern 0x0000: the constant named ZERO.
#define HIPRT_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x0000U)
135142
// Pattern 0x7F80: the constant named INF.
#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
@@ -592,6 +599,52 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned s
592599
return u.bf16;
593600
}
594601

602+
/**
603+
* \ingroup HIP_INTRINSIC_BFLOAT16_SHFL
604+
* \brief shfl warp intrinsic for bfloat16
605+
*/
606+
__BF16_DEVICE_STATIC__
607+
__hip_bfloat16 __shfl(MAYBE_UNDEF __hip_bfloat16 var, int src_lane, int width = warpSize) {
608+
union { int i; __hip_bfloat16 f; } tmp; tmp.f = var;
609+
tmp.i = __shfl(tmp.i, src_lane, width);
610+
return tmp.f;
611+
}
612+
613+
/**
614+
* \ingroup HIP_INTRINSIC_BFLOAT16_SHFL
615+
* \brief shfl up warp intrinsic for bfloat16
616+
*/
617+
__BF16_DEVICE_STATIC__
618+
__hip_bfloat16 __shfl_up(MAYBE_UNDEF __hip_bfloat16 var,
619+
unsigned int lane_delta, int width = warpSize) {
620+
union { int i; __hip_bfloat16 f; } tmp; tmp.f = var;
621+
tmp.i = __shfl_up(tmp.i, lane_delta, width);
622+
return tmp.f;
623+
}
624+
625+
/**
626+
* \ingroup HIP_INTRINSIC_BFLOAT16_SHFL
627+
* \brief shfl down warp intrinsic for bfloat16
628+
*/
629+
__BF16_DEVICE_STATIC__
630+
__hip_bfloat16 __shfl_down(MAYBE_UNDEF __hip_bfloat16 var,
631+
unsigned int lane_delta, int width = warpSize) {
632+
union { int i; __hip_bfloat16 f; } tmp; tmp.f = var;
633+
tmp.i = __shfl_down(tmp.i, lane_delta, width);
634+
return tmp.f;
635+
}
636+
637+
/**
638+
* \ingroup HIP_INTRINSIC_BFLOAT16_SHFL
639+
* \brief shfl xor warp intrinsic for bfloat16
640+
*/
641+
__BF16_DEVICE_STATIC__
642+
__hip_bfloat16 __shfl_xor(MAYBE_UNDEF __hip_bfloat16 var, int lane_mask, int width = warpSize) {
643+
union { int i; __hip_bfloat16 f; } tmp; tmp.f = var;
644+
tmp.i = __shfl_xor(tmp.i, lane_mask, width);
645+
return tmp.f;
646+
}
647+
595648
#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS
596649
/**
597650
* \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
@@ -1787,4 +1840,5 @@ __BF16_DEVICE_STATIC__ __hip_bfloat16 unsafeAtomicAdd(__hip_bfloat16 *address,
17871840
return __high2bfloat16(out);
17881841
}
17891842
#endif // defined(__clang__) && defined(__HIP__)
1843+
// Restore whatever MAYBE_UNDEF meant before this header's push_macro.
#pragma pop_macro("MAYBE_UNDEF")
17901844
#endif

hipamd/include/hip/amd_detail/amd_hip_cooperative_groups.h

-7
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,6 @@ class coalesced_group : public thread_group {
462462
*/
463463
template <class T>
464464
__CG_QUALIFIER__ T shfl(T var, int srcRank) const {
465-
static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
466465

467466
srcRank = srcRank % static_cast<int>(size());
468467

@@ -489,7 +488,6 @@ class coalesced_group : public thread_group {
489488
*/
490489
template <class T>
491490
__CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
492-
static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
493491

494492
// Note: The cuda implementation appears to use the remainder of lane_delta
495493
// and WARP_SIZE as the shift value rather than lane_delta itself.
@@ -530,7 +528,6 @@ class coalesced_group : public thread_group {
530528
*/
531529
template <class T>
532530
__CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
533-
static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
534531

535532
// Note: The cuda implementation appears to use the remainder of lane_delta
536533
// and WARP_SIZE as the shift value rather than lane_delta itself.
@@ -838,22 +835,18 @@ template <unsigned int size> class thread_block_tile_base : public tile_base<siz
838835
}
839836

840837
// Tile shuffle: forwards var and srcRank to __shfl, passing numThreads as the width argument.
template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
841-
static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
842838
return (__shfl(var, srcRank, numThreads));
843839
}
844840

845841
// Tile shuffle-down: forwards var and lane_delta to __shfl_down, passing numThreads as the width argument.
template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
846-
static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
847842
return (__shfl_down(var, lane_delta, numThreads));
848843
}
849844

850845
// Tile shuffle-up: forwards var and lane_delta to __shfl_up, passing numThreads as the width argument.
template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
851-
static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
852846
return (__shfl_up(var, lane_delta, numThreads));
853847
}
854848

855849
// Tile shuffle-xor: forwards var and laneMask to __shfl_xor, passing numThreads as the width argument.
template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
856-
static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
857850
return (__shfl_xor(var, laneMask, numThreads));
858851
}
859852

0 commit comments

Comments
 (0)