Skip to content

Commit ea9c6ab

Browse files
committed
core: support dynamic auth tokens for model providers
1 parent 6afaa7d commit ea9c6ab

17 files changed

+678
-4
lines changed

codex-rs/core/config.schema.json

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,10 +816,62 @@
816816
},
817817
"type": "object"
818818
},
819+
"ModelProviderAuthInfo": {
820+
"additionalProperties": false,
821+
"description": "Configuration for obtaining a provider bearer token from a command.",
822+
"properties": {
823+
"args": {
824+
"default": [],
825+
"description": "Command arguments.",
826+
"items": {
827+
"type": "string"
828+
},
829+
"type": "array"
830+
},
831+
"command": {
832+
"description": "Command to execute. Bare names are resolved via `PATH`; paths are resolved against `cwd`.",
833+
"type": "string"
834+
},
835+
"cwd": {
836+
"allOf": [
837+
{
838+
"$ref": "#/definitions/AbsolutePathBuf"
839+
}
840+
],
841+
"description": "Working directory used when running the token command."
842+
},
843+
"refresh_interval_ms": {
844+
"default": 300000,
845+
"description": "Maximum age for the cached token before rerunning the command.",
846+
"format": "uint64",
847+
"minimum": 1.0,
848+
"type": "integer"
849+
},
850+
"timeout_ms": {
851+
"default": 5000,
852+
"description": "Maximum time to wait for the token command to exit successfully.",
853+
"format": "uint64",
854+
"minimum": 1.0,
855+
"type": "integer"
856+
}
857+
},
858+
"required": [
859+
"command"
860+
],
861+
"type": "object"
862+
},
819863
"ModelProviderInfo": {
820864
"additionalProperties": false,
821865
"description": "Serializable representation of a provider definition.",
822866
"properties": {
867+
"auth": {
868+
"allOf": [
869+
{
870+
"$ref": "#/definitions/ModelProviderAuthInfo"
871+
}
872+
],
873+
"description": "Command-backed bearer-token configuration for this provider."
874+
},
823875
"base_url": {
824876
"description": "Base URL for the provider's OpenAI-compatible API.",
825877
"type": "string"

codex-rs/core/src/auth_env_telemetry.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ mod tests {
6464
env_key: Some("sk-should-not-leak".to_string()),
6565
env_key_instructions: None,
6666
experimental_bearer_token: None,
67+
auth: None,
6768
wire_api: crate::model_provider_info::WireApi::Responses,
6869
query_params: None,
6970
http_headers: None,

codex-rs/core/src/client.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ use crate::error::Result;
104104
use crate::flags::CODEX_RS_SSE_FIXTURE;
105105
use crate::model_provider_info::ModelProviderInfo;
106106
use crate::model_provider_info::WireApi;
107+
use crate::provider_auth::auth_manager_for_provider;
107108
use crate::response_debug_context::extract_response_debug_context;
108109
use crate::response_debug_context::extract_response_debug_context_from_api_error;
109110
use crate::response_debug_context::telemetry_api_error_message;
@@ -261,6 +262,7 @@ impl ModelClient {
261262
include_timing_metrics: bool,
262263
beta_features_header: Option<String>,
263264
) -> Self {
265+
let auth_manager = auth_manager_for_provider(auth_manager, &provider);
264266
let codex_api_key_env_enabled = auth_manager
265267
.as_ref()
266268
.is_some_and(|manager| manager.codex_api_key_env_enabled());
@@ -294,6 +296,10 @@ impl ModelClient {
294296
}
295297
}
296298

299+
pub(crate) fn auth_manager(&self) -> Option<Arc<AuthManager>> {
300+
self.state.auth_manager.clone()
301+
}
302+
297303
fn take_cached_websocket_session(&self) -> WebsocketSession {
298304
let mut cached_websocket_session = self
299305
.state

codex-rs/core/src/config/config_tests.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,26 @@ web_search = false
243243
);
244244
}
245245

246+
#[test]
247+
fn rejects_provider_auth_with_env_key() {
248+
let err = toml::from_str::<ConfigToml>(
249+
r#"
250+
[model_providers.corp]
251+
name = "Corp"
252+
env_key = "CORP_TOKEN"
253+
254+
[model_providers.corp.auth]
255+
command = "print-token"
256+
"#,
257+
)
258+
.unwrap_err();
259+
260+
assert!(
261+
err.to_string()
262+
.contains("model_providers.corp: provider auth cannot be combined with env_key")
263+
);
264+
}
265+
246266
#[test]
247267
fn config_toml_deserializes_model_availability_nux() {
248268
let toml = r#"
@@ -4315,6 +4335,7 @@ model_verbosity = "high"
43154335
wire_api: crate::WireApi::Responses,
43164336
env_key_instructions: None,
43174337
experimental_bearer_token: None,
4338+
auth: None,
43184339
query_params: None,
43194340
http_headers: None,
43204341
env_http_headers: None,

codex-rs/core/src/config/mod.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,14 +1837,26 @@ Built-in providers cannot be overridden. Rename your custom provider (for exampl
18371837
}
18381838
}
18391839

1840+
fn validate_model_providers(
1841+
model_providers: &HashMap<String, ModelProviderInfo>,
1842+
) -> Result<(), String> {
1843+
validate_reserved_model_provider_ids(model_providers)?;
1844+
for (key, provider) in model_providers {
1845+
provider
1846+
.validate()
1847+
.map_err(|message| format!("model_providers.{key}: {message}"))?;
1848+
}
1849+
Ok(())
1850+
}
1851+
18401852
fn deserialize_model_providers<'de, D>(
18411853
deserializer: D,
18421854
) -> Result<HashMap<String, ModelProviderInfo>, D::Error>
18431855
where
18441856
D: serde::Deserializer<'de>,
18451857
{
18461858
let model_providers = HashMap::<String, ModelProviderInfo>::deserialize(deserializer)?;
1847-
validate_reserved_model_provider_ids(&model_providers).map_err(serde::de::Error::custom)?;
1859+
validate_model_providers(&model_providers).map_err(serde::de::Error::custom)?;
18481860
Ok(model_providers)
18491861
}
18501862

@@ -1969,7 +1981,7 @@ impl Config {
19691981
codex_home: PathBuf,
19701982
config_layer_stack: ConfigLayerStack,
19711983
) -> std::io::Result<Self> {
1972-
validate_reserved_model_provider_ids(&cfg.model_providers)
1984+
validate_model_providers(&cfg.model_providers)
19731985
.map_err(|message| std::io::Error::new(std::io::ErrorKind::InvalidInput, message))?;
19741986
// Ensure that every field of ConfigRequirements is applied to the final
19751987
// Config.

codex-rs/core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub mod utils;
6565
pub use utils::path_utils;
6666
pub mod personality_migration;
6767
pub mod plugins;
68+
mod provider_auth;
6869
pub(crate) mod mentions {
6970
pub(crate) use crate::plugins::build_connector_slug_counts;
7071
pub(crate) use crate::plugins::build_skill_name_counts;
@@ -104,6 +105,7 @@ mod text_encoding;
104105
mod unified_exec;
105106
pub mod windows_sandbox;
106107
pub use client::X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER;
108+
pub use codex_protocol::config_types::ModelProviderAuthInfo;
107109
pub use model_provider_info::DEFAULT_LMSTUDIO_PORT;
108110
pub use model_provider_info::DEFAULT_OLLAMA_PORT;
109111
pub use model_provider_info::LMSTUDIO_OSS_PROVIDER_ID;

codex-rs/core/src/model_provider_info.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::auth::AuthMode;
99
use crate::error::EnvVarError;
1010
use codex_api::Provider as ApiProvider;
1111
use codex_api::provider::RetryConfig as ApiRetryConfig;
12+
use codex_protocol::config_types::ModelProviderAuthInfo;
1213
use http::HeaderMap;
1314
use http::header::HeaderName;
1415
use http::header::HeaderValue;
@@ -86,6 +87,9 @@ pub struct ModelProviderInfo {
8687
/// this may be necessary when using this programmatically.
8788
pub experimental_bearer_token: Option<String>,
8889

90+
/// Command-backed bearer-token configuration for this provider.
91+
pub auth: Option<ModelProviderAuthInfo>,
92+
8993
/// Which wire protocol this provider expects.
9094
#[serde(default)]
9195
pub wire_api: WireApi,
@@ -130,6 +134,36 @@ pub struct ModelProviderInfo {
130134
}
131135

132136
impl ModelProviderInfo {
137+
pub(crate) fn validate(&self) -> std::result::Result<(), String> {
138+
let Some(auth) = self.auth.as_ref() else {
139+
return Ok(());
140+
};
141+
142+
if auth.command.trim().is_empty() {
143+
return Err("provider auth.command must not be empty".to_string());
144+
}
145+
146+
let mut conflicts = Vec::new();
147+
if self.env_key.is_some() {
148+
conflicts.push("env_key");
149+
}
150+
if self.experimental_bearer_token.is_some() {
151+
conflicts.push("experimental_bearer_token");
152+
}
153+
if self.requires_openai_auth {
154+
conflicts.push("requires_openai_auth");
155+
}
156+
157+
if conflicts.is_empty() {
158+
Ok(())
159+
} else {
160+
Err(format!(
161+
"provider auth cannot be combined with {}",
162+
conflicts.join(", ")
163+
))
164+
}
165+
}
166+
133167
fn build_header_map(&self) -> crate::error::Result<HeaderMap> {
134168
let capacity = self.http_headers.as_ref().map_or(0, HashMap::len)
135169
+ self.env_http_headers.as_ref().map_or(0, HashMap::len);
@@ -246,6 +280,7 @@ impl ModelProviderInfo {
246280
env_key: None,
247281
env_key_instructions: None,
248282
experimental_bearer_token: None,
283+
auth: None,
249284
wire_api: WireApi::Responses,
250285
query_params: None,
251286
http_headers: Some(
@@ -277,6 +312,10 @@ impl ModelProviderInfo {
277312
pub fn is_openai(&self) -> bool {
278313
self.name == OPENAI_PROVIDER_NAME
279314
}
315+
316+
pub(crate) fn has_command_auth(&self) -> bool {
317+
self.auth.is_some()
318+
}
280319
}
281320

282321
pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234;
@@ -338,6 +377,7 @@ pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> M
338377
env_key: None,
339378
env_key_instructions: None,
340379
experimental_bearer_token: None,
380+
auth: None,
341381
wire_api,
342382
query_params: None,
343383
http_headers: None,

codex-rs/core/src/model_provider_info_tests.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
use super::*;
2+
use codex_utils_absolute_path::AbsolutePathBuf;
3+
use codex_utils_absolute_path::AbsolutePathBufGuard;
24
use pretty_assertions::assert_eq;
5+
use std::num::NonZeroU64;
6+
use tempfile::tempdir;
37

48
#[test]
59
fn test_deserialize_ollama_model_provider_toml() {
@@ -13,6 +17,7 @@ base_url = "http://localhost:11434/v1"
1317
env_key: None,
1418
env_key_instructions: None,
1519
experimental_bearer_token: None,
20+
auth: None,
1621
wire_api: WireApi::Responses,
1722
query_params: None,
1823
http_headers: None,
@@ -43,6 +48,7 @@ query_params = { api-version = "2025-04-01-preview" }
4348
env_key: Some("AZURE_OPENAI_API_KEY".into()),
4449
env_key_instructions: None,
4550
experimental_bearer_token: None,
51+
auth: None,
4652
wire_api: WireApi::Responses,
4753
query_params: Some(maplit::hashmap! {
4854
"api-version".to_string() => "2025-04-01-preview".to_string(),
@@ -76,6 +82,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
7682
env_key: Some("API_KEY".into()),
7783
env_key_instructions: None,
7884
experimental_bearer_token: None,
85+
auth: None,
7986
wire_api: WireApi::Responses,
8087
query_params: None,
8188
http_headers: Some(maplit::hashmap! {
@@ -121,3 +128,31 @@ supports_websockets = true
121128
let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap();
122129
assert_eq!(provider.websocket_connect_timeout_ms, Some(15_000));
123130
}
131+
132+
#[test]
133+
fn test_deserialize_provider_auth_config_defaults() {
134+
let base_dir = tempdir().unwrap();
135+
let provider_toml = r#"
136+
name = "Corp"
137+
138+
[auth]
139+
command = "./scripts/print-token"
140+
args = ["--format=text"]
141+
"#;
142+
143+
let provider: ModelProviderInfo = {
144+
let _guard = AbsolutePathBufGuard::new(base_dir.path());
145+
toml::from_str(provider_toml).unwrap()
146+
};
147+
148+
assert_eq!(
149+
provider.auth,
150+
Some(ModelProviderAuthInfo {
151+
command: "./scripts/print-token".to_string(),
152+
args: vec!["--format=text".to_string()],
153+
timeout_ms: NonZeroU64::new(5_000).unwrap(),
154+
refresh_interval_ms: NonZeroU64::new(300_000).unwrap(),
155+
cwd: AbsolutePathBuf::resolve_path_against_base(".", base_dir.path()).unwrap(),
156+
})
157+
);
158+
}

codex-rs/core/src/models_manager/manager.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::model_provider_info::ModelProviderInfo;
1414
use crate::models_manager::collaboration_mode_presets::CollaborationModesConfig;
1515
use crate::models_manager::collaboration_mode_presets::builtin_collaboration_mode_presets;
1616
use crate::models_manager::model_info;
17+
use crate::provider_auth::required_auth_manager_for_provider;
1718
use crate::response_debug_context::extract_response_debug_context;
1819
use crate::response_debug_context::telemetry_transport_error_message;
1920
use crate::util::FeedbackRequestTags;
@@ -212,6 +213,7 @@ impl ModelsManager {
212213
collaboration_modes_config: CollaborationModesConfig,
213214
provider: ModelProviderInfo,
214215
) -> Self {
216+
let auth_manager = required_auth_manager_for_provider(auth_manager, &provider);
215217
let cache_path = codex_home.join(MODEL_CACHE_FILE);
216218
let cache_manager = ModelsCacheManager::new(cache_path, DEFAULT_MODEL_CACHE_TTL);
217219
let catalog_mode = if model_catalog.is_some() {
@@ -396,7 +398,9 @@ impl ModelsManager {
396398
return Ok(());
397399
}
398400

399-
if self.auth_manager.auth_mode() != Some(AuthMode::Chatgpt) {
401+
if self.auth_manager.auth_mode() != Some(AuthMode::Chatgpt)
402+
&& !self.provider.has_command_auth()
403+
{
400404
if matches!(
401405
refresh_strategy,
402406
RefreshStrategy::Offline | RefreshStrategy::OnlineIfUncached

0 commit comments

Comments
 (0)