Skip to content

Commit db5b5b8

Browse files
committed
add safety docs, make subgroup_quad_broadcast safe
1 parent 8924aa0 commit db5b5b8

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

crates/spirv-std/src/arch/subgroup.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,11 @@ pub fn subgroup_all_equal<T: VectorOrScalar>(value: T) -> bool {
294294
/// The resulting value is undefined if `id` is an inactive invocation, or is greater than or equal to the size of the group.
295295
///
296296
/// Requires Capability `GroupNonUniformBallot`.
297+
///
298+
/// # Safety
299+
/// * `id` must not be dynamically uniform
300+
/// * before 1.5: `id` must be constant
301+
/// * Result is undefined if `id` is an inactive invocation or out of bounds
297302
#[spirv_std_macros::gpu_only]
298303
#[doc(alias = "OpGroupNonUniformBroadcast")]
299304
#[inline]
@@ -396,6 +401,9 @@ pub fn subgroup_ballot(predicate: bool) -> SubgroupMask {
396401
/// `value` is a set of bitfields where the first invocation is represented in the lowest bit of the first vector component and the last (up to the size of the group) is the higher bit number of the last bitmask needed to represent all bits of the group invocations.
397402
///
398403
/// Requires Capability `GroupNonUniformBallot`.
404+
///
405+
/// # Safety
406+
/// * `value` must be the same for all dynamic instances of this instruction
399407
#[spirv_std_macros::gpu_only]
400408
#[doc(alias = "OpGroupNonUniformInverseBallot")]
401409
#[inline]
@@ -434,6 +442,10 @@ pub unsafe fn subgroup_inverse_ballot(value: SubgroupMask) -> bool {
434442
/// The resulting value is undefined if `index` is greater than or equal to the size of the group.
435443
///
436444
/// Requires Capability `GroupNonUniformBallot`.
445+
///
446+
/// # Safety
447+
/// * This function is safe
448+
/// * Result is undefined if `id` is out of bounds
437449
#[spirv_std_macros::gpu_only]
438450
#[doc(alias = "OpGroupNonUniformBallotBitExtract")]
439451
#[inline]
@@ -520,6 +532,10 @@ macro_subgroup_ballot_bit_count!(
520532
/// `value` is a set of bitfields where the first invocation is represented in the lowest bit of the first vector component and the last (up to the size of the group) is the higher bit number of the last bitmask needed to represent all bits of the group invocations.
521533
///
522534
/// Requires Capability `GroupNonUniformBallot`.
535+
///
536+
/// # Safety
537+
/// * This function is safe
538+
/// * Result is undefined if `id` is an inactive invocation or out of bounds
523539
#[spirv_std_macros::gpu_only]
524540
#[doc(alias = "OpGroupNonUniformBallotFindLSB")]
525541
#[inline]
@@ -588,6 +604,10 @@ pub fn subgroup_ballot_find_msb(value: SubgroupMask) -> u32 {
588604
/// The resulting value is undefined if `id` is an inactive invocation, or is greater than or equal to the size of the group.
589605
///
590606
/// Requires Capability `GroupNonUniformShuffle`.
607+
///
608+
/// # Safety
609+
/// * This function is safe
610+
/// * Result is undefined if `id` is an inactive invocation or out of bounds
591611
#[spirv_std_macros::gpu_only]
592612
#[doc(alias = "OpGroupNonUniformShuffle")]
593613
#[inline]
@@ -625,6 +645,10 @@ pub fn subgroup_shuffle<T: VectorOrScalar>(value: T, id: u32) -> T {
625645
/// The resulting value is undefined if current invocation’s id within the group xor’ed with Mask is an inactive invocation, or is greater than or equal to the size of the group.
626646
///
627647
/// Requires Capability `GroupNonUniformShuffle`.
648+
///
649+
/// # Safety
650+
/// * This function is safe
651+
/// * Result is undefined if current invocation’s id within the group xor’ed with `mask` is an inactive invocation or out of bounds
628652
#[spirv_std_macros::gpu_only]
629653
#[doc(alias = "OpGroupNonUniformShuffleXor")]
630654
#[inline]
@@ -662,6 +686,10 @@ pub fn subgroup_shuffle_xor<T: VectorOrScalar>(value: T, mask: u32) -> T {
662686
/// Delta is treated as unsigned and the resulting value is undefined if Delta is greater than the current invocation’s id within the group or if the selected lane is inactive.
663687
///
664688
/// Requires Capability `GroupNonUniformShuffleRelative`.
689+
///
690+
/// # Safety
691+
/// * This function is safe
692+
/// * Result is undefined if `delta` is greater than the current invocation’s id within the group or if the selected lane is inactive
665693
#[spirv_std_macros::gpu_only]
666694
#[doc(alias = "OpGroupNonUniformShuffleUp")]
667695
#[inline]
@@ -699,6 +727,10 @@ pub fn subgroup_shuffle_up<T: VectorOrScalar>(value: T, delta: u32) -> T {
699727
/// Delta is treated as unsigned and the resulting value is undefined if Delta is greater than or equal to the size of the group, or if the current invocation’s id within the group + Delta is either an inactive invocation or greater than or equal to the size of the group.
700728
///
701729
/// Requires Capability `GroupNonUniformShuffleRelative`.
730+
///
731+
/// # Safety
732+
/// * This function is safe
733+
/// * Result is undefined if `delta` is greater than or equal to the size of the group, or if the current invocation’s id within the group + `delta` is either an inactive invocation or greater than or equal to the size of the group.
702734
#[spirv_std_macros::gpu_only]
703735
#[doc(alias = "OpGroupNonUniformShuffleDown")]
704736
#[inline]
@@ -1264,10 +1296,14 @@ Requires Capability `GroupNonUniformArithmetic` and `GroupNonUniformClustered`.
12641296
/// If the value of `index` is greater than or equal to 4, or refers to an inactive invocation, the resulting value is undefined.
12651297
///
12661298
/// Requires Capability `GroupNonUniformQuad`.
1299+
///
1300+
/// # Safety
1301+
/// * This function is safe
1302+
/// * Result is undefined if the value of `index` is greater than or equal to 4, or refers to an inactive invocation
12671303
#[spirv_std_macros::gpu_only]
12681304
#[doc(alias = "OpGroupNonUniformQuadBroadcast")]
12691305
#[inline]
1270-
pub unsafe fn subgroup_quad_broadcast<T: VectorOrScalar>(value: T, index: u32) -> T {
1306+
pub fn subgroup_quad_broadcast<T: VectorOrScalar>(value: T, index: u32) -> T {
12711307
let mut result = T::default();
12721308

12731309
unsafe {
@@ -1343,6 +1379,10 @@ pub enum QuadDirection {
13431379
/// If an active invocation reads `value` from an inactive invocation, the resulting value is undefined.
13441380
///
13451381
/// Requires Capability `GroupNonUniformQuad`.
1382+
///
1383+
/// # Safety
1384+
/// * This function is safe
1385+
/// * Result is undefined if an active invocation reads `value` from an inactive invocation
13461386
#[spirv_std_macros::gpu_only]
13471387
#[doc(alias = "OpGroupNonUniformQuadSwap")]
13481388
#[inline]

0 commit comments

Comments
 (0)