From d877fb34e829d9f3f3132c0cde7ae9ca8eceb232 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 3 Nov 2024 21:47:01 -0600 Subject: [PATCH] feat: LoRA adapters --- src/adapter.rs | 52 ++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 ++ src/session/run_options.rs | 9 +++++++ 3 files changed, 63 insertions(+) create mode 100644 src/adapter.rs diff --git a/src/adapter.rs b/src/adapter.rs new file mode 100644 index 00000000..90979b9b --- /dev/null +++ b/src/adapter.rs @@ -0,0 +1,52 @@ +use std::{ + path::Path, + ptr::{self, NonNull}, + sync::Arc +}; + +use crate::{Allocator, Result, ortsys, util}; + +#[derive(Debug)] +pub(crate) struct AdapterInner { + pub(crate) ptr: NonNull +} + +impl Drop for AdapterInner { + fn drop(&mut self) { + ortsys![unsafe ReleaseLoraAdapter(self.ptr.as_ptr())]; + } +} + +#[derive(Debug, Clone)] +pub struct Adapter { + pub(crate) inner: Arc +} + +impl Adapter { + pub fn from_file(path: impl AsRef, allocator: Option<&Allocator>) -> Result { + let path = util::path_to_os_char(path); + let allocator_ptr = allocator.map(|c| c.ptr()).unwrap_or_else(ptr::null_mut); + let mut ptr = ptr::null_mut(); + ortsys![unsafe CreateLoraAdapter(path.as_ptr(), allocator_ptr, &mut ptr)?]; + Ok(Adapter { + inner: Arc::new(AdapterInner { + ptr: unsafe { NonNull::new_unchecked(ptr) } + }) + }) + } + + pub fn from_memory(bytes: &[u8], allocator: Option<&Allocator>) -> Result { + let allocator_ptr = allocator.map(|c| c.ptr()).unwrap_or_else(ptr::null_mut); + let mut ptr = ptr::null_mut(); + ortsys![unsafe CreateLoraAdapterFromArray(bytes.as_ptr().cast(), bytes.len(), allocator_ptr, &mut ptr)?]; + Ok(Adapter { + inner: Arc::new(AdapterInner { + ptr: unsafe { NonNull::new_unchecked(ptr) } + }) + }) + } + + pub fn ptr(&self) -> *mut ort_sys::OrtLoraAdapter { + self.inner.ptr.as_ptr() + } +} diff --git a/src/lib.rs b/src/lib.rs index 46d17550..191641cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ #[cfg(all(test, not(feature = "fetch-models")))] compile_error!("`cargo test --features fetch-models`!!1!"); +pub(crate) mod adapter; pub(crate) mod environment; pub(crate) mod error; pub(crate) mod execution_providers; @@ -48,6 +49,7 @@ pub use self::tensor::ArrayExtensions; #[cfg_attr(docsrs, doc(cfg(feature = "training")))] pub use self::training::*; pub use self::{ + adapter::Adapter, environment::{Environment, EnvironmentBuilder, EnvironmentGlobalThreadPoolOptions, get_environment, init}, error::{Error, ErrorCode, Result}, execution_providers::*, diff --git a/src/session/run_options.rs b/src/session/run_options.rs index 5ee6a9bf..6834b5a6 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, sync::Arc}; use crate::{ + adapter::{Adapter, AdapterInner}, error::Result, ortsys, session::Output, @@ -157,6 +158,7 @@ impl SelectedOutputMarker for HasSelectedOutputs {} pub struct RunOptions { pub(crate) run_options_ptr: NonNull, pub(crate) outputs: OutputSelector, + adapters: Vec>, _marker: PhantomData } @@ -175,6 +177,7 @@ impl RunOptions { Ok(RunOptions { run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) }, outputs: OutputSelector::default(), + adapters: Vec::new(), _marker: PhantomData }) } @@ -303,6 +306,12 @@ impl RunOptions { ortsys![unsafe AddRunConfigEntry(self.run_options_ptr.as_ptr(), key.as_ptr(), value.as_ptr())?]; Ok(()) } + + pub fn add_adapter(&mut self, adapter: &Adapter) -> Result<()> { + ortsys![unsafe RunOptionsAddActiveLoraAdapter(self.run_options_ptr.as_ptr(), adapter.ptr())?]; + self.adapters.push(Arc::clone(&adapter.inner)); + Ok(()) + } } impl Drop for RunOptions {