Skip to content

Commit

Permalink
Enhance slice operation to support more range variation (#1989)
Browse files Browse the repository at this point in the history
* Enhance slice operation to support more range variation

* Fix doc clippy

* Fixed doc test

* Fix flipped attribute names

* Fix clippy
  • Loading branch information
antimora authored Jul 8, 2024
1 parent c0211e2 commit e8b915a
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 121 deletions.
25 changes: 15 additions & 10 deletions crates/burn-import/src/burn/node/slice.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
Expand All @@ -8,8 +8,7 @@ use quote::quote;
pub struct SliceNode {
pub input: TensorType,
pub output: TensorType,
pub starts: Vec<usize>,
pub ends: Vec<usize>,
pub ranges: Vec<Option<(i64, i64)>>,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
Expand All @@ -22,11 +21,19 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let starts = &self.starts;
let ends = &self.ends;

let ranges = self.ranges.iter().map(|range| match range {
Some((start, end)) => {
let start = start.to_tokens();
let end = end.to_tokens();

quote! { Some((#start, #end))}
}
None => quote! { None },
});

quote! {
let #output = #input.slice([#(#starts..#ends),*]);
let #output = #input.slice([#(#ranges),*]);
}
}
fn into_node(self) -> Node<PS> {
Expand All @@ -51,8 +58,7 @@ mod tests {
graph.register(SliceNode::new(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
vec![0, 0, 0, 0],
vec![1, 1, 1, 1],
vec![Some((0, 1)), Some((0, 1)), Some((0, 1)), Some((0, 1))],
));
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

Expand All @@ -78,8 +84,7 @@ mod tests {
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.slice([0usize..1usize,0usize..1usize,0usize..1usize,0usize..1usize]);

let tensor2 = tensor1.slice([Some((0, 1)), Some((0, 1)), Some((0, 1)), Some((0, 1))]);
tensor2
}
}
Expand Down
97 changes: 55 additions & 42 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1283,60 +1283,73 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
(start_dim as usize, end_dim as usize)
}

pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
fn ensure_1d_tensor(node: &Node, index: usize) {
match &node.inputs[index].ty {
ArgType::Tensor(tensor) => assert_eq!(tensor.dim, 1, "Slice: tensor must be 1D"),
_ => panic!("Only tensor input is valid"),
};
}

fn get_input_values(node: &Node, index: usize) -> Vec<usize> {
let tensor_shape = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.shape.as_ref().unwrap(),
_ => panic!("Only tensor input is valid"),
};
pub fn slice_config(node: &Node) -> Vec<Option<(i64, i64)>> {
fn get_input_values(node: &Node, index: usize) -> Vec<i64> {
// If the input is not provided, return an empty vector
if node.inputs.get(index).is_none() {
return Vec::new();
}

match &node.inputs[index].value {
Some(Data::Int64s(shape)) => shape
.iter()
.enumerate()
.map(|(i, x)| {
if x.is_negative() {
tensor_shape[i] - x.wrapping_abs() as usize
} else {
*x as usize
}
})
.collect(),
Some(Data::Int64s(shape)) => shape.clone(),

_ => panic!("Tensor data type must be int64"),
}
}

ensure_1d_tensor(node, 1);
ensure_1d_tensor(node, 2);

let starts = get_input_values(node, 1);
let ends = get_input_values(node, 2);
let mut starts = get_input_values(node, 1);
let mut ends = get_input_values(node, 2);
let mut axes = get_input_values(node, 3);
let mut steps = get_input_values(node, 4);

// https://burn.dev/docs/burn/prelude/struct.Tensor.html#method.slice
// TODO default missing axes ranges to the full range of the corresponding axis
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axes" => {
let mut i = 0;
value.clone().into_i64s().iter().for_each(|x| {
assert_eq!(*x, i, "Slice: axes must be consecutive");
i += 1;
})
}
"steps" => value.clone().into_i64s().into_iter().for_each(|x| {
if x != 1 {
panic!("Slice: steps other than 1 are not supported");
}
}),
"starts" => starts = value.clone().into_i64s(),
"ends" => ends = value.clone().into_i64s(),
"axes" => axes = value.clone().into_i64s(),
"steps" => steps = value.clone().into_i64s(),
_ => {}
}
}

(starts, ends)
if !steps.is_empty() && steps.iter().any(|&x| x != 1) {
panic!("Slice: steps other than 1 are not supported");
}

// Extract the shape of the input tensor
let input_dim = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor.dim,
_ => panic!("Only tensor input is valid"),
};

// If axes is not provided, it defaults to all axes
if axes.is_empty() {
axes = (0..starts.len() as i64).collect();
}

// assert len(starts) == len(ends) == len(axes)
if starts.len() != ends.len() || starts.len() != axes.len() {
panic!("Slice: starts, ends, and axes must have the same length");
}

// If dim is negative, it is counted from the end
// Negative value means counting dimensions from the back.
for axis in &mut axes {
if *axis < 0 {
*axis = *axis + input_dim as i64;

Check failure on line 1341 in crates/burn-import/src/onnx/op_configuration.rs

View workflow job for this annotation

GitHub Actions / clippy

[clippy] crates/burn-import/src/onnx/op_configuration.rs#L1341

error: manual implementation of an assign operation --> crates/burn-import/src/onnx/op_configuration.rs:1341:13 | 1341 | *axis = *axis + input_dim as i64; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: replace it with: `*axis += input_dim as i64` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#assign_op_pattern = note: `-D clippy::assign-op-pattern` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::assign_op_pattern)]`
Raw output
crates/burn-import/src/onnx/op_configuration.rs:1341:13:e:error: manual implementation of an assign operation
    --> crates/burn-import/src/onnx/op_configuration.rs:1341:13
     |
1341 |             *axis = *axis + input_dim as i64;
     |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: replace it with: `*axis += input_dim as i64`
     |
     = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#assign_op_pattern
     = note: `-D clippy::assign-op-pattern` implied by `-D warnings`
     = help: to override `-D warnings` add `#[allow(clippy::assign_op_pattern)]`


__END__
}
}

// convert starts, ends, and axes to ranges. Use None for missing axes ranges
let mut ranges: Vec<Option<(i64, i64)>> = vec![None; input_dim];
for i in 0..axes.len() {
let axis = axes[i] as usize;
ranges[axis] = Some((starts[i], ends[i]));
}

ranges
}

pub fn transpose_config(curr: &Node) -> Vec<i64> {
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,9 +736,9 @@ impl ParsedOnnxGraph {
fn slice_conversion(node: Node) -> SliceNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let (starts, ends) = slice_config(&node);
let ranges = slice_config(&node);

SliceNode::new(input, output, starts, ends)
SliceNode::new(input, output, ranges)
}

fn sum_conversion(node: Node) -> SumNode {
Expand Down
162 changes: 127 additions & 35 deletions crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,50 +579,64 @@ where

/// Returns a tensor containing the elements selected from the given ranges.
///
/// # Arguments
///
/// * `ranges` - A type implementing the `RangesArg` trait, which can be:
/// - An array of `core::ops::Range<usize>`
/// - An array of `Option<(i64, i64)>`
/// - An array of `(i64, i64)` tuples
///
/// # Behavior
///
/// - Supports partial and full slicing in any number of dimensions.
/// - Handles negative indices by wrapping around from the end of the dimension.
/// - Clamps ranges to the tensor's dimensions if they exceed the bounds.
/// - For `Option<(i64, i64)>` ranges, `None` selects the full range of that dimension.
///
/// # Panics
///
/// If a range exceeds the number of elements on a dimension.
/// - If the number of ranges provided doesn't match the tensor's dimensions.
/// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1).
///
/// # Example
/// # Examples
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// // Create a tensor with a single dimension of ints between 0 and 11
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device);
/// // Select elements 0, 1, 2, 3 from the first dimension
/// let tensor_slices = tensor.clone().slice([0..4]);
/// println!("\nexpecting [0,1,2,3] : {:?}", tensor);
/// println!("expecting [4] : {:?}", tensor.dims());
///
/// // Create a Tensor with 3 dimensions
/// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
/// // This slice will select the element 0 on the first dimension,
/// // elements 0,1,2 of the second dimension and element 1 of third dimension
/// let tensor_slices = tensor.slice([0..1, 0..3, 1..2]);
/// println!("expecting [1, 3, 1] : {:?}", tensor_slices.dims());
///
/// // Create a tensor of ints from 0 to 11 and reshape it into three dimensions
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device);
/// let tensor = tensor.reshape([1, 3, 4]);
/// println!("\nexpecting [[[0,1,2,3],[4,5,6,7],[8,9,10,11]]] : {:?}", tensor);
/// println!("expecting [1, 3, 4] : {:?}", tensor.dims());
/// // Select element 0 of first dimension, elements 1,2 of second dimension
/// // and element 1 of third dimension
/// //
/// // This is the equivalent of this pseudo code
/// // let mut v = vec![[[]]];
/// // v[0][0][0] = tensor[0][1][1];
/// // v[0][1][0] = tensor[0][2][1];
/// let tensor_slices = tensor.slice([0..1, 1..3, 1..2]);
/// println!("\nexpecting [1, 2, 1] : {:?}", tensor_slices.dims());
/// println!("expecting [[[5],[9]]] : {:?}", tensor_slices);
///
/// // 1D slicing
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
/// let slice = tensor.slice([1..4]);
/// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
///
/// // 2D slicing
/// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 4]), &device);
/// let slice = tensor.slice([1..3, 0..2]);
/// assert_eq!(slice.dims(), [2, 2]);
///
/// // Using negative indices
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
/// let slice = tensor.slice([(1, -1)]); // Equivalent to 1..4
/// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
///
/// // Using Option<(i64, i64)>
/// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device).reshape([3, 4]);
/// let slice = tensor.slice([Some((1, -1)), None]); // Select rows 1 and 2, all columns
/// assert_eq!(slice.dims(), [2, 4]);
/// }
/// ```
pub fn slice<const D2: usize>(self, ranges: [core::ops::Range<usize>; D2]) -> Self {
///
/// # Note
///
/// This function uses the `RangesArg` trait for flexible range specification. The trait
/// handles the conversion of various range formats and applies clamping and negative
/// index handling internally.
pub fn slice<const D2: usize, R: RangesArg<D2>>(self, ranges: R) -> Self {
let ranges = ranges.into_ranges(self.shape());

check!(TensorCheck::slice(&self.shape(), &ranges));
Self::new(K::slice(self.primitive, ranges))
}
Expand Down Expand Up @@ -675,8 +689,8 @@ where
/// Converts the data of the current tensor.
pub fn into_data(self) -> TensorData {
crate::try_read_sync(self.into_data_async()).expect(
"Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.
"Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.
If possible, try using into_data_async instead.",
)
}
Expand Down Expand Up @@ -875,7 +889,7 @@ where
/// If the backend fails to read the tensor data synchronously.
pub fn into_scalar(self) -> K::Elem {
crate::try_read_sync(self.into_scalar_async()).expect(
"Failed to read tensor data synchronously. This can happen on platforms
"Failed to read tensor data synchronously. This can happen on platforms
that don't support blocking futures like WASM. Try into_scalar_async instead.",
)
}
Expand Down Expand Up @@ -2109,6 +2123,84 @@ impl MovedimArgs for i32 {
}
}

/// Trait used for slice arguments
pub trait RangesArg<const D2: usize> {
/// Converts into a set of ranges to `[core::ops::Range<usize>; D2]` for the `tensor.slice()` function
fn into_ranges<const D: usize>(self, shape: Shape<D>) -> [core::ops::Range<usize>; D2];

/// Handles negative index values
fn handle_negative_index(start: i64, end: i64, dim: usize) -> (usize, usize) {
let start = if start < 0 {
(dim as i64 + start) as usize
} else {
start as usize
};
let end = if end < 0 {
(dim as i64 + end) as usize
} else {
end as usize
};
(start, end)
}

/// Clamps the range to the shape dimensions
fn clamp_range(start: usize, end: usize, dim: usize) -> (usize, usize) {
let start = start.clamp(0, dim);
let end = end.clamp(0, dim);
(start, end)
}
}

impl<const D2: usize> RangesArg<D2> for [core::ops::Range<usize>; D2] {
fn into_ranges<const D: usize>(self, shape: Shape<D>) -> [core::ops::Range<usize>; D2] {
// clamp the ranges to the shape dimensions
let ranges = self
.iter()
.enumerate()
.map(|(i, range)| {
let (start, end) = Self::clamp_range(range.start, range.end, shape.dims[i]);
start..end
})
.collect::<Vec<_>>();
ranges.try_into().unwrap()
}
}

impl<const D2: usize> RangesArg<D2> for [Option<(i64, i64)>; D2] {
fn into_ranges<const D: usize>(self, shape: Shape<D>) -> [core::ops::Range<usize>; D2] {
let ranges = self
.iter()
.enumerate()
.map(|(i, range)| match range {
Some((start, end)) => {
let (start, end) = Self::handle_negative_index(*start, *end, shape.dims[i]);
let (start, end) = Self::clamp_range(start, end, shape.dims[i]);
start..end
}
None => 0..shape.dims[i], // if None, use the full range
})
.collect::<Vec<_>>();

ranges.try_into().unwrap()
}
}

impl<const D2: usize> RangesArg<D2> for [(i64, i64); D2] {
fn into_ranges<const D: usize>(self, shape: Shape<D>) -> [core::ops::Range<usize>; D2] {
let ranges = self
.iter()
.enumerate()
.map(|(i, &(start, end))| {
let (start, end) = Self::handle_negative_index(start, end, shape.dims[i]);
let (start, end) = Self::clamp_range(start, end, shape.dims[i]);
start..end
})
.collect::<Vec<_>>();

ranges.try_into().unwrap()
}
}

/// Trait used for reshape arguments.
pub trait ReshapeArgs<const D2: usize> {
/// Converts to a shape.
Expand Down
Loading

0 comments on commit e8b915a

Please sign in to comment.