
Commit b40e461 (parent 23b71e8)

Feat: Add tensor reduction

8 files changed: 399 additions, 1 deletion

crates/cudnn/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -69,7 +69,7 @@ The previous tensor descriptor can be used together with a `i8` device buffer an
 
 Currently this crate does not support `f16` and `bf16` data types.
 
-### Tensor formats
+### cuDNN tensor formats
 
 We decided not to check tensor format configurations at compile time, since it is too strong of a requirement. As a consequence, should you mess up, the program will fail at run-time. A proper understanding of the cuDNN API mechanics is thus fundamental to properly use this crate.
 
```

crates/cudnn/src/activation/activation_mode.rs

Lines changed: 1 addition & 0 deletions

```diff
@@ -4,6 +4,7 @@ use crate::sys;
 ///
 /// cuDNN [docs](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnActivationMode_t)
 /// may offer additional information about the API behavior.
+#[non_exhaustive]
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum ActivationMode {
     /// Selects the sigmoid function.
```

crates/cudnn/src/lib.rs

Lines changed: 2 additions & 0 deletions

```diff
@@ -13,6 +13,7 @@ mod math_type;
 mod nan_propagation;
 mod op;
 mod pooling;
+mod reduction;
 mod rnn;
 mod softmax;
 mod sys;
@@ -31,6 +32,7 @@ pub use math_type::*;
 pub use nan_propagation::*;
 pub use op::*;
 pub use pooling::*;
+pub use reduction::*;
 pub use rnn::*;
 pub use softmax::*;
 pub use tensor::*;
```
crates/cudnn/src/reduction/indices_type.rs

Lines changed: 22 additions & 0 deletions

```rust
use crate::sys;

/// Indicates the data type of the indices computed by a reduction operation.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IndicesType {
    /// 8-bit unsigned integer indices.
    U8,
    /// 16-bit unsigned integer indices.
    U16,
    /// 32-bit unsigned integer indices.
    U32,
    /// 64-bit unsigned integer indices.
    U64,
}

impl From<IndicesType> for sys::cudnnIndicesType_t {
    fn from(mode: IndicesType) -> Self {
        match mode {
            IndicesType::U8 => Self::CUDNN_8BIT_INDICES,
            IndicesType::U16 => Self::CUDNN_16BIT_INDICES,
            IndicesType::U32 => Self::CUDNN_32BIT_INDICES,
            IndicesType::U64 => Self::CUDNN_64BIT_INDICES,
        }
    }
}
```
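
One practical note: the `reduce` method added by this commit (see `mod.rs` below) hands the computed indices back as raw bytes (`impl GpuBuffer<u8>`), so a caller decoding them needs the element width of the chosen variant. A minimal hedged helper, not part of the commit, with widths assumed from the variant names:

```rust
/// Byte width of one index element for a given `IndicesType`.
/// Hedged helper, not part of this commit; widths are assumed
/// from the variant names.
fn index_element_width(t: IndicesType) -> usize {
    match t {
        IndicesType::U8 => 1,
        IndicesType::U16 => 2,
        IndicesType::U32 => 4,
        IndicesType::U64 => 8,
    }
}
```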

crates/cudnn/src/reduction/mod.rs

Lines changed: 223 additions & 0 deletions

````rust
mod indices_type;
mod reduce_indices;
mod reduce_op;
mod reduction_descriptor;

pub use indices_type::*;
pub use reduce_indices::*;
pub use reduce_op::*;
pub use reduction_descriptor::*;

use std::mem::MaybeUninit;

use cust::memory::GpuBuffer;

use crate::{
    sys, CudnnContext, CudnnError, DataType, IntoResult, ScalingDataType, TensorDescriptor,
};

impl CudnnContext {
    /// Returns the minimum size of the workspace to be passed to the reduction given the input and
    /// output tensors.
    ///
    /// # Arguments
    ///
    /// * `desc` - reduction descriptor.
    ///
    /// * `a_desc` - input tensor descriptor.
    ///
    /// * `c_desc` - output tensor descriptor.
    pub fn get_reduction_workspace_size<T, U, V>(
        &self,
        desc: &ReductionDescriptor<T>,
        a_desc: &TensorDescriptor<U>,
        c_desc: &TensorDescriptor<V>,
    ) -> Result<usize, CudnnError>
    where
        T: DataType,
        U: DataType,
        V: DataType,
    {
        let mut size = MaybeUninit::uninit();

        unsafe {
            sys::cudnnGetReductionWorkspaceSize(
                self.raw,
                desc.raw,
                a_desc.raw,
                c_desc.raw,
                size.as_mut_ptr(),
            )
            .into_result()?;

            Ok(size.assume_init())
        }
    }

    /// Returns the minimum size of the index space to be passed to the reduction given the input
    /// and output tensors.
    ///
    /// # Arguments
    ///
    /// * `desc` - reduction descriptor.
    ///
    /// * `a_desc` - input tensor descriptor.
    ///
    /// * `c_desc` - output tensor descriptor.
    pub fn get_reduction_indices_size<T, U, V>(
        &self,
        desc: &ReductionDescriptor<T>,
        a_desc: &TensorDescriptor<U>,
        c_desc: &TensorDescriptor<V>,
    ) -> Result<usize, CudnnError>
    where
        T: DataType,
        U: DataType,
        V: DataType,
    {
        let mut size = MaybeUninit::uninit();

        unsafe {
            sys::cudnnGetReductionIndicesSize(
                self.raw,
                desc.raw,
                a_desc.raw,
                c_desc.raw,
                size.as_mut_ptr(),
            )
            .into_result()?;

            Ok(size.assume_init())
        }
    }

    /// This function reduces tensor `a` by implementing the equation:
    ///
    /// C = alpha * reduce op ( A ) + gamma * C
    ///
    /// given tensors `a` and `c` and scaling factors `alpha` and `gamma`.
    /// Each dimension of the output tensor `c` must match the corresponding dimension of the
    /// input tensor `a` or must be equal to 1.
    ///
    /// The dimensions equal to 1 indicate the dimensions of `a` to be reduced.
    ///
    /// **Do note** that currently only the 32-bit indices type is supported and that the data
    /// types of the tensors `a` and `c` must match if of type double. In this case, `alpha` and
    /// `gamma` are both assumed to be of type double.
    ///
    /// # Arguments
    ///
    /// * `desc` - tensor reduction descriptor.
    ///
    /// * `indices` - indices buffer in device memory.
    ///
    /// * `workspace` - workspace for the reduction operation.
    ///
    /// * `alpha` - scaling factor for the input tensor.
    ///
    /// * `a_desc` - tensor descriptor for the input tensor.
    ///
    /// * `a` - input tensor in device memory.
    ///
    /// * `gamma` - scaling factor for the output tensor.
    ///
    /// * `c_desc` - tensor descriptor for the output tensor.
    ///
    /// * `c` - output tensor in device memory.
    ///
    /// # Errors
    ///
    /// Returns errors if an unsupported configuration of arguments is detected.
    ///
    /// # Examples
    ///
    /// ```
    /// # use std::error::Error;
    /// #
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// use cudnn::{CudnnContext, NanPropagation, ReduceOp, ReduceIndices, ReductionDescriptor, TensorDescriptor};
    /// use cust::memory::DeviceBuffer;
    ///
    /// let ctx = CudnnContext::new()?;
    ///
    /// let op = ReduceOp::Add;
    /// let nan_policy = NanPropagation::PropagateNaN;
    /// let indices = ReduceIndices::None;
    /// let indices_type = None;
    ///
    /// let desc = ReductionDescriptor::<f32>::new(op, nan_policy, indices, indices_type)?;
    ///
    /// let alpha = 1.0;
    /// let a_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    /// let a = DeviceBuffer::<i8>::from_slice(&[4, 4, 4, 4, 4])?;
    ///
    /// let gamma = 1.0;
    /// let c_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 1], &[1, 1, 1, 1])?;
    /// let mut c = DeviceBuffer::<i8>::from_slice(&[0])?;
    ///
    /// let workspace_size = ctx.get_reduction_workspace_size(&desc, &a_desc, &c_desc)?;
    /// let mut workspace = unsafe { DeviceBuffer::uninitialized(workspace_size)? };
    ///
    /// let indices: Option<&mut DeviceBuffer<u8>> = None;
    ///
    /// ctx.reduce(&desc, indices, &mut workspace, alpha, &a_desc, &a, gamma, &c_desc, &mut c)?;
    ///
    /// let c_host = c.as_host_vec()?;
    ///
    /// assert!(c_host.iter().all(|x| *x == 20));
    /// # Ok(())
    /// # }
    /// ```
    pub fn reduce<CompT, U, V>(
        &self,
        desc: &ReductionDescriptor<CompT>,
        indices: Option<&mut impl GpuBuffer<u8>>,
        workspace: &mut impl GpuBuffer<u8>,
        alpha: CompT,
        a_desc: &TensorDescriptor<U>,
        a: &impl GpuBuffer<U>,
        gamma: CompT,
        c_desc: &TensorDescriptor<V>,
        c: &mut impl GpuBuffer<V>,
    ) -> Result<(), CudnnError>
    where
        CompT: ScalingDataType<U>,
        U: DataType,
        V: DataType,
    {
        let (indices_ptr, indices_size) = {
            indices.map_or((std::ptr::null_mut(), 0), |indices| {
                (indices.as_device_ptr().as_mut_ptr() as _, indices.len())
            })
        };

        let workspace_ptr = workspace.as_device_ptr().as_mut_ptr() as _;
        let workspace_size = workspace.len();

        let a_data = a.as_device_ptr().as_ptr() as _;
        let c_data = c.as_device_ptr().as_mut_ptr() as _;

        let alpha = &alpha as *const CompT as _;
        let gamma = &gamma as *const CompT as _;

        unsafe {
            sys::cudnnReduceTensor(
                self.raw,
                desc.raw,
                indices_ptr,
                indices_size,
                workspace_ptr,
                workspace_size,
                alpha,
                a_desc.raw,
                a_data,
                gamma,
                c_desc.raw,
                c_data,
            )
            .into_result()?;
        }

        Ok(())
    }
}
````
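
The doc-test above covers the no-indices path, so here is a hedged sketch, not part of the commit, of a `Max` reduction that also materializes flattened indices, exercising `get_reduction_indices_size` and `IndicesType::U32` (per the docs, the only width currently supported). It reuses only API surface introduced in this commit; the byte-level decoding of the indices buffer and cuDNN's support for this exact `i8` configuration are assumptions, and an unsupported configuration would simply surface as a `CudnnError`.

```rust
use cudnn::{
    CudnnContext, IndicesType, NanPropagation, ReduceIndices, ReduceOp, ReductionDescriptor,
    TensorDescriptor,
};
use cust::memory::DeviceBuffer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ctx = CudnnContext::new()?;

    // Max reduction that also returns the flattened index of the maximum.
    // Only the 32-bit indices type is currently supported.
    let desc = ReductionDescriptor::<f32>::new(
        ReduceOp::Max,
        NanPropagation::PropagateNaN,
        ReduceIndices::Flattened,
        Some(IndicesType::U32),
    )?;

    // Reduce a 1 x 1 x 1 x 5 tensor down to a single element.
    let a_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
    let a = DeviceBuffer::<i8>::from_slice(&[2, 7, 1, 5, 4])?;

    let c_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 1], &[1, 1, 1, 1])?;
    let mut c = DeviceBuffer::<i8>::from_slice(&[0])?;

    // Query both scratch sizes before launching the reduction.
    let workspace_size = ctx.get_reduction_workspace_size(&desc, &a_desc, &c_desc)?;
    let indices_size = ctx.get_reduction_indices_size(&desc, &a_desc, &c_desc)?;

    let mut workspace = unsafe { DeviceBuffer::<u8>::uninitialized(workspace_size)? };
    let mut indices = unsafe { DeviceBuffer::<u8>::uninitialized(indices_size)? };

    // C = 1.0 * max(A) + 1.0 * C, with C zero-initialized.
    ctx.reduce(&desc, Some(&mut indices), &mut workspace, 1.0, &a_desc, &a, 1.0, &c_desc, &mut c)?;

    assert_eq!(c.as_host_vec()?[0], 7); // the maximum element

    // Assumption: indices come back as native-endian 32-bit values packed
    // into the raw byte buffer; `7` sits at flattened position 1.
    let bytes = indices.as_host_vec()?;
    let first = u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    assert_eq!(first, 1);

    Ok(())
}
```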
crates/cudnn/src/reduction/reduce_indices.rs

Lines changed: 21 additions & 0 deletions

```rust
use crate::sys;

/// Indicates whether a reduction operation should compute indices or not.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceIndices {
    /// Do not compute indices.
    None,
    /// Compute indices. The resulting indices are relative to the dimensions being reduced, and
    /// flattened.
    Flattened,
}

impl From<ReduceIndices> for sys::cudnnReduceTensorIndices_t {
    fn from(mode: ReduceIndices) -> Self {
        match mode {
            ReduceIndices::None => Self::CUDNN_REDUCE_TENSOR_NO_INDICES,
            ReduceIndices::Flattened => Self::CUDNN_REDUCE_TENSOR_FLATTENED_INDICES,
        }
    }
}
```
crates/cudnn/src/reduction/reduce_op.rs

Lines changed: 32 additions & 0 deletions

```rust
use crate::sys;

/// Tensor reduction operation.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Sum of the elements.
    Add,
    /// Product of the elements.
    Mul,
    /// Minimum of the elements.
    Min,
    /// Maximum of the elements.
    Max,
    /// Maximum of the absolute values of the elements.
    Amax,
    /// Average of the elements.
    Avg,
    /// Sum of the absolute values of the elements (L1 norm).
    Norm1,
    /// Square root of the sum of squares of the elements (L2 norm).
    Norm2,
    /// Product of the elements, ignoring zero-valued entries.
    MulNoZeros,
}

impl From<ReduceOp> for sys::cudnnReduceTensorOp_t {
    fn from(op: ReduceOp) -> Self {
        match op {
            ReduceOp::Add => Self::CUDNN_REDUCE_TENSOR_ADD,
            ReduceOp::Mul => Self::CUDNN_REDUCE_TENSOR_MUL,
            ReduceOp::Min => Self::CUDNN_REDUCE_TENSOR_MIN,
            ReduceOp::Max => Self::CUDNN_REDUCE_TENSOR_MAX,
            ReduceOp::Amax => Self::CUDNN_REDUCE_TENSOR_AMAX,
            ReduceOp::Avg => Self::CUDNN_REDUCE_TENSOR_AVG,
            ReduceOp::Norm1 => Self::CUDNN_REDUCE_TENSOR_NORM1,
            ReduceOp::Norm2 => Self::CUDNN_REDUCE_TENSOR_NORM2,
            ReduceOp::MulNoZeros => Self::CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS,
        }
    }
}
```
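
To illustrate swapping the operation, here is a hedged variant of the commit's own doc-test that computes an L1 norm instead of a sum. Everything except the op, the input values, and the expected result follows that example; cuDNN's support for `Norm1` over `i8` data is an assumption.

```rust
use cudnn::{
    CudnnContext, NanPropagation, ReduceIndices, ReduceOp, ReductionDescriptor, TensorDescriptor,
};
use cust::memory::DeviceBuffer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ctx = CudnnContext::new()?;

    // Norm1 reduction: C = alpha * sum(|A|) + gamma * C.
    let desc = ReductionDescriptor::<f32>::new(
        ReduceOp::Norm1,
        NanPropagation::PropagateNaN,
        ReduceIndices::None,
        None,
    )?;

    let a_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 4], &[4, 4, 4, 1])?;
    let a = DeviceBuffer::<i8>::from_slice(&[-3, 1, -2, 4])?;

    let c_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 1], &[1, 1, 1, 1])?;
    let mut c = DeviceBuffer::<i8>::from_slice(&[0])?;

    let workspace_size = ctx.get_reduction_workspace_size(&desc, &a_desc, &c_desc)?;
    let mut workspace = unsafe { DeviceBuffer::<u8>::uninitialized(workspace_size)? };

    // No indices are needed for a norm reduction.
    let indices: Option<&mut DeviceBuffer<u8>> = None;
    ctx.reduce(&desc, indices, &mut workspace, 1.0, &a_desc, &a, 1.0, &c_desc, &mut c)?;

    assert_eq!(c.as_host_vec()?[0], 10); // |-3| + |1| + |-2| + |4|

    Ok(())
}
```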
