Skip to content

Commit

Permalink
feat(training): improve training API coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Feb 22, 2025
1 parent eff10de commit 8aab523
Show file tree
Hide file tree
Showing 8 changed files with 356 additions and 54 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"rust-analyzer.diagnostics.experimental.enable": true,
"rust-analyzer.showUnlinkedFileNotification": false,
"deno.enablePaths": [
"tools/api-coverage.ts"
"tools/api-coverage.ts",
"tools/training-api-coverage.ts"
]
}
4 changes: 3 additions & 1 deletion examples/training/examples/train-clm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn main() -> ort::Result<()> {
)
.unwrap();

let optimizer = trainer.optimizer();
let mut optimizer = trainer.optimizer();
optimizer.set_lr(7e-5)?;

let mut dataset = File::open("dataset.bin").unwrap();
Expand Down Expand Up @@ -93,6 +93,8 @@ fn main() -> ort::Result<()> {
if loss.is_nan() {
return Ok(());
}

let mut optimizer = trainer.optimizer();
optimizer.step()?;
optimizer.reset_grad()?;
}
Expand Down
35 changes: 35 additions & 0 deletions ort-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,41 @@ pub struct OrtTrainingApi {
inference_model_path: *const ortchar,
graph_outputs_len: usize,
graph_output_names: *const *const c_char
) -> OrtStatusPtr,
pub SetSeed: unsafe extern "system" fn(seed: i64) -> OrtStatusPtr,
pub TrainingSessionGetTrainingModelInputCount: unsafe extern "system" fn(session: *const OrtTrainingSession, out: *mut usize) -> OrtStatusPtr,
pub TrainingSessionGetEvalModelInputCount: unsafe extern "system" fn(session: *const OrtTrainingSession, out: *mut usize) -> OrtStatusPtr,
pub TrainingSessionGetTrainingModelInputName:
unsafe extern "system" fn(session: *const OrtTrainingSession, index: usize, allocator: *mut OrtAllocator, output: *mut *const c_char) -> OrtStatusPtr,
pub TrainingSessionGetEvalModelInputName:
unsafe extern "system" fn(session: *const OrtTrainingSession, index: usize, allocator: *mut OrtAllocator, output: *mut *const c_char) -> OrtStatusPtr,
pub AddProperty: unsafe extern "system" fn(
checkpoint_state: *mut OrtCheckpointState,
property_name: *const c_char,
property_type: OrtPropertyType,
property_value: *const ()
) -> OrtStatusPtr,
pub GetProperty: unsafe extern "system" fn(
checkpoint_state: *mut OrtCheckpointState,
property_name: *const c_char,
allocator: *mut OrtAllocator,
property_type: *mut OrtPropertyType,
property_value: *mut *const ()
) -> OrtStatusPtr,
pub LoadCheckpointFromBuffer:
unsafe extern "system" fn(checkpoint_buffer: *const (), num_bytes: usize, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr,
pub GetParameterTypeAndShape: unsafe extern "system" fn(
checkpoint_state: *const OrtCheckpointState,
parameter_name: *const c_char,
parameter_type_and_shape: *mut *mut OrtTensorTypeAndShapeInfo
) -> OrtStatusPtr,
pub UpdateParameter:
unsafe extern "system" fn(checkpoint_state: *mut OrtCheckpointState, parameter_name: *const c_char, parameter: *mut OrtValue) -> OrtStatusPtr,
pub GetParameter: unsafe extern "system" fn(
checkpoint_state: *const OrtCheckpointState,
parameter_name: *const c_char,
allocator: *mut OrtAllocator,
parameter: *mut *mut OrtValue
) -> OrtStatusPtr
}
#[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"]
Expand Down
173 changes: 156 additions & 17 deletions src/training/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
//! Provides [`Trainer`], a simple interface for on-device training/fine-tuning.
use std::{
path::Path,
ptr::{self, NonNull},
sync::OnceLock
use alloc::{
ffi::CString,
string::{String, ToString}
};
use core::{
ffi::{CStr, c_char},
marker::PhantomData,
ptr::{self, NonNull}
};
use std::{path::Path, sync::OnceLock};

use crate::{AsPointer, Error, Result, ortsys, session::RunOptions};
use crate::{
AsPointer, Error, Result,
memory::Allocator,
ortsys,
session::{NoSelectedOutputs, RunOptions},
value::DynTensor
};

mod simple;
mod trainer;
Expand Down Expand Up @@ -44,6 +55,12 @@ pub fn training_api() -> Result<NonNull<ort_sys::OrtTrainingApi>> {
.ok_or_else(|| Error::new("Training is not enbled in this build of ONNX Runtime."))
}

/// Sets the seed used for RNG when training.
///
/// Forwards `seed` unchanged to the training API's `SetSeed` function.
///
/// # Errors
/// Returns an error if the training API cannot be obtained or if `SetSeed` reports a failure status.
pub fn set_seed(seed: i64) -> Result<()> {
trainsys![unsafe SetSeed(seed)?];
Ok(())
}

macro_rules! trainsys {
($method:ident) => {
($crate::training::training_api().unwrap().as_ref().$method)
Expand All @@ -60,7 +77,7 @@ macro_rules! trainsys {
_x
}};
(unsafe $method:ident($($n:expr),+ $(,)?)?) => {
unsafe { $crate::error::status_to_result(($crate::training::training_api()?.as_ref().$method)($($n),+)) }?;
unsafe { $crate::error::status_to_result(($crate::training::training_api()?.as_ref().$method)($($n),+)) }?
};
(unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {{
unsafe { $crate::error::status_to_result(($crate::training::training_api()?.as_ref().$method)($($n),+)) }?;
Expand All @@ -84,11 +101,98 @@ impl Checkpoint {
})
}

/// Loads a checkpoint from an in-memory buffer.
///
/// The pointer and length of `buffer` are handed to the training API's `LoadCheckpointFromBuffer`, which
/// writes the resulting checkpoint state pointer into a local out-parameter.
///
/// # Errors
/// Returns an error if the training API cannot be obtained, if `LoadCheckpointFromBuffer` reports a failure
/// status, or if the `nonNull` check on the returned pointer fails.
pub fn load_from_buffer(buffer: &[u8]) -> Result<Self> {
let mut ptr: *mut ort_sys::OrtCheckpointState = ptr::null_mut();
trainsys![unsafe LoadCheckpointFromBuffer(buffer.as_ptr().cast(), buffer.len(), &mut ptr)?; nonNull(ptr)];
Ok(Checkpoint {
// SAFETY: relies on the `nonNull(ptr)` check in the macro call above having rejected a null pointer
// (NOTE(review): that macro arm is not fully visible here - confirm it returns early on null).
ptr: unsafe { NonNull::new_unchecked(ptr) }
})
}

/// Saves this checkpoint to `path` via the training API's `SaveCheckpoint`.
///
/// `include_optimizer_state` is passed through to `SaveCheckpoint` unchanged.
///
/// # Errors
/// Returns an error if the training API cannot be obtained or if `SaveCheckpoint` reports a failure status.
pub fn save(&self, path: impl AsRef<Path>, include_optimizer_state: bool) -> Result<()> {
// Convert the path into the platform character representation expected by the C API
// (presumably NUL-terminated - verify against `path_to_os_char`).
let path = crate::util::path_to_os_char(path);
trainsys![unsafe SaveCheckpoint(self.ptr.as_ptr(), path.as_ptr(), include_optimizer_state)?];
Ok(())
}

pub fn add_property(&mut self, name: impl AsRef<str>, property: impl Into<Property>) -> Result<()> {
let name = CString::new(name.as_ref())?;
match property.into() {
Property::Int(value) => {
trainsys![unsafe AddProperty(self.ptr.as_ptr(), name.as_ptr(), ort_sys::OrtPropertyType::OrtIntProperty, (&value as *const i64).cast())?]
}
Property::Float(value) => {
trainsys![unsafe AddProperty(self.ptr.as_ptr(), name.as_ptr(), ort_sys::OrtPropertyType::OrtFloatProperty, (&value as *const f32).cast())?]
}
Property::String(value) => {
let value = CString::new(value)?;
trainsys![unsafe AddProperty(self.ptr.as_ptr(), name.as_ptr(), ort_sys::OrtPropertyType::OrtStringProperty, value.as_ptr().cast())?]
}
}
Ok(())
}

/// Looks up the property stored under `name` in this checkpoint.
///
/// The value is fetched through the training API's `GetProperty`, copied into an owned [`Property`], and
/// the buffer returned by the API is released again before this function returns.
///
/// Returns `None` if `name` contains an interior NUL byte, if `GetProperty` reports a failure status, or if
/// it succeeds without producing a value pointer.
pub fn get_property(&self, name: impl AsRef<str>) -> Option<Property> {
	let name = CString::new(name.as_ref()).ok()?;
	let mut allocator = Allocator::default();
	// Placeholder initial value; overwritten by `GetProperty` on success.
	let mut property_type: ort_sys::OrtPropertyType = ort_sys::OrtPropertyType::OrtIntProperty;
	let mut property_value: *const () = ptr::null();

	let status = trainsys![unsafe GetProperty(
		self.ptr.as_ptr(),
		name.as_ptr(),
		allocator.ptr_mut(),
		&mut property_type,
		&mut property_value
	)];
	unsafe { crate::error::status_to_result(status) }.ok()?;

	// Never dereference a null out-pointer, even if the call claimed success.
	if property_value.is_null() {
		return None;
	}

	// Copy the payload out of the API-owned buffer into an owned Rust value.
	let property = match property_type {
		ort_sys::OrtPropertyType::OrtIntProperty => Property::Int(unsafe { *property_value.cast::<i64>() }),
		ort_sys::OrtPropertyType::OrtFloatProperty => Property::Float(unsafe { *property_value.cast::<f32>() }),
		ort_sys::OrtPropertyType::OrtStringProperty => {
			Property::String(unsafe { CStr::from_ptr(property_value.cast::<c_char>()) }.to_string_lossy().into_owned())
		}
	};

	// `GetProperty` allocates the returned buffer with the supplied allocator for *every* property type
	// (not just strings), so it must be freed on all paths; freeing only strings leaked int/float values.
	unsafe { allocator.free(property_value.cast_mut()) };

	Some(property)
}

/// Retrieves the parameter called `name` from this checkpoint as a [`DynTensor`].
///
/// `allocator` is passed to the training API's `GetParameter`; presumably it is used to allocate the
/// returned tensor - confirm against the ONNX Runtime training API documentation.
///
/// # Errors
/// Returns an error if `name` contains an interior NUL byte, if the training API cannot be obtained, if
/// `GetParameter` reports a failure status, or if the `nonNull` check on the returned pointer fails.
pub fn get_parameter(&self, name: impl AsRef<str>, allocator: &Allocator) -> Result<DynTensor> {
let name = CString::new(name.as_ref())?;

let mut value_ptr = ptr::null_mut();
trainsys![unsafe GetParameter(self.ptr.as_ptr(), name.as_ptr(), allocator.ptr().cast_mut(), &mut value_ptr)?; nonNull(value_ptr)];
// SAFETY: relies on the `nonNull(value_ptr)` check above having rejected a null pointer
// (NOTE(review): that macro arm is not fully visible here - confirm).
Ok(unsafe { DynTensor::from_ptr(NonNull::new_unchecked(value_ptr), None) })
}
}

/// A value that can be attached to a checkpoint under a name.
///
/// Conversions are provided from `i64`, `f32`, `&str` and `String`, so plain values can be passed wherever
/// an `impl Into<Property>` is accepted.
#[derive(Debug, Clone, PartialEq)]
pub enum Property {
	/// A 64-bit signed integer property.
	Int(i64),
	/// A 32-bit floating point property.
	Float(f32),
	/// A string property.
	String(String)
}

/// Wraps an integer as [`Property::Int`].
impl From<i64> for Property {
	fn from(int: i64) -> Self {
		Property::Int(int)
	}
}

/// Wraps a float as [`Property::Float`].
impl From<f32> for Property {
	fn from(float: f32) -> Self {
		Property::Float(float)
	}
}

/// Copies a string slice into an owned [`Property::String`].
impl From<&str> for Property {
	fn from(text: &str) -> Self {
		Property::String(String::from(text))
	}
}

/// Moves an owned string into [`Property::String`] without copying.
impl From<String> for Property {
	fn from(text: String) -> Self {
		Property::String(text)
	}
}

impl AsPointer for Checkpoint {
Expand All @@ -106,32 +210,67 @@ impl Drop for Checkpoint {
}
}

/// Describes a learning rate scheduler that can be registered on an optimizer.
#[derive(Debug, Clone)]
pub enum LearningRateScheduler {
/// A linear scheduler; its fields are forwarded, in order, to the training API's
/// `RegisterLinearLRScheduler`.
Linear {
/// Number of warmup steps (presumably the steps over which the learning rate ramps up - confirm
/// against the ONNX Runtime training API documentation).
warmup_step_count: i64,
/// Total number of steps the schedule spans.
total_step_count: i64,
/// Initial learning rate handed to the scheduler.
initial_lr: f32
}
}

#[derive(Debug)]
pub struct Optimizer(NonNull<ort_sys::OrtTrainingSession>);
/// Handle for the optimizer-related operations of a training session.
///
/// Stores the raw training session pointer plus a lifetime marker; it has no `Drop` logic of its own in
/// the visible code, so it presumably does not own the session - confirm against the `Trainer` type.
pub struct Optimizer<'s> {
// Raw pointer to the training session every optimizer call is issued against.
session: NonNull<ort_sys::OrtTrainingSession>,
// Zero-sized marker carrying the `'s` lifetime (presumably the borrow of the owning trainer - confirm).
_p: PhantomData<&'s ()>
}

impl Optimizer<'_> {
/// Creates an optimizer handle for the given training session pointer.
///
/// The lifetime of the returned handle is chosen by the caller; this constructor does not itself tie it
/// to any borrow.
pub(crate) fn new(session: NonNull<ort_sys::OrtTrainingSession>) -> Self {
Self { session, _p: PhantomData }
}

impl Optimizer {
pub fn reset_grad(&self) -> Result<()> {
trainsys![unsafe LazyResetGrad(self.0.as_ptr())?];
pub fn reset_grad(&mut self) -> Result<()> {
trainsys![unsafe LazyResetGrad(self.session.as_ptr())?];
Ok(())
}

pub fn lr(&self) -> Result<f32> {
let mut lr = f32::NAN;
trainsys![unsafe GetLearningRate(self.0.as_ptr(), &mut lr)?];
trainsys![unsafe GetLearningRate(self.session.as_ptr(), &mut lr)?];
Ok(lr)
}

pub fn set_lr(&self, lr: f32) -> Result<()> {
trainsys![unsafe SetLearningRate(self.0.as_ptr(), lr)?];
pub fn set_lr(&mut self, lr: f32) -> Result<()> {
trainsys![unsafe SetLearningRate(self.session.as_ptr(), lr)?];
Ok(())
}

pub fn step(&self) -> Result<()> {
self.step_with_options(RunOptions::new()?)
/// Registers a learning rate scheduler on the underlying training session.
///
/// # Errors
/// Returns an error if the training API cannot be obtained or if the registration call reports a failure
/// status.
pub fn register_scheduler(&mut self, scheduler: LearningRateScheduler) -> Result<()> {
// Exhaustive match on purpose: adding a scheduler variant should force this function to be updated.
match scheduler {
LearningRateScheduler::Linear {
warmup_step_count,
total_step_count,
initial_lr
} => {
trainsys![unsafe RegisterLinearLRScheduler(self.session.as_ptr(), warmup_step_count, total_step_count, initial_lr)?];
}
}
Ok(())
}

/// Performs one optimizer step without run options.
///
/// Calls the training API's `OptimizerStep` with a null run-options pointer.
///
/// # Errors
/// Returns an error if the training API cannot be obtained or if `OptimizerStep` reports a failure status.
pub fn step(&mut self) -> Result<()> {
trainsys![unsafe OptimizerStep(self.session.as_ptr(), ptr::null_mut())?];
Ok(())
}

/// Performs one optimizer step using the given run options.
///
/// `options` is taken by value and stays alive for the duration of the `OptimizerStep` call.
///
/// # Errors
/// Returns an error if the training API cannot be obtained or if `OptimizerStep` reports a failure status.
pub fn step_with_options(&mut self, options: RunOptions<NoSelectedOutputs>) -> Result<()> {
trainsys![unsafe OptimizerStep(self.session.as_ptr(), options.ptr())?];
Ok(())
}

pub fn step_with_options(&self, options: RunOptions) -> Result<()> {
trainsys![unsafe OptimizerStep(self.0.as_ptr(), options.ptr())?];
/// Advances the learning rate scheduler by one step via the training API's `SchedulerStep`.
///
/// NOTE(review): presumably a scheduler must have been registered on the session beforehand - confirm
/// against the ONNX Runtime training API documentation.
///
/// # Errors
/// Returns an error if the training API cannot be obtained or if `SchedulerStep` reports a failure status.
pub fn step_scheduler(&mut self) -> Result<()> {
trainsys![unsafe SchedulerStep(self.session.as_ptr())?];
Ok(())
}
}
2 changes: 1 addition & 1 deletion src/training/simple/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl<'t> TrainerControl<'t> {
self.trainer.export(out_path, output_names)
}

pub fn optimizer(&self) -> &Optimizer {
pub fn optimizer(&self) -> Optimizer<'_> {
self.trainer.optimizer()
}

Expand Down
2 changes: 1 addition & 1 deletion src/training/simple/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Trainer {
&self,
mut args: TrainingArguments<I, L, NI, NL>
) -> Result<()> {
let optimizer = self.optimizer();
let mut optimizer = self.optimizer();
optimizer.set_lr(args.lr)?;

let mut saved_ckpts = VecDeque::new();
Expand Down
Loading

0 comments on commit 8aab523

Please sign in to comment.