
Commit 4b0f484

feat: add reasoning support (#63)
1 parent 7af3a85 commit 4b0f484

10 files changed: +659, -11 lines

src/models/chat.rs

Lines changed: 72 additions & 0 deletions
@@ -12,6 +12,76 @@ use super::tool_choice::ToolChoice;
 use super::tool_definition::ToolDefinition;
 use super::usage::Usage;
 
+#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
+pub struct ReasoningConfig {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub effort: Option<String>, // "low" | "medium" | "high"
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u32>, // Alternative to effort
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub exclude: Option<bool>, // Whether to exclude from response (default: false)
+}
+
+impl ReasoningConfig {
+    pub fn validate(&self) -> Result<(), String> {
+        if self.effort.is_some() && self.max_tokens.is_some() {
+            tracing::warn!("Both effort and max_tokens specified - prioritizing max_tokens");
+        }
+
+        // Only validate effort if max_tokens is not present (since max_tokens takes priority)
+        if let Some(effort) = &self.effort {
+            if effort.trim().is_empty() {
+                return Err("Effort cannot be empty string".to_string());
+            } else if self.max_tokens.is_none()
+                && !["low", "medium", "high"].contains(&effort.as_str())
+            {
+                return Err("Invalid effort value. Must be 'low', 'medium', or 'high'".to_string());
+            }
+        }
+
+        Ok(())
+    }
+
+    // For OpenAI/Azure - Direct passthrough (but prioritize max_tokens over effort)
+    pub fn to_openai_effort(&self) -> Option<String> {
+        if self.max_tokens.is_some() {
+            // If max_tokens is specified, don't use effort for OpenAI
+            None
+        } else {
+            // Only return effort if it's not empty
+            self.effort
+                .as_ref()
+                .filter(|e| !e.trim().is_empty())
+                .cloned()
+        }
+    }
+
+    // For Vertex AI (Gemini) - Use max_tokens directly
+    pub fn to_gemini_thinking_budget(&self) -> Option<i32> {
+        self.max_tokens.map(|tokens| tokens as i32)
+    }
+
+    // For Anthropic/Bedrock - Custom prompt generation (prioritize max_tokens over effort)
+    pub fn to_thinking_prompt(&self) -> Option<String> {
+        if self.max_tokens.is_some() {
+            // If max_tokens is specified, use a generic thinking prompt
+            Some("Think through this step-by-step with detailed reasoning.".to_string())
+        } else {
+            match self.effort.as_deref() {
+                Some(effort) if !effort.trim().is_empty() => match effort {
+                    "high" => {
+                        Some("Think through this step-by-step with detailed reasoning.".to_string())
+                    }
+                    "medium" => Some("Consider this problem thoughtfully.".to_string()),
+                    "low" => Some("Think about this briefly.".to_string()),
+                    _ => None,
+                },
+                _ => None,
+            }
+        }
+    }
+}
+
 #[derive(Deserialize, Serialize, Clone, ToSchema)]
 pub struct ChatCompletionRequest {
     pub model: String,
@@ -50,6 +120,8 @@ pub struct ChatCompletionRequest {
     pub top_logprobs: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub response_format: Option<ResponseFormat>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option<ReasoningConfig>,
 }
 
 // Note: ChatCompletionResponse cannot derive ToSchema due to BoxStream
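
To make the mapping concrete, here is a minimal sketch (not part of the commit) exercising the three helpers on a single ReasoningConfig; it assumes ReasoningConfig is in scope from this module:

    fn reasoning_mapping_demo() {
        // Effort-only config: passes through to OpenAI/Azure, maps to a
        // canned prompt for Anthropic/Bedrock, and yields no Gemini budget.
        let by_effort = ReasoningConfig {
            effort: Some("medium".to_string()),
            max_tokens: None,
            exclude: None,
        };
        assert!(by_effort.validate().is_ok());
        assert_eq!(by_effort.to_openai_effort(), Some("medium".to_string()));
        assert_eq!(by_effort.to_gemini_thinking_budget(), None);
        assert_eq!(
            by_effort.to_thinking_prompt(),
            Some("Consider this problem thoughtfully.".to_string())
        );

        // When both are set, max_tokens wins: effort is suppressed for
        // OpenAI/Azure and the generic thinking prompt is used instead.
        let by_budget = ReasoningConfig {
            effort: Some("high".to_string()),
            max_tokens: Some(1024),
            exclude: None,
        };
        assert_eq!(by_budget.to_openai_effort(), None);
        assert_eq!(by_budget.to_gemini_thinking_budget(), Some(1024));
    }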

src/models/streaming.rs

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@ pub struct ChoiceDelta {
     pub role: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_calls: Option<Vec<ChatMessageToolCall>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option<String>,
 }
 
 #[derive(Deserialize, Serialize, Clone, Debug, ToSchema)]
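
Illustrative only (hypothetical payload, not from the commit): a streamed delta carrying reasoning text would deserialize as below, assuming ChoiceDelta's other fields are likewise optional and serde_json is available:

    fn reasoning_delta_demo() {
        // Hypothetical chunk fragment: reasoning arrives in the delta
        // alongside the role, before any visible content.
        let raw = r#"{"role":"assistant","reasoning":"Check the premises first."}"#;
        let delta: ChoiceDelta = serde_json::from_str(raw).expect("delta should parse");
        assert_eq!(delta.reasoning.as_deref(), Some("Check the premises first."));
        assert!(delta.tool_calls.is_none()); // absent optional fields stay None
    }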

src/providers/anthropic/models.rs

Lines changed: 11 additions & 1 deletion
@@ -90,7 +90,7 @@ impl From<ChatCompletionRequest> for AnthropicChatCompletionRequest {
             ))
         );
 
-        let system = request
+        let mut system = request
            .messages
            .iter()
            .find(|msg| msg.role == "system")
@@ -103,6 +103,16 @@ impl From<ChatCompletionRequest> for AnthropicChatCompletionRequest {
            _ => None,
        });
 
+        // Add reasoning prompt if reasoning is requested
+        if let Some(reasoning_config) = &request.reasoning {
+            if let Some(thinking_prompt) = reasoning_config.to_thinking_prompt() {
+                system = Some(match system {
+                    Some(existing) => format!("{}\n\n{}", existing, thinking_prompt),
+                    None => thinking_prompt,
+                });
+            }
+        }
+
        let messages: Vec<ChatCompletionMessage> = request
            .messages
            .into_iter()
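
The merge rule above is small enough to restate standalone; this sketch (assumed behavior, mirroring the diff rather than importing it) shows how an existing system prompt gains the thinking prompt after a blank-line separator:

    fn append_thinking(system: Option<String>, thinking: Option<String>) -> Option<String> {
        match (system, thinking) {
            // Existing system prompt: append the thinking prompt after a blank line.
            (Some(existing), Some(prompt)) => Some(format!("{}\n\n{}", existing, prompt)),
            // No system prompt yet: the thinking prompt becomes the system prompt.
            (None, Some(prompt)) => Some(prompt),
            // No reasoning requested: leave the system prompt untouched.
            (system, None) => system,
        }
    }

    fn merge_demo() {
        assert_eq!(
            append_thinking(Some("You are terse.".into()), Some("Think about this briefly.".into())),
            Some("You are terse.\n\nThink about this briefly.".into())
        );
    }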

src/providers/anthropic/provider.rs

Lines changed: 28 additions & 0 deletions
@@ -1,6 +1,7 @@
 use async_trait::async_trait;
 use axum::http::StatusCode;
 use reqwest::Client;
+use tracing::info;
 
 use super::models::{AnthropicChatCompletionRequest, AnthropicChatCompletionResponse};
 use crate::config::models::{ModelConfig, Provider as ProviderConfig};
@@ -38,6 +39,33 @@ impl Provider for AnthropicProvider {
         payload: ChatCompletionRequest,
         _model_config: &ModelConfig,
     ) -> Result<ChatCompletionResponse, StatusCode> {
+        // Validate reasoning config if present
+        if let Some(reasoning) = &payload.reasoning {
+            if let Err(e) = reasoning.validate() {
+                tracing::error!("Invalid reasoning config: {}", e);
+                return Err(StatusCode::BAD_REQUEST);
+            }
+
+            if let Some(max_tokens) = reasoning.max_tokens {
+                info!(
+                    "✅ Anthropic reasoning enabled with max_tokens: {}",
+                    max_tokens
+                );
+            } else if let Some(thinking_prompt) = reasoning.to_thinking_prompt() {
+                info!(
+                    "✅ Anthropic reasoning enabled with effort level: {:?} -> prompt: \"{}\"",
+                    reasoning.effort,
+                    thinking_prompt.chars().take(50).collect::<String>() + "..."
+                );
+            } else {
+                tracing::debug!(
+                    "ℹ️ Anthropic reasoning config present but no valid parameters (effort: {:?}, max_tokens: {:?})",
+                    reasoning.effort,
+                    reasoning.max_tokens
+                );
+            }
+        }
+
         let request = AnthropicChatCompletionRequest::from(payload);
         let response = self
             .http_client
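
Given the validate() shown in src/models/chat.rs, the gate above maps to concrete outcomes; a sketch (assumes ReasoningConfig is in scope):

    fn validation_gate_demo() {
        // Whitespace-only effort is rejected -> handler returns BAD_REQUEST.
        let blank = ReasoningConfig { effort: Some("  ".into()), max_tokens: None, exclude: None };
        assert!(blank.validate().is_err());

        // Unknown effort level with no budget is rejected -> BAD_REQUEST.
        let unknown = ReasoningConfig { effort: Some("extreme".into()), max_tokens: None, exclude: None };
        assert!(unknown.validate().is_err());

        // The same unknown effort passes once max_tokens is set, because
        // max_tokens takes priority and the effort-level check is skipped.
        let with_budget = ReasoningConfig { effort: Some("extreme".into()), max_tokens: Some(512), exclude: None };
        assert!(with_budget.validate().is_ok());
    }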

src/providers/azure/provider.rs

Lines changed: 53 additions & 1 deletion
@@ -1,6 +1,7 @@
 use async_trait::async_trait;
 use axum::http::StatusCode;
 use reqwest_streams::JsonStreamResponse;
+use serde::{Deserialize, Serialize};
 
 use crate::config::constants::stream_buffer_size_bytes;
 use crate::config::models::{ModelConfig, Provider as ProviderConfig};
@@ -12,6 +13,28 @@ use crate::providers::provider::Provider;
 use reqwest::Client;
 use tracing::info;
 
+#[derive(Serialize, Deserialize, Clone)]
+struct AzureChatCompletionRequest {
+    #[serde(flatten)]
+    base: ChatCompletionRequest,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    reasoning_effort: Option<String>,
+}
+
+impl From<ChatCompletionRequest> for AzureChatCompletionRequest {
+    fn from(mut base: ChatCompletionRequest) -> Self {
+        let reasoning_effort = base.reasoning.as_ref().and_then(|r| r.to_openai_effort());
+
+        // Remove reasoning field from base request since Azure uses reasoning_effort
+        base.reasoning = None;
+
+        Self {
+            base,
+            reasoning_effort,
+        }
+    }
+}
+
 pub struct AzureProvider {
     config: ProviderConfig,
     http_client: Client,
@@ -55,6 +78,32 @@ impl Provider for AzureProvider {
         payload: ChatCompletionRequest,
         model_config: &ModelConfig,
     ) -> Result<ChatCompletionResponse, StatusCode> {
+        // Validate reasoning config if present
+        if let Some(reasoning) = &payload.reasoning {
+            if let Err(e) = reasoning.validate() {
+                tracing::error!("Invalid reasoning config: {}", e);
+                return Err(StatusCode::BAD_REQUEST);
+            }
+
+            if let Some(max_tokens) = reasoning.max_tokens {
+                info!(
+                    "✅ Azure reasoning with max_tokens: {} (note: Azure uses effort levels, max_tokens ignored)",
+                    max_tokens
+                );
+            } else if let Some(effort) = reasoning.to_openai_effort() {
+                info!(
+                    "✅ Azure reasoning enabled with effort level: \"{}\"",
+                    effort
+                );
+            } else {
+                tracing::debug!(
+                    "ℹ️ Azure reasoning config present but no valid parameters (effort: {:?}, max_tokens: {:?})",
+                    reasoning.effort,
+                    reasoning.max_tokens
+                );
+            }
+        }
+
         let deployment = model_config.params.get("deployment").unwrap();
         let api_version = self.api_version();
         let url = format!(
@@ -64,11 +113,14 @@ impl Provider for AzureProvider {
             api_version
         );
 
+        // Convert to Azure-specific request format
+        let azure_request = AzureChatCompletionRequest::from(payload.clone());
+
         let response = self
             .http_client
             .post(&url)
             .header("api-key", &self.config.api_key)
-            .json(&payload)
+            .json(&azure_request)
             .send()
             .await
             .map_err(|e| {
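
To see what #[serde(flatten)] buys here, a sketch of the request-body transformation. Assumptions: it runs inside this module (AzureChatCompletionRequest is private), ChatCompletionRequest accepts this minimal JSON, and the model name is a placeholder:

    fn azure_body_demo() {
        let payload: ChatCompletionRequest = serde_json::from_value(serde_json::json!({
            "model": "my-reasoning-model", // placeholder deployment/model name
            "messages": [{ "role": "user", "content": "hi" }],
            "reasoning": { "effort": "high" }
        }))
        .expect("minimal request should deserialize");

        let azure_request = AzureChatCompletionRequest::from(payload);
        let body = serde_json::to_value(&azure_request).expect("request should serialize");

        // The flat Azure key replaces the gateway-level `reasoning` object.
        assert_eq!(body["reasoning_effort"], "high");
        assert!(body.get("reasoning").is_none());
    }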
