Skip to content

Commit

Permalink
refactor: reduce the size of generated code
Browse files Browse the repository at this point in the history
small breaking change is that `run_async` now requires `RunOptions`
  • Loading branch information
decahedron1 committed Mar 7, 2025
1 parent e46689f commit 4d45331
Show file tree
Hide file tree
Showing 22 changed files with 532 additions and 718 deletions.
5 changes: 3 additions & 2 deletions examples/async-gpt2-api/examples/async-gpt2-api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use axum::{
use futures::Stream;
use ort::{
execution_providers::CUDAExecutionProvider,
session::{Session, builder::GraphOptimizationLevel},
session::{RunOptions, Session, builder::GraphOptimizationLevel},
value::TensorRef
};
use rand::Rng;
Expand Down Expand Up @@ -74,7 +74,8 @@ fn generate_stream(
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let probabilities = {
let mut session = session.lock().await;
let outputs = session.run_async(ort::inputs![input])?.await?;
let options = RunOptions::new()?;
let outputs = session.run_async(ort::inputs![input], &options)?.await?;
let (dim, probabilities) = outputs["output1"].try_extract_raw_tensor()?;

// Collect and sort logits
Expand Down
4 changes: 2 additions & 2 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ mod tests {
};

#[test]
#[cfg(feature = "std")]
fn test_lora() -> crate::Result<()> {
let model = std::fs::read("tests/data/lora_model.onnx").expect("");
let mut session = Session::builder()?.commit_from_memory(&model)?;
let lora = std::fs::read("tests/data/adapter.orl").expect("");
let lora = Adapter::from_memory(&lora, None)?;
let lora = Adapter::from_file("tests/data/adapter.orl", None)?;

let mut run_options = RunOptions::new()?;
run_options.add_adapter(&lora)?;
Expand Down
6 changes: 0 additions & 6 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,6 @@ impl From<ErrorCode> for ort_sys::OrtErrorCode {
}
}

pub(crate) fn assert_non_null_pointer<T>(ptr: *const T, name: &'static str) -> Result<()> {
(!ptr.is_null())
.then_some(())
.ok_or_else(|| Error::new(format!("Expected pointer `{name}` to not be null")))
}

/// Converts an [`ort_sys::OrtStatus`] to a [`Result`].
///
/// Note that this frees `status`!
Expand Down
32 changes: 27 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub mod tensor;
#[cfg(feature = "training")]
#[cfg_attr(docsrs, doc(cfg(feature = "training")))]
pub mod training;
pub(crate) mod util;
pub mod util;
pub mod value;

#[cfg(feature = "load-dynamic")]
Expand Down Expand Up @@ -218,17 +218,39 @@ macro_rules! ortsys {
(unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }.expect($e)
};
(unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{
(unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr); nonNull($($check:ident),+ $(,)?)$(;)?) => {{
unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }.expect($e);
$(
// TODO: #[cfg(debug_assertions)]?
if ($check).is_null() {
$crate::util::cold();
panic!(concat!("expected `", stringify!($check), "` to not be null"));
}
)+
}};
(unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:ident),+ $(,)?)$(;)?) => {{
let _x = unsafe { ($crate::api().$method)($($n),+) };
$($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+
$(
// TODO: #[cfg(debug_assertions)]?
if ($check).is_null() {
$crate::util::cold();
panic!(concat!("expected `", stringify!($check), "` to not be null"));
}
)+
_x
}};
(unsafe $method:ident($($n:expr),+ $(,)?)?) => {
unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }?;
};
(unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {{
(unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:ident),+ $(,)?)$(;)?) => {{
unsafe { $crate::error::status_to_result(($crate::api().$method)($($n),+)) }?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
$(
// TODO: #[cfg(debug_assertions)]?
if ($check).is_null() {
$crate::util::cold();
return Err($crate::Error::new(concat!("expected `", stringify!($check), "` to not be null")));
}
)+
}};
}

Expand Down
6 changes: 6 additions & 0 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,12 @@ impl MemoryInfo {
}
}

impl Default for MemoryInfo {
fn default() -> Self {
MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default).expect("failed to create default memory info")
}
}

impl Clone for MemoryInfo {
fn clone(&self) -> Self {
MemoryInfo::new(self.allocation_device(), self.device_id(), self.allocator_type(), self.memory_type()).expect("failed to clone memory info")
Expand Down
55 changes: 14 additions & 41 deletions src/session/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use core::{
ffi::{c_char, c_void},
future::Future,
marker::PhantomData,
ops::Deref,
pin::Pin,
ptr::NonNull,
task::{Context, Poll, Waker}
Expand All @@ -13,8 +12,8 @@ use std::sync::Mutex;

use crate::{
error::Result,
session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner},
value::Value
session::{SessionOutputs, SharedSessionInner, run_options::UntypedRunOptions},
value::{Value, ValueInner}
};

#[derive(Debug)]
Expand Down Expand Up @@ -53,43 +52,17 @@ impl<'r, 's> InferenceFutInner<'r, 's> {
unsafe impl Send for InferenceFutInner<'_, '_> {}
unsafe impl Sync for InferenceFutInner<'_, '_> {}

pub enum RunOptionsRef<'r, O: SelectedOutputMarker> {
Arc(Arc<RunOptions<O>>),
Ref(&'r RunOptions<O>)
}

impl<O: SelectedOutputMarker> From<&Arc<RunOptions<O>>> for RunOptionsRef<'_, O> {
fn from(value: &Arc<RunOptions<O>>) -> Self {
Self::Arc(Arc::clone(value))
}
}

impl<'r, O: SelectedOutputMarker> From<&'r RunOptions<O>> for RunOptionsRef<'r, O> {
fn from(value: &'r RunOptions<O>) -> Self {
Self::Ref(value)
}
}

impl<O: SelectedOutputMarker> Deref for RunOptionsRef<'_, O> {
type Target = RunOptions<O>;

fn deref(&self) -> &Self::Target {
match self {
Self::Arc(r) => r,
Self::Ref(r) => r
}
}
}

pub struct InferenceFut<'s, 'r, 'v, O: SelectedOutputMarker> {
pub struct InferenceFut<'s, 'r, 'v> {
inner: Arc<InferenceFutInner<'r, 's>>,
run_options: RunOptionsRef<'r, O>,
run_options: &'r UntypedRunOptions,
did_receive: bool,
_inputs: PhantomData<&'v ()>
}

impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, '_, O> {
pub(crate) fn new(inner: Arc<InferenceFutInner<'r, 's>>, run_options: RunOptionsRef<'r, O>) -> Self {
unsafe impl Send for InferenceFut<'_, '_, '_> {}

impl<'s, 'r> InferenceFut<'s, 'r, '_> {
pub(crate) fn new(inner: Arc<InferenceFutInner<'r, 's>>, run_options: &'r UntypedRunOptions) -> Self {
Self {
inner,
run_options,
Expand All @@ -99,7 +72,7 @@ impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, '_, O> {
}
}

impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, '_, O> {
impl<'s, 'r> Future for InferenceFut<'s, 'r, '_> {
type Output = Result<SessionOutputs<'r, 's>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Expand All @@ -115,7 +88,7 @@ impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, '_, O> {
}
}

impl<O: SelectedOutputMarker> Drop for InferenceFut<'_, '_, '_, O> {
impl Drop for InferenceFut<'_, '_, '_> {
fn drop(&mut self) {
if !self.did_receive {
let _ = self.run_options.terminate();
Expand All @@ -124,19 +97,19 @@ impl<O: SelectedOutputMarker> Drop for InferenceFut<'_, '_, '_, O> {
}
}

pub(crate) struct AsyncInferenceContext<'r, 's, 'v> {
pub(crate) struct AsyncInferenceContext<'r, 's> {
pub(crate) inner: Arc<InferenceFutInner<'r, 's>>,
pub(crate) _input_values: Vec<SessionInputValue<'v>>,
pub(crate) input_ort_values: Vec<*const ort_sys::OrtValue>,
pub(crate) _input_inner_holders: Vec<Arc<ValueInner>>,
pub(crate) input_name_ptrs: Vec<*const c_char>,
pub(crate) output_name_ptrs: Vec<*const c_char>,
pub(crate) session_inner: &'s Arc<SharedSessionInner>,
pub(crate) output_names: Vec<&'s str>,
pub(crate) output_names: Vec<&'r str>,
pub(crate) output_value_ptrs: Vec<*mut ort_sys::OrtValue>
}

pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: ort_sys::OrtStatusPtr) {
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_, '_>>()) };
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };

// Reconvert name ptrs to CString so drop impl is called and memory is freed
for p in ctx.input_name_ptrs {
Expand Down
32 changes: 20 additions & 12 deletions src/session/builder/impl_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use core::{
};
#[cfg(feature = "std")]
use std::path::Path;
#[cfg(feature = "fetch-models")]
use std::path::PathBuf;

use super::SessionBuilder;
#[cfg(feature = "std")]
Expand All @@ -28,22 +30,27 @@ impl SessionBuilder {
#[cfg(all(feature = "fetch-models", feature = "std"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", feature = "std"))))]
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> {
let downloaded_path = SessionBuilder::download(model_url.as_ref())?;
self.commit_from_file(downloaded_path)
}

#[cfg(all(feature = "fetch-models", feature = "std"))]
fn download(url: &str) -> Result<PathBuf> {
let mut download_dir = ort_sys::internal::dirs::cache_dir()
.expect("could not determine cache directory")
.join("models");
if std::fs::create_dir_all(&download_dir).is_err() {
download_dir = std::env::current_dir().expect("Failed to obtain current working directory");
}

let url = model_url.as_ref();
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
let _ = write!(&mut s, "{:02x}", b);
s
});
let model_filepath = download_dir.join(&model_filename);
let downloaded_path = if model_filepath.exists() {
if model_filepath.exists() {
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
model_filepath
Ok(model_filepath)
} else {
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");

Expand Down Expand Up @@ -71,29 +78,31 @@ impl SessionBuilder {
drop(writer);

match std::fs::rename(&temp_filepath, &model_filepath) {
Ok(()) => model_filepath,
Ok(()) => Ok(model_filepath),
Err(e) => {
if model_filepath.exists() {
let _ = std::fs::remove_file(temp_filepath);
model_filepath
Ok(model_filepath)
} else {
return Err(Error::new(format!("Failed to download model: {e}")));
Err(Error::new(format!("Failed to download model: {e}")))
}
}
}
};

self.commit_from_file(downloaded_path)
}
}

/// Loads an ONNX model from a file and builds the session.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session>
pub fn commit_from_file<P>(self, model_filepath: P) -> Result<Session>
where
P: AsRef<Path>
{
let model_filepath = model_filepath_ref.as_ref();
self.commit_from_file_inner(model_filepath.as_ref())
}

#[cfg(feature = "std")]
fn commit_from_file_inner(mut self, model_filepath: &Path) -> Result<Session> {
if !model_filepath.exists() {
return Err(Error::new_with_code(ErrorCode::NoSuchFile, format!("File at `{}` does not exist", model_filepath.display())));
}
Expand Down Expand Up @@ -166,7 +175,6 @@ impl SessionBuilder {
self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?;

let session = self.commit_from_memory(model_bytes)?;

Ok(InMemorySession { session, phantom: PhantomData })
}

Expand Down
16 changes: 6 additions & 10 deletions src/session/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@ use core::{
ptr::{self, NonNull}
};

use crate::{
AsPointer,
error::{Result, assert_non_null_pointer},
memory::MemoryInfo,
operator::OperatorDomain,
ortsys,
value::DynValue
};
use crate::{AsPointer, error::Result, memory::MemoryInfo, operator::OperatorDomain, ortsys, value::DynValue};

mod impl_commit;
mod impl_config_keys;
Expand Down Expand Up @@ -51,8 +44,11 @@ pub struct SessionBuilder {
impl Clone for SessionBuilder {
fn clone(&self) -> Self {
let mut session_options_ptr = ptr::null_mut();
ortsys![unsafe CloneSessionOptions(self.ptr(), ptr::addr_of_mut!(session_options_ptr)).expect("error cloning session options")];
assert_non_null_pointer(session_options_ptr, "OrtSessionOptions").expect("Cloned session option pointer is null");
ortsys![
unsafe CloneSessionOptions(self.ptr(), ptr::addr_of_mut!(session_options_ptr))
.expect("error cloning session options");
nonNull(session_options_ptr)
];
Self {
session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) },
memory_info: self.memory_info.clone(),
Expand Down
14 changes: 12 additions & 2 deletions src/session/input.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
use alloc::{borrow::Cow, vec::Vec};
use alloc::{borrow::Cow, sync::Arc, vec::Vec};
use core::ops::Deref;

use crate::value::{DynValueTypeMarker, Value, ValueRef, ValueRefMut, ValueTypeMarker};
use crate::value::{DynValueTypeMarker, Value, ValueInner, ValueRef, ValueRefMut, ValueTypeMarker};

pub enum SessionInputValue<'v> {
ViewMut(ValueRefMut<'v, DynValueTypeMarker>),
View(ValueRef<'v, DynValueTypeMarker>),
Owned(Value<DynValueTypeMarker>)
}

impl SessionInputValue<'_> {
pub(crate) fn inner(&self) -> &Arc<ValueInner> {
match self {
Self::ViewMut(v) => v.inner(),
Self::View(v) => v.inner(),
Self::Owned(v) => v.inner()
}
}
}

impl Deref for SessionInputValue<'_> {
type Target = Value;

Expand Down
Loading

0 comments on commit 4d45331

Please sign in to comment.