From dc9ad905e4b96be3f63b07c79048c46bf1c95a0e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Wed, 5 Mar 2025 21:35:18 -0600 Subject: [PATCH] refactor: shorten the names of EP-specific enums --- src/execution_providers/cann.rs | 24 +++++++------- src/execution_providers/coreml.rs | 18 +++++------ src/execution_providers/cuda.rs | 24 +++++++------- src/execution_providers/qnn.rs | 50 ++++++++++++++-------------- src/execution_providers/webgpu.rs | 54 +++++++++++++++---------------- 5 files changed, 86 insertions(+), 84 deletions(-) diff --git a/src/execution_providers/cann.rs b/src/execution_providers/cann.rs index 7403e424..23cef42f 100644 --- a/src/execution_providers/cann.rs +++ b/src/execution_providers/cann.rs @@ -9,7 +9,7 @@ use crate::{ #[derive(Debug, Clone, PartialEq, Eq)] #[non_exhaustive] -pub enum CANNExecutionProviderPrecisionMode { +pub enum CANNPrecisionMode { /// Convert to float32 first according to operator implementation ForceFP32, /// Convert to float16 when float16 and float32 are both supported @@ -24,7 +24,7 @@ pub enum CANNExecutionProviderPrecisionMode { #[derive(Debug, Clone, PartialEq, Eq)] #[non_exhaustive] -pub enum CANNExecutionProviderImplementationMode { +pub enum CANNImplementationMode { HighPrecision, HighPerformance } @@ -74,15 +74,15 @@ impl CANNExecutionProvider { self } - /// Set the precision mode of the operator. See [`CANNExecutionProviderPrecisionMode`]. + /// Set the precision mode of the operator. See [`CANNPrecisionMode`]. #[must_use] - pub fn with_precision_mode(mut self, mode: CANNExecutionProviderPrecisionMode) -> Self { + pub fn with_precision_mode(mut self, mode: CANNPrecisionMode) -> Self { self.options.set("precision_mode", match mode { - CANNExecutionProviderPrecisionMode::ForceFP32 => "force_fp32", - CANNExecutionProviderPrecisionMode::ForceFP16 => "force_fp16", - CANNExecutionProviderPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16", - CANNExecutionProviderPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype", - CANNExecutionProviderPrecisionMode::AllowMixedPrecision => "allow_mix_precision" + CANNPrecisionMode::ForceFP32 => "force_fp32", + CANNPrecisionMode::ForceFP16 => "force_fp16", + CANNPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16", + CANNPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype", + CANNPrecisionMode::AllowMixedPrecision => "allow_mix_precision" }); self } @@ -90,10 +90,10 @@ impl CANNExecutionProvider { /// Configure the implementation mode for operators. Some CANN operators can have both high-precision and /// high-performance implementations. 
#[must_use] - pub fn with_implementation_mode(mut self, mode: CANNExecutionProviderImplementationMode) -> Self { + pub fn with_implementation_mode(mut self, mode: CANNImplementationMode) -> Self { self.options.set("op_select_impl_mode", match mode { - CANNExecutionProviderImplementationMode::HighPrecision => "high_precision", - CANNExecutionProviderImplementationMode::HighPerformance => "high_performance" + CANNImplementationMode::HighPrecision => "high_precision", + CANNImplementationMode::HighPerformance => "high_performance" }); self } diff --git a/src/execution_providers/coreml.rs b/src/execution_providers/coreml.rs index 5bc1dfe7..d87e1d20 100644 --- a/src/execution_providers/coreml.rs +++ b/src/execution_providers/coreml.rs @@ -8,12 +8,12 @@ use crate::{ }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CoreMLExecutionProviderSpecializationStrategy { +pub enum CoreMLSpecializationStrategy { Default, FastPrediction } -impl CoreMLExecutionProviderSpecializationStrategy { +impl CoreMLSpecializationStrategy { #[must_use] pub fn as_str(&self) -> &'static str { match self { @@ -24,14 +24,14 @@ impl CoreMLExecutionProviderSpecializationStrategy { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CoreMLExecutionProviderComputeUnits { +pub enum CoreMLComputeUnits { All, CPUAndNeuralEngine, CPUAndGPU, CPUOnly } -impl CoreMLExecutionProviderComputeUnits { +impl CoreMLComputeUnits { #[must_use] pub fn as_str(&self) -> &'static str { match self { @@ -44,14 +44,14 @@ impl CoreMLExecutionProviderComputeUnits { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CoreMLExecutionProviderModelFormat { +pub enum CoreMLModelFormat { /// Requires Core ML 5 or later (iOS 15+ or macOS 12+). MLProgram, /// Default; requires Core ML 3 or later (iOS 13+ or macOS 10.15+). 
NeuralNetwork } -impl CoreMLExecutionProviderModelFormat { +impl CoreMLModelFormat { #[must_use] pub fn as_str(&self) -> &'static str { match self { @@ -83,19 +83,19 @@ impl CoreMLExecutionProvider { } #[must_use] - pub fn with_model_format(mut self, model_format: CoreMLExecutionProviderModelFormat) -> Self { + pub fn with_model_format(mut self, model_format: CoreMLModelFormat) -> Self { self.options.set("ModelFormat", model_format.as_str()); self } #[must_use] - pub fn with_specialization_strategy(mut self, strategy: CoreMLExecutionProviderSpecializationStrategy) -> Self { + pub fn with_specialization_strategy(mut self, strategy: CoreMLSpecializationStrategy) -> Self { self.options.set("SpecializationStrategy", strategy.as_str()); self } #[must_use] - pub fn with_compute_units(mut self, units: CoreMLExecutionProviderComputeUnits) -> Self { + pub fn with_compute_units(mut self, units: CoreMLComputeUnits) -> Self { self.options.set("MLComputeUnits", units.as_str()); self } diff --git a/src/execution_providers/cuda.rs b/src/execution_providers/cuda.rs index c154502b..a34f946d 100644 --- a/src/execution_providers/cuda.rs +++ b/src/execution_providers/cuda.rs @@ -11,9 +11,9 @@ use crate::{ // https://github.com/microsoft/onnxruntime/blob/ffceed9d44f2f3efb9dd69fa75fea51163c91d91/onnxruntime/contrib_ops/cpu/bert/attention_common.h#L160-L171 #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(transparent)] -pub struct CUDAExecutionProviderAttentionBackend(u32); +pub struct CUDAAttentionBackend(u32); -impl CUDAExecutionProviderAttentionBackend { +impl CUDAAttentionBackend { pub const FLASH_ATTENTION: Self = Self(1 << 0); pub const EFFICIENT_ATTENTION: Self = Self(1 << 1); pub const TRT_FUSED_ATTENTION: Self = Self(1 << 2); @@ -24,8 +24,10 @@ impl CUDAExecutionProviderAttentionBackend { pub const TRT_CROSS_ATTENTION: Self = Self(1 << 6); pub const TRT_CAUSAL_ATTENTION: Self = Self(1 << 7); + pub const LEAN_ATTENTION: Self = Self(1 << 8); + pub fn none() -> Self { - CUDAExecutionProviderAttentionBackend(0) + CUDAAttentionBackend(0) } pub fn all() -> Self { @@ -40,7 +42,7 @@ impl CUDAExecutionProviderAttentionBackend { } } -impl BitOr for CUDAExecutionProviderAttentionBackend { +impl BitOr for CUDAAttentionBackend { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { Self(rhs.0 | self.0) @@ -49,7 +51,7 @@ impl BitOr for CUDAExecutionProviderAttentionBackend { /// The type of search done for cuDNN convolution algorithms. #[derive(Debug, Clone)] -pub enum CUDAExecutionProviderCuDNNConvAlgoSearch { +pub enum CuDNNConvAlgorithmSearch { /// Expensive exhaustive benchmarking using [`cudnnFindConvolutionForwardAlgorithmEx`][exhaustive]. /// This function will attempt all possible algorithms for `cudnnConvolutionForward` to find the fastest algorithm. /// Exhaustive search trades off between memory usage and speed. The first execution of a graph will be slow while @@ -77,7 +79,7 @@ pub enum CUDAExecutionProviderCuDNNConvAlgoSearch { Default } -impl Default for CUDAExecutionProviderCuDNNConvAlgoSearch { +impl Default for CuDNNConvAlgorithmSearch { fn default() -> Self { Self::Exhaustive } @@ -118,11 +120,11 @@ impl CUDAExecutionProvider { /// configuration (input shape, filter shape, etc.) in each `Conv` node. This option controlls the type of search /// done for cuDNN convolution algorithms. See [`CUDAExecutionProviderCuDNNConvAlgoSearch`] for more info. 
#[must_use] - pub fn with_conv_algorithm_search(mut self, search: CUDAExecutionProviderCuDNNConvAlgoSearch) -> Self { + pub fn with_conv_algorithm_search(mut self, search: CuDNNConvAlgorithmSearch) -> Self { self.options.set("cudnn_conv_algo_search", match search { - CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive => "EXHAUSTIVE", - CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic => "HEURISTIC", - CUDAExecutionProviderCuDNNConvAlgoSearch::Default => "DEFAULT" + CuDNNConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE", + CuDNNConvAlgorithmSearch::Heuristic => "HEURISTIC", + CuDNNConvAlgorithmSearch::Default => "DEFAULT" }); self } @@ -222,7 +224,7 @@ impl CUDAExecutionProvider { } #[must_use] - pub fn with_attention_backend(mut self, flags: CUDAExecutionProviderAttentionBackend) -> Self { + pub fn with_attention_backend(mut self, flags: CUDAAttentionBackend) -> Self { self.options.set("sdpa_kernel", flags.0.to_string()); self } diff --git a/src/execution_providers/qnn.rs b/src/execution_providers/qnn.rs index 7b467808..b584efd8 100644 --- a/src/execution_providers/qnn.rs +++ b/src/execution_providers/qnn.rs @@ -8,7 +8,7 @@ use crate::{ }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum QNNExecutionProviderPerformanceMode { +pub enum QNNPerformanceMode { Default, Burst, Balanced, @@ -20,42 +20,42 @@ pub enum QNNExecutionProviderPerformanceMode { SustainedHighPerformance } -impl QNNExecutionProviderPerformanceMode { +impl QNNPerformanceMode { #[must_use] pub fn as_str(&self) -> &'static str { match self { - QNNExecutionProviderPerformanceMode::Default => "default", - QNNExecutionProviderPerformanceMode::Burst => "burst", - QNNExecutionProviderPerformanceMode::Balanced => "balanced", - QNNExecutionProviderPerformanceMode::HighPerformance => "high_performance", - QNNExecutionProviderPerformanceMode::HighPowerSaver => "high_power_saver", - QNNExecutionProviderPerformanceMode::LowPowerSaver => "low_power_saver", - QNNExecutionProviderPerformanceMode::LowBalanced => "low_balanced", - QNNExecutionProviderPerformanceMode::PowerSaver => "power_saver", - QNNExecutionProviderPerformanceMode::SustainedHighPerformance => "sustained_high_performance" + QNNPerformanceMode::Default => "default", + QNNPerformanceMode::Burst => "burst", + QNNPerformanceMode::Balanced => "balanced", + QNNPerformanceMode::HighPerformance => "high_performance", + QNNPerformanceMode::HighPowerSaver => "high_power_saver", + QNNPerformanceMode::LowPowerSaver => "low_power_saver", + QNNPerformanceMode::LowBalanced => "low_balanced", + QNNPerformanceMode::PowerSaver => "power_saver", + QNNPerformanceMode::SustainedHighPerformance => "sustained_high_performance" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum QNNExecutionProviderProfilingLevel { +pub enum QNNProfilingLevel { Off, Basic, Detailed } -impl QNNExecutionProviderProfilingLevel { +impl QNNProfilingLevel { pub fn as_str(&self) -> &'static str { match self { - QNNExecutionProviderProfilingLevel::Off => "off", - QNNExecutionProviderProfilingLevel::Basic => "basic", - QNNExecutionProviderProfilingLevel::Detailed => "detailed" + QNNProfilingLevel::Off => "off", + QNNProfilingLevel::Basic => "basic", + QNNProfilingLevel::Detailed => "detailed" } } } #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum QNNExecutionProviderContextPriority { +pub enum QNNContextPriority { Low, #[default] Normal, @@ -63,13 +63,13 @@ pub enum QNNExecutionProviderContextPriority { High } -impl QNNExecutionProviderContextPriority { +impl 
QNNContextPriority { pub fn as_str(&self) -> &'static str { match self { - QNNExecutionProviderContextPriority::Low => "low", - QNNExecutionProviderContextPriority::Normal => "normal", - QNNExecutionProviderContextPriority::NormalHigh => "normal_high", - QNNExecutionProviderContextPriority::High => "normal_high" + QNNContextPriority::Low => "low", + QNNContextPriority::Normal => "normal", + QNNContextPriority::NormalHigh => "normal_high", + QNNContextPriority::High => "normal_high" } } } @@ -89,7 +89,7 @@ impl QNNExecutionProvider { } #[must_use] - pub fn with_profiling(mut self, level: QNNExecutionProviderProfilingLevel) -> Self { + pub fn with_profiling(mut self, level: QNNProfilingLevel) -> Self { self.options.set("profiling_level", level.as_str()); self } @@ -114,7 +114,7 @@ impl QNNExecutionProvider { } #[must_use] - pub fn with_performance_mode(mut self, mode: QNNExecutionProviderPerformanceMode) -> Self { + pub fn with_performance_mode(mut self, mode: QNNPerformanceMode) -> Self { self.options.set("htp_performance_mode", mode.as_str()); self } @@ -126,7 +126,7 @@ impl QNNExecutionProvider { } #[must_use] - pub fn with_context_priority(mut self, priority: QNNExecutionProviderContextPriority) -> Self { + pub fn with_context_priority(mut self, priority: QNNContextPriority) -> Self { self.options.set("qnn_context_priority", priority.as_str()); self } diff --git a/src/execution_providers/webgpu.rs b/src/execution_providers/webgpu.rs index 6e60a32f..e12dfbdf 100644 --- a/src/execution_providers/webgpu.rs +++ b/src/execution_providers/webgpu.rs @@ -11,73 +11,73 @@ use crate::{ }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebGPUExecutionProviderPreferredLayout { +pub enum WebGPUPreferredLayout { NCHW, NHWC } -impl WebGPUExecutionProviderPreferredLayout { +impl WebGPUPreferredLayout { #[must_use] pub fn as_str(&self) -> &'static str { match self { - WebGPUExecutionProviderPreferredLayout::NCHW => "NCHW", - WebGPUExecutionProviderPreferredLayout::NHWC => "NHWC" + WebGPUPreferredLayout::NCHW => "NCHW", + WebGPUPreferredLayout::NHWC => "NHWC" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebGPUExecutionProviderDawnBackendType { +pub enum WebGPUDawnBackendType { Vulkan, D3D12 } -impl WebGPUExecutionProviderDawnBackendType { +impl WebGPUDawnBackendType { #[must_use] pub fn as_str(&self) -> &'static str { match self { - WebGPUExecutionProviderDawnBackendType::Vulkan => "Vulkan", - WebGPUExecutionProviderDawnBackendType::D3D12 => "D3D12" + WebGPUDawnBackendType::Vulkan => "Vulkan", + WebGPUDawnBackendType::D3D12 => "D3D12" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebGPUExecutionProviderBufferCacheMode { +pub enum WebGPUBufferCacheMode { Disabled, LazyRelease, Simple, Bucket } -impl WebGPUExecutionProviderBufferCacheMode { +impl WebGPUBufferCacheMode { #[must_use] pub fn as_str(&self) -> &'static str { match self { - WebGPUExecutionProviderBufferCacheMode::Disabled => "disabled", - WebGPUExecutionProviderBufferCacheMode::LazyRelease => "lazyRelease", - WebGPUExecutionProviderBufferCacheMode::Simple => "simple", - WebGPUExecutionProviderBufferCacheMode::Bucket => "bucket" + WebGPUBufferCacheMode::Disabled => "disabled", + WebGPUBufferCacheMode::LazyRelease => "lazyRelease", + WebGPUBufferCacheMode::Simple => "simple", + WebGPUBufferCacheMode::Bucket => "bucket" } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WebGPUExecutionProviderValidationMode { +pub enum WebGPUValidationMode { Disabled, WgpuOnly, Basic, Full } -impl 
WebGPUExecutionProviderValidationMode { +impl WebGPUValidationMode { #[must_use] pub fn as_str(&self) -> &'static str { match self { - WebGPUExecutionProviderValidationMode::Disabled => "disabled", - WebGPUExecutionProviderValidationMode::WgpuOnly => "wgpuOnly", - WebGPUExecutionProviderValidationMode::Basic => "basic", - WebGPUExecutionProviderValidationMode::Full => "full" + WebGPUValidationMode::Disabled => "disabled", + WebGPUValidationMode::WgpuOnly => "wgpuOnly", + WebGPUValidationMode::Basic => "basic", + WebGPUValidationMode::Full => "full" } } } @@ -89,7 +89,7 @@ pub struct WebGPUExecutionProvider { impl WebGPUExecutionProvider { #[must_use] - pub fn with_preferred_layout(mut self, layout: WebGPUExecutionProviderPreferredLayout) -> Self { + pub fn with_preferred_layout(mut self, layout: WebGPUPreferredLayout) -> Self { self.options.set("WebGPU:preferredLayout", layout.as_str()); self } @@ -107,7 +107,7 @@ impl WebGPUExecutionProvider { } #[must_use] - pub fn with_dawn_backend_type(mut self, backend_type: WebGPUExecutionProviderDawnBackendType) -> Self { + pub fn with_dawn_backend_type(mut self, backend_type: WebGPUDawnBackendType) -> Self { self.options.set("WebGPU:dawnBackendType", backend_type.as_str()); self } @@ -119,31 +119,31 @@ impl WebGPUExecutionProvider { } #[must_use] - pub fn with_storage_buffer_cache_mode(mut self, mode: WebGPUExecutionProviderBufferCacheMode) -> Self { + pub fn with_storage_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self { self.options.set("WebGPU:storageBufferCacheMode", mode.as_str()); self } #[must_use] - pub fn with_uniform_buffer_cache_mode(mut self, mode: WebGPUExecutionProviderBufferCacheMode) -> Self { + pub fn with_uniform_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self { self.options.set("WebGPU:uniformBufferCacheMode", mode.as_str()); self } #[must_use] - pub fn with_query_resolve_buffer_cache_mode(mut self, mode: WebGPUExecutionProviderBufferCacheMode) -> Self { + pub fn with_query_resolve_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self { self.options.set("WebGPU:queryResolveBufferCacheMode", mode.as_str()); self } #[must_use] - pub fn with_default_buffer_cache_mode(mut self, mode: WebGPUExecutionProviderBufferCacheMode) -> Self { + pub fn with_default_buffer_cache_mode(mut self, mode: WebGPUBufferCacheMode) -> Self { self.options.set("WebGPU:defaultBufferCacheMode", mode.as_str()); self } #[must_use] - pub fn with_validation_mode(mut self, mode: WebGPUExecutionProviderValidationMode) -> Self { + pub fn with_validation_mode(mut self, mode: WebGPUValidationMode) -> Self { self.options.set("WebGPU:validationMode", mode.as_str()); self }
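
Call-site impact (illustrative sketch, not part of the diff above). It assumes the builder methods touched by this patch, a `Default` constructor on each provider struct, and re-exports from `ort::execution_providers`; exact import paths and registration calls may differ:

    // Hypothetical call site after the rename; imports shown are assumptions.
    use ort::execution_providers::{CUDAExecutionProvider, CuDNNConvAlgorithmSearch, CoreMLComputeUnits, CoreMLExecutionProvider};

    // Previously: .with_conv_algorithm_search(CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic)
    let cuda = CUDAExecutionProvider::default()
        .with_conv_algorithm_search(CuDNNConvAlgorithmSearch::Heuristic);

    // Previously: .with_compute_units(CoreMLExecutionProviderComputeUnits::CPUAndNeuralEngine)
    let coreml = CoreMLExecutionProvider::default()
        .with_compute_units(CoreMLComputeUnits::CPUAndNeuralEngine);

Note that the provider structs themselves keep their full `*ExecutionProvider` names; only the option enums (and the `CUDAAttentionBackend` flags struct) are shortened.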