diff --git a/benches/squeezenet.rs b/benches/squeezenet.rs index ee4e61eb..6dd3920c 100644 --- a/benches/squeezenet.rs +++ b/benches/squeezenet.rs @@ -3,7 +3,10 @@ use std::{path::Path, sync::Arc}; use glassbench::{Bench, pretend_used}; use image::{ImageBuffer, Pixel, Rgb, imageops::FilterType}; use ndarray::{Array4, s}; -use ort::session::{Session, builder::GraphOptimizationLevel}; +use ort::{ + session::{Session, builder::GraphOptimizationLevel}, + value::TensorRef +}; fn load_squeezenet_data() -> ort::Result<(Session, Array4)> { const IMAGE_TO_LOAD: &str = "mushroom.png"; @@ -44,7 +47,7 @@ fn bench_squeezenet(bench: &mut Bench) { let (session, data) = load_squeezenet_data().unwrap(); bench.task("ArrayView", |task| { task.iter(|| { - pretend_used(session.run(ort::inputs![data.view()].unwrap()).unwrap()); + pretend_used(session.run(ort::inputs![TensorRef::from_array_view(&data).unwrap()]).unwrap()); }) }); @@ -52,7 +55,11 @@ fn bench_squeezenet(bench: &mut Bench) { let shape: Vec = data.shape().iter().map(|c| *c as _).collect(); bench.task("Raw data", |task| { task.iter(|| { - pretend_used(session.run(ort::inputs![(shape.clone(), Arc::clone(&raw))].unwrap()).unwrap()); + pretend_used( + session + .run(ort::inputs![TensorRef::from_array_view((shape.clone(), Arc::clone(&raw))).unwrap()]) + .unwrap() + ); }) }); } diff --git a/examples/async-gpt2-api/Cargo.toml b/examples/async-gpt2-api/Cargo.toml index 9fa82d5b..7038e833 100644 --- a/examples/async-gpt2-api/Cargo.toml +++ b/examples/async-gpt2-api/Cargo.toml @@ -18,7 +18,7 @@ tokio = { version = "1.36", features = [ "full" ] } tokio-stream = "0.1" tower-http = { version = "0.5", features = ["fs", "trace"] } anyhow = "1.0" -async-stream = "0.3" +async-stream-lite = "0.2" [features] load-dynamic = [ "ort/load-dynamic" ] diff --git a/examples/async-gpt2-api/examples/async-gpt2-api.rs b/examples/async-gpt2-api/examples/async-gpt2-api.rs index 165a9746..2e2104d8 100644 --- a/examples/async-gpt2-api/examples/async-gpt2-api.rs +++ b/examples/async-gpt2-api/examples/async-gpt2-api.rs @@ -10,11 +10,10 @@ use axum::{ routing::post }; use futures::Stream; -use ndarray::{Array1, ArrayViewD, Axis, array, concatenate, s}; use ort::{ execution_providers::CUDAExecutionProvider, - inputs, - session::{Session, builder::GraphOptimizationLevel} + session::{Session, builder::GraphOptimizationLevel}, + value::TensorRef }; use rand::Rng; use tokenizers::Tokenizer; @@ -64,37 +63,31 @@ struct AppState { tokenizer: Arc } -fn generate_stream(tokenizer: Arc, session: Arc, tokens: Vec, gen_tokens: usize) -> impl Stream> + Send { - async_stream::try_stream! 
{ - let mut tokens = Array1::from_iter(tokens.iter().cloned()); - +fn generate_stream(tokenizer: Arc, session: Arc, mut tokens: Vec, gen_tokens: usize) -> impl Stream> + Send { + async_stream_lite::try_async_stream(|yielder| async move { for _ in 0..gen_tokens { - let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1)); - let outputs = session.run_async(inputs![array]?)?.await?; - let generated_tokens: ArrayViewD = outputs["output1"].try_extract_tensor()?; + let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?; + let outputs = session.run_async(ort::inputs![input])?.await?; + let (dim, probabilities) = outputs["output1"].try_extract_raw_tensor()?; // Collect and sort logits - let probabilities = &mut generated_tokens - .slice(s![0, 0, -1, ..]) - .insert_axis(Axis(0)) - .to_owned() - .iter() - .cloned() - .enumerate() - .collect::>(); + let (seq_len, vocab_size) = (dim[2] as usize, dim[3] as usize); + let mut probabilities: Vec<(usize, f32)> = probabilities[(seq_len - 1) * vocab_size..].iter().copied().enumerate().collect(); probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); // Sample using top-k sampling let token = { let mut rng = rand::thread_rng(); - probabilities[rng.gen_range(0..=5)].0 + probabilities[rng.gen_range(0..=5)].0 as i64 }; - tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]]; + tokens.push(token); let token_str = tokenizer.decode(&[token as _], true).unwrap(); - yield Event::default().data(token_str); + yielder.r#yield(Event::default().data(token_str)).await; } - } + + Ok(()) + }) } impl FromRef for Arc { diff --git a/examples/cudarc/src/main.rs b/examples/cudarc/src/main.rs index e65fc332..14c333c4 100644 --- a/examples/cudarc/src/main.rs +++ b/examples/cudarc/src/main.rs @@ -45,7 +45,7 @@ fn main() -> anyhow::Result<()> { ) .unwrap() }; - let outputs = model.run([tensor.into()])?; + let outputs = model.run(ort::inputs![tensor])?; let output = outputs["output"].try_extract_tensor::()?; diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index 61a3cf2b..2f7c4cb2 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -1,4 +1,3 @@ -use ndarray::Array2; use ort::{ operator::{ Operator, OperatorDomain, @@ -6,7 +5,8 @@ use ort::{ kernel::{Kernel, KernelAttributes, KernelContext} }, session::Session, - tensor::TensorElementType + tensor::TensorElementType, + value::Tensor }; struct CustomOpOne; @@ -78,7 +78,16 @@ fn main() -> ort::Result<()> { .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? 
.commit_from_file("tests/data/custom_op_test.onnx")?; - let values = session.run(ort::inputs![Array2::::zeros((3, 5)), Array2::::ones((3, 5))]?)?; + let allocator = session.allocator(); + let value1 = Tensor::::new(allocator, [3, 5])?; + let mut value2 = Tensor::::new(allocator, [3, 5])?; + { + let (_, data) = value2.extract_raw_tensor_mut(); + for datum in data { + *datum = 1.; + } + } + let values = session.run(ort::inputs![&value1, &value2])?; println!("{:?}", values[0].try_extract_tensor::()?); Ok(()) diff --git a/examples/gpt2/examples/gpt2-no-ndarray.rs b/examples/gpt2/examples/gpt2-no-ndarray.rs index 4ce95d7a..09078859 100644 --- a/examples/gpt2/examples/gpt2-no-ndarray.rs +++ b/examples/gpt2/examples/gpt2-no-ndarray.rs @@ -1,13 +1,13 @@ use std::{ io::{self, Write}, - path::Path, - sync::Arc + path::Path }; use ort::{ execution_providers::CUDAExecutionProvider, inputs, - session::{Session, builder::GraphOptimizationLevel} + session::{Session, builder::GraphOptimizationLevel}, + value::TensorRef }; use rand::Rng; use tokenizers::Tokenizer; @@ -33,7 +33,7 @@ fn main() -> ort::Result<()> { .with_execution_providers([CUDAExecutionProvider::default().build()]) .commit()?; - let mut stdout = io::stdout(); + let mut stdout: io::Stdout = io::stdout(); let mut rng = rand::thread_rng(); // Load our model @@ -45,7 +45,7 @@ fn main() -> ort::Result<()> { // Load the tokenizer and encode the prompt into a sequence of tokens. let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap(); let tokens = tokenizer.encode(PROMPT, false).unwrap(); - let mut tokens = Arc::new(tokens.get_ids().iter().map(|i| *i as i64).collect::>().into_boxed_slice()); + let mut tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); print!("{PROMPT}"); stdout.flush().unwrap(); @@ -53,8 +53,8 @@ fn main() -> ort::Result<()> { for _ in 0..GEN_TOKENS { // Raw tensor construction takes a tuple of (dimensions, data). 
// The model expects our input to have shape [B, _, S] - let input = (vec![1, 1, tokens.len() as i64], Arc::clone(&tokens)); - let outputs = session.run(inputs![input]?)?; + let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?; + let outputs = session.run(inputs![input])?; let (dim, mut probabilities) = outputs["output1"].try_extract_raw_tensor()?; // The output tensor will have shape [B, _, S + 1, V] @@ -70,9 +70,7 @@ fn main() -> ort::Result<()> { let token = probabilities[rng.gen_range(0..=TOP_K)].0 as i64; // Add our generated token to the input sequence - let mut vec = tokens.to_vec(); - vec.push(token); - *Arc::make_mut(&mut tokens) = vec.into_boxed_slice(); + tokens.push(token); let token_str = tokenizer.decode(&[token as u32], true).unwrap(); print!("{}", token_str); diff --git a/examples/gpt2/examples/gpt2.rs b/examples/gpt2/examples/gpt2.rs index 9bebff03..381faa08 100644 --- a/examples/gpt2/examples/gpt2.rs +++ b/examples/gpt2/examples/gpt2.rs @@ -7,7 +7,8 @@ use ndarray::{Array1, ArrayViewD, Axis, array, concatenate, s}; use ort::{ execution_providers::CUDAExecutionProvider, inputs, - session::{Session, builder::GraphOptimizationLevel} + session::{Session, builder::GraphOptimizationLevel}, + value::TensorRef }; use rand::Rng; use tokenizers::Tokenizer; @@ -54,7 +55,7 @@ fn main() -> ort::Result<()> { for _ in 0..GEN_TOKENS { let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1)); - let outputs = session.run(inputs![array]?)?; + let outputs = session.run(inputs![TensorRef::from_array_view(array)?])?; let generated_tokens: ArrayViewD = outputs["output1"].try_extract_tensor()?; // Collect and sort logits diff --git a/examples/modnet/examples/modnet.rs b/examples/modnet/examples/modnet.rs index 6c02a26b..5ce39a93 100644 --- a/examples/modnet/examples/modnet.rs +++ b/examples/modnet/examples/modnet.rs @@ -4,7 +4,7 @@ use std::{ops::Mul, path::Path}; use image::{GenericImageView, ImageBuffer, Rgba, imageops::FilterType}; use ndarray::Array; -use ort::{execution_providers::CUDAExecutionProvider, inputs, session::Session}; +use ort::{execution_providers::CUDAExecutionProvider, inputs, session::Session, value::TensorRef}; use show_image::{AsImageView, WindowOptions, event}; #[show_image::main] @@ -31,7 +31,7 @@ fn main() -> ort::Result<()> { input[[0, 2, y, x]] = (b as f32 - 127.5) / 127.5; } - let outputs = model.run(inputs!["input" => input.view()]?)?; + let outputs = model.run(inputs!["input" => TensorRef::from_array_view(input.view())?])?; let output = outputs["output"].try_extract_tensor::()?; diff --git a/examples/phi-3-vision/src/main.rs b/examples/phi-3-vision/src/main.rs index 3849a74f..ffd8cf31 100644 --- a/examples/phi-3-vision/src/main.rs +++ b/examples/phi-3-vision/src/main.rs @@ -4,7 +4,10 @@ use std::{path::Path, time::Instant}; use anyhow::Result; use image::DynamicImage; use ndarray::{Array, Array2, Array3, Array4, ArrayView, Ix3, Ix4, s}; -use ort::{session::Session, value::Tensor}; +use ort::{ + session::Session, + value::{Tensor, TensorRef} +}; use tokenizers::Tokenizer; const VISION_MODEL_NAME: &str = "phi-3-v-128k-instruct-vision.onnx"; @@ -31,11 +34,10 @@ fn get_image_embedding(vision_model: &Session, img: &Option) -> Re pixel_values = result.pixel_values.shape(), image_sizes = result.image_sizes.shape(), ); - let model_inputs = ort::inputs![ - "pixel_values" => result.pixel_values, - "image_sizes" => result.image_sizes, - ]?; - let outputs = vision_model.run(model_inputs)?; + let outputs = 
vision_model.run(ort::inputs![ + "pixel_values" => Tensor::from_array(result.pixel_values)?, + "image_sizes" => Tensor::from_array(result.image_sizes)?, + ])?; let predictions_view: ArrayView = outputs["visual_features"].try_extract_tensor::()?; predictions_view.into_dimensionality::()?.to_owned() } else { @@ -45,10 +47,9 @@ fn get_image_embedding(vision_model: &Session, img: &Option) -> Re } fn get_text_embedding(text_embedding_model: &Session, input_ids: &Array2) -> Result> { - let model_inputs = ort::inputs![ - "input_ids" => input_ids.to_owned(), - ]?; - let outputs = text_embedding_model.run(model_inputs)?; + let outputs = text_embedding_model.run(ort::inputs![ + "input_ids" => TensorRef::from_array_view(input_ids)?, + ])?; let inputs_embeds_view: ArrayView = outputs["inputs_embeds"].try_extract_tensor::()?; let inputs_embeds = inputs_embeds_view.into_dimensionality::()?.to_owned(); Ok(inputs_embeds) @@ -144,12 +145,12 @@ pub async fn generate_text( // Prepare model inputs let model_inputs = { let mut model_inputs = ort::inputs![ - "inputs_embeds" => next_inputs_embeds.clone(), - "attention_mask" => attention_mask.clone(), - ]?; + "inputs_embeds" => TensorRef::from_array_view(&next_inputs_embeds)?, + "attention_mask" => TensorRef::from_array_view(&attention_mask)?, + ]; for i in 0..32 { - model_inputs.push((format!("past_key_values.{}.key", i).into(), Tensor::from_array(past_key_values[i * 2].view())?.into())); - model_inputs.push((format!("past_key_values.{}.value", i).into(), Tensor::from_array(past_key_values[i * 2 + 1].view())?.into())); + model_inputs.push((format!("past_key_values.{}.key", i).into(), TensorRef::from_array_view(&past_key_values[i * 2])?.into())); + model_inputs.push((format!("past_key_values.{}.value", i).into(), TensorRef::from_array_view(&past_key_values[i * 2 + 1])?.into())); } model_inputs }; diff --git a/examples/sentence-transformers/examples/semantic-similarity.rs b/examples/sentence-transformers/examples/semantic-similarity.rs index 96901cd3..6aed6e63 100644 --- a/examples/sentence-transformers/examples/semantic-similarity.rs +++ b/examples/sentence-transformers/examples/semantic-similarity.rs @@ -1,10 +1,11 @@ use std::path::Path; -use ndarray::{Array2, Axis, Ix2}; +use ndarray::{Axis, Ix2}; use ort::{ Error, execution_providers::CUDAExecutionProvider, - session::{Session, builder::GraphOptimizationLevel} + session::{Session, builder::GraphOptimizationLevel}, + value::TensorRef }; use tokenizers::Tokenizer; @@ -45,11 +46,11 @@ fn main() -> ort::Result<()> { let mask: Vec = encodings.iter().flat_map(|e| e.get_attention_mask().iter().map(|i| *i as i64)).collect(); // Convert our flattened arrays into 2-dimensional tensors of shape [N, L]. - let a_ids = Array2::from_shape_vec([inputs.len(), padded_token_length], ids).unwrap(); - let a_mask = Array2::from_shape_vec([inputs.len(), padded_token_length], mask).unwrap(); + let a_ids = TensorRef::from_array_view(([inputs.len(), padded_token_length], &*ids))?; + let a_mask = TensorRef::from_array_view(([inputs.len(), padded_token_length], &*mask))?; // Run the model. - let outputs = session.run(ort::inputs![a_ids, a_mask]?)?; + let outputs = session.run(ort::inputs![a_ids, a_mask])?; // Extract our embeddings tensor and convert it to a strongly-typed 2-dimensional array. 
let embeddings = outputs[1].try_extract_tensor::()?.into_dimensionality::().unwrap(); diff --git a/examples/training/examples/train-clm-simple.rs b/examples/training/examples/train-clm-simple.rs index 647dbe7c..187cf827 100644 --- a/examples/training/examples/train-clm-simple.rs +++ b/examples/training/examples/train-clm-simple.rs @@ -5,12 +5,12 @@ use std::{ }; use kdam::BarExt; -use ndarray::{Array1, Array2, ArrayViewD, Axis, concatenate, s}; use ort::{ execution_providers::CUDAExecutionProvider, memory::Allocator, session::{Session, builder::SessionBuilder}, - training::{CheckpointStrategy, Trainer, TrainerCallbacks, TrainerControl, TrainerState, TrainingArguments} + training::{CheckpointStrategy, Trainer, TrainerCallbacks, TrainerControl, TrainerState, TrainingArguments}, + value::{Tensor, TensorRef} }; use rand::RngCore; use tokenizers::Tokenizer; @@ -94,10 +94,10 @@ fn main() -> ort::Result<()> { .unwrap(); } - Ok(( - ort::inputs![Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?, - ort::inputs![Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap()]? - )) + let inputs = Tensor::from_array(([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect::>()))?; + let labels = Tensor::from_array(([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect::>()))?; + + Ok((ort::inputs![inputs], ort::inputs![labels])) }; trainer.train( @@ -115,26 +115,19 @@ fn main() -> ort::Result<()> { let mut stdout = std::io::stdout(); let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); - let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); - - let mut tokens = Array1::from_iter(tokens.iter().cloned()); + let mut tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); for _ in 0..50 { - let array = tokens.view().insert_axis(Axis(0)); - let outputs = session.run(ort::inputs![array]?)?; - let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; - - let probabilities = &mut generated_tokens - .slice(s![-1, ..]) - .to_owned() - .iter() - .cloned() - .enumerate() - .collect::>(); + let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?; + let outputs = session.run(ort::inputs![input])?; + let (dim, probabilities) = outputs["probs"].try_extract_raw_tensor()?; + + let (seq_len, vocab_size) = (dim[2] as usize, dim[3] as usize); + let mut probabilities: Vec<(usize, f32)> = probabilities[(seq_len - 1) * vocab_size..].iter().copied().enumerate().collect(); probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); - let token = probabilities[0].0; - tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + let token = probabilities[0].0 as i64; + tokens.push(token); let token_str = tokenizer.decode(&[token as _], false).unwrap(); print!("{}", token_str); diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs index b854bbb2..2c29c69d 100644 --- a/examples/training/examples/train-clm.rs +++ b/examples/training/examples/train-clm.rs @@ -5,12 +5,12 @@ use std::{ }; use kdam::BarExt; -use ndarray::{Array1, Array2, ArrayViewD, Axis, concatenate, s}; use ort::{ execution_providers::CUDAExecutionProvider, memory::Allocator, session::{Session, builder::SessionBuilder}, - training::{Checkpoint, Trainer} + training::{Checkpoint, Trainer}, + value::{Tensor, 
TensorRef} }; use rand::RngCore; use tokenizers::Tokenizer; @@ -83,10 +83,10 @@ fn main() -> ort::Result<()> { .unwrap(); } - let inputs = Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap(); - let labels = Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap(); + let inputs = Tensor::from_array(([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect::>()))?; + let labels = Tensor::from_array(([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect::>()))?; - let outputs = trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?; + let outputs = trainer.step(ort::inputs![inputs], ort::inputs![labels])?; let loss = outputs[0].try_extract_scalar::()?; pb.set_postfix(format!("loss={loss:.3}")); pb.update(1).unwrap(); @@ -107,26 +107,19 @@ fn main() -> ort::Result<()> { let mut stdout = std::io::stdout(); let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); - let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); - - let mut tokens = Array1::from_iter(tokens.iter().cloned()); + let mut tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); for _ in 0..50 { - let array = tokens.view().insert_axis(Axis(0)); - let outputs = session.run(ort::inputs![array]?)?; - let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; - - let probabilities = &mut generated_tokens - .slice(s![-1, ..]) - .to_owned() - .iter() - .cloned() - .enumerate() - .collect::>(); + let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?; + let outputs = session.run(ort::inputs![input])?; + let (dim, probabilities) = outputs["probs"].try_extract_raw_tensor()?; + + let (seq_len, vocab_size) = (dim[2] as usize, dim[3] as usize); + let mut probabilities: Vec<(usize, f32)> = probabilities[(seq_len - 1) * vocab_size..].iter().copied().enumerate().collect(); probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); - let token = probabilities[0].0; - tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + let token = probabilities[0].0 as i64; + tokens.push(token); let token_str = tokenizer.decode(&[token as _], false).unwrap(); print!("{}", token_str); diff --git a/examples/yolov8/examples/yolov8.rs b/examples/yolov8/examples/yolov8.rs index f1f90c37..7c24ddf5 100644 --- a/examples/yolov8/examples/yolov8.rs +++ b/examples/yolov8/examples/yolov8.rs @@ -7,7 +7,8 @@ use ndarray::{Array, Axis, s}; use ort::{ execution_providers::CUDAExecutionProvider, inputs, - session::{Session, SessionOutputs} + session::{Session, SessionOutputs}, + value::TensorRef }; use raqote::{DrawOptions, DrawTarget, LineJoin, PathBuilder, SolidSource, Source, StrokeStyle}; use show_image::{AsImageView, WindowOptions, event}; @@ -66,7 +67,7 @@ fn main() -> ort::Result<()> { let model = Session::builder()?.commit_from_url(YOLOV8M_URL)?; // Run YOLOv8 inference - let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?; + let outputs: SessionOutputs = model.run(inputs!["images" => TensorRef::from_array_view(&input)?])?; let output = outputs["output0"].try_extract_tensor::()?.t().into_owned(); let mut boxes = Vec::new(); diff --git a/src/adapter.rs b/src/adapter.rs index 5906430f..03698fab 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -173,7 +173,7 @@ mod tests { run_options.add_adapter(&lora)?; 
let output: Tensor = model - .run_with_options(crate::inputs![Tensor::::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)? + .run_with_options(crate::inputs![Tensor::::from_array(([4, 4], vec![1.0; 16]))?], &run_options)? .remove("output") .expect("") .downcast()?; @@ -198,7 +198,7 @@ mod tests { run_options.add_adapter(&lora)?; let output: Tensor = model - .run_with_options(crate::inputs![Tensor::::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)? + .run_with_options(crate::inputs![Tensor::::from_array(([4, 4], vec![1.0; 16]))?], &run_options)? .remove("output") .expect("") .downcast()?; diff --git a/src/operator/tests.rs b/src/operator/tests.rs index 8dd2d5e3..88d2e100 100644 --- a/src/operator/tests.rs +++ b/src/operator/tests.rs @@ -1,5 +1,3 @@ -use ndarray::{Array2, arr2}; - use crate::{ Result, operator::{ @@ -8,7 +6,8 @@ use crate::{ kernel::{Kernel, KernelAttributes, KernelContext} }, session::Session, - tensor::TensorElementType + tensor::TensorElementType, + value::Tensor }; struct CustomOpOne; @@ -82,8 +81,17 @@ fn test_custom_ops() -> crate::Result<()> { .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? .commit_from_file("tests/data/custom_op_test.onnx")?; - let values = session.run(crate::inputs![Array2::::zeros((3, 5)), Array2::::ones((3, 5))]?)?; - assert_eq!(values[0].try_extract_tensor::()?, arr2(&[[0, 1, 0, 3, 0], [5, 0, 7, 0, 9], [0, 11, 0, 13, 0]]).view().into_dyn()); + let allocator = session.allocator(); + let value1 = Tensor::::new(allocator, [3, 5])?; + let mut value2 = Tensor::::new(allocator, [3, 5])?; + { + let (_, data) = value2.extract_raw_tensor_mut(); + for datum in data { + *datum = 1.; + } + } + let values = session.run(crate::inputs![&value1, &value2])?; + assert_eq!(values[0].try_extract_raw_tensor::()?.1, [0, 1, 0, 3, 0, 5, 0, 7, 0, 9, 0, 11, 0, 13, 0]); Ok(()) } diff --git a/src/session/input.rs b/src/session/input.rs index 88658346..edec9116 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -35,6 +35,11 @@ impl From> for SessionInputValue<'_> { SessionInputValue::Owned(value.into_dyn()) } } +impl<'v, T: ValueTypeMarker + ?Sized> From<&'v Value> for SessionInputValue<'v> { + fn from(value: &'v Value) -> Self { + SessionInputValue::View(value.view().into_dyn()) + } +} /// The inputs to a [`Session::run`] call. /// @@ -123,35 +128,27 @@ impl<'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<'_, #[macro_export] macro_rules! inputs { ($($v:expr),+ $(,)?) => ( - (|| -> $crate::Result<_> { - Ok([$(::std::convert::Into::<$crate::session::SessionInputValue<'_>>::into(::std::convert::TryInto::<$crate::value::DynValue>::try_into($v).map_err($crate::Error::from)?)),+]) - })() + [$(::std::convert::Into::<$crate::session::SessionInputValue<'_>>::into($v)),+] ); ($($n:expr => $v:expr),+ $(,)?) 
=> ( - (|| -> $crate::Result<_> { - Ok(vec![$( - ::std::convert::TryInto::<$crate::value::DynValue>::try_into($v) - .map_err($crate::Error::from) - .map(|v| (::std::borrow::Cow::::from($n), $crate::session::SessionInputValue::from(v)))?,)+]) - })() + vec![$((::std::borrow::Cow::::from($n), $crate::session::SessionInputValue::<'_>::from($v)),)+] ); } #[cfg(test)] mod tests { - use std::{collections::HashMap, sync::Arc}; + use std::collections::HashMap; use super::SessionInputs; - use crate::value::DynTensor; + use crate::value::{DynTensor, Tensor}; #[test] fn test_hashmap_static_keys() -> crate::Result<()> { let v: Vec = vec![1., 2., 3., 4., 5.]; - let arc = Arc::new(v.clone().into_boxed_slice()); let shape = vec![v.len() as i64]; let mut inputs: HashMap<&str, DynTensor> = HashMap::new(); - inputs.insert("test", (shape, arc).try_into()?); + inputs.insert("test", Tensor::from_array((shape, v))?.upcast()); let _ = SessionInputs::from(inputs); Ok(()) @@ -160,11 +157,10 @@ mod tests { #[test] fn test_hashmap_string_keys() -> crate::Result<()> { let v: Vec = vec![1., 2., 3., 4., 5.]; - let arc = Arc::new(v.clone().into_boxed_slice()); let shape = vec![v.len() as i64]; let mut inputs: HashMap = HashMap::new(); - inputs.insert("test".to_string(), (shape, arc).try_into()?); + inputs.insert("test".to_string(), Tensor::from_array((shape, v))?.upcast()); let _ = SessionInputs::from(inputs); Ok(()) diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index a11b654a..6409397a 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -7,17 +7,18 @@ use std::{ sync::Arc }; +use ndarray::ArrayViewMut; #[cfg(feature = "ndarray")] use ndarray::{ArcArray, Array, ArrayView, CowArray, Dimension}; -use super::{DynTensor, Tensor, TensorRefMut, calculate_tensor_size}; +use super::{Tensor, TensorRef, TensorRefMut, calculate_tensor_size}; use crate::{ AsPointer, error::{Error, ErrorCode, Result, assert_non_null_pointer}, memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType}, ortsys, tensor::{PrimitiveTensorElementType, TensorElementType, Utf8Data}, - value::{DynValue, Value, ValueInner, ValueType} + value::{Value, ValueInner, ValueType} }; impl Tensor { @@ -48,7 +49,7 @@ impl Tensor { /// ``` /// /// Note that string data will *always* be copied, no matter what form the data is provided in. - pub fn from_string_array(input: impl IntoValueTensor) -> Result> { + pub fn from_string_array(input: impl TensorArrayData) -> Result> { let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); let (shape, data) = input.ref_parts()?; @@ -180,13 +181,13 @@ impl Tensor { /// /// Raw data provided as a `Arc>`, `Box<[T]>`, or `Vec` will never be copied. Raw data is expected to be /// in standard, contigous layout. 
- pub fn from_array(input: impl IntoValueTensor) -> Result> { + pub fn from_array(input: impl OwnedTensorArrayData) -> Result> { let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); // f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime - let (shape, ptr, ptr_len, guard) = input.into_parts()?; + let TensorArrayDataParts { shape, ptr, num_elements, guard } = input.into_parts()?; let shape_ptr: *const i64 = shape.as_ptr(); let shape_len = shape.len(); @@ -197,7 +198,7 @@ impl Tensor { unsafe CreateTensorWithDataAsOrtValue( memory_info.ptr(), tensor_values_ptr, - ptr_len * std::mem::size_of::(), + num_elements * std::mem::size_of::(), shape_ptr, shape_len, T::into_tensor_element_type().into(), @@ -223,7 +224,99 @@ impl Tensor { } } +impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> { + pub fn from_array_view(input: impl TensorArrayData) -> Result> { + let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; + + let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); + + // f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime + let (shape, data) = input.ref_parts()?; + let num_elements = calculate_tensor_size(&shape); + let shape_ptr: *const i64 = shape.as_ptr(); + let shape_len = shape.len(); + + let tensor_values_ptr: *mut std::ffi::c_void = data.as_ptr() as *mut _; + assert_non_null_pointer(tensor_values_ptr, "TensorValues")?; + + ortsys![ + unsafe CreateTensorWithDataAsOrtValue( + memory_info.ptr(), + tensor_values_ptr, + num_elements * std::mem::size_of::(), + shape_ptr, + shape_len, + T::into_tensor_element_type().into(), + &mut value_ptr + )?; + nonNull(value_ptr) + ]; + + let mut tensor = TensorRef::new(Tensor { + inner: Arc::new(ValueInner { + ptr: unsafe { NonNull::new_unchecked(value_ptr) }, + dtype: ValueType::Tensor { + ty: T::into_tensor_element_type(), + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, + drop: true, + memory_info: Some(memory_info), + _backing: None + }), + _markers: PhantomData + }); + tensor.upgradable = false; + Ok(tensor) + } +} + impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { + pub fn from_array_view_mut(mut input: impl TensorArrayDataMut) -> Result> { + let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; + + let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); + + // f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime + let (shape, data) = input.ref_parts_mut()?; + let num_elements = calculate_tensor_size(&shape); + let shape_ptr: *const i64 = shape.as_ptr(); + let shape_len = shape.len(); + + let tensor_values_ptr: *mut std::ffi::c_void = data.as_ptr() as *mut _; + assert_non_null_pointer(tensor_values_ptr, "TensorValues")?; + + ortsys![ + unsafe CreateTensorWithDataAsOrtValue( + memory_info.ptr(), + tensor_values_ptr, + num_elements * std::mem::size_of::(), + shape_ptr, + shape_len, + T::into_tensor_element_type().into(), + &mut value_ptr + )?; + nonNull(value_ptr) + ]; + + let mut tensor = TensorRefMut::new(Tensor { + inner: Arc::new(ValueInner { + ptr: unsafe { NonNull::new_unchecked(value_ptr) }, + dtype: ValueType::Tensor { + ty: T::into_tensor_element_type(), + dimensions: shape, + dimension_symbols: vec![None; shape_len] + }, + drop: true, + 
memory_info: Some(memory_info), + _backing: None + }), + _markers: PhantomData + }); + tensor.upgradable = false; + Ok(tensor) + } + /// Create a mutable tensor view from a raw pointer and shape. /// /// The length of data is determined by `T` and the given shape, so the given buffer must be at least @@ -269,7 +362,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { nonNull(value_ptr) ]; - Ok(TensorRefMut::new(Value { + let mut tensor = TensorRefMut::new(Value { inner: Arc::new(ValueInner { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, dtype: ValueType::Tensor { @@ -282,16 +375,29 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { _backing: None }), _markers: PhantomData - })) + }); + tensor.upgradable = false; + Ok(tensor) } } -pub trait IntoValueTensor { - type Item; +pub trait TensorArrayData { + fn ref_parts(&self) -> Result<(Vec, &[I])>; +} + +pub trait TensorArrayDataMut: TensorArrayData { + fn ref_parts_mut(&mut self) -> Result<(Vec, &mut [I])>; +} + +pub trait OwnedTensorArrayData: TensorArrayData { + fn into_parts(self) -> Result>; +} - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])>; - #[allow(clippy::type_complexity)] - fn into_parts(self) -> Result<(Vec, *mut Self::Item, usize, Box)>; +pub struct TensorArrayDataParts { + pub shape: Vec, + pub ptr: *mut I, + pub num_elements: usize, + pub guard: Box } pub trait ToDimensions { @@ -355,302 +461,180 @@ impl_to_dimensions!( for [usize; N], for [i32; N], for [i64; N]); #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: Clone + 'static, D: Dimension + 'static> IntoValueTensor for &'i CowArray<'v, T, D> -where - 'v: 'i -{ - type Item = T; - - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])> { +impl TensorArrayData for &CowArray<'_, T, D> { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; Ok((shape, data)) } - - fn into_parts(self) -> Result<(Vec, *mut Self::Item, usize, Box)> { - // This will result in a copy in either form of the CowArray - let mut contiguous_array = self.as_standard_layout().into_owned(); - let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); - let ptr = contiguous_array.as_mut_ptr(); - let ptr_len = contiguous_array.len(); - let guard = Box::new(contiguous_array); - Ok((shape, ptr, ptr_len, guard)) - } } #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl IntoValueTensor for &mut ArcArray { - type Item = T; - - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])> { +impl TensorArrayData for ArcArray { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; Ok((shape, data)) } +} - fn into_parts(self) -> Result<(Vec, *mut Self::Item, usize, Box)> { - if self.is_standard_layout() { - // We can avoid the copy here and use the data as is - let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); - let ptr = self.as_mut_ptr(); - let ptr_len = self.len(); - let guard = Box::new(self.clone()); - Ok((shape, ptr, ptr_len, guard)) - } else { - // Need to do a copy here to get data in to standard layout - let mut contiguous_array = self.as_standard_layout().into_owned(); - 
let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); - let ptr = contiguous_array.as_mut_ptr(); - let ptr_len: usize = contiguous_array.len(); - let guard = Box::new(contiguous_array); - Ok((shape, ptr, ptr_len, guard)) - } +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl TensorArrayData for Array { + fn ref_parts(&self) -> Result<(Vec, &[T])> { + let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + let data = self + .as_slice() + .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; + Ok((shape, data)) } } #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl IntoValueTensor for Array { - type Item = T; - - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])> { +impl TensorArrayData for &Array { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; Ok((shape, data)) } +} - fn into_parts(self) -> Result<(Vec, *mut Self::Item, usize, Box)> { +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl OwnedTensorArrayData for Array { + fn into_parts(self) -> Result> { if self.is_standard_layout() { // We can avoid the copy here and use the data as is let mut guard = Box::new(self); let shape: Vec = guard.shape().iter().map(|d| *d as i64).collect(); let ptr = guard.as_mut_ptr(); - let ptr_len = guard.len(); - Ok((shape, ptr, ptr_len, guard)) + let num_elements = guard.len(); + Ok(TensorArrayDataParts { shape, ptr, num_elements, guard }) } else { // Need to do a copy here to get data in to standard layout let mut contiguous_array = self.as_standard_layout().into_owned(); let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); let ptr = contiguous_array.as_mut_ptr(); - let ptr_len: usize = contiguous_array.len(); + let num_elements: usize = contiguous_array.len(); let guard = Box::new(contiguous_array); - Ok((shape, ptr, ptr_len, guard)) + Ok(TensorArrayDataParts { shape, ptr, num_elements, guard }) } } } #[cfg(feature = "ndarray")] -impl IntoValueTensor for ArrayView<'_, T, D> { - type Item = T; - - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])> { +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl TensorArrayData for ArrayView<'_, T, D> { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); let data = self .as_slice() .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; Ok((shape, data)) } +} - fn into_parts(self) -> Result<(Vec, *mut Self::Item, usize, Box)> { - // This will result in a copy in either form of the ArrayView - let mut contiguous_array = self.as_standard_layout().into_owned(); - let shape: Vec = contiguous_array.shape().iter().map(|d| *d as i64).collect(); - let ptr = contiguous_array.as_mut_ptr(); - let ptr_len = contiguous_array.len(); - let guard = Box::new(contiguous_array); - Ok((shape, ptr, ptr_len, guard)) +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl TensorArrayData for ArrayViewMut<'_, T, D> { + fn ref_parts(&self) -> Result<(Vec, &[T])> { + let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + let data = self + .as_slice() + .ok_or_else(|| Error::new("Array has a non-contiguous layout and 
cannot be used to construct a Tensor"))?; + Ok((shape, data)) } } -impl IntoValueTensor for (D, &[T]) { - type Item = T; +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl TensorArrayDataMut for ArrayViewMut<'_, T, D> { + fn ref_parts_mut(&mut self) -> Result<(Vec, &mut [T])> { + let shape: Vec = self.shape().iter().map(|d| *d as i64).collect(); + let data = self + .as_slice_mut() + .ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?; + Ok((shape, data)) + } +} - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])> { +impl TensorArrayData for (D, &[T]) { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape = self.0.to_dimensions(Some(self.1.len()))?; Ok((shape, self.1)) } +} - fn into_parts(self) -> Result<(Vec, *mut Self::Item, usize, Box)> { +impl TensorArrayData for (D, &mut [T]) { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape = self.0.to_dimensions(Some(self.1.len()))?; - let mut data = self.1.to_vec(); - let ptr = data.as_mut_ptr(); - let ptr_len: usize = data.len(); - Ok((shape, ptr, ptr_len, Box::new(data))) + Ok((shape, self.1)) } } -impl IntoValueTensor for (D, Vec) { - type Item = T; +impl TensorArrayDataMut for (D, &mut [T]) { + fn ref_parts_mut(&mut self) -> Result<(Vec, &mut [T])> { + let shape = self.0.to_dimensions(Some(self.1.len()))?; + Ok((shape, self.1)) + } +} - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])> { +impl TensorArrayData for (D, Vec) { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape = self.0.to_dimensions(Some(self.1.len()))?; let data = &*self.1; Ok((shape, data)) } +} - fn into_parts(mut self) -> Result<(Vec, *mut Self::Item, usize, Box)> { +impl OwnedTensorArrayData for (D, Vec) { + fn into_parts(mut self) -> Result> { let shape = self.0.to_dimensions(Some(self.1.len()))?; let ptr = self.1.as_mut_ptr(); - let ptr_len: usize = self.1.len(); - Ok((shape, ptr, ptr_len, Box::new(self.1))) + let num_elements: usize = self.1.len(); + Ok(TensorArrayDataParts { + shape, + ptr, + num_elements, + guard: Box::new(self.1) + }) } } -impl IntoValueTensor for (D, Box<[T]>) { - type Item = T; - - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])> { +impl TensorArrayData for (D, Box<[T]>) { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape = self.0.to_dimensions(Some(self.1.len()))?; let data = &*self.1; Ok((shape, data)) } +} - fn into_parts(mut self) -> Result<(Vec, *mut Self::Item, usize, Box)> { +impl OwnedTensorArrayData for (D, Box<[T]>) { + fn into_parts(mut self) -> Result> { let shape = self.0.to_dimensions(Some(self.1.len()))?; let ptr = self.1.as_mut_ptr(); - let ptr_len: usize = self.1.len(); - Ok((shape, ptr, ptr_len, Box::new(self.1))) + let num_elements: usize = self.1.len(); + Ok(TensorArrayDataParts { + shape, + ptr, + num_elements, + guard: Box::new(self.1) + }) } } -impl IntoValueTensor for (D, Arc>) { - type Item = T; - - fn ref_parts(&self) -> Result<(Vec, &[Self::Item])> { +impl TensorArrayData for (D, Arc>) { + fn ref_parts(&self) -> Result<(Vec, &[T])> { let shape = self.0.to_dimensions(Some(self.1.len()))?; let data = &*self.1; Ok((shape, data)) } - - fn into_parts(mut self) -> Result<(Vec, *mut Self::Item, usize, Box)> { - let shape = self.0.to_dimensions(Some(self.1.len()))?; - let ptr = std::sync::Arc::>::make_mut(&mut self.1).as_mut_ptr(); - let ptr_len: usize = self.1.len(); - let guard = Box::new(Arc::clone(&self.1)); - Ok((shape, ptr, ptr_len, guard)) - } -} - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, 
doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Tensor -where - 'v: 'i -{ - type Error = Error; - fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { - Tensor::from_array(arr) - } -} - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Tensor { - type Error = Error; - fn try_from(arr: ArrayView<'v, T, D>) -> Result { - Tensor::from_array(arr) - } -} - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor -where - 'v: 'i -{ - type Error = Error; - fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { - Tensor::from_array(arr).map(|c| c.upcast()) - } -} - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { - type Error = Error; - fn try_from(arr: ArrayView<'v, T, D>) -> Result { - Tensor::from_array(arr).map(|c| c.upcast()) - } -} - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue -where - 'v: 'i -{ - type Error = Error; - fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { - Tensor::from_array(arr).map(|c| c.into_dyn()) - } } - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynValue { - type Error = Error; - fn try_from(arr: ArrayView<'v, T, D>) -> Result { - Tensor::from_array(arr).map(|c| c.into_dyn()) - } -} - -macro_rules! 
impl_try_from { - (@T,I $($t:ty),+) => { - $( - impl TryFrom<$t> for Tensor { - type Error = Error; - fn try_from(value: $t) -> Result { - Tensor::from_array(value) - } - } - impl TryFrom<$t> for DynTensor { - type Error = Error; - fn try_from(value: $t) -> Result { - Tensor::from_array(value).map(|c| c.upcast()) - } - } - impl TryFrom<$t> for crate::value::DynValue { - type Error = Error; - fn try_from(value: $t) -> Result { - Tensor::from_array(value).map(|c| c.into_dyn()) - } - } - )+ - }; - (@T,D $($t:ty),+) => { - $( - #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for Tensor { - type Error = Error; - fn try_from(value: $t) -> Result { - Tensor::from_array(value) - } - } - #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for DynTensor { - type Error = Error; - fn try_from(value: $t) -> Result { - Tensor::from_array(value).map(|c| c.upcast()) - } - } - #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for crate::value::DynValue { - type Error = Error; - fn try_from(value: $t) -> Result { - Tensor::from_array(value).map(|c| c.into_dyn()) - } - } - )+ - }; -} - -#[cfg(feature = "ndarray")] -impl_try_from!(@T,D &mut ArcArray, Array); -impl_try_from!(@T,I (I, Arc>), (I, Vec), (I, Box<[T]>), (I, &[T])); diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index 34375b0d..ebdc709c 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -271,7 +271,11 @@ mod tests { use ndarray::{ArcArray1, Array1, CowArray}; use super::Tensor; - use crate::{memory::Allocator, tensor::TensorElementType, value::ValueType}; + use crate::{ + memory::Allocator, + tensor::TensorElementType, + value::{TensorRef, ValueType} + }; #[test] #[cfg(feature = "ndarray")] @@ -298,18 +302,18 @@ mod tests { let v: Vec = vec![1., 2., 3., 4., 5.]; let arc1 = ArcArray1::from_vec(v.clone()); - let mut arc2 = ArcArray1::clone(&arc1); - let value = Tensor::from_array(&mut arc2)?; + let arc2 = ArcArray1::clone(&arc1); + let value = TensorRef::from_array_view(arc2.clone())?; drop((arc1, arc2)); assert_eq!(value.extract_raw_tensor().1, &v); let cow = CowArray::from(Array1::from_vec(v.clone())); - let value = Tensor::from_array(&cow)?; + let value = TensorRef::from_array_view(&cow)?; assert_eq!(value.extract_raw_tensor().1, &v); let owned = Array1::from_vec(v.clone()); - let value = Tensor::from_array(owned.view())?; + let value = TensorRef::from_array_view(owned.view())?; drop(owned); assert_eq!(value.extract_raw_tensor().1, &v); @@ -322,7 +326,7 @@ mod tests { let arc = Arc::new(v.clone().into_boxed_slice()); let shape = vec![v.len() as i64]; - let value = Tensor::from_array((shape, Arc::clone(&arc)))?; + let value = TensorRef::from_array_view((shape, Arc::clone(&arc)))?; drop(arc); assert_eq!(value.try_extract_raw_tensor::()?.1, &v); @@ -358,10 +362,10 @@ mod tests { let v: Vec = vec![1., 2., 3., 4., 5.]; let shape = [v.len()]; - let value_arc_box = Tensor::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?; + let value_arc_box = TensorRef::from_array_view((shape, Arc::new(v.clone().into_boxed_slice())))?; let value_box = Tensor::from_array((shape, v.clone().into_boxed_slice()))?; let value_vec = Tensor::from_array((shape, v.clone()))?; - let value_slice = Tensor::from_array((shape, &v[..]))?; + let value_slice = TensorRef::from_array_view((shape, &v[..]))?; assert_eq!(value_arc_box.extract_raw_tensor().1, &v); assert_eq!(value_box.extract_raw_tensor().1, &v); diff --git a/src/value/mod.rs b/src/value/mod.rs index 
8616592b..54142043 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -78,12 +78,19 @@ impl Drop for ValueInner { #[derive(Debug)] pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> { inner: Value, + pub(crate) upgradable: bool, lifetime: PhantomData<&'v ()> } impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { pub(crate) fn new(inner: Value) -> Self { - ValueRef { inner, lifetime: PhantomData } + ValueRef { + // We cannot upgade a value which we cannot drop, i.e. `ValueRef`s used in operator kernels. Those only last for the + // duration of the kernel, allowing an upgrade would allow a UAF. + upgradable: inner.inner.drop, + inner, + lifetime: PhantomData + } } /// Attempts to downcast a temporary dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed @@ -100,9 +107,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { /// Attempts to upgrade this `ValueRef` to an owned [`Value`] holding the same data. pub fn try_upgrade(self) -> Result, Self> { - // We cannot upgade a value which we cannot drop, i.e. `ValueRef`s used in operator kernels. Those only last for the - // duration of the kernel, allowing an upgrade would allow a UAF. - if !self.inner.inner.drop { + if !self.upgradable { return Err(self); } @@ -126,12 +131,19 @@ impl Deref for ValueRef<'_, Type> { #[derive(Debug)] pub struct ValueRefMut<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> { inner: Value, + pub(crate) upgradable: bool, lifetime: PhantomData<&'v ()> } impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { pub(crate) fn new(inner: Value) -> Self { - ValueRefMut { inner, lifetime: PhantomData } + ValueRefMut { + // We cannot upgade a value which we cannot drop, i.e. `ValueRef`s used in operator kernels. Those only last for the + // duration of the kernel, allowing an upgrade would allow a UAF. + upgradable: inner.inner.drop, + inner, + lifetime: PhantomData + } } /// Attempts to downcast a temporary mutable dynamic value (like [`DynValue`] or [`DynTensor`]) to a more @@ -148,9 +160,7 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { /// Attempts to upgrade this `ValueRefMut` to an owned [`Value`] holding the same data. pub fn try_upgrade(self) -> Result, Self> { - // We cannot upgade a value which we cannot drop, i.e. `ValueRef`s used in operator kernels. Those only last for the - // duration of the kernel, allowing an upgrade would allow a UAF. - if !self.inner.inner.drop { + if !self.upgradable { return Err(self); } diff --git a/tests/mnist.rs b/tests/mnist.rs index 3e21484d..49b1f6dd 100644 --- a/tests/mnist.rs +++ b/tests/mnist.rs @@ -4,7 +4,8 @@ use image::{ImageBuffer, Luma, Pixel, imageops::FilterType}; use ort::{ inputs, session::{Session, builder::GraphOptimizationLevel}, - tensor::ArrayExtensions + tensor::ArrayExtensions, + value::TensorRef }; use test_log::test; @@ -45,7 +46,7 @@ fn mnist_5() -> ort::Result<()> { }); // Perform the inference - let outputs = session.run(inputs![array]?)?; + let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?; let mut probabilities: Vec<(usize, f32)> = outputs[0] .try_extract_tensor()? 
diff --git a/tests/squeezenet.rs b/tests/squeezenet.rs index de93f11d..fd4fac05 100644 --- a/tests/squeezenet.rs +++ b/tests/squeezenet.rs @@ -10,7 +10,8 @@ use ndarray::s; use ort::{ Error, inputs, session::{Session, builder::GraphOptimizationLevel}, - tensor::ArrayExtensions + tensor::ArrayExtensions, + value::TensorRef }; use test_log::test; @@ -70,7 +71,7 @@ fn squeezenet_mushroom() -> ort::Result<()> { } // Perform the inference - let outputs = session.run(inputs![array]?)?; + let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?; // Downloaded model does not have a softmax as final layer; call softmax on second axis // and iterate on resulting probabilities, creating an index to later access labels. diff --git a/tests/upsample.rs b/tests/upsample.rs index 747f4c29..47d12d76 100644 --- a/tests/upsample.rs +++ b/tests/upsample.rs @@ -4,7 +4,8 @@ use image::RgbImage; use ndarray::{Array, ArrayViewD, CowArray, Ix4}; use ort::{ inputs, - session::{Session, builder::GraphOptimizationLevel} + session::{Session, builder::GraphOptimizationLevel}, + value::TensorRef }; use test_log::test; @@ -69,7 +70,7 @@ fn upsample() -> ort::Result<()> { let array = convert_image_to_cow_array(&image_buffer); // Perform the inference - let outputs = session.run(inputs![&array]?)?; + let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?; assert_eq!(outputs.len(), 1); let output: ArrayViewD = outputs[0].try_extract_tensor()?; @@ -106,7 +107,7 @@ fn upsample_with_ort_model() -> ort::Result<()> { let array = convert_image_to_cow_array(&image_buffer); // Perform the inference - let outputs = session.run(inputs![&array]?)?; + let outputs = session.run(inputs![TensorRef::from_array_view(&array)?])?; assert_eq!(outputs.len(), 1); let output: ArrayViewD = outputs[0].try_extract_tensor()?; diff --git a/tests/vectorizer.rs b/tests/vectorizer.rs index 0528d3bc..00810857 100644 --- a/tests/vectorizer.rs +++ b/tests/vectorizer.rs @@ -26,11 +26,7 @@ fn vectorizer() -> ort::Result<()> { let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap()); - // Just one input - let input_tensor_values = inputs![Tensor::from_string_array(&array)?]?; - - // Perform the inference - let outputs = session.run(input_tensor_values)?; + let outputs = session.run(inputs![Tensor::from_string_array(&array)?])?; assert_eq!(outputs[0].try_extract_tensor::()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap()); Ok(())
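Taken together, the example migrations in this patch follow one pattern: borrowed data goes through TensorRef::from_array_view (either an ndarray view or a (shape, &[T]) tuple), owned data goes through Tensor::from_array, and ort::inputs! is now infallible, so the trailing ? disappears. Below is a minimal end-to-end sketch of that pattern; the model path and the "input"/"output" names are placeholders for illustration, not part of this diff.

use ort::{
    session::Session,
    value::{Tensor, TensorRef}
};

fn main() -> ort::Result<()> {
    // Hypothetical model with one f32 input named "input" and one output named "output".
    let session = Session::builder()?.commit_from_file("model.onnx")?;

    // Borrow existing data without copying: (shape, &[T]) becomes a TensorRef.
    let data = vec![0.0_f32; 3 * 224 * 224];
    let borrowed = TensorRef::from_array_view(([1_usize, 3, 224, 224], data.as_slice()))?;
    let outputs = session.run(ort::inputs!["input" => borrowed])?;

    // Raw extraction returns (shape, flat data) without going through ndarray.
    let (shape, flat) = outputs["output"].try_extract_raw_tensor::<f32>()?;
    println!("{shape:?} -> {} values", flat.len());

    // Or hand ownership to ONNX Runtime instead: Tensor::from_array takes (shape, Vec<T>).
    let owned = Tensor::from_array(([1_usize, 3, 224, 224], data.clone()))?;
    let _outputs = session.run(ort::inputs!["input" => owned])?;
    Ok(())
}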
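The custom-op and operator-test hunks also lean on two other pieces introduced here: allocator-backed tensors filled in place via extract_raw_tensor_mut, and the new From<&Value> impl that lets a tensor be passed to inputs! by reference. A small sketch of that pattern, again with the session and output type as placeholders:

use ort::{session::Session, value::Tensor};

fn run_with_preallocated(session: &Session) -> ort::Result<()> {
    // Allocate two f32 tensors of shape [3, 5] through the session's allocator.
    let zeros = Tensor::<f32>::new(session.allocator(), [3, 5])?;
    let mut ones = Tensor::<f32>::new(session.allocator(), [3, 5])?;
    {
        // Fill the second tensor in place through the mutable raw view.
        let (_shape, data) = ones.extract_raw_tensor_mut();
        data.fill(1.0);
    }

    // Borrow both tensors; no copy, and no trailing `?` on `inputs!`.
    let outputs = session.run(ort::inputs![&zeros, &ones])?;
    let (shape, data) = outputs[0].try_extract_raw_tensor::<f32>()?;
    println!("{shape:?}: {data:?}");
    Ok(())
}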