Skip to content

Commit

Permalink
feat(training): 100% API coverage, support input value maps
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Feb 23, 2025
1 parent d738b17 commit 59cc310
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 72 deletions.
4 changes: 2 additions & 2 deletions ort-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,9 @@ pub struct OrtTrainingApi {
pub TrainingSessionGetTrainingModelOutputCount: unsafe extern "system" fn(sess: *const OrtTrainingSession, out: *mut usize) -> OrtStatusPtr,
pub TrainingSessionGetEvalModelOutputCount: unsafe extern "system" fn(sess: *const OrtTrainingSession, out: *mut usize) -> OrtStatusPtr,
pub TrainingSessionGetTrainingModelOutputName:
unsafe extern "system" fn(sess: *const OrtTrainingSession, index: usize, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr,
unsafe extern "system" fn(sess: *const OrtTrainingSession, index: usize, allocator: *mut OrtAllocator, output: *mut *const c_char) -> OrtStatusPtr,
pub TrainingSessionGetEvalModelOutputName:
unsafe extern "system" fn(sess: *const OrtTrainingSession, index: usize, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr,
unsafe extern "system" fn(sess: *const OrtTrainingSession, index: usize, allocator: *mut OrtAllocator, output: *mut *const c_char) -> OrtStatusPtr,
pub LazyResetGrad: unsafe extern "system" fn(session: *mut OrtTrainingSession) -> OrtStatusPtr,
pub TrainStep: unsafe extern "system" fn(
session: *mut OrtTrainingSession,
Expand Down
37 changes: 27 additions & 10 deletions src/training/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
memory::Allocator,
ortsys,
session::{NoSelectedOutputs, RunOptions},
value::DynTensor
value::{DynTensor, Value, ValueType, ValueTypeMarker, r#type::extract_data_type_from_tensor_info}
};

mod simple;
Expand All @@ -36,14 +36,14 @@ pub use self::{
/// May panic if:
/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime.
/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled.
pub fn training_api() -> Result<NonNull<ort_sys::OrtTrainingApi>> {
pub fn training_api() -> Result<&'static ort_sys::OrtTrainingApi> {
struct TrainingApiPointer(*const ort_sys::OrtTrainingApi);
unsafe impl Send for TrainingApiPointer {}
unsafe impl Sync for TrainingApiPointer {}

static TRAINING_API: OnceLock<TrainingApiPointer> = OnceLock::new();

NonNull::new(
let ptr = NonNull::new(
TRAINING_API
.get_or_init(|| {
let training_api = ortsys![unsafe GetTrainingApi(ort_sys::ORT_API_VERSION)];
Expand All @@ -52,7 +52,8 @@ pub fn training_api() -> Result<NonNull<ort_sys::OrtTrainingApi>> {
.0
.cast_mut()
)
.ok_or_else(|| Error::new("Training is not enbled in this build of ONNX Runtime."))
.ok_or_else(|| Error::new("Training is not enbled in this build of ONNX Runtime."))?;
Ok(unsafe { ptr.as_ref() })
}

/// Sets the seed used for RNG when training.
Expand All @@ -63,24 +64,24 @@ pub fn set_seed(seed: i64) -> Result<()> {

macro_rules! trainsys {
($method:ident) => {
($crate::training::training_api().unwrap().as_ref().$method)
($crate::training::training_api().unwrap().$method)
};
(unsafe $method:ident($($n:expr),+ $(,)?)) => {
unsafe { ($crate::training::training_api().unwrap().as_ref().$method)($($n),+) }
unsafe { ($crate::training::training_api().unwrap().$method)($($n),+) }
};
(unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
unsafe { $crate::error::status_to_result(($crate::training::training_api().unwrap().as_ref().$method)($($n),+)) }.expect($e)
unsafe { $crate::error::status_to_result(($crate::training::training_api().unwrap().$method)($($n),+)) }.expect($e)
};
(unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{
let _x = unsafe { ($crate::training::training_api().unwrap().as_ref().$method)($($n),+) };
let _x = unsafe { ($crate::training::training_api().unwrap().$method)($($n),+) };
$($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+
_x
}};
(unsafe $method:ident($($n:expr),+ $(,)?)?) => {
unsafe { $crate::error::status_to_result(($crate::training::training_api()?.as_ref().$method)($($n),+)) }?
unsafe { $crate::error::status_to_result(($crate::training::training_api()?.$method)($($n),+)) }?
};
(unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {{
unsafe { $crate::error::status_to_result(($crate::training::training_api()?.as_ref().$method)($($n),+)) }?;
unsafe { $crate::error::status_to_result(($crate::training::training_api()?.$method)($($n),+)) }?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
}};
}
Expand Down Expand Up @@ -165,6 +166,22 @@ impl Checkpoint {
trainsys![unsafe GetParameter(self.ptr.as_ptr(), name.as_ptr(), allocator.ptr().cast_mut(), &mut value_ptr)?; nonNull(value_ptr)];
Ok(unsafe { DynTensor::from_ptr(NonNull::new_unchecked(value_ptr), None) })
}

pub fn update_parameter<T: ValueTypeMarker>(&mut self, name: impl AsRef<str>, value: &Value<T>) -> Result<()> {
let name = CString::new(name.as_ref())?;
trainsys![unsafe UpdateParameter(self.ptr.as_ptr(), name.as_ptr(), value.ptr().cast_mut())?];
Ok(())
}

pub fn get_parameter_type(&self, name: impl AsRef<str>) -> Result<ValueType> {
let name = CString::new(name.as_ref())?;

let mut shape_info = ptr::null_mut();
trainsys![unsafe GetParameterTypeAndShape(self.ptr.as_ptr(), name.as_ptr(), &mut shape_info)?; nonNull(shape_info)];
let value_type = unsafe { extract_data_type_from_tensor_info(shape_info) };
ortsys![unsafe ReleaseTensorTypeAndShapeInfo(shape_info)];
Ok(value_type)
}
}

#[derive(Debug, Clone, PartialEq)]
Expand Down
194 changes: 134 additions & 60 deletions src/training/trainer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use alloc::ffi::CString;
use core::ptr::{self, NonNull};
use alloc::{borrow::Cow, ffi::CString};
use core::{
fmt,
ptr::{self, NonNull}
};
use std::path::Path;

use ort_sys::c_char;
Expand All @@ -10,14 +13,17 @@ use crate::{
error::{Result, assert_non_null_pointer, status_to_result},
memory::Allocator,
session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder},
value::Value
tensor::IntoTensorElementType,
value::{Tensor, Value}
};

#[derive(Debug)]
pub struct Trainer {
ptr: NonNull<ort_sys::OrtTrainingSession>,
train_output_names: Vec<String>,
eval_output_names: Vec<String>,
train_input_names: Vec<String>,
eval_input_names: Vec<String>,
ckpt: Checkpoint,
_allocator: Allocator
}
Expand Down Expand Up @@ -98,47 +104,23 @@ impl Trainer {
}

fn new_inner(ptr: NonNull<ort_sys::OrtTrainingSession>, allocator: Allocator, ckpt: Checkpoint) -> Result<Self> {
let mut train_output_len = 0;
trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len)?];
let train_output_names = (0..train_output_len)
.map(|i| {
let mut name_bytes: *mut c_char = std::ptr::null_mut();
trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)?];
let name = match char_p_to_string(name_bytes) {
Ok(name) => name,
Err(e) => {
unsafe { allocator.free(name_bytes) };
return Err(e);
}
};
unsafe { allocator.free(name_bytes) };
Ok(name)
})
.collect::<Result<Vec<String>>>()?;

let mut eval_output_len = 0;
trainsys![unsafe TrainingSessionGetEvalModelOutputCount(ptr.as_ptr(), &mut eval_output_len)?];
let eval_output_names = (0..eval_output_len)
.map(|i| {
let mut name_bytes: *mut c_char = std::ptr::null_mut();
trainsys![unsafe TrainingSessionGetEvalModelOutputName(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)?];
let name = match char_p_to_string(name_bytes) {
Ok(name) => name,
Err(e) => {
unsafe { allocator.free(name_bytes) };
return Err(e);
}
};
unsafe { allocator.free(name_bytes) };
Ok(name)
})
.collect::<Result<Vec<String>>>()?;
let train_output_names =
extract_io_names(ptr, &allocator, trainsys![TrainingSessionGetTrainingModelOutputCount], trainsys![TrainingSessionGetTrainingModelOutputName])?;
let eval_output_names =
extract_io_names(ptr, &allocator, trainsys![TrainingSessionGetEvalModelOutputCount], trainsys![TrainingSessionGetEvalModelOutputName])?;

let train_input_names =
extract_io_names(ptr, &allocator, trainsys![TrainingSessionGetTrainingModelInputCount], trainsys![TrainingSessionGetTrainingModelInputName])?;
let eval_input_names =
extract_io_names(ptr, &allocator, trainsys![TrainingSessionGetEvalModelInputCount], trainsys![TrainingSessionGetEvalModelInputName])?;

Ok(Self {
ptr,
_allocator: allocator,
train_output_names,
train_input_names,
eval_output_names,
eval_input_names,
ckpt
})
}
Expand All @@ -150,29 +132,45 @@ impl Trainer {
) -> Result<SessionOutputs<'s, 's>> {
match inputs.into() {
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None),
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels).map(Some), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()).map(Some), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.train_input_names, &labels);
self.step_inner(input_values.iter().map(Some).chain(labels), None)
}
},
SessionInputs::ValueArray(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None),
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels).map(Some), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()).map(Some), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.train_input_names, &labels);
self.step_inner(input_values.iter().map(Some).chain(labels), None)
}
},
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
SessionInputs::ValueMap(input_values) => {
let input_values = mapped_inputs(&self.train_input_names, &input_values);
match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.into_iter().chain(labels.iter().map(Some)), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.into_iter().chain(labels.iter().map(Some)), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.train_input_names, &labels);
self.step_inner(input_values.into_iter().chain(labels), None)
}
}
}
}
}

fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
&'s self,
input_values: impl Iterator<Item = &'i1 SessionInputValue<'v1>>,
input_values: impl Iterator<Item = Option<&'i1 SessionInputValue<'v1>>>,
run_options: Option<&'r RunOptions>
) -> Result<SessionOutputs<'r, 's>> {
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()];
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![ptr::null_mut(); self.train_output_names.len()];

let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr()).collect();
let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|v| v.map_or(ptr::null(), |v| v.ptr())).collect();

let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr() } else { std::ptr::null() };
let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr() } else { ptr::null() };

trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr())?];

Expand All @@ -195,29 +193,45 @@ impl Trainer {
) -> Result<SessionOutputs<'s, 's>> {
match inputs.into() {
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None),
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels).map(Some), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()).map(Some), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.eval_input_names, &labels);
self.eval_step_inner(input_values.iter().map(Some).chain(labels), None)
}
},
SessionInputs::ValueArray(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None),
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels).map(Some), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()).map(Some), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.eval_input_names, &labels);
self.eval_step_inner(input_values.iter().map(Some).chain(labels), None)
}
},
SessionInputs::ValueMap(_) => unimplemented!("named values not supported?")
SessionInputs::ValueMap(input_values) => {
let input_values = mapped_inputs(&self.eval_input_names, &input_values);
match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.into_iter().chain(labels.iter().map(Some)), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.into_iter().chain(labels.iter().map(Some)), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.eval_input_names, &labels);
self.eval_step_inner(input_values.into_iter().chain(labels), None)
}
}
}
}
}

fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
&'s self,
input_values: impl Iterator<Item = &'i1 SessionInputValue<'v1>>,
input_values: impl Iterator<Item = Option<&'i1 SessionInputValue<'v1>>>,
run_options: Option<&'r RunOptions>
) -> Result<SessionOutputs<'r, 's>> {
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.eval_output_names.len()];
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![ptr::null_mut(); self.eval_output_names.len()];

let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr()).collect();
let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|v| v.map_or(ptr::null(), |v| v.ptr())).collect();

let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr() } else { std::ptr::null() };
let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr() } else { ptr::null() };

trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr())?];

Expand Down Expand Up @@ -261,6 +275,22 @@ impl Trainer {
Ok(())
}

pub fn num_params(&self, trainable_only: bool) -> Result<usize> {
let mut out = 0;
trainsys![unsafe GetParametersSize(self.ptr.as_ptr(), &mut out, trainable_only)?];
Ok(out)
}

pub fn copy_parameters_to<T: IntoTensorElementType + fmt::Debug>(&self, value: &mut Tensor<T>, trainable_only: bool) -> Result<()> {
trainsys![unsafe CopyParametersToBuffer(self.ptr.as_ptr(), value.ptr_mut(), trainable_only)?];
Ok(())
}

pub fn copy_parameters_from<T: IntoTensorElementType + fmt::Debug>(&mut self, value: &Tensor<T>, trainable_only: bool) -> Result<()> {
trainsys![unsafe CopyBufferToParameters(self.ptr.as_ptr(), value.ptr().cast_mut(), trainable_only)?];
Ok(())
}

pub fn optimizer(&self) -> Optimizer<'_> {
Optimizer::new(self.ptr)
}
Expand All @@ -284,3 +314,47 @@ impl Drop for Trainer {
trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())];
}
}

fn mapped_inputs<'v, 'a>(input_names: &[String], values: &'a [(Cow<'_, str>, SessionInputValue<'v>)]) -> Vec<Option<&'a SessionInputValue<'v>>> {
let mut out = Vec::with_capacity(input_names.len());
'o: for want_name in input_names {
for (name, value) in values {
if want_name == name {
out.push(Some(value));
continue 'o;
}
}
out.push(None);
}
out
}

fn extract_io_names(
ptr: NonNull<ort_sys::OrtTrainingSession>,
allocator: &Allocator,
get_count: unsafe extern "system" fn(sess: *const ort_sys::OrtTrainingSession, out: *mut usize) -> ort_sys::OrtStatusPtr,
get_name: unsafe extern "system" fn(
sess: *const ort_sys::OrtTrainingSession,
index: usize,
allocator: *mut ort_sys::OrtAllocator,
output: *mut *const c_char
) -> ort_sys::OrtStatusPtr
) -> Result<Vec<String>> {
let mut count = 0;
unsafe { status_to_result(get_count(ptr.as_ptr(), &mut count)) }?;
(0..count)
.map(|i| {
let mut name_bytes: *const c_char = ptr::null();
unsafe { status_to_result(get_name(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)) }?;
let name = match char_p_to_string(name_bytes) {
Ok(name) => name,
Err(e) => {
unsafe { allocator.free(name_bytes.cast_mut()) };
return Err(e);
}
};
unsafe { allocator.free(name_bytes.cast_mut()) };
Ok(name)
})
.collect::<Result<Vec<String>>>()
}

0 comments on commit 59cc310

Please sign in to comment.