Skip to content

Commit

Permalink
feat: more CoreML EP flags
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Nov 28, 2024
1 parent f90a3b6 commit 7fa9734
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/execution_providers/coreml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ extern "C" {
pub struct CoreMLExecutionProvider {
use_cpu_only: bool,
enable_on_subgraph: bool,
only_enable_device_with_ane: bool
only_enable_device_with_ane: bool,
only_static_input_shapes: bool,
mlprogram: bool,
use_cpu_and_gpu: bool
}

impl CoreMLExecutionProvider {
Expand Down Expand Up @@ -41,6 +44,28 @@ impl CoreMLExecutionProvider {
self
}

/// Only allow the CoreML EP to take nodes with inputs that have static shapes. By default the CoreML EP will also
/// allow inputs with dynamic shapes, however performance may be negatively impacted by inputs with dynamic shapes.
#[must_use]
pub fn with_static_input_shapes(mut self) -> Self {
self.only_static_input_shapes = true;
self
}

/// Create an MLProgram format model. Requires Core ML 5 or later (iOS 15+ or macOS 12+). The default is for a
/// NeuralNetwork model to be created as that requires Core ML 3 or later (iOS 13+ or macOS 10.15+).
#[must_use]
pub fn with_mlprogram(mut self) -> Self {
self.mlprogram = true;
self
}

#[must_use]
pub fn with_cpu_and_gpu(mut self) -> Self {
self.use_cpu_and_gpu = true;
self
}

#[must_use]
pub fn build(self) -> ExecutionProviderDispatch {
self.into()
Expand Down Expand Up @@ -79,6 +104,15 @@ impl ExecutionProvider for CoreMLExecutionProvider {
if self.only_enable_device_with_ane {
flags |= 0x004;
}
if self.only_static_input_shapes {
flags |= 0x008;
}
if self.mlprogram {
flags |= 0x010;
}
if self.use_cpu_and_gpu {
flags |= 0x020;
}
return crate::error::status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_CoreML(session_builder.ptr_mut(), flags) });
}

Expand Down

0 comments on commit 7fa9734

Please sign in to comment.