From 42aebf8130e706703cc23e8627bd8334dc429da1 Mon Sep 17 00:00:00 2001 From: lucas Date: Sat, 28 Jun 2025 21:57:17 +0800 Subject: [PATCH] enhanced-deepseek cot Signed-off-by: lucas --- .../pom.xml | 6 + models/spring-ai-deepseek/pom.xml | 6 + .../ai/deepseek/DeepSeekAssistantMessage.java | 6 +- .../DeepSeekChatClientMessageAggregator.java | 28 ++++ .../ai/deepseek/DeepSeekChatModel.java | 17 +- .../deepseek/DeepSeekMessageAggregator.java | 143 ++++++++++++++++ .../DeepSeekMessageChatMemoryAdvisor.java | 158 ++++++++++++++++++ .../advisor/DeepSeekSimpleLogAdvisor.java | 127 ++++++++++++++ 8 files changed, 480 insertions(+), 11 deletions(-) create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatClientMessageAggregator.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekMessageAggregator.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/advisor/DeepSeekMessageChatMemoryAdvisor.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/advisor/DeepSeekSimpleLogAdvisor.java diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/pom.xml index 2f36a8c976e..77f1cdbb6bd 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/pom.xml +++ b/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/pom.xml @@ -72,6 +72,12 @@ true + + org.springframework.ai + spring-ai-autoconfigure-model-chat-client + ${project.parent.version} + + org.springframework.ai diff --git a/models/spring-ai-deepseek/pom.xml b/models/spring-ai-deepseek/pom.xml index 0f4c2a68a48..32f861ade08 100644 --- a/models/spring-ai-deepseek/pom.xml +++ b/models/spring-ai-deepseek/pom.xml @@ -50,6 +50,12 @@ slf4j-api + + org.springframework.ai + spring-ai-client-chat + ${project.parent.version} + + org.springframework.ai diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java index 6159d9beadb..98e0bf3131b 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java @@ -19,7 +19,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; - import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.content.Media; @@ -38,6 +37,11 @@ public DeepSeekAssistantMessage(String content, String reasoningContent) { this.reasoningContent = reasoningContent; } + public DeepSeekAssistantMessage(String content, String reasoningContent, Map properties) { + super(content, properties); + this.reasoningContent = reasoningContent; + } + public DeepSeekAssistantMessage(String content, Map properties) { super(content, properties); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatClientMessageAggregator.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatClientMessageAggregator.java new file mode 100644 index 00000000000..0b277431ab4 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatClientMessageAggregator.java @@ -0,0 +1,28 @@ +package org.springframework.ai.deepseek; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import org.springframework.ai.chat.client.ChatClientResponse; +import reactor.core.publisher.Flux; + +public class DeepSeekChatClientMessageAggregator { + + public Flux aggregateChatClientResponse( + Flux chatClientResponses, + Consumer aggregationHandler) { + + AtomicReference> context = new AtomicReference<>(new HashMap<>()); + + return new DeepSeekMessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> { + context.get().putAll(chatClientResponse.context()); + return chatClientResponse.chatResponse(); + }), aggregatedChatResponse -> { + ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder() + .chatResponse(aggregatedChatResponse).context(context.get()).build(); + aggregationHandler.accept(aggregatedChatClientResponse); + }).map(chatResponse -> ChatClientResponse.builder().chatResponse(chatResponse) + .context(context.get()).build()); + } +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java index 4b7607c6e38..282290dd8d5 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -16,19 +16,14 @@ package org.springframework.ai.deepseek; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -40,7 +35,6 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; @@ -69,6 +63,9 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal DeepSeek} @@ -312,7 +309,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on - return new MessageAggregator().aggregate(flux, observationContext::setResponse); + return new DeepSeekMessageAggregator().aggregate(flux, observationContext::setResponse); }); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekMessageAggregator.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekMessageAggregator.java new file mode 100644 index 00000000000..693ea5a6c28 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekMessageAggregator.java @@ -0,0 +1,143 @@ +package org.springframework.ai.deepseek; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.EmptyRateLimit; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.util.StringUtils; +import reactor.core.publisher.Flux; + +/** + * deepseek消息聚合器 + * lucas + */ +public class DeepSeekMessageAggregator extends MessageAggregator { + private static final Logger logger = LoggerFactory.getLogger(DeepSeekMessageAggregator.class); + @Override + public Flux aggregate(Flux fluxChatResponse, + Consumer onAggregationComplete) { + + // Assistant Message + AtomicReference messageTextContentRef = new AtomicReference<>( + new StringBuilder()); + // Reasoning Message + AtomicReference reasoningContentRef = new AtomicReference<>( + new StringBuilder()); + AtomicReference> messageMetadataMapRef = new AtomicReference<>(); + + // ChatGeneration Metadata + AtomicReference generationMetadataRef = new AtomicReference<>( + ChatGenerationMetadata.NULL); + + // Usage + AtomicReference metadataUsagePromptTokensRef = new AtomicReference(0); + AtomicReference metadataUsageGenerationTokensRef = new AtomicReference(0); + AtomicReference metadataUsageTotalTokensRef = new AtomicReference(0); + + AtomicReference metadataPromptMetadataRef = new AtomicReference<>( + PromptMetadata.empty()); + AtomicReference metadataRateLimitRef = new AtomicReference<>(new EmptyRateLimit()); + + AtomicReference metadataIdRef = new AtomicReference<>(""); + AtomicReference metadataModelRef = new AtomicReference<>(""); + + return fluxChatResponse.doOnSubscribe(subscription -> { + messageTextContentRef.set(new StringBuilder()); + reasoningContentRef.set(new StringBuilder()); + messageMetadataMapRef.set(new HashMap<>()); + metadataIdRef.set(""); + metadataModelRef.set(""); + metadataUsagePromptTokensRef.set(0); + metadataUsageGenerationTokensRef.set(0); + metadataUsageTotalTokensRef.set(0); + metadataPromptMetadataRef.set(PromptMetadata.empty()); + metadataRateLimitRef.set(new EmptyRateLimit()); + + }).doOnNext(chatResponse -> { + + if (chatResponse.getResult() != null) { + if (chatResponse.getResult().getMetadata() != null + && chatResponse.getResult().getMetadata() != ChatGenerationMetadata.NULL) { + generationMetadataRef.set(chatResponse.getResult().getMetadata()); + } + if (chatResponse.getResult().getOutput().getText() != null) { + messageTextContentRef.get().append(chatResponse.getResult().getOutput().getText()); + } + if (chatResponse.getResult() + .getOutput() instanceof DeepSeekAssistantMessage deepSeekAssistantMessage) { + reasoningContentRef.get().append(deepSeekAssistantMessage.getReasoningContent()); + } + messageMetadataMapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata()); + } + if (chatResponse.getMetadata() != null) { + if (chatResponse.getMetadata().getUsage() != null) { + Usage usage = chatResponse.getMetadata().getUsage(); + metadataUsagePromptTokensRef.set( + usage.getPromptTokens() > 0 ? usage.getPromptTokens() + : metadataUsagePromptTokensRef.get()); + metadataUsageGenerationTokensRef.set( + usage.getCompletionTokens() > 0 ? usage.getCompletionTokens() + : metadataUsageGenerationTokensRef.get()); + metadataUsageTotalTokensRef + .set(usage.getTotalTokens() > 0 ? usage.getTotalTokens() + : metadataUsageTotalTokensRef.get()); + } + if (chatResponse.getMetadata().getPromptMetadata() != null + && chatResponse.getMetadata().getPromptMetadata().iterator().hasNext()) { + metadataPromptMetadataRef.set(chatResponse.getMetadata().getPromptMetadata()); + } + if (chatResponse.getMetadata().getRateLimit() != null + && !(metadataRateLimitRef.get() instanceof EmptyRateLimit)) { + metadataRateLimitRef.set(chatResponse.getMetadata().getRateLimit()); + } + if (StringUtils.hasText(chatResponse.getMetadata().getId())) { + metadataIdRef.set(chatResponse.getMetadata().getId()); + } + if (StringUtils.hasText(chatResponse.getMetadata().getModel())) { + metadataModelRef.set(chatResponse.getMetadata().getModel()); + } + } + }).doOnComplete(() -> { + + var usage = new DefaultUsage(metadataUsagePromptTokensRef.get(), + metadataUsageGenerationTokensRef.get(), + metadataUsageTotalTokensRef.get()); + + var chatResponseMetadata = ChatResponseMetadata.builder() + .id(metadataIdRef.get()) + .model(metadataModelRef.get()) + .rateLimit(metadataRateLimitRef.get()) + .usage(usage) + .promptMetadata(metadataPromptMetadataRef.get()) + .build(); + onAggregationComplete.accept(new ChatResponse(List.of(new Generation( + new DeepSeekAssistantMessage(messageTextContentRef.get().toString(), + reasoningContentRef.get().toString(), messageMetadataMapRef.get()), + generationMetadataRef.get())), chatResponseMetadata)); + + messageTextContentRef.set(new StringBuilder()); + reasoningContentRef.set(new StringBuilder()); + messageMetadataMapRef.set(new HashMap<>()); + metadataIdRef.set(""); + metadataModelRef.set(""); + metadataUsagePromptTokensRef.set(0); + metadataUsageGenerationTokensRef.set(0); + metadataUsageTotalTokensRef.set(0); + metadataPromptMetadataRef.set(PromptMetadata.empty()); + metadataRateLimitRef.set(new EmptyRateLimit()); + + }).doOnError(e -> logger.error("Aggregation Error", e)); + } +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/advisor/DeepSeekMessageChatMemoryAdvisor.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/advisor/DeepSeekMessageChatMemoryAdvisor.java new file mode 100644 index 00000000000..5859a23236e --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/advisor/DeepSeekMessageChatMemoryAdvisor.java @@ -0,0 +1,158 @@ +package org.springframework.ai.deepseek.advisor; + +import java.util.ArrayList; +import java.util.List; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.AdvisorChain; +import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; +import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.deepseek.DeepSeekChatClientMessageAggregator; +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; + +public class DeepSeekMessageChatMemoryAdvisor implements BaseChatMemoryAdvisor { + + private final ChatMemory chatMemory; + + private final String defaultConversationId; + + private final int order; + + private final Scheduler scheduler; + + private DeepSeekMessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order, + Scheduler scheduler) { + Assert.notNull(chatMemory, "chatMemory cannot be null"); + Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty"); + Assert.notNull(scheduler, "scheduler cannot be null"); + this.chatMemory = chatMemory; + this.defaultConversationId = defaultConversationId; + this.order = order; + this.scheduler = scheduler; + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public Scheduler getScheduler() { + return this.scheduler; + } + + @Override + public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) { + String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId); + + // 1. Retrieve the chat memory for the current conversation. + List memoryMessages = this.chatMemory.get(conversationId); + + // 2. Advise the request messages list. + List processedMessages = new ArrayList<>(memoryMessages); + processedMessages.addAll(chatClientRequest.prompt().getInstructions()); + + // 3. Create a new request with the advised messages. + ChatClientRequest processedChatClientRequest = chatClientRequest.mutate() + .prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build()) + .build(); + + // 4. Add the new user message to the conversation memory. + UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); + this.chatMemory.add(conversationId, userMessage); + + return processedChatClientRequest; + } + + @Override + public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { + List assistantMessages = new ArrayList<>(); + if (chatClientResponse.chatResponse() != null) { + assistantMessages = chatClientResponse.chatResponse() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); + } + this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), + assistantMessages); + return chatClientResponse; + } + + @Override + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + // Get the scheduler from BaseAdvisor + Scheduler scheduler = this.getScheduler(); + + // Process the request with the before method + return Mono.just(chatClientRequest) + .publishOn(scheduler) + .map(request -> this.before(request, streamAdvisorChain)) + .flatMapMany(streamAdvisorChain::nextStream) + .transform(flux -> new DeepSeekChatClientMessageAggregator().aggregateChatClientResponse(flux, + response -> this.after(response, streamAdvisorChain))); + } + + public static Builder builder(ChatMemory chatMemory) { + return new Builder(chatMemory); + } + + public static final class Builder { + + private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID; + + private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER; + + private Scheduler scheduler = BaseAdvisor.DEFAULT_SCHEDULER; + + private ChatMemory chatMemory; + + private Builder(ChatMemory chatMemory) { + this.chatMemory = chatMemory; + } + + /** + * Set the conversation id. + * @param conversationId the conversation id + * @return the builder + */ + public Builder conversationId(String conversationId) { + this.conversationId = conversationId; + return this; + } + + /** + * Set the order. + * @param order the order + * @return the builder + */ + public Builder order(int order) { + this.order = order; + return this; + } + + public Builder scheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + + /** + * Build the advisor. + * @return the advisor + */ + public DeepSeekMessageChatMemoryAdvisor build() { + return new DeepSeekMessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler); + } + + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/advisor/DeepSeekSimpleLogAdvisor.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/advisor/DeepSeekSimpleLogAdvisor.java new file mode 100644 index 00000000000..d33e8aeff62 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/advisor/DeepSeekSimpleLogAdvisor.java @@ -0,0 +1,127 @@ +package org.springframework.ai.deepseek.advisor; + +import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.deepseek.DeepSeekChatClientMessageAggregator; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.lang.Nullable; +import reactor.core.publisher.Flux; + +public class DeepSeekSimpleLogAdvisor implements CallAdvisor, StreamAdvisor { + + public static final Function DEFAULT_REQUEST_TO_STRING = ChatClientRequest::toString; + + public static final Function DEFAULT_RESPONSE_TO_STRING = ModelOptionsUtils::toJsonStringPrettyPrinter; + + private static final Logger logger = LoggerFactory.getLogger(DeepSeekSimpleLogAdvisor.class); + + private final Function requestToString; + + private final Function responseToString; + + private final int order; + + public DeepSeekSimpleLogAdvisor() { + this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, 0); + } + + public DeepSeekSimpleLogAdvisor(int order) { + this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, order); + } + + public DeepSeekSimpleLogAdvisor(@Nullable Function requestToString, + @Nullable Function responseToString, int order) { + this.requestToString = requestToString != null ? requestToString : DEFAULT_REQUEST_TO_STRING; + this.responseToString = responseToString != null ? responseToString : DEFAULT_RESPONSE_TO_STRING; + this.order = order; + } + + @Override + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + logRequest(chatClientRequest); + + ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest); + + logResponse(chatClientResponse); + + return chatClientResponse; + } + + @Override + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + logRequest(chatClientRequest); + + Flux chatClientResponses = streamAdvisorChain.nextStream(chatClientRequest); + + return new DeepSeekChatClientMessageAggregator().aggregateChatClientResponse(chatClientResponses, this::logResponse); + } + + private void logRequest(ChatClientRequest request) { + logger.debug("request: {}", this.requestToString.apply(request)); + } + + private void logResponse(ChatClientResponse chatClientResponse) { + logger.debug("response: {}", this.responseToString.apply(chatClientResponse.chatResponse())); + } + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public String toString() { + return DeepSeekSimpleLogAdvisor.class.getSimpleName(); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private Function requestToString; + + private Function responseToString; + + private int order = 0; + + private Builder() { + } + + public Builder requestToString(Function requestToString) { + this.requestToString = requestToString; + return this; + } + + public Builder responseToString(Function responseToString) { + this.responseToString = responseToString; + return this; + } + + public Builder order(int order) { + this.order = order; + return this; + } + + public DeepSeekSimpleLogAdvisor build() { + return new DeepSeekSimpleLogAdvisor(this.requestToString, this.responseToString, this.order); + } + + } + +}