Skip to content

Add tool support for DeepSeek #30223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 195 additions & 59 deletions crates/language_models/src/provider/deepseek.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
Expand All @@ -11,11 +12,14 @@ use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, prelude::*};
Expand All @@ -27,6 +31,13 @@ const PROVIDER_ID: &str = "deepseek";
const PROVIDER_NAME: &str = "DeepSeek";
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";

/// Accumulator for a single tool call whose pieces arrive across several
/// streamed response chunks.
///
/// `DeepSeekEventMapper::map_event` fills one of these per tool-call index
/// and turns it into a tool-use event once the response finishes.
#[derive(Default)]
struct RawToolCall {
    /// Tool call identifier; overwritten whenever a chunk carries an id.
    id: String,
    /// Function name; overwritten whenever a chunk carries a name.
    name: String,
    /// Argument text, built by appending every streamed fragment in arrival
    /// order. Parsed as JSON only when the call is finalized.
    arguments: String,
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct DeepSeekSettings {
pub api_url: String,
Expand Down Expand Up @@ -279,7 +290,7 @@ impl LanguageModel for DeepSeekLanguageModel {
}

fn supports_tools(&self) -> bool {
false
true
}

fn telemetry_id(&self) -> String {
Expand Down Expand Up @@ -338,27 +349,8 @@ impl LanguageModel for DeepSeekLanguageModel {
let stream = self.stream_completion(request, cx);

async move {
let stream = stream.await?;
Ok(stream
.map(|result| {
result
.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
let mapper = DeepSeekEventMapper::new();
Ok(mapper.map_stream(stream.await?).boxed())
}
.boxed()
}
Expand All @@ -371,47 +363,44 @@ pub fn into_deepseek(
) -> deepseek::Request {
let is_reasoner = model == "deepseek-reasoner";

let len = request.messages.len();
let merged_messages =
let messages = if is_reasoner {
let len = request.messages.len();
request
.messages
.into_iter()
.fold(Vec::with_capacity(len), |mut acc, msg| {
let role = msg.role;
let content = msg.string_contents();

if is_reasoner {
if let Some(last_msg) = acc.last_mut() {
match (last_msg, role) {
(deepseek::RequestMessage::User { content: last }, Role::User) => {
last.push(' ');
last.push_str(&content);
return acc;
}

(
deepseek::RequestMessage::Assistant {
content: last_content,
..
},
Role::Assistant,
) => {
*last_content = last_content
.take()
.map(|c| {
let mut s =
String::with_capacity(c.len() + content.len() + 1);
s.push_str(&c);
s.push(' ');
s.push_str(&content);
s
})
.or(Some(content));

return acc;
}
_ => {}
if let Some(last_msg) = acc.last_mut() {
match (last_msg, role) {
(deepseek::RequestMessage::User { content: last }, Role::User) => {
last.push(' ');
last.push_str(&content);
return acc;
}

(
deepseek::RequestMessage::Assistant {
content: last_content,
..
},
Role::Assistant,
) => {
*last_content = last_content
.take()
.map(|c| {
let mut s = String::with_capacity(c.len() + content.len() + 1);
s.push_str(&c);
s.push(' ');
s.push_str(&content);
s
})
.or(Some(content));

return acc;
}
_ => {}
}
}

Expand All @@ -424,11 +413,61 @@ pub fn into_deepseek(
Role::System => deepseek::RequestMessage::System { content },
});
acc
});
})
} else {
let mut messages = Vec::new();
for message in request.messages {
for content in message.content {
match content {
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
.push(match message.role {
Role::User => deepseek::RequestMessage::User { content: text },
Role::Assistant => deepseek::RequestMessage::Assistant {
content: Some(text),
tool_calls: Vec::new(),
},
Role::System => deepseek::RequestMessage::System { content: text },
}),
MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {}
MessageContent::ToolUse(tool_use) => {
let tool_call = deepseek::ToolCall {
id: tool_use.id.to_string(),
content: deepseek::ToolCallContent::Function {
function: deepseek::FunctionContent {
name: tool_use.name.to_string(),
arguments: serde_json::to_string(&tool_use.input)
.unwrap_or_default(),
},
},
};

if let Some(deepseek::RequestMessage::Assistant { tool_calls, .. }) =
messages.last_mut()
{
tool_calls.push(tool_call);
} else {
messages.push(deepseek::RequestMessage::Assistant {
content: None,
tool_calls: vec![tool_call],
});
}
}
MessageContent::ToolResult(tool_result) => {
messages.push(deepseek::RequestMessage::Tool {
content: tool_result.content.to_string(),
tool_call_id: tool_result.tool_use_id.to_string(),
});
}
}
}
}
messages
};

deepseek::Request {
model,
messages: merged_messages,
messages,
stream: true,
max_tokens: max_output_tokens,
temperature: if is_reasoner {
Expand All @@ -451,6 +490,103 @@ pub fn into_deepseek(
}
}

/// Stateful translator from DeepSeek streaming responses to
/// `LanguageModelCompletionEvent`s.
///
/// The only state is the set of tool calls that have been partially received
/// so far, keyed by the index the API assigns to each call in the stream.
/// `Default` is derived so the type can be built without going through
/// `new()`; both produce an empty mapper.
#[derive(Default)]
pub struct DeepSeekEventMapper {
    /// Tool calls still being assembled, keyed by their streamed index.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl DeepSeekEventMapper {
/// Creates a mapper with no partially-received tool calls.
pub fn new() -> Self {
    let tool_calls_by_index = HashMap::default();
    Self {
        tool_calls_by_index,
    }
}

/// Consumes the mapper and adapts a stream of raw DeepSeek responses into a
/// stream of completion events.
///
/// Each successful response is expanded through `map_event`, which may yield
/// zero or more events. A transport-level error becomes a single
/// `LanguageModelCompletionError::Other` item.
pub fn map_stream(
    mut self,
    events: Pin<Box<dyn Send + Stream<Item = Result<deepseek::StreamResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
    events.flat_map(move |item| {
        let mapped = match item {
            Ok(response) => self.map_event(response),
            Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
        };
        futures::stream::iter(mapped)
    })
}

/// Converts one streamed DeepSeek response chunk into completion events.
///
/// Only the first choice of the response is inspected:
/// - If there is no choice, a single `Other` error is returned.
/// - Text content in the delta is emitted as a `Text` event.
/// - Tool-call fragments in the delta are merged into the per-index
///   accumulator (id and name are overwritten, arguments are appended).
/// - A `finish_reason` of `"stop"` emits `Stop(EndTurn)`.
/// - A `finish_reason` of `"tool_calls"` emits one `ToolUse` event per
///   accumulated call, in ascending index order, followed by
///   `Stop(ToolUse)`. A call whose arguments are not valid JSON yields a
///   `BadInputJson` error instead of a `ToolUse` event.
/// - Any other `finish_reason` is logged and treated as `Stop(EndTurn)`.
pub fn map_event(
    &mut self,
    event: deepseek::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    let Some(choice) = event.choices.first() else {
        return vec![Err(LanguageModelCompletionError::Other(anyhow!(
            "Response contained no choices"
        )))];
    };

    let mut events = Vec::new();
    if let Some(content) = choice.delta.content.clone() {
        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
    }

    if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
        for tool_call in tool_calls {
            let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

            if let Some(tool_id) = tool_call.id.clone() {
                entry.id = tool_id;
            }

            if let Some(function) = tool_call.function.as_ref() {
                if let Some(name) = function.name.clone() {
                    entry.name = name;
                }

                // Arguments arrive as text fragments; concatenate them in
                // arrival order and parse only once the call is complete.
                if let Some(arguments) = function.arguments.as_deref() {
                    entry.arguments.push_str(arguments);
                }
            }
        }
    }

    match choice.finish_reason.as_deref() {
        Some("stop") => {
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
        }
        Some("tool_calls") => {
            // Hash map iteration order is unspecified, so sort by the
            // streamed index to emit tool calls in a stable order.
            let mut finished: Vec<_> = self.tool_calls_by_index.drain().collect();
            finished.sort_unstable_by_key(|(index, _)| *index);

            events.extend(finished.into_iter().map(|(_, tool_call)| {
                match serde_json::Value::from_str(&tool_call.arguments) {
                    Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                        LanguageModelToolUse {
                            id: tool_call.id.into(),
                            name: tool_call.name.as_str().into(),
                            is_input_complete: true,
                            input,
                            raw_input: tool_call.arguments,
                        },
                    )),
                    Err(error) => Err(LanguageModelCompletionError::BadInputJson {
                        id: tool_call.id.into(),
                        tool_name: tool_call.name.as_str().into(),
                        raw_input: tool_call.arguments.into(),
                        json_parse_error: error.to_string(),
                    }),
                }
            }));

            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        Some(stop_reason) => {
            log::error!("Unexpected DeepSeek stop_reason: {stop_reason:?}",);
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
        }
        None => {}
    }

    events
}
}

struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: Entity<State>,
Expand Down
Loading