Skip to content

Commit d56c0ed

Browse files
committed
subgroup: make VectorOrScalar trait match discussions in EmbarkStudios/rust-gpu#1030
1 parent 589af48 commit d56c0ed

File tree

6 files changed

+66
-112
lines changed

6 files changed

+66
-112
lines changed

crates/spirv-std/src/arch/subgroup.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::float::Float;
44
use crate::integer::{Integer, SignedInteger, UnsignedInteger};
55
#[cfg(target_arch = "spirv")]
66
use crate::memory::{Scope, Semantics};
7-
use crate::scalar::VectorOrScalar;
7+
use crate::vector::VectorOrScalar;
88
#[cfg(target_arch = "spirv")]
99
use core::arch::asm;
1010

crates/spirv-std/src/float.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
//! Traits and helper functions related to floats.
22
3-
use crate::scalar::VectorOrScalar;
43
use crate::vector::Vector;
4+
use crate::vector::{create_dim, VectorOrScalar};
55
#[cfg(target_arch = "spirv")]
66
use core::arch::asm;
7+
use core::num::NonZeroUsize;
78

89
/// Abstract trait representing a SPIR-V floating point type.
910
///
@@ -74,6 +75,7 @@ struct F32x2 {
7475
}
7576
unsafe impl VectorOrScalar for F32x2 {
7677
type Scalar = f32;
78+
const DIM: NonZeroUsize = create_dim(2);
7779
}
7880
unsafe impl Vector<f32, 2> for F32x2 {}
7981

crates/spirv-std/src/scalar.rs

Lines changed: 15 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,7 @@
11
//! Traits related to scalars.
22
3-
/// Abstract trait representing either a vector or a scalar type.
4-
///
5-
/// # Safety
6-
/// Implementing this trait on non-scalar or non-vector types may break assumptions about other
7-
/// unsafe code, and should not be done.
8-
pub unsafe trait VectorOrScalar: Default {
9-
/// Either the scalar component type of the vector or the scalar itself.
10-
type Scalar: Scalar;
11-
}
12-
13-
unsafe impl VectorOrScalar for bool {
14-
type Scalar = bool;
15-
}
16-
unsafe impl VectorOrScalar for f32 {
17-
type Scalar = f32;
18-
}
19-
unsafe impl VectorOrScalar for f64 {
20-
type Scalar = f64;
21-
}
22-
unsafe impl VectorOrScalar for u8 {
23-
type Scalar = u8;
24-
}
25-
unsafe impl VectorOrScalar for u16 {
26-
type Scalar = u16;
27-
}
28-
unsafe impl VectorOrScalar for u32 {
29-
type Scalar = u32;
30-
}
31-
unsafe impl VectorOrScalar for u64 {
32-
type Scalar = u64;
33-
}
34-
unsafe impl VectorOrScalar for i8 {
35-
type Scalar = i8;
36-
}
37-
unsafe impl VectorOrScalar for i16 {
38-
type Scalar = i16;
39-
}
40-
unsafe impl VectorOrScalar for i32 {
41-
type Scalar = i32;
42-
}
43-
unsafe impl VectorOrScalar for i64 {
44-
type Scalar = i64;
45-
}
3+
use crate::vector::{create_dim, VectorOrScalar};
4+
use core::num::NonZeroUsize;
465

476
/// Abstract trait representing a SPIR-V scalar type.
487
///
@@ -54,14 +13,16 @@ pub unsafe trait Scalar:
5413
{
5514
}
5615

57-
unsafe impl Scalar for bool {}
58-
unsafe impl Scalar for f32 {}
59-
unsafe impl Scalar for f64 {}
60-
unsafe impl Scalar for u8 {}
61-
unsafe impl Scalar for u16 {}
62-
unsafe impl Scalar for u32 {}
63-
unsafe impl Scalar for u64 {}
64-
unsafe impl Scalar for i8 {}
65-
unsafe impl Scalar for i16 {}
66-
unsafe impl Scalar for i32 {}
67-
unsafe impl Scalar for i64 {}
16+
macro_rules! impl_scalar {
17+
($($ty:ty),+) => {
18+
$(
19+
unsafe impl VectorOrScalar for $ty {
20+
type Scalar = Self;
21+
const DIM: NonZeroUsize = create_dim(1);
22+
}
23+
unsafe impl Scalar for $ty {}
24+
)+
25+
};
26+
}
27+
28+
impl_scalar!(bool, f32, f64, u8, u16, u32, u64, i8, i16, i32, i64);

crates/spirv-std/src/vector.rs

Lines changed: 35 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,28 @@
11
//! Traits related to vectors.
22
3-
use crate::scalar::{Scalar, VectorOrScalar};
3+
use crate::scalar::Scalar;
4+
use core::num::NonZeroUsize;
45
use glam::{Vec3Swizzles, Vec4Swizzles};
56

6-
unsafe impl VectorOrScalar for glam::Vec2 {
7-
type Scalar = f32;
8-
}
9-
unsafe impl VectorOrScalar for glam::Vec3 {
10-
type Scalar = f32;
11-
}
12-
unsafe impl VectorOrScalar for glam::Vec3A {
13-
type Scalar = f32;
14-
}
15-
unsafe impl VectorOrScalar for glam::Vec4 {
16-
type Scalar = f32;
17-
}
18-
19-
unsafe impl VectorOrScalar for glam::DVec2 {
20-
type Scalar = f64;
21-
}
22-
unsafe impl VectorOrScalar for glam::DVec3 {
23-
type Scalar = f64;
24-
}
25-
unsafe impl VectorOrScalar for glam::DVec4 {
26-
type Scalar = f64;
27-
}
7+
/// Abstract trait representing either a vector or a scalar type.
8+
///
9+
/// # Safety
10+
/// Implementing this trait on non-scalar or non-vector types may break assumptions about other
11+
/// unsafe code, and should not be done.
12+
pub unsafe trait VectorOrScalar: Default {
13+
/// Either the scalar component type of the vector or the scalar itself.
14+
type Scalar: Scalar;
2815

29-
unsafe impl VectorOrScalar for glam::UVec2 {
30-
type Scalar = u32;
31-
}
32-
unsafe impl VectorOrScalar for glam::UVec3 {
33-
type Scalar = u32;
34-
}
35-
unsafe impl VectorOrScalar for glam::UVec4 {
36-
type Scalar = u32;
16+
/// The dimension of the vector, or 1 if it is a scalar
17+
const DIM: NonZeroUsize;
3718
}
3819

39-
unsafe impl VectorOrScalar for glam::IVec2 {
40-
type Scalar = i32;
41-
}
42-
unsafe impl VectorOrScalar for glam::IVec3 {
43-
type Scalar = i32;
44-
}
45-
unsafe impl VectorOrScalar for glam::IVec4 {
46-
type Scalar = i32;
20+
/// replace with `NonZeroUsize::new(n).unwrap()` once `unwrap()` is const stabilized
21+
pub(crate) const fn create_dim(n: usize) -> NonZeroUsize {
22+
match NonZeroUsize::new(n) {
23+
None => panic!("dim must not be 0"),
24+
Some(n) => n,
25+
}
4726
}
4827

4928
/// Abstract trait representing a SPIR-V vector type.
@@ -53,22 +32,24 @@ unsafe impl VectorOrScalar for glam::IVec4 {
5332
/// should not be done.
5433
pub unsafe trait Vector<T: Scalar, const N: usize>: VectorOrScalar<Scalar = T> {}
5534

56-
unsafe impl Vector<f32, 2> for glam::Vec2 {}
57-
unsafe impl Vector<f32, 3> for glam::Vec3 {}
58-
unsafe impl Vector<f32, 3> for glam::Vec3A {}
59-
unsafe impl Vector<f32, 4> for glam::Vec4 {}
60-
61-
unsafe impl Vector<f64, 2> for glam::DVec2 {}
62-
unsafe impl Vector<f64, 3> for glam::DVec3 {}
63-
unsafe impl Vector<f64, 4> for glam::DVec4 {}
64-
65-
unsafe impl Vector<u32, 2> for glam::UVec2 {}
66-
unsafe impl Vector<u32, 3> for glam::UVec3 {}
67-
unsafe impl Vector<u32, 4> for glam::UVec4 {}
35+
macro_rules! impl_vector {
36+
($($scalar:ty: $($vec:ty => $dim:literal),+;)+) => {
37+
$($(
38+
unsafe impl VectorOrScalar for $vec {
39+
type Scalar = $scalar;
40+
const DIM: NonZeroUsize = create_dim($dim);
41+
}
42+
unsafe impl Vector<$scalar, $dim> for $vec {}
43+
)+)+
44+
};
45+
}
6846

69-
unsafe impl Vector<i32, 2> for glam::IVec2 {}
70-
unsafe impl Vector<i32, 3> for glam::IVec3 {}
71-
unsafe impl Vector<i32, 4> for glam::IVec4 {}
47+
impl_vector! {
48+
f32: glam::Vec2 => 2, glam::Vec3 => 3, glam::Vec3A => 3, glam::Vec4 => 4;
49+
f64: glam::DVec2 => 2, glam::DVec3 => 3, glam::DVec4 => 4;
50+
u32: glam::UVec2 => 2, glam::UVec3 => 3, glam::UVec4 => 4;
51+
i32: glam::IVec2 => 2, glam::IVec3 => 3, glam::IVec4 => 4;
52+
}
7253

7354
/// Trait that implements slicing of a vector into a scalar or vector of lower dimensions, by
7455
/// ignoring the higter dimensions

tests/ui/arch/all.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
#![feature(repr_simd)]
44

5+
use core::num::NonZeroUsize;
56
use spirv_std::spirv;
6-
use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
7+
use spirv_std::{scalar::Scalar, vector::Vector, vector::VectorOrScalar};
78

89
/// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members
910
/// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()`
@@ -14,6 +15,10 @@ use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
1415
struct Vec2<T>(T, T);
1516
unsafe impl<T: Scalar> VectorOrScalar for Vec2<T> {
1617
type Scalar = T;
18+
const DIM: NonZeroUsize = match NonZeroUsize::new(2) {
19+
None => panic!(),
20+
Some(n) => n,
21+
};
1722
}
1823
unsafe impl<T: Scalar> Vector<T, 2> for Vec2<T> {}
1924

tests/ui/arch/any.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
#![feature(repr_simd)]
44

5+
use core::num::NonZeroUsize;
56
use spirv_std::spirv;
6-
use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
7+
use spirv_std::{scalar::Scalar, vector::Vector, vector::VectorOrScalar};
78

89
/// HACK(shesp). Rust doesn't allow us to declare regular (tuple-)structs containing `bool` members
910
/// as `#[repl(simd)]`. But we need this for `spirv_std::arch::any()` and `spirv_std::arch::all()`
@@ -14,6 +15,10 @@ use spirv_std::{scalar::Scalar, scalar::VectorOrScalar, vector::Vector};
1415
struct Vec2<T>(T, T);
1516
unsafe impl<T: Scalar> VectorOrScalar for Vec2<T> {
1617
type Scalar = T;
18+
const DIM: NonZeroUsize = match NonZeroUsize::new(2) {
19+
None => panic!(),
20+
Some(n) => n,
21+
};
1722
}
1823
unsafe impl<T: Scalar> Vector<T, 2> for Vec2<T> {}
1924

0 commit comments

Comments
 (0)