Skip to content

enhanced-deepseek cot #3703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
<version>${project.parent.version}</version>
</dependency>

<!-- Test dependencies -->
<dependency>
<groupId>org.springframework.ai</groupId>
Expand Down
6 changes: 6 additions & 0 deletions models/spring-ai-deepseek/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@
<artifactId>slf4j-api</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-client-chat</artifactId>
<version>${project.parent.version}</version>
</dependency>

<!-- test dependencies -->
<dependency>
<groupId>org.springframework.ai</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.content.Media;

Expand All @@ -38,6 +37,11 @@ public DeepSeekAssistantMessage(String content, String reasoningContent) {
this.reasoningContent = reasoningContent;
}

/**
 * Creates an assistant message that carries reasoning content in addition to the
 * regular message content.
 * @param content forwarded to the superclass constructor as the message content
 * @param reasoningContent stored in this message's {@code reasoningContent} field
 * @param properties forwarded to the superclass constructor together with
 * {@code content} (presumably the message metadata; confirm against the superclass)
 */
public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> properties) {
super(content, properties);
this.reasoningContent = reasoningContent;
}

/**
 * Creates an assistant message from content and properties only. This constructor
 * does not assign {@code reasoningContent}.
 * @param content forwarded to the superclass constructor as the message content
 * @param properties forwarded to the superclass constructor together with
 * {@code content}
 */
public DeepSeekAssistantMessage(String content, Map<String, Object> properties) {
super(content, properties);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.springframework.ai.deepseek;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.springframework.ai.chat.client.ChatClientResponse;
import reactor.core.publisher.Flux;

/**
 * Aggregates a stream of {@link ChatClientResponse} elements by delegating the
 * chat-response aggregation to {@link DeepSeekMessageAggregator} while merging the
 * context entries of every streamed element into one shared map.
 */
public class DeepSeekChatClientMessageAggregator {

	/**
	 * Passes the stream through unchanged in content, and reports the aggregated result
	 * to {@code aggregationHandler}.
	 * @param chatClientResponses the streamed responses; elements whose chat response is
	 * {@code null} are dropped, but their context entries are still merged
	 * @param aggregationHandler receives one {@link ChatClientResponse} built from the
	 * aggregated chat response and the merged context
	 * @return a flux that re-wraps every chat response emitted by the aggregator together
	 * with the merged context map (the same map instance is handed to every element)
	 */
	public Flux<ChatClientResponse> aggregateChatClientResponse(Flux<ChatClientResponse> chatClientResponses,
			Consumer<ChatClientResponse> aggregationHandler) {

		AtomicReference<Map<String, Object>> mergedContext = new AtomicReference<>(new HashMap<>());

		// Unwrap each element to its chat response, remembering its context entries.
		var chatResponses = chatClientResponses.mapNotNull(streamed -> {
			mergedContext.get().putAll(streamed.context());
			return streamed.chatResponse();
		});

		// Delegate aggregation; on completion hand the wrapped aggregate to the caller.
		var aggregatedStream = new DeepSeekMessageAggregator().aggregate(chatResponses,
				aggregatedChatResponse -> aggregationHandler.accept(ChatClientResponse.builder()
					.chatResponse(aggregatedChatResponse)
					.context(mergedContext.get())
					.build()));

		// Re-wrap every streamed chat response with the merged context.
		return aggregatedStream.map(chatResponse -> ChatClientResponse.builder()
			.chatResponse(chatResponse)
			.context(mergedContext.get())
			.build());
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,14 @@

package org.springframework.ai.deepseek;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
Expand All @@ -40,7 +35,6 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
Expand Down Expand Up @@ -69,6 +63,9 @@
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal DeepSeek}
Expand Down Expand Up @@ -312,7 +309,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on

return new MessageAggregator().aggregate(flux, observationContext::setResponse);
return new DeepSeekMessageAggregator().aggregate(flux, observationContext::setResponse);

});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package org.springframework.ai.deepseek;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.util.StringUtils;

/**
 * DeepSeek message aggregator.
 *
 * <p>
 * Observes a stream of {@link ChatResponse} chunks and, when the stream completes,
 * hands a single aggregated {@link ChatResponse} to a callback. In addition to the
 * message text, the reasoning content of {@link DeepSeekAssistantMessage} chunks is
 * concatenated and carried on the aggregated message.
 *
 * @author lucas
 */
public class DeepSeekMessageAggregator extends MessageAggregator {

	private static final Logger logger = LoggerFactory.getLogger(DeepSeekMessageAggregator.class);

	/**
	 * Returns the given flux with side-effect operators attached that accumulate text,
	 * reasoning content, message metadata, usage and response metadata from each chunk.
	 * The streamed elements themselves are passed through unchanged.
	 * @param fluxChatResponse the streamed chat responses to observe
	 * @param onAggregationComplete invoked once per completed subscription with the
	 * aggregated response; the aggregated output is a {@link DeepSeekAssistantMessage}
	 * @return the instrumented flux
	 */
	@Override
	public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
			Consumer<ChatResponse> onAggregationComplete) {

		// Assistant message text
		AtomicReference<StringBuilder> messageTextContentRef = new AtomicReference<>(new StringBuilder());
		// Reasoning (chain-of-thought) text
		AtomicReference<StringBuilder> reasoningContentRef = new AtomicReference<>(new StringBuilder());
		AtomicReference<Map<String, Object>> messageMetadataMapRef = new AtomicReference<>();

		// Generation metadata: the last non-NULL value seen wins
		AtomicReference<ChatGenerationMetadata> generationMetadataRef = new AtomicReference<>(
				ChatGenerationMetadata.NULL);

		// Usage: the last positive value seen wins for each counter
		AtomicReference<Integer> metadataUsagePromptTokensRef = new AtomicReference<>(0);
		AtomicReference<Integer> metadataUsageGenerationTokensRef = new AtomicReference<>(0);
		AtomicReference<Integer> metadataUsageTotalTokensRef = new AtomicReference<>(0);

		AtomicReference<PromptMetadata> metadataPromptMetadataRef = new AtomicReference<>(PromptMetadata.empty());
		AtomicReference<RateLimit> metadataRateLimitRef = new AtomicReference<>(new EmptyRateLimit());

		AtomicReference<String> metadataIdRef = new AtomicReference<>("");
		AtomicReference<String> metadataModelRef = new AtomicReference<>("");

		return fluxChatResponse.doOnSubscribe(subscription -> {
			// Start every subscription from a clean slate.
			messageTextContentRef.set(new StringBuilder());
			reasoningContentRef.set(new StringBuilder());
			messageMetadataMapRef.set(new HashMap<>());
			metadataIdRef.set("");
			metadataModelRef.set("");
			metadataUsagePromptTokensRef.set(0);
			metadataUsageGenerationTokensRef.set(0);
			metadataUsageTotalTokensRef.set(0);
			metadataPromptMetadataRef.set(PromptMetadata.empty());
			metadataRateLimitRef.set(new EmptyRateLimit());

		}).doOnNext(chatResponse -> {

			Generation result = chatResponse.getResult();
			if (result != null) {
				ChatGenerationMetadata generationMetadata = result.getMetadata();
				if (generationMetadata != null && generationMetadata != ChatGenerationMetadata.NULL) {
					generationMetadataRef.set(generationMetadata);
				}
				var output = result.getOutput();
				if (output.getText() != null) {
					messageTextContentRef.get().append(output.getText());
				}
				if (output instanceof DeepSeekAssistantMessage deepSeekAssistantMessage) {
					// A chunk may carry no reasoning content; appending null would
					// insert the literal text "null" into the aggregated reasoning.
					String reasoningChunk = deepSeekAssistantMessage.getReasoningContent();
					if (reasoningChunk != null) {
						reasoningContentRef.get().append(reasoningChunk);
					}
				}
				messageMetadataMapRef.get().putAll(output.getMetadata());
			}

			ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
			if (responseMetadata != null) {
				Usage usage = responseMetadata.getUsage();
				if (usage != null) {
					metadataUsagePromptTokensRef.set(usage.getPromptTokens() > 0 ? usage.getPromptTokens()
							: metadataUsagePromptTokensRef.get());
					metadataUsageGenerationTokensRef.set(usage.getCompletionTokens() > 0
							? usage.getCompletionTokens() : metadataUsageGenerationTokensRef.get());
					metadataUsageTotalTokensRef.set(usage.getTotalTokens() > 0 ? usage.getTotalTokens()
							: metadataUsageTotalTokensRef.get());
				}
				if (responseMetadata.getPromptMetadata() != null
						&& responseMetadata.getPromptMetadata().iterator().hasNext()) {
					metadataPromptMetadataRef.set(responseMetadata.getPromptMetadata());
				}
				// Test the incoming rate limit: testing the stored one (which always
				// starts out as EmptyRateLimit) would make this branch unreachable.
				RateLimit rateLimit = responseMetadata.getRateLimit();
				if (rateLimit != null && !(rateLimit instanceof EmptyRateLimit)) {
					metadataRateLimitRef.set(rateLimit);
				}
				if (StringUtils.hasText(responseMetadata.getId())) {
					metadataIdRef.set(responseMetadata.getId());
				}
				if (StringUtils.hasText(responseMetadata.getModel())) {
					metadataModelRef.set(responseMetadata.getModel());
				}
			}
		}).doOnComplete(() -> {

			var usage = new DefaultUsage(metadataUsagePromptTokensRef.get(), metadataUsageGenerationTokensRef.get(),
					metadataUsageTotalTokensRef.get());

			var chatResponseMetadata = ChatResponseMetadata.builder()
				.id(metadataIdRef.get())
				.model(metadataModelRef.get())
				.rateLimit(metadataRateLimitRef.get())
				.usage(usage)
				.promptMetadata(metadataPromptMetadataRef.get())
				.build();

			var aggregatedMessage = new DeepSeekAssistantMessage(messageTextContentRef.get().toString(),
					reasoningContentRef.get().toString(), messageMetadataMapRef.get());
			onAggregationComplete.accept(new ChatResponse(
					List.of(new Generation(aggregatedMessage, generationMetadataRef.get())), chatResponseMetadata));

			// Release the accumulated state once the aggregate has been delivered.
			messageTextContentRef.set(new StringBuilder());
			reasoningContentRef.set(new StringBuilder());
			messageMetadataMapRef.set(new HashMap<>());
			metadataIdRef.set("");
			metadataModelRef.set("");
			metadataUsagePromptTokensRef.set(0);
			metadataUsageGenerationTokensRef.set(0);
			metadataUsageTotalTokensRef.set(0);
			metadataPromptMetadataRef.set(PromptMetadata.empty());
			metadataRateLimitRef.set(new EmptyRateLimit());

		}).doOnError(e -> logger.error("Aggregation Error", e));
	}

}
Loading