From 8aab523770ce8156cc98acb67f7222d5d61b2a9b Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Fri, 21 Feb 2025 23:26:36 -0600 Subject: [PATCH] feat(training): improve training API coverage --- .vscode/settings.json | 3 +- examples/training/examples/train-clm.rs | 4 +- ort-sys/src/lib.rs | 35 +++++ src/training/mod.rs | 173 +++++++++++++++++++++--- src/training/simple/callbacks.rs | 2 +- src/training/simple/mod.rs | 2 +- src/training/trainer.rs | 116 +++++++++++----- tools/training-api-coverage.ts | 75 ++++++++++ 8 files changed, 356 insertions(+), 54 deletions(-) create mode 100644 tools/training-api-coverage.ts diff --git a/.vscode/settings.json b/.vscode/settings.json index f1733e71..6952f106 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,6 +10,7 @@ "rust-analyzer.diagnostics.experimental.enable": true, "rust-analyzer.showUnlinkedFileNotification": false, "deno.enablePaths": [ - "tools/api-coverage.ts" + "tools/api-coverage.ts", + "tools/training-api-coverage.ts" ] } diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs index 3f94da03..d11c2795 100644 --- a/examples/training/examples/train-clm.rs +++ b/examples/training/examples/train-clm.rs @@ -45,7 +45,7 @@ fn main() -> ort::Result<()> { ) .unwrap(); - let optimizer = trainer.optimizer(); + let mut optimizer = trainer.optimizer(); optimizer.set_lr(7e-5)?; let mut dataset = File::open("dataset.bin").unwrap(); @@ -93,6 +93,8 @@ fn main() -> ort::Result<()> { if loss.is_nan() { return Ok(()); } + + let mut optimizer = trainer.optimizer(); optimizer.step()?; optimizer.reset_grad()?; } diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index 6b18507b..68054fdf 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -552,6 +552,41 @@ pub struct OrtTrainingApi { inference_model_path: *const ortchar, graph_outputs_len: usize, graph_output_names: *const *const c_char + ) -> OrtStatusPtr, + pub SetSeed: unsafe extern "system" fn(seed: i64) -> OrtStatusPtr, + pub TrainingSessionGetTrainingModelInputCount: unsafe extern "system" fn(session: *const OrtTrainingSession, out: *mut usize) -> OrtStatusPtr, + pub TrainingSessionGetEvalModelInputCount: unsafe extern "system" fn(session: *const OrtTrainingSession, out: *mut usize) -> OrtStatusPtr, + pub TrainingSessionGetTrainingModelInputName: + unsafe extern "system" fn(session: *const OrtTrainingSession, index: usize, allocator: *mut OrtAllocator, output: *mut *const c_char) -> OrtStatusPtr, + pub TrainingSessionGetEvalModelInputName: + unsafe extern "system" fn(session: *const OrtTrainingSession, index: usize, allocator: *mut OrtAllocator, output: *mut *const c_char) -> OrtStatusPtr, + pub AddProperty: unsafe extern "system" fn( + checkpoint_state: *mut OrtCheckpointState, + property_name: *const c_char, + property_type: OrtPropertyType, + property_value: *const () + ) -> OrtStatusPtr, + pub GetProperty: unsafe extern "system" fn( + checkpoint_state: *mut OrtCheckpointState, + property_name: *const c_char, + allocator: *mut OrtAllocator, + property_type: *mut OrtPropertyType, + property_value: *mut *const () + ) -> OrtStatusPtr, + pub LoadCheckpointFromBuffer: + unsafe extern "system" fn(checkpoint_buffer: *const (), num_bytes: usize, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr, + pub GetParameterTypeAndShape: unsafe extern "system" fn( + checkpoint_state: *const OrtCheckpointState, + parameter_name: *const c_char, + parameter_type_and_shape: *mut *mut OrtTensorTypeAndShapeInfo + ) -> OrtStatusPtr, + pub UpdateParameter: + unsafe extern "system" fn(checkpoint_state: *mut OrtCheckpointState, parameter_name: *const c_char, parameter: *mut OrtValue) -> OrtStatusPtr, + pub GetParameter: unsafe extern "system" fn( + checkpoint_state: *const OrtCheckpointState, + parameter_name: *const c_char, + allocator: *mut OrtAllocator, + parameter: *mut *mut OrtValue ) -> OrtStatusPtr } #[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"] diff --git a/src/training/mod.rs b/src/training/mod.rs index 4599509d..1012d59d 100644 --- a/src/training/mod.rs +++ b/src/training/mod.rs @@ -1,12 +1,23 @@ //! Provides [`Trainer`], a simple interface for on-device training/fine-tuning. -use std::{ - path::Path, - ptr::{self, NonNull}, - sync::OnceLock +use alloc::{ + ffi::CString, + string::{String, ToString} }; +use core::{ + ffi::{CStr, c_char}, + marker::PhantomData, + ptr::{self, NonNull} +}; +use std::{path::Path, sync::OnceLock}; -use crate::{AsPointer, Error, Result, ortsys, session::RunOptions}; +use crate::{ + AsPointer, Error, Result, + memory::Allocator, + ortsys, + session::{NoSelectedOutputs, RunOptions}, + value::DynTensor +}; mod simple; mod trainer; @@ -44,6 +55,12 @@ pub fn training_api() -> Result> { .ok_or_else(|| Error::new("Training is not enbled in this build of ONNX Runtime.")) } +/// Sets the seed used for RNG when training. +pub fn set_seed(seed: i64) -> Result<()> { + trainsys![unsafe SetSeed(seed)?]; + Ok(()) +} + macro_rules! trainsys { ($method:ident) => { ($crate::training::training_api().unwrap().as_ref().$method) @@ -60,7 +77,7 @@ macro_rules! trainsys { _x }}; (unsafe $method:ident($($n:expr),+ $(,)?)?) => { - unsafe { $crate::error::status_to_result(($crate::training::training_api()?.as_ref().$method)($($n),+)) }?; + unsafe { $crate::error::status_to_result(($crate::training::training_api()?.as_ref().$method)($($n),+)) }? }; (unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {{ unsafe { $crate::error::status_to_result(($crate::training::training_api()?.as_ref().$method)($($n),+)) }?; @@ -84,11 +101,98 @@ impl Checkpoint { }) } + pub fn load_from_buffer(buffer: &[u8]) -> Result { + let mut ptr: *mut ort_sys::OrtCheckpointState = ptr::null_mut(); + trainsys![unsafe LoadCheckpointFromBuffer(buffer.as_ptr().cast(), buffer.len(), &mut ptr)?; nonNull(ptr)]; + Ok(Checkpoint { + ptr: unsafe { NonNull::new_unchecked(ptr) } + }) + } + pub fn save(&self, path: impl AsRef, include_optimizer_state: bool) -> Result<()> { let path = crate::util::path_to_os_char(path); trainsys![unsafe SaveCheckpoint(self.ptr.as_ptr(), path.as_ptr(), include_optimizer_state)?]; Ok(()) } + + pub fn add_property(&mut self, name: impl AsRef, property: impl Into) -> Result<()> { + let name = CString::new(name.as_ref())?; + match property.into() { + Property::Int(value) => { + trainsys![unsafe AddProperty(self.ptr.as_ptr(), name.as_ptr(), ort_sys::OrtPropertyType::OrtIntProperty, (&value as *const i64).cast())?] + } + Property::Float(value) => { + trainsys![unsafe AddProperty(self.ptr.as_ptr(), name.as_ptr(), ort_sys::OrtPropertyType::OrtFloatProperty, (&value as *const f32).cast())?] + } + Property::String(value) => { + let value = CString::new(value)?; + trainsys![unsafe AddProperty(self.ptr.as_ptr(), name.as_ptr(), ort_sys::OrtPropertyType::OrtStringProperty, value.as_ptr().cast())?] + } + } + Ok(()) + } + + pub fn get_property(&self, name: impl AsRef) -> Option { + let name = CString::new(name.as_ref()).ok()?; + let mut allocator = Allocator::default(); + let mut property_type: ort_sys::OrtPropertyType = ort_sys::OrtPropertyType::OrtIntProperty; + let mut property_value: *const () = ptr::null(); + + let status = trainsys![unsafe GetProperty( + self.ptr.as_ptr(), + name.as_ptr(), + allocator.ptr_mut(), + &mut property_type, + &mut property_value + )]; + unsafe { crate::error::status_to_result(status) }.ok()?; + + Some(match property_type { + ort_sys::OrtPropertyType::OrtIntProperty => Property::Int(unsafe { *property_value.cast::() }), + ort_sys::OrtPropertyType::OrtFloatProperty => Property::Float(unsafe { *property_value.cast::() }), + ort_sys::OrtPropertyType::OrtStringProperty => { + let value = unsafe { CStr::from_ptr(property_value.cast::()) }.to_string_lossy().into(); + unsafe { allocator.free(property_value.cast_mut()) }; + Property::String(value) + } + }) + } + + pub fn get_parameter(&self, name: impl AsRef, allocator: &Allocator) -> Result { + let name = CString::new(name.as_ref())?; + + let mut value_ptr = ptr::null_mut(); + trainsys![unsafe GetParameter(self.ptr.as_ptr(), name.as_ptr(), allocator.ptr().cast_mut(), &mut value_ptr)?; nonNull(value_ptr)]; + Ok(unsafe { DynTensor::from_ptr(NonNull::new_unchecked(value_ptr), None) }) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Property { + Int(i64), + Float(f32), + String(String) +} + +impl From for Property { + fn from(value: i64) -> Self { + Self::Int(value) + } +} +impl From for Property { + fn from(value: f32) -> Self { + Self::Float(value) + } +} +impl From<&str> for Property { + fn from(value: &str) -> Self { + Self::String(value.to_string()) + } +} +impl From for Property { + fn from(value: String) -> Self { + Self::String(value) + } } impl AsPointer for Checkpoint { @@ -106,32 +210,67 @@ impl Drop for Checkpoint { } } +#[derive(Debug, Clone)] +pub enum LearningRateScheduler { + Linear { + warmup_step_count: i64, + total_step_count: i64, + initial_lr: f32 + } +} + #[derive(Debug)] -pub struct Optimizer(NonNull); +pub struct Optimizer<'s> { + session: NonNull, + _p: PhantomData<&'s ()> +} + +impl Optimizer<'_> { + pub(crate) fn new(session: NonNull) -> Self { + Self { session, _p: PhantomData } + } -impl Optimizer { - pub fn reset_grad(&self) -> Result<()> { - trainsys![unsafe LazyResetGrad(self.0.as_ptr())?]; + pub fn reset_grad(&mut self) -> Result<()> { + trainsys![unsafe LazyResetGrad(self.session.as_ptr())?]; Ok(()) } pub fn lr(&self) -> Result { let mut lr = f32::NAN; - trainsys![unsafe GetLearningRate(self.0.as_ptr(), &mut lr)?]; + trainsys![unsafe GetLearningRate(self.session.as_ptr(), &mut lr)?]; Ok(lr) } - pub fn set_lr(&self, lr: f32) -> Result<()> { - trainsys![unsafe SetLearningRate(self.0.as_ptr(), lr)?]; + pub fn set_lr(&mut self, lr: f32) -> Result<()> { + trainsys![unsafe SetLearningRate(self.session.as_ptr(), lr)?]; Ok(()) } - pub fn step(&self) -> Result<()> { - self.step_with_options(RunOptions::new()?) + pub fn register_scheduler(&mut self, scheduler: LearningRateScheduler) -> Result<()> { + match scheduler { + LearningRateScheduler::Linear { + warmup_step_count, + total_step_count, + initial_lr + } => { + trainsys![unsafe RegisterLinearLRScheduler(self.session.as_ptr(), warmup_step_count, total_step_count, initial_lr)?]; + } + } + Ok(()) + } + + pub fn step(&mut self) -> Result<()> { + trainsys![unsafe OptimizerStep(self.session.as_ptr(), ptr::null_mut())?]; + Ok(()) + } + + pub fn step_with_options(&mut self, options: RunOptions) -> Result<()> { + trainsys![unsafe OptimizerStep(self.session.as_ptr(), options.ptr())?]; + Ok(()) } - pub fn step_with_options(&self, options: RunOptions) -> Result<()> { - trainsys![unsafe OptimizerStep(self.0.as_ptr(), options.ptr())?]; + pub fn step_scheduler(&mut self) -> Result<()> { + trainsys![unsafe SchedulerStep(self.session.as_ptr())?]; Ok(()) } } diff --git a/src/training/simple/callbacks.rs b/src/training/simple/callbacks.rs index 07cfc044..f55f53bd 100644 --- a/src/training/simple/callbacks.rs +++ b/src/training/simple/callbacks.rs @@ -65,7 +65,7 @@ impl<'t> TrainerControl<'t> { self.trainer.export(out_path, output_names) } - pub fn optimizer(&self) -> &Optimizer { + pub fn optimizer(&self) -> Optimizer<'_> { self.trainer.optimizer() } diff --git a/src/training/simple/mod.rs b/src/training/simple/mod.rs index 0791bafe..dc6d7bd7 100644 --- a/src/training/simple/mod.rs +++ b/src/training/simple/mod.rs @@ -45,7 +45,7 @@ impl Trainer { &self, mut args: TrainingArguments ) -> Result<()> { - let optimizer = self.optimizer(); + let mut optimizer = self.optimizer(); optimizer.set_lr(args.lr)?; let mut saved_ckpts = VecDeque::new(); diff --git a/src/training/trainer.rs b/src/training/trainer.rs index b8113c40..f776f93d 100644 --- a/src/training/trainer.rs +++ b/src/training/trainer.rs @@ -1,8 +1,6 @@ -use std::{ - ffi::CString, - path::Path, - ptr::{self, NonNull} -}; +use alloc::ffi::CString; +use core::ptr::{self, NonNull}; +use std::path::Path; use ort_sys::c_char; @@ -19,7 +17,7 @@ use crate::{ pub struct Trainer { ptr: NonNull, train_output_names: Vec, - optimizer: Optimizer, + eval_output_names: Vec, ckpt: Checkpoint, _allocator: Allocator } @@ -43,7 +41,63 @@ impl Trainer { trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr)?; nonNull(ptr)]; let ptr = unsafe { NonNull::new_unchecked(ptr) }; + Self::new_inner(ptr, allocator, ckpt) + } + pub fn new_from_artifacts( + session_options: SessionBuilder, + allocator: Allocator, + base_dir: impl AsRef, + override_ckpt: Option + ) -> Result { + let base_dir = base_dir.as_ref(); + let ckpt = if let Some(ckpt) = override_ckpt { + ckpt + } else { + Checkpoint::load(base_dir.join("checkpoint"))? + }; + Self::new( + session_options, + allocator, + ckpt, + base_dir.join("training_model.onnx"), + base_dir.join("eval_model.onnx"), + base_dir.join("optimizer_model.onnx") + ) + } + + pub fn new_from_memory( + session_options: SessionBuilder, + allocator: Allocator, + ckpt: Checkpoint, + training_model: &[u8], + eval_model: &[u8], + optimizer_model: &[u8] + ) -> Result { + let env = crate::environment::get_environment()?; + + let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); + trainsys![ + unsafe CreateTrainingSessionFromBuffer( + env.ptr(), + session_options.ptr(), + ckpt.ptr.as_ptr(), + training_model.as_ptr().cast(), + training_model.len(), + eval_model.as_ptr().cast(), + eval_model.len(), + optimizer_model.as_ptr().cast(), + optimizer_model.len(), + &mut ptr + )?; + nonNull(ptr) + ]; + + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + Self::new_inner(ptr, allocator, ckpt) + } + + fn new_inner(ptr: NonNull, allocator: Allocator, ckpt: Checkpoint) -> Result { let mut train_output_len = 0; trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len)?]; let train_output_names = (0..train_output_len) @@ -62,37 +116,33 @@ impl Trainer { }) .collect::>>()?; + let mut eval_output_len = 0; + trainsys![unsafe TrainingSessionGetEvalModelOutputCount(ptr.as_ptr(), &mut eval_output_len)?]; + let eval_output_names = (0..eval_output_len) + .map(|i| { + let mut name_bytes: *mut c_char = std::ptr::null_mut(); + trainsys![unsafe TrainingSessionGetEvalModelOutputName(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)?]; + let name = match char_p_to_string(name_bytes) { + Ok(name) => name, + Err(e) => { + unsafe { allocator.free(name_bytes) }; + return Err(e); + } + }; + unsafe { allocator.free(name_bytes) }; + Ok(name) + }) + .collect::>>()?; + Ok(Self { ptr, _allocator: allocator, train_output_names, - optimizer: Optimizer(ptr), + eval_output_names, ckpt }) } - pub fn new_from_artifacts( - session_options: SessionBuilder, - allocator: Allocator, - base_dir: impl AsRef, - override_ckpt: Option - ) -> Result { - let base_dir = base_dir.as_ref(); - let ckpt = if let Some(ckpt) = override_ckpt { - ckpt - } else { - Checkpoint::load(base_dir.join("checkpoint"))? - }; - Self::new( - session_options, - allocator, - ckpt, - base_dir.join("training_model.onnx"), - base_dir.join("eval_model.onnx"), - base_dir.join("optimizer_model.onnx") - ) - } - pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( &'s self, inputs: impl Into>, @@ -163,7 +213,7 @@ impl Trainer { input_values: impl Iterator>, run_options: Option<&'r RunOptions> ) -> Result> { - let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.eval_output_names.len()]; let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr()).collect(); @@ -180,7 +230,7 @@ impl Trainer { }) .collect(); - Ok(SessionOutputs::new(self.train_output_names.iter().map(String::as_str).collect(), outputs)) + Ok(SessionOutputs::new(self.eval_output_names.iter().map(String::as_str).collect(), outputs)) } pub fn export>(&self, out_path: impl AsRef, output_names: impl AsRef<[O]>) -> Result<()> { @@ -211,8 +261,8 @@ impl Trainer { Ok(()) } - pub fn optimizer(&self) -> &Optimizer { - &self.optimizer + pub fn optimizer(&self) -> Optimizer<'_> { + Optimizer::new(self.ptr) } pub fn checkpoint(&self) -> &Checkpoint { diff --git a/tools/training-api-coverage.ts b/tools/training-api-coverage.ts new file mode 100644 index 00000000..b47def3f --- /dev/null +++ b/tools/training-api-coverage.ts @@ -0,0 +1,75 @@ +import { dirname, join } from 'jsr:@std/path@1.0.3'; +import { walk } from 'jsr:@std/fs@1.0.2'; + +const PROJECT_ROOT = dirname(import.meta.dirname!); +const DECODER = new TextDecoder('utf-8'); +const SYMBOL_DEF_REGEX = /pub\s+([A-Za-z_][A-Za-z0-9_]+):/; +const SYMBOL_USAGE_REGEX = /trainsys!\[\s*(?:unsafe\s+)?([A-Za-z_][A-Za-z0-9_]+)/gm; + +const IGNORED_SYMBOLS = new Set([ + 'CreateEnv', // we will always create an env with a custom logger for integration w/ tracing + 'CreateEnvWithGlobalThreadPools', + 'KernelContext_GetScratchBuffer', // implemented in src/operator/kernel.rs but impl appears to be broken so ignoring + 'RegisterCustomOpsLibrary', // we use RegisterCustomOpsLibrary_V2 + 'RegisterCustomOpsUsingFunction', + 'SessionOptionsAppendExecutionProvider_CUDA', // we use V2 + 'SessionOptionsAppendExecutionProvider_TensorRT', // we use V2 + 'GetValueType', // we get value types via GetTypeInfo -> GetOnnxTypeFromTypeInfo, which is equivalent + 'SetLanguageProjection', // someday we shall have `ORT_PROJECTION_RUST`, but alas, today is not that day... + + // we use allocator APIs directly on the Allocator struct + 'AllocatorAlloc', + 'AllocatorFree', + 'AllocatorGetInfo', + + // functions that don't make sense with SessionBuilder API + 'HasSessionConfigEntry', + 'GetSessionConfigEntry', + 'DisableProfiling', + 'GetCUDAProviderOptionsAsString', + 'GetTensorRTProviderOptionsAsString', + 'GetCANNProviderOptionsAsString', + 'GetDnnlProviderOptionsAsString' +]); + +const sysSymbols = new Set(); +const sysFile = await Deno.readFile(join(PROJECT_ROOT, 'ort-sys', 'src', 'lib.rs')); +let isInOrtApi = false; +for (const line of DECODER.decode(sysFile).split('\n')) { + if (line === 'pub struct OrtTrainingApi {') { + isInOrtApi = true; + continue; + } + + if (isInOrtApi) { + if (line === '}') { + isInOrtApi = false; + continue; + } + + const trimmedLine = line.trimStart(); + if (SYMBOL_DEF_REGEX.test(trimmedLine)) { + const [ _, symbol ] = trimmedLine.match(SYMBOL_DEF_REGEX)!; + sysSymbols.add(symbol); + } + } +} + +const usedSymbols = new Set(); +for await (const sourceFile of walk(join(PROJECT_ROOT, 'src'))) { + if (sourceFile.isDirectory) { + continue; + } + + const contents = DECODER.decode(await Deno.readFile(sourceFile.path)); + for (const [ _, symbol ] of contents.matchAll(SYMBOL_USAGE_REGEX)) { + usedSymbols.add(symbol); + } +} + +const nonIgnoredSymbols = sysSymbols.difference(IGNORED_SYMBOLS); +const unusedSymbols = nonIgnoredSymbols.difference(usedSymbols); +for (const symbol of unusedSymbols) { + console.log(`%c\t${symbol}`, 'color: red'); +} +console.log(`%cCoverage: ${usedSymbols.size}/${nonIgnoredSymbols.size} (${((usedSymbols.size / nonIgnoredSymbols.size) * 100).toFixed(2)}%)`, 'color: green; font-weight: bold');