Skip to content

Commit

Permalink
refactor!: allow zero-copy from_array for array views with TensorRef
Browse files Browse the repository at this point in the history
This has all sorts of fun breaking changes:
- `ort::inputs!` no longer yields an `ort::Result<...>` (thank God)
- `Tensor::from_array` now only accepts owned data.
- Introduce `TensorRef::from_array_view` and `TensorRefMut::from_array_view_mut`.
- `TryFrom<A>` is no longer implemented for `Tensor<T>`, for any source type `A`.

This opens the door to new optimizations on top of fixing a few unsoundness issues.

TODO: update docs
  • Loading branch information
decahedron1 committed Dec 21, 2024
1 parent d7d4493 commit 9ea18d8
Show file tree
Hide file tree
Showing 23 changed files with 379 additions and 381 deletions.
13 changes: 10 additions & 3 deletions benches/squeezenet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::{path::Path, sync::Arc};
use glassbench::{Bench, pretend_used};
use image::{ImageBuffer, Pixel, Rgb, imageops::FilterType};
use ndarray::{Array4, s};
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::{
session::{Session, builder::GraphOptimizationLevel},
value::TensorRef
};

fn load_squeezenet_data() -> ort::Result<(Session, Array4<f32>)> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
Expand Down Expand Up @@ -44,15 +47,19 @@ fn bench_squeezenet(bench: &mut Bench) {
let (session, data) = load_squeezenet_data().unwrap();
bench.task("ArrayView", |task| {
task.iter(|| {
pretend_used(session.run(ort::inputs![data.view()].unwrap()).unwrap());
pretend_used(session.run(ort::inputs![TensorRef::from_array_view(&data).unwrap()]).unwrap());
})
});

let raw = Arc::new(data.as_standard_layout().as_slice().unwrap().to_owned().into_boxed_slice());
let shape: Vec<i64> = data.shape().iter().map(|c| *c as _).collect();
bench.task("Raw data", |task| {
task.iter(|| {
pretend_used(session.run(ort::inputs![(shape.clone(), Arc::clone(&raw))].unwrap()).unwrap());
pretend_used(
session
.run(ort::inputs![TensorRef::from_array_view((shape.clone(), Arc::clone(&raw))).unwrap()])
.unwrap()
);
})
});
}
Expand Down
2 changes: 1 addition & 1 deletion examples/async-gpt2-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ tokio = { version = "1.36", features = [ "full" ] }
tokio-stream = "0.1"
tower-http = { version = "0.5", features = ["fs", "trace"] }
anyhow = "1.0"
async-stream = "0.3"
async-stream-lite = "0.2"

[features]
load-dynamic = [ "ort/load-dynamic" ]
Expand Down
37 changes: 15 additions & 22 deletions examples/async-gpt2-api/examples/async-gpt2-api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ use axum::{
routing::post
};
use futures::Stream;
use ndarray::{Array1, ArrayViewD, Axis, array, concatenate, s};
use ort::{
execution_providers::CUDAExecutionProvider,
inputs,
session::{Session, builder::GraphOptimizationLevel}
session::{Session, builder::GraphOptimizationLevel},
value::TensorRef
};
use rand::Rng;
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -64,37 +63,31 @@ struct AppState {
tokenizer: Arc<Tokenizer>
}

fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, tokens: Vec<i64>, gen_tokens: usize) -> impl Stream<Item = ort::Result<Event>> + Send {
async_stream::try_stream! {
let mut tokens = Array1::from_iter(tokens.iter().cloned());

fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, mut tokens: Vec<i64>, gen_tokens: usize) -> impl Stream<Item = ort::Result<Event>> + Send {
async_stream_lite::try_async_stream(|yielder| async move {
for _ in 0..gen_tokens {
let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1));
let outputs = session.run_async(inputs![array]?)?.await?;
let generated_tokens: ArrayViewD<f32> = outputs["output1"].try_extract_tensor()?;
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let outputs = session.run_async(ort::inputs![input])?.await?;
let (dim, probabilities) = outputs["output1"].try_extract_raw_tensor()?;

// Collect and sort logits
let probabilities = &mut generated_tokens
.slice(s![0, 0, -1, ..])
.insert_axis(Axis(0))
.to_owned()
.iter()
.cloned()
.enumerate()
.collect::<Vec<_>>();
let (seq_len, vocab_size) = (dim[2] as usize, dim[3] as usize);
let mut probabilities: Vec<(usize, f32)> = probabilities[(seq_len - 1) * vocab_size..].iter().copied().enumerate().collect();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

// Sample using top-k sampling
let token = {
let mut rng = rand::thread_rng();
probabilities[rng.gen_range(0..=5)].0
probabilities[rng.gen_range(0..=5)].0 as i64
};
tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]];
tokens.push(token);

let token_str = tokenizer.decode(&[token as _], true).unwrap();
yield Event::default().data(token_str);
yielder.r#yield(Event::default().data(token_str)).await;
}
}

Ok(())
})
}

impl FromRef<AppState> for Arc<Session> {
Expand Down
2 changes: 1 addition & 1 deletion examples/cudarc/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn main() -> anyhow::Result<()> {
)
.unwrap()
};
let outputs = model.run([tensor.into()])?;
let outputs = model.run(ort::inputs![tensor])?;

let output = outputs["output"].try_extract_tensor::<f32>()?;

Expand Down
15 changes: 12 additions & 3 deletions examples/custom-ops/examples/custom-ops.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use ndarray::Array2;
use ort::{
operator::{
Operator, OperatorDomain,
io::{OperatorInput, OperatorOutput},
kernel::{Kernel, KernelAttributes, KernelContext}
},
session::Session,
tensor::TensorElementType
tensor::TensorElementType,
value::Tensor
};

struct CustomOpOne;
Expand Down Expand Up @@ -78,7 +78,16 @@ fn main() -> ort::Result<()> {
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
.commit_from_file("tests/data/custom_op_test.onnx")?;

let values = session.run(ort::inputs![Array2::<f32>::zeros((3, 5)), Array2::<f32>::ones((3, 5))]?)?;
let allocator = session.allocator();
let value1 = Tensor::<f32>::new(allocator, [3, 5])?;
let mut value2 = Tensor::<f32>::new(allocator, [3, 5])?;
{
let (_, data) = value2.extract_raw_tensor_mut();
for datum in data {
*datum = 1.;
}
}
let values = session.run(ort::inputs![&value1, &value2])?;
println!("{:?}", values[0].try_extract_tensor::<i32>()?);

Ok(())
Expand Down
18 changes: 8 additions & 10 deletions examples/gpt2/examples/gpt2-no-ndarray.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::{
io::{self, Write},
path::Path,
sync::Arc
path::Path
};

use ort::{
execution_providers::CUDAExecutionProvider,
inputs,
session::{Session, builder::GraphOptimizationLevel}
session::{Session, builder::GraphOptimizationLevel},
value::TensorRef
};
use rand::Rng;
use tokenizers::Tokenizer;
Expand All @@ -33,7 +33,7 @@ fn main() -> ort::Result<()> {
.with_execution_providers([CUDAExecutionProvider::default().build()])
.commit()?;

let mut stdout = io::stdout();
let mut stdout: io::Stdout = io::stdout();
let mut rng = rand::thread_rng();

// Load our model
Expand All @@ -45,16 +45,16 @@ fn main() -> ort::Result<()> {
// Load the tokenizer and encode the prompt into a sequence of tokens.
let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();
let tokens = tokenizer.encode(PROMPT, false).unwrap();
let mut tokens = Arc::new(tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>().into_boxed_slice());
let mut tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();

print!("{PROMPT}");
stdout.flush().unwrap();

for _ in 0..GEN_TOKENS {
// Raw tensor construction takes a tuple of (dimensions, data).
// The model expects our input to have shape [B, _, S]
let input = (vec![1, 1, tokens.len() as i64], Arc::clone(&tokens));
let outputs = session.run(inputs![input]?)?;
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let outputs = session.run(inputs![input])?;
let (dim, mut probabilities) = outputs["output1"].try_extract_raw_tensor()?;

// The output tensor will have shape [B, _, S + 1, V]
Expand All @@ -70,9 +70,7 @@ fn main() -> ort::Result<()> {
let token = probabilities[rng.gen_range(0..=TOP_K)].0 as i64;

// Add our generated token to the input sequence
let mut vec = tokens.to_vec();
vec.push(token);
*Arc::make_mut(&mut tokens) = vec.into_boxed_slice();
tokens.push(token);

let token_str = tokenizer.decode(&[token as u32], true).unwrap();
print!("{}", token_str);
Expand Down
5 changes: 3 additions & 2 deletions examples/gpt2/examples/gpt2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use ndarray::{Array1, ArrayViewD, Axis, array, concatenate, s};
use ort::{
execution_providers::CUDAExecutionProvider,
inputs,
session::{Session, builder::GraphOptimizationLevel}
session::{Session, builder::GraphOptimizationLevel},
value::TensorRef
};
use rand::Rng;
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -54,7 +55,7 @@ fn main() -> ort::Result<()> {

for _ in 0..GEN_TOKENS {
let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1));
let outputs = session.run(inputs![array]?)?;
let outputs = session.run(inputs![TensorRef::from_array_view(array)?])?;
let generated_tokens: ArrayViewD<f32> = outputs["output1"].try_extract_tensor()?;

// Collect and sort logits
Expand Down
4 changes: 2 additions & 2 deletions examples/modnet/examples/modnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{ops::Mul, path::Path};

use image::{GenericImageView, ImageBuffer, Rgba, imageops::FilterType};
use ndarray::Array;
use ort::{execution_providers::CUDAExecutionProvider, inputs, session::Session};
use ort::{execution_providers::CUDAExecutionProvider, inputs, session::Session, value::TensorRef};
use show_image::{AsImageView, WindowOptions, event};

#[show_image::main]
Expand All @@ -31,7 +31,7 @@ fn main() -> ort::Result<()> {
input[[0, 2, y, x]] = (b as f32 - 127.5) / 127.5;
}

let outputs = model.run(inputs!["input" => input.view()]?)?;
let outputs = model.run(inputs!["input" => TensorRef::from_array_view(input.view())?])?;

let output = outputs["output"].try_extract_tensor::<f32>()?;

Expand Down
31 changes: 16 additions & 15 deletions examples/phi-3-vision/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use std::{path::Path, time::Instant};
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Array2, Array3, Array4, ArrayView, Ix3, Ix4, s};
use ort::{session::Session, value::Tensor};
use ort::{
session::Session,
value::{Tensor, TensorRef}
};
use tokenizers::Tokenizer;

const VISION_MODEL_NAME: &str = "phi-3-v-128k-instruct-vision.onnx";
Expand All @@ -31,11 +34,10 @@ fn get_image_embedding(vision_model: &Session, img: &Option<DynamicImage>) -> Re
pixel_values = result.pixel_values.shape(),
image_sizes = result.image_sizes.shape(),
);
let model_inputs = ort::inputs![
"pixel_values" => result.pixel_values,
"image_sizes" => result.image_sizes,
]?;
let outputs = vision_model.run(model_inputs)?;
let outputs = vision_model.run(ort::inputs![
"pixel_values" => Tensor::from_array(result.pixel_values)?,
"image_sizes" => Tensor::from_array(result.image_sizes)?,
])?;
let predictions_view: ArrayView<f32, _> = outputs["visual_features"].try_extract_tensor::<f32>()?;
predictions_view.into_dimensionality::<Ix3>()?.to_owned()
} else {
Expand All @@ -45,10 +47,9 @@ fn get_image_embedding(vision_model: &Session, img: &Option<DynamicImage>) -> Re
}

fn get_text_embedding(text_embedding_model: &Session, input_ids: &Array2<i64>) -> Result<Array3<f32>> {
let model_inputs = ort::inputs![
"input_ids" => input_ids.to_owned(),
]?;
let outputs = text_embedding_model.run(model_inputs)?;
let outputs = text_embedding_model.run(ort::inputs![
"input_ids" => TensorRef::from_array_view(input_ids)?,
])?;
let inputs_embeds_view: ArrayView<f32, _> = outputs["inputs_embeds"].try_extract_tensor::<f32>()?;
let inputs_embeds = inputs_embeds_view.into_dimensionality::<Ix3>()?.to_owned();
Ok(inputs_embeds)
Expand Down Expand Up @@ -144,12 +145,12 @@ pub async fn generate_text(
// Prepare model inputs
let model_inputs = {
let mut model_inputs = ort::inputs![
"inputs_embeds" => next_inputs_embeds.clone(),
"attention_mask" => attention_mask.clone(),
]?;
"inputs_embeds" => TensorRef::from_array_view(&next_inputs_embeds)?,
"attention_mask" => TensorRef::from_array_view(&attention_mask)?,
];
for i in 0..32 {
model_inputs.push((format!("past_key_values.{}.key", i).into(), Tensor::from_array(past_key_values[i * 2].view())?.into()));
model_inputs.push((format!("past_key_values.{}.value", i).into(), Tensor::from_array(past_key_values[i * 2 + 1].view())?.into()));
model_inputs.push((format!("past_key_values.{}.key", i).into(), TensorRef::from_array_view(&past_key_values[i * 2])?.into()));
model_inputs.push((format!("past_key_values.{}.value", i).into(), TensorRef::from_array_view(&past_key_values[i * 2 + 1])?.into()));
}
model_inputs
};
Expand Down
11 changes: 6 additions & 5 deletions examples/sentence-transformers/examples/semantic-similarity.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::path::Path;

use ndarray::{Array2, Axis, Ix2};
use ndarray::{Axis, Ix2};
use ort::{
Error,
execution_providers::CUDAExecutionProvider,
session::{Session, builder::GraphOptimizationLevel}
session::{Session, builder::GraphOptimizationLevel},
value::TensorRef
};
use tokenizers::Tokenizer;

Expand Down Expand Up @@ -45,11 +46,11 @@ fn main() -> ort::Result<()> {
let mask: Vec<i64> = encodings.iter().flat_map(|e| e.get_attention_mask().iter().map(|i| *i as i64)).collect();

// Convert our flattened arrays into 2-dimensional tensors of shape [N, L].
let a_ids = Array2::from_shape_vec([inputs.len(), padded_token_length], ids).unwrap();
let a_mask = Array2::from_shape_vec([inputs.len(), padded_token_length], mask).unwrap();
let a_ids = TensorRef::from_array_view(([inputs.len(), padded_token_length], &*ids))?;
let a_mask = TensorRef::from_array_view(([inputs.len(), padded_token_length], &*mask))?;

// Run the model.
let outputs = session.run(ort::inputs![a_ids, a_mask]?)?;
let outputs = session.run(ort::inputs![a_ids, a_mask])?;

// Extract our embeddings tensor and convert it to a strongly-typed 2-dimensional array.
let embeddings = outputs[1].try_extract_tensor::<f32>()?.into_dimensionality::<Ix2>().unwrap();
Expand Down
37 changes: 15 additions & 22 deletions examples/training/examples/train-clm-simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use std::{
};

use kdam::BarExt;
use ndarray::{Array1, Array2, ArrayViewD, Axis, concatenate, s};
use ort::{
execution_providers::CUDAExecutionProvider,
memory::Allocator,
session::{Session, builder::SessionBuilder},
training::{CheckpointStrategy, Trainer, TrainerCallbacks, TrainerControl, TrainerState, TrainingArguments}
training::{CheckpointStrategy, Trainer, TrainerCallbacks, TrainerControl, TrainerState, TrainingArguments},
value::{Tensor, TensorRef}
};
use rand::RngCore;
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -94,10 +94,10 @@ fn main() -> ort::Result<()> {
.unwrap();
}

Ok((
ort::inputs![Array2::<i64>::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?,
ort::inputs![Array1::<i64>::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?
))
let inputs = Tensor::from_array(([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect::<Vec<i64>>()))?;
let labels = Tensor::from_array(([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect::<Vec<i64>>()))?;

Ok((ort::inputs![inputs], ort::inputs![labels]))
};

trainer.train(
Expand All @@ -115,26 +115,19 @@ fn main() -> ort::Result<()> {
let mut stdout = std::io::stdout();

let tokens = tokenizer.encode("<|endoftext|>", false).unwrap();
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();

let mut tokens = Array1::from_iter(tokens.iter().cloned());
let mut tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();

for _ in 0..50 {
let array = tokens.view().insert_axis(Axis(0));
let outputs = session.run(ort::inputs![array]?)?;
let generated_tokens: ArrayViewD<f32> = outputs["probs"].try_extract_tensor()?;

let probabilities = &mut generated_tokens
.slice(s![-1, ..])
.to_owned()
.iter()
.cloned()
.enumerate()
.collect::<Vec<_>>();
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let outputs = session.run(ort::inputs![input])?;
let (dim, probabilities) = outputs["probs"].try_extract_raw_tensor()?;

let (seq_len, vocab_size) = (dim[2] as usize, dim[3] as usize);
let mut probabilities: Vec<(usize, f32)> = probabilities[(seq_len - 1) * vocab_size..].iter().copied().enumerate().collect();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

let token = probabilities[0].0;
tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]];
let token = probabilities[0].0 as i64;
tokens.push(token);

let token_str = tokenizer.decode(&[token as _], false).unwrap();
print!("{}", token_str);
Expand Down
Loading

0 comments on commit 9ea18d8

Please sign in to comment.