Skip to content

Commit

Permalink
refactor: shorten the names of EP-specific enums
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Mar 6, 2025
1 parent e8d873a commit dc9ad90
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 84 deletions.
24 changes: 12 additions & 12 deletions src/execution_providers/cann.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{

#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum CANNExecutionProviderPrecisionMode {
pub enum CANNPrecisionMode {
/// Convert to float32 first according to operator implementation
ForceFP32,
/// Convert to float16 when float16 and float32 are both supported
Expand All @@ -24,7 +24,7 @@ pub enum CANNExecutionProviderPrecisionMode {

#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum CANNExecutionProviderImplementationMode {
pub enum CANNImplementationMode {
HighPrecision,
HighPerformance
}
Expand Down Expand Up @@ -74,26 +74,26 @@ impl CANNExecutionProvider {
self
}

/// Set the precision mode of the operator. See [`CANNExecutionProviderPrecisionMode`].
/// Set the precision mode of the operator. See [`CANNPrecisionMode`].
#[must_use]
pub fn with_precision_mode(mut self, mode: CANNExecutionProviderPrecisionMode) -> Self {
pub fn with_precision_mode(mut self, mode: CANNPrecisionMode) -> Self {
self.options.set("precision_mode", match mode {
CANNExecutionProviderPrecisionMode::ForceFP32 => "force_fp32",
CANNExecutionProviderPrecisionMode::ForceFP16 => "force_fp16",
CANNExecutionProviderPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
CANNExecutionProviderPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
CANNExecutionProviderPrecisionMode::AllowMixedPrecision => "allow_mix_precision"
CANNPrecisionMode::ForceFP32 => "force_fp32",
CANNPrecisionMode::ForceFP16 => "force_fp16",
CANNPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
CANNPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
CANNPrecisionMode::AllowMixedPrecision => "allow_mix_precision"
});
self
}

/// Configure the implementation mode for operators. Some CANN operators can have both high-precision and
/// high-performance implementations.
#[must_use]
pub fn with_implementation_mode(mut self, mode: CANNExecutionProviderImplementationMode) -> Self {
pub fn with_implementation_mode(mut self, mode: CANNImplementationMode) -> Self {
self.options.set("op_select_impl_mode", match mode {
CANNExecutionProviderImplementationMode::HighPrecision => "high_precision",
CANNExecutionProviderImplementationMode::HighPerformance => "high_performance"
CANNImplementationMode::HighPrecision => "high_precision",
CANNImplementationMode::HighPerformance => "high_performance"
});
self
}
Expand Down
18 changes: 9 additions & 9 deletions src/execution_providers/coreml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use crate::{
};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoreMLExecutionProviderSpecializationStrategy {
pub enum CoreMLSpecializationStrategy {
Default,
FastPrediction
}

impl CoreMLExecutionProviderSpecializationStrategy {
impl CoreMLSpecializationStrategy {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Expand All @@ -24,14 +24,14 @@ impl CoreMLExecutionProviderSpecializationStrategy {
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoreMLExecutionProviderComputeUnits {
pub enum CoreMLComputeUnits {
All,
CPUAndNeuralEngine,
CPUAndGPU,
CPUOnly
}

impl CoreMLExecutionProviderComputeUnits {
impl CoreMLComputeUnits {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Expand All @@ -44,14 +44,14 @@ impl CoreMLExecutionProviderComputeUnits {
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoreMLExecutionProviderModelFormat {
pub enum CoreMLModelFormat {
/// Requires Core ML 5 or later (iOS 15+ or macOS 12+).
MLProgram,
/// Default; requires Core ML 3 or later (iOS 13+ or macOS 10.15+).
NeuralNetwork
}

impl CoreMLExecutionProviderModelFormat {
impl CoreMLModelFormat {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Expand Down Expand Up @@ -83,19 +83,19 @@ impl CoreMLExecutionProvider {
}

#[must_use]
pub fn with_model_format(mut self, model_format: CoreMLExecutionProviderModelFormat) -> Self {
pub fn with_model_format(mut self, model_format: CoreMLModelFormat) -> Self {
self.options.set("ModelFormat", model_format.as_str());
self
}

#[must_use]
pub fn with_specialization_strategy(mut self, strategy: CoreMLExecutionProviderSpecializationStrategy) -> Self {
pub fn with_specialization_strategy(mut self, strategy: CoreMLSpecializationStrategy) -> Self {
self.options.set("SpecializationStrategy", strategy.as_str());
self
}

#[must_use]
pub fn with_compute_units(mut self, units: CoreMLExecutionProviderComputeUnits) -> Self {
pub fn with_compute_units(mut self, units: CoreMLComputeUnits) -> Self {
self.options.set("MLComputeUnits", units.as_str());
self
}
Expand Down
24 changes: 13 additions & 11 deletions src/execution_providers/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ use crate::{
// https://github.com/microsoft/onnxruntime/blob/ffceed9d44f2f3efb9dd69fa75fea51163c91d91/onnxruntime/contrib_ops/cpu/bert/attention_common.h#L160-L171
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct CUDAExecutionProviderAttentionBackend(u32);
pub struct CUDAAttentionBackend(u32);

impl CUDAExecutionProviderAttentionBackend {
impl CUDAAttentionBackend {
pub const FLASH_ATTENTION: Self = Self(1 << 0);
pub const EFFICIENT_ATTENTION: Self = Self(1 << 1);
pub const TRT_FUSED_ATTENTION: Self = Self(1 << 2);
Expand All @@ -24,8 +24,10 @@ impl CUDAExecutionProviderAttentionBackend {
pub const TRT_CROSS_ATTENTION: Self = Self(1 << 6);
pub const TRT_CAUSAL_ATTENTION: Self = Self(1 << 7);

pub const LEAN_ATTENTION: Self = Self(1 << 8);

pub fn none() -> Self {
CUDAExecutionProviderAttentionBackend(0)
CUDAAttentionBackend(0)
}

pub fn all() -> Self {
Expand All @@ -40,7 +42,7 @@ impl CUDAExecutionProviderAttentionBackend {
}
}

impl BitOr for CUDAExecutionProviderAttentionBackend {
impl BitOr for CUDAAttentionBackend {
type Output = Self;
fn bitor(self, rhs: Self) -> Self::Output {
Self(rhs.0 | self.0)
Expand All @@ -49,7 +51,7 @@ impl BitOr for CUDAExecutionProviderAttentionBackend {

/// The type of search done for cuDNN convolution algorithms.
#[derive(Debug, Clone)]
pub enum CUDAExecutionProviderCuDNNConvAlgoSearch {
pub enum CuDNNConvAlgorithmSearch {
/// Expensive exhaustive benchmarking using [`cudnnFindConvolutionForwardAlgorithmEx`][exhaustive].
/// This function will attempt all possible algorithms for `cudnnConvolutionForward` to find the fastest algorithm.
/// Exhaustive search trades off between memory usage and speed. The first execution of a graph will be slow while
Expand Down Expand Up @@ -77,7 +79,7 @@ pub enum CUDAExecutionProviderCuDNNConvAlgoSearch {
Default
}

impl Default for CUDAExecutionProviderCuDNNConvAlgoSearch {
impl Default for CuDNNConvAlgorithmSearch {
fn default() -> Self {
Self::Exhaustive
}
Expand Down Expand Up @@ -118,11 +120,11 @@ impl CUDAExecutionProvider {
/// configuration (input shape, filter shape, etc.) in each `Conv` node. This option controls the type of search
/// done for cuDNN convolution algorithms. See [`CuDNNConvAlgorithmSearch`] for more info.
#[must_use]
pub fn with_conv_algorithm_search(mut self, search: CUDAExecutionProviderCuDNNConvAlgoSearch) -> Self {
pub fn with_conv_algorithm_search(mut self, search: CuDNNConvAlgorithmSearch) -> Self {
self.options.set("cudnn_conv_algo_search", match search {
CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive => "EXHAUSTIVE",
CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic => "HEURISTIC",
CUDAExecutionProviderCuDNNConvAlgoSearch::Default => "DEFAULT"
CuDNNConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE",
CuDNNConvAlgorithmSearch::Heuristic => "HEURISTIC",
CuDNNConvAlgorithmSearch::Default => "DEFAULT"
});
self
}
Expand Down Expand Up @@ -222,7 +224,7 @@ impl CUDAExecutionProvider {
}

#[must_use]
pub fn with_attention_backend(mut self, flags: CUDAExecutionProviderAttentionBackend) -> Self {
pub fn with_attention_backend(mut self, flags: CUDAAttentionBackend) -> Self {
self.options.set("sdpa_kernel", flags.0.to_string());
self
}
Expand Down
50 changes: 25 additions & 25 deletions src/execution_providers/qnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QNNExecutionProviderPerformanceMode {
pub enum QNNPerformanceMode {
Default,
Burst,
Balanced,
Expand All @@ -20,56 +20,56 @@ pub enum QNNExecutionProviderPerformanceMode {
SustainedHighPerformance
}

impl QNNExecutionProviderPerformanceMode {
impl QNNPerformanceMode {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
QNNExecutionProviderPerformanceMode::Default => "default",
QNNExecutionProviderPerformanceMode::Burst => "burst",
QNNExecutionProviderPerformanceMode::Balanced => "balanced",
QNNExecutionProviderPerformanceMode::HighPerformance => "high_performance",
QNNExecutionProviderPerformanceMode::HighPowerSaver => "high_power_saver",
QNNExecutionProviderPerformanceMode::LowPowerSaver => "low_power_saver",
QNNExecutionProviderPerformanceMode::LowBalanced => "low_balanced",
QNNExecutionProviderPerformanceMode::PowerSaver => "power_saver",
QNNExecutionProviderPerformanceMode::SustainedHighPerformance => "sustained_high_performance"
QNNPerformanceMode::Default => "default",
QNNPerformanceMode::Burst => "burst",
QNNPerformanceMode::Balanced => "balanced",
QNNPerformanceMode::HighPerformance => "high_performance",
QNNPerformanceMode::HighPowerSaver => "high_power_saver",
QNNPerformanceMode::LowPowerSaver => "low_power_saver",
QNNPerformanceMode::LowBalanced => "low_balanced",
QNNPerformanceMode::PowerSaver => "power_saver",
QNNPerformanceMode::SustainedHighPerformance => "sustained_high_performance"
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QNNExecutionProviderProfilingLevel {
pub enum QNNProfilingLevel {
Off,
Basic,
Detailed
}

impl QNNExecutionProviderProfilingLevel {
impl QNNProfilingLevel {
pub fn as_str(&self) -> &'static str {
match self {
QNNExecutionProviderProfilingLevel::Off => "off",
QNNExecutionProviderProfilingLevel::Basic => "basic",
QNNExecutionProviderProfilingLevel::Detailed => "detailed"
QNNProfilingLevel::Off => "off",
QNNProfilingLevel::Basic => "basic",
QNNProfilingLevel::Detailed => "detailed"
}
}
}

#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum QNNExecutionProviderContextPriority {
pub enum QNNContextPriority {
Low,
#[default]
Normal,
NormalHigh,
High
}

impl QNNExecutionProviderContextPriority {
impl QNNContextPriority {
pub fn as_str(&self) -> &'static str {
match self {
QNNExecutionProviderContextPriority::Low => "low",
QNNExecutionProviderContextPriority::Normal => "normal",
QNNExecutionProviderContextPriority::NormalHigh => "normal_high",
QNNExecutionProviderContextPriority::High => "normal_high"
QNNContextPriority::Low => "low",
QNNContextPriority::Normal => "normal",
QNNContextPriority::NormalHigh => "normal_high",
QNNContextPriority::High => "normal_high"
}
}
}
Expand All @@ -89,7 +89,7 @@ impl QNNExecutionProvider {
}

#[must_use]
pub fn with_profiling(mut self, level: QNNExecutionProviderProfilingLevel) -> Self {
pub fn with_profiling(mut self, level: QNNProfilingLevel) -> Self {
self.options.set("profiling_level", level.as_str());
self
}
Expand All @@ -114,7 +114,7 @@ impl QNNExecutionProvider {
}

#[must_use]
pub fn with_performance_mode(mut self, mode: QNNExecutionProviderPerformanceMode) -> Self {
pub fn with_performance_mode(mut self, mode: QNNPerformanceMode) -> Self {
self.options.set("htp_performance_mode", mode.as_str());
self
}
Expand All @@ -126,7 +126,7 @@ impl QNNExecutionProvider {
}

#[must_use]
pub fn with_context_priority(mut self, priority: QNNExecutionProviderContextPriority) -> Self {
pub fn with_context_priority(mut self, priority: QNNContextPriority) -> Self {
self.options.set("qnn_context_priority", priority.as_str());
self
}
Expand Down
Loading

0 comments on commit dc9ad90

Please sign in to comment.