Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance slice operation to support more range variation #1989

Merged
merged 5 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions crates/burn-import/src/burn/node/slice.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
Expand All @@ -8,8 +8,7 @@ use quote::quote;
pub struct SliceNode {
pub input: TensorType,
pub output: TensorType,
pub starts: Vec<usize>,
pub ends: Vec<usize>,
pub ranges: Vec<Option<(i64, i64)>>,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
Expand All @@ -22,11 +21,19 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let starts = &self.starts;
let ends = &self.ends;

let ranges = self.ranges.iter().map(|range| match range {
Some((start, end)) => {
let start = start.to_tokens();
let end = end.to_tokens();

quote! { Some((#start, #end))}
}
None => quote! { None },
});

quote! {
let #output = #input.slice([#(#starts..#ends),*]);
let #output = #input.slice([#(#ranges),*]);
}
}
fn into_node(self) -> Node<PS> {
Expand All @@ -51,8 +58,7 @@ mod tests {
graph.register(SliceNode::new(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
vec![0, 0, 0, 0],
vec![1, 1, 1, 1],
vec![Some((0, 1)), Some((0, 1)), Some((0, 1)), Some((0, 1))],
));
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

Expand All @@ -78,8 +84,7 @@ mod tests {
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.slice([0usize..1usize,0usize..1usize,0usize..1usize,0usize..1usize]);

let tensor2 = tensor1.slice([Some((0, 1)), Some((0, 1)), Some((0, 1)), Some((0, 1))]);
tensor2
}
}
Expand Down
97 changes: 55 additions & 42 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1283,60 +1283,73 @@
(start_dim as usize, end_dim as usize)
}

pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
fn ensure_1d_tensor(node: &Node, index: usize) {
match &node.inputs[index].ty {
ArgType::Tensor(tensor) => assert_eq!(tensor.dim, 1, "Slice: tensor must be 1D"),
_ => panic!("Only tensor input is valid"),
};
}

fn get_input_values(node: &Node, index: usize) -> Vec<usize> {
let tensor_shape = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.shape.as_ref().unwrap(),
_ => panic!("Only tensor input is valid"),
};
pub fn slice_config(node: &Node) -> Vec<Option<(i64, i64)>> {
fn get_input_values(node: &Node, index: usize) -> Vec<i64> {
// If the input is not provided, return an empty vector
if node.inputs.get(index).is_none() {
return Vec::new();
}

match &node.inputs[index].value {
Some(Data::Int64s(shape)) => shape
.iter()
.enumerate()
.map(|(i, x)| {
if x.is_negative() {
tensor_shape[i] - x.wrapping_abs() as usize
} else {
*x as usize
}
})
.collect(),
Some(Data::Int64s(shape)) => shape.clone(),

_ => panic!("Tensor data type must be int64"),
}
}

ensure_1d_tensor(node, 1);
ensure_1d_tensor(node, 2);

let starts = get_input_values(node, 1);
let ends = get_input_values(node, 2);
let mut starts = get_input_values(node, 1);
let mut ends = get_input_values(node, 2);
let mut axes = get_input_values(node, 3);
let mut steps = get_input_values(node, 4);

// https://burn.dev/docs/burn/prelude/struct.Tensor.html#method.slice
// TODO default missing axes ranges to the full range of the corresponding axis
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axes" => {
let mut i = 0;
value.clone().into_i64s().iter().for_each(|x| {
assert_eq!(*x, i, "Slice: axes must be consecutive");
i += 1;
})
}
"steps" => value.clone().into_i64s().into_iter().for_each(|x| {
if x != 1 {
panic!("Slice: steps other than 1 are not supported");
}
}),
"starts" => starts = value.clone().into_i64s(),
"ends" => ends = value.clone().into_i64s(),
"axes" => axes = value.clone().into_i64s(),
"steps" => steps = value.clone().into_i64s(),
_ => {}
}
}

(starts, ends)
if !steps.is_empty() && steps.iter().any(|&x| x != 1) {
panic!("Slice: steps other than 1 are not supported");
}

// Extract the shape of the input tensor
let input_dim = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor.dim,
_ => panic!("Only tensor input is valid"),
};

// If axes is not provided, it defaults to all axes
if axes.is_empty() {
axes = (0..starts.len() as i64).collect();
}

// assert len(starts) == len(ends) == len(axes)
if starts.len() != ends.len() || starts.len() != axes.len() {
panic!("Slice: starts, ends, and axes must have the same length");
}

// If dim is negative, it is counted from the end
// Negative value means counting dimensions from the back.
for axis in &mut axes {
if *axis < 0 {
*axis = *axis + input_dim as i64;

Check failure on line 1341 in crates/burn-import/src/onnx/op_configuration.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu-22.04, stable, std)

[clippy] reported by reviewdog 🐶 error: manual implementation of an assign operation --> crates/burn-import/src/onnx/op_configuration.rs:1341:13 | 1341 | *axis = *axis + input_dim as i64; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: replace it with: `*axis += input_dim as i64` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#assign_op_pattern = note: `-D clippy::assign-op-pattern` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::assign_op_pattern)]` Raw Output: crates/burn-import/src/onnx/op_configuration.rs:1341:13:e:error: manual implementation of an assign operation --> crates/burn-import/src/onnx/op_configuration.rs:1341:13 | 1341 | *axis = *axis + input_dim as i64; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: replace it with: `*axis += input_dim as i64` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#assign_op_pattern = note: `-D clippy::assign-op-pattern` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::assign_op_pattern)]` __END__
}
}

// convert starts, ends, and axes to ranges. Use None for missing axes ranges
let mut ranges: Vec<Option<(i64, i64)>> = vec![None; input_dim];
for i in 0..axes.len() {
let axis = axes[i] as usize;
ranges[axis] = Some((starts[i], ends[i]));
}

ranges
}

pub fn transpose_config(curr: &Node) -> Vec<i64> {
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,9 +736,9 @@ impl ParsedOnnxGraph {
fn slice_conversion(node: Node) -> SliceNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let (starts, ends) = slice_config(&node);
let ranges = slice_config(&node);

SliceNode::new(input, output, starts, ends)
SliceNode::new(input, output, ranges)
}

fn sum_conversion(node: Node) -> SumNode {
Expand Down
162 changes: 127 additions & 35 deletions crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,50 +579,64 @@ where

/// Returns a tensor containing the elements selected from the given ranges.
///
/// # Arguments
///
/// * `ranges` - A type implementing the `RangesArg` trait, which can be:
/// - An array of `core::ops::Range<usize>`
/// - An array of `Option<(i64, i64)>`
/// - An array of `(i64, i64)` tuples
///
/// # Behavior
///
/// - Supports partial and full slicing in any number of dimensions.
/// - Handles negative indices by wrapping around from the end of the dimension.
/// - Clamps ranges to the tensor's dimensions if they exceed the bounds.
/// - For `Option<(i64, i64)>` ranges, `None` selects the full range of that dimension.
///
/// # Panics
///
/// If a range exceeds the number of elements on a dimension.
/// - If the number of ranges provided doesn't match the tensor's dimensions.
/// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1).
///
/// # Example
/// # Examples
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// // Create a tensor with a single dimension of ints between 0 and 11
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device);
/// // Select elements 0, 1, 2, 3 from the first dimension
/// let tensor_slices = tensor.clone().slice([0..4]);
/// println!("\nexpecting [0,1,2,3] : {:?}", tensor);
/// println!("expecting [4] : {:?}", tensor.dims());
///
/// // Create a Tensor with 3 dimensions
/// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
/// // This slice will select the element 0 on the first dimension,
/// // elements 0,1,2 of the second dimension and element 1 of third dimension
/// let tensor_slices = tensor.slice([0..1, 0..3, 1..2]);
/// println!("expecting [1, 3, 1] : {:?}", tensor_slices.dims());
///
/// // Create a tensor of ints from 0 to 11 and reshape it into three dimensions
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device);
/// let tensor = tensor.reshape([1, 3, 4]);
/// println!("\nexpecting [[[0,1,2,3],[4,5,6,7],[8,9,10,11]]] : {:?}", tensor);
/// println!("expecting [1, 3, 4] : {:?}", tensor.dims());
/// // Select element 0 of first dimension, elements 1,2 of second dimension
/// // and element 1 of third dimension
/// //
/// // This is the equivalent of this pseudo code
/// // let mut v = vec![[[]]];
/// // v[0][0][0] = tensor[0][1][1];
/// // v[0][1][0] = tensor[0][2][1];
/// let tensor_slices = tensor.slice([0..1, 1..3, 1..2]);
/// println!("\nexpecting [1, 2, 1] : {:?}", tensor_slices.dims());
/// println!("expecting [[[5],[9]]] : {:?}", tensor_slices);
///
/// // 1D slicing
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
/// let slice = tensor.slice([1..4]);
/// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
///
/// // 2D slicing
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 4]), &device);
/// let slice = tensor.slice([1..3, 0..2]);
/// assert_eq!(slice.dims(), [2, 2]);
///
/// // Using negative indices
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
/// let slice = tensor.slice([(1, -1)]); // Equivalent to 1..4
/// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
///
/// // Using Option<(i64, i64)>
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device).reshape([3, 4]);
/// let slice = tensor.slice([Some((1, -1)), None]); // -1 resolves to the exclusive end 2, so this is 1..2: row 1 only, all columns
/// assert_eq!(slice.dims(), [1, 4]);
/// }
/// ```
pub fn slice<const D2: usize>(self, ranges: [core::ops::Range<usize>; D2]) -> Self {
///
/// # Note
///
/// This function uses the `RangesArg` trait for flexible range specification. The trait
/// handles the conversion of various range formats and applies clamping and negative
/// index handling internally.
pub fn slice<const D2: usize, R: RangesArg<D2>>(self, ranges: R) -> Self {
    // Resolve the caller's range specification against this tensor's shape first, so that
    // the check below and the backend call both see plain `core::ops::Range<usize>` values.
    let resolved = ranges.into_ranges(self.shape());

    check!(TensorCheck::slice(&self.shape(), &resolved));
    Self::new(K::slice(self.primitive, resolved))
}
Expand Down Expand Up @@ -675,8 +689,8 @@ where
/// Converts the data of the current tensor.
pub fn into_data(self) -> TensorData {
crate::try_read_sync(self.into_data_async()).expect(
"Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.
"Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.
If possible, try using into_data_async instead.",
)
}
Expand Down Expand Up @@ -875,7 +889,7 @@ where
/// If the backend fails to read the tensor data synchronously.
pub fn into_scalar(self) -> K::Elem {
crate::try_read_sync(self.into_scalar_async()).expect(
"Failed to read tensor data synchronously. This can happen on platforms
"Failed to read tensor data synchronously. This can happen on platforms
that don't support blocking futures like WASM. Try into_scalar_async instead.",
)
}
Expand Down Expand Up @@ -2090,6 +2104,84 @@ impl MovedimArgs for i32 {
}
}

/// Trait used for slice arguments.
///
/// Implementors describe `D2` ranges in some caller-friendly format and know how to turn
/// them into concrete `core::ops::Range<usize>` values for a given tensor shape.
pub trait RangesArg<const D2: usize> {
    /// Converts into a set of ranges to `[core::ops::Range<usize>; D2]` for the `tensor.slice()` function
    fn into_ranges<const D: usize>(self, shape: Shape<D>) -> [core::ops::Range<usize>; D2];

    /// Resolves possibly negative `start` / `end` indices against a dimension of size `dim`.
    ///
    /// A negative index counts from the end of the dimension, so `-1` becomes `dim - 1`.
    /// A negative index that reaches before the beginning of the dimension (for example
    /// `-10` when `dim == 5`) saturates at `0`. Without that floor the intermediate negative
    /// value would wrap to a huge `usize` on the cast and later be clamped to `dim`, turning
    /// a "from the very start" request into an empty or descending range.
    ///
    /// Non-negative indices are returned unchanged (cast to `usize`); they are not clamped
    /// here, see [`RangesArg::clamp_range`].
    fn handle_negative_index(start: i64, end: i64, dim: usize) -> (usize, usize) {
        let resolve = |index: i64| -> usize {
            if index < 0 {
                // `index < 0` and `dim >= 0`, so the sum cannot overflow; floor it at zero
                // before converting so it never wraps around.
                (dim as i64 + index).max(0) as usize
            } else {
                index as usize
            }
        };

        (resolve(start), resolve(end))
    }

    /// Clamps the range to the shape dimensions.
    ///
    /// Both bounds are limited to at most `dim`; the lower bound of an unsigned value is
    /// already `0`, so only the upper limit needs enforcing.
    fn clamp_range(start: usize, end: usize, dim: usize) -> (usize, usize) {
        (start.min(dim), end.min(dim))
    }
}

impl<const D2: usize> RangesArg<D2> for [core::ops::Range<usize>; D2] {
    /// Clamps every requested range to the size of the matching dimension of `shape`.
    ///
    /// Range `i` is matched with `shape.dims[i]`; indexing panics if `shape` has fewer
    /// than `D2` dimensions.
    fn into_ranges<const D: usize>(self, shape: Shape<D>) -> [core::ops::Range<usize>; D2] {
        // Build the output array element by element, one axis at a time.
        core::array::from_fn(|axis| {
            let requested = &self[axis];
            let (start, end) =
                Self::clamp_range(requested.start, requested.end, shape.dims[axis]);
            start..end
        })
    }
}

impl<const D2: usize> RangesArg<D2> for [Option<(i64, i64)>; D2] {
    /// Converts optional `(start, end)` pairs into concrete ranges for `shape`.
    ///
    /// `Some((start, end))` has negative indices resolved relative to the dimension size and
    /// is then clamped to that size; `None` selects the full extent `0..dim`. Entry `i` is
    /// matched with `shape.dims[i]`; indexing panics if `shape` has fewer than `D2` dimensions.
    fn into_ranges<const D: usize>(self, shape: Shape<D>) -> [core::ops::Range<usize>; D2] {
        core::array::from_fn(|axis| {
            let extent = shape.dims[axis];
            match self[axis] {
                // No bounds given: take the whole dimension.
                None => 0..extent,
                Some((start, end)) => {
                    let (start, end) = Self::handle_negative_index(start, end, extent);
                    let (start, end) = Self::clamp_range(start, end, extent);
                    start..end
                }
            }
        })
    }
}

impl<const D2: usize> RangesArg<D2> for [(i64, i64); D2] {
    /// Converts signed `(start, end)` pairs into concrete ranges for `shape`.
    ///
    /// Each pair has negative indices resolved relative to the dimension size and is then
    /// clamped to that size. Pair `i` is matched with `shape.dims[i]`; indexing panics if
    /// `shape` has fewer than `D2` dimensions.
    fn into_ranges<const D: usize>(self, shape: Shape<D>) -> [core::ops::Range<usize>; D2] {
        core::array::from_fn(|axis| {
            let (raw_start, raw_end) = self[axis];
            let extent = shape.dims[axis];
            let (start, end) = Self::handle_negative_index(raw_start, raw_end, extent);
            let (start, end) = Self::clamp_range(start, end, extent);
            start..end
        })
    }
}

/// Trait used for reshape arguments.
pub trait ReshapeArgs<const D2: usize> {
/// Converts to a shape.
Expand Down
Loading
Loading