From 56151270e2133a1a2caf5a25d2f3e0e3ca6b0925 Mon Sep 17 00:00:00 2001 From: mck Date: Wed, 24 Apr 2024 14:31:25 +0200 Subject: [PATCH] Implement LLM call caching with a vector store Use a vector store on top of the existing agent_conversations table --- .../com/datastax/ai/agent/AiApplication.java | 27 ++- .../ai/agent/history/AiAgentSession.java | 18 +- .../ai/agent/history/ChatHistoryImpl.java | 1 + .../agent/llmCache/AiAgentSessionVector.java | 159 ++++++++++++++++++ .../ai/agent/vector/AiAgentVector.java | 2 - .../prompt-templates/system-prompt-qa.txt | 1 + 6 files changed, 198 insertions(+), 10 deletions(-) create mode 100644 src/main/java/com/datastax/ai/agent/llmCache/AiAgentSessionVector.java diff --git a/src/main/java/com/datastax/ai/agent/AiApplication.java b/src/main/java/com/datastax/ai/agent/AiApplication.java index f939961..7e3037b 100644 --- a/src/main/java/com/datastax/ai/agent/AiApplication.java +++ b/src/main/java/com/datastax/ai/agent/AiApplication.java @@ -16,10 +16,12 @@ */ package com.datastax.ai.agent; +import java.util.HashMap; import java.util.Map; import com.datastax.ai.agent.base.AiAgent; import com.datastax.ai.agent.history.AiAgentSession; +import com.datastax.ai.agent.llmCache.AiAgentSessionVector; import com.datastax.ai.agent.vector.AiAgentVector; import com.datastax.oss.driver.api.core.CqlSession; @@ -35,6 +37,7 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.vectorstore.CassandraVectorStore; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; @@ -55,10 +58,11 @@ public class AiApplication implements AppShellConfigurator { @Route("") static class AiChatUI extends VerticalLayout { - public AiChatUI(AiAgent baseAgent, CqlSession cqlSession, CassandraVectorStore store) { + public AiChatUI(AiAgent baseAgent, CqlSession cqlSession, CassandraVectorStore store, EmbeddingClient embeddingClient) { AiAgentSession sessionAgent = AiAgentSession.create(baseAgent, cqlSession); - AiAgentVector agent = AiAgentVector.create(sessionAgent, store); + AiAgentVector agentVector = AiAgentVector.create(sessionAgent, store); + AiAgentSessionVector agent = AiAgentSessionVector.create(agentVector, cqlSession, embeddingClient); var messageList = new VerticalLayout(); var messageInput = new MessageInput(); @@ -70,7 +74,7 @@ public AiChatUI(AiAgent baseAgent, CqlSession cqlSession, CassandraVectorStore s messageList.add(userUI, assistantUI); - Prompt prompt = agent.createPrompt(new UserMessage(question), Map.of()); + Prompt prompt = agent.createPrompt(new UserMessageWithProperties(question), Map.of()); agent.send(prompt) .subscribe((response) -> { @@ -101,4 +105,21 @@ private static boolean isValidResponse(ChatResponse chatResponse) { public static void main(String[] args) { SpringApplication.run(AiApplication.class, args); } + + static class UserMessageWithProperties extends UserMessage { + + // intentionally overrides and hides AbstractMessage.properties which UserMessage does not use + private final Map properties = new HashMap<>(); + + public UserMessageWithProperties(String message) { + super(message); + } + + @Override + public Map getProperties() { + return properties; + } + + + } } diff --git a/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java b/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java index 56df346..8fcaefa 100644 --- a/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java +++ b/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java @@ -26,6 +26,9 @@ import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; @@ -36,8 +39,10 @@ public final class AiAgentSession implements AiAgent { + private static final Logger logger = LoggerFactory.getLogger(AiAgentSession.class); + private static final int CHAT_HISTORY_WINDOW_SIZE = 40; - + private final AiAgent agent; private final ChatHistoryImpl chatHistory; private ChatExchange exchange; @@ -54,16 +59,19 @@ public static AiAgentSession create(AiAgent agent, CqlSession cqlSession) { @Override public Prompt createPrompt(UserMessage message, Map promptProperties) { - Prompt prompt = agent.createPrompt(message, promptProperties(promptProperties)); exchange = new ChatExchange(exchange.sessionId()); exchange.messages().add(message); - chatHistory.add(exchange); - return prompt; + + // UserMessage must have been created with a mutable map + message.getProperties().put("ChatExchange_sessionId", exchange.sessionId()); + message.getProperties().put("ChatExchange_exchange_timestamp", exchange.timestamp()); + + return agent.createPrompt(message, promptProperties(promptProperties)); } @Override public Flux send(Prompt prompt) { - + Preconditions.checkArgument( prompt.getInstructions().stream().anyMatch((i) -> exchange.messages().contains(i)), "user message in prompt doesn't match"); diff --git a/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java b/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java index 3b66ae9..90986b6 100644 --- a/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java +++ b/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java @@ -78,6 +78,7 @@ public static ChatHistoryImpl create(CqlSession session) { @Override public void add(ChatExchange exchange) { + Preconditions.checkArgument(2 == exchange.messages().size()); List primaryKeyValues = config.chatExchangeToPrimaryKeyTranslator.apply(exchange); BoundStatementBuilder builder = config.addStmt.boundStatementBuilder(); diff --git a/src/main/java/com/datastax/ai/agent/llmCache/AiAgentSessionVector.java b/src/main/java/com/datastax/ai/agent/llmCache/AiAgentSessionVector.java new file mode 100644 index 0000000..7b12880 --- /dev/null +++ b/src/main/java/com/datastax/ai/agent/llmCache/AiAgentSessionVector.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * See the NOTICE file distributed with this work for additional information + * regarding copyright ownership. + */ +package com.datastax.ai.agent.llmCache; + +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import com.datastax.ai.agent.base.AiAgent; +import com.datastax.ai.agent.base.AiAgentDelegator; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.type.DataTypes; +import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; +import java.time.Instant; +import java.util.Map; +import java.util.UUID; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.vectorstore.CassandraVectorStore; +import org.springframework.ai.vectorstore.CassandraVectorStoreConfig; +import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.DocumentIdTranslator; +import org.springframework.ai.vectorstore.CassandraVectorStoreConfig.SchemaColumn; +import org.springframework.ai.vectorstore.SearchRequest; + +import reactor.core.publisher.Flux; + +public class AiAgentSessionVector extends AiAgentDelegator { + + private static final CassandraVectorStoreConfig.PrimaryKeyTranslator PRIMARY_KEY_TRANSLATOR + = (pKeyColumns) -> { + if (pKeyColumns.isEmpty()) { + return UUID.randomUUID().toString() + "§¶0"; + } + Preconditions.checkArgument(2 == pKeyColumns.size()); + + String sessionId = pKeyColumns.get(0) instanceof UUID + ? ((UUID) pKeyColumns.get(0)).toString() + : (String)pKeyColumns.get(0); + + String exchangeTimestamp = pKeyColumns.get(1) instanceof Instant + ? String.valueOf(((Instant) pKeyColumns.get(1)).toEpochMilli()) + : (String) pKeyColumns.get(1); + + return sessionId.toString() + "§¶" + exchangeTimestamp; + }; + + private static final DocumentIdTranslator DOCUMENT_ID_TRANSLATOR + = (id) -> { + String[] parts = id.split("§¶"); + Preconditions.checkArgument(2 == parts.length); + UUID sessionId = UUID.fromString(parts[0]); + Instant exchangeTimestamp = Instant.ofEpochMilli(Long.parseLong(parts[1])); + return List.of(sessionId, exchangeTimestamp); + }; + + private static final Logger logger = LoggerFactory.getLogger(AiAgentSessionVector.class); + + private final AiAgent agent; + private final CassandraVectorStore store; + + public static AiAgentSessionVector create(AiAgent agent, CqlSession cqlSession, EmbeddingClient embeddingClient) { + return new AiAgentSessionVector(agent, cqlSession, embeddingClient); + } + + AiAgentSessionVector(AiAgent agent, CqlSession cqlSession, EmbeddingClient embeddingClient) { + super(agent); + this.agent = agent; + + CassandraVectorStoreConfig config = CassandraVectorStoreConfig.builder() + .withCqlSession(cqlSession) + .withKeyspaceName("datastax_ai_agent") + .withTableName("agent_conversations") + .withPartitionKeys(List.of(new SchemaColumn("session_id", DataTypes.TIMEUUID))) + .withClusteringKeys(List.of(new SchemaColumn("exchange_timestamp", DataTypes.TIMESTAMP))) + .withContentColumnName("prompt_request") + .addMetadataColumn(new SchemaColumn("prompt_response", DataTypes.TEXT)) + .withPrimaryKeyTranslator(PRIMARY_KEY_TRANSLATOR) + .withDocumentIdTranslator(DOCUMENT_ID_TRANSLATOR) + .build(); + + this.store = new CassandraVectorStore(config, embeddingClient); + } + + @Override + public Flux send(Prompt prompt) { + + final UserMessage userMsg = (UserMessage) prompt.getInstructions() + .stream().filter(m -> m instanceof UserMessage).findFirst().get(); + + final SystemMessage systemMsg = (SystemMessage) prompt.getInstructions() + .stream().filter(m -> m instanceof SystemMessage).findFirst().get(); + + // AiAgentSession is expected to have put these into the UserMessage + Preconditions.checkState(userMsg.getProperties().containsKey("ChatExchange_sessionId")); + Preconditions.checkState(userMsg.getProperties().containsKey("ChatExchange_exchange_timestamp")); + + SearchRequest request = SearchRequest + .query(userMsg.getContent()) + .withTopK(10) + .withSimilarityThreshold(0.99); + + List similarPromptResponses = store.similaritySearch(request); + if (!similarPromptResponses.isEmpty()) { + String similarResponse = (String) similarPromptResponses.get(0).getMetadata().get("prompt_response"); + return Flux.just(new ChatResponse(List.of(new Generation(similarResponse)))); + } else { + + final AtomicReference stringBufferRef = new AtomicReference<>(); + + return agent.send(prompt).doOnSubscribe(subscription -> { + stringBufferRef.set(new StringBuilder()); + }).doOnNext(chatResponse -> { + if (null != chatResponse.getResult()) { + if (null != chatResponse.getResult().getOutput().getContent()) { + stringBufferRef.get().append(chatResponse.getResult().getOutput().getContent()); + } + } + }).doOnComplete(() -> { + + Document promptRequestResponse = new Document( + PRIMARY_KEY_TRANSLATOR.apply(List.of( + userMsg.getProperties().get("ChatExchange_sessionId"), + userMsg.getProperties().get("ChatExchange_exchange_timestamp"))), + userMsg.getContent(), + Map.of("prompt_response", stringBufferRef.get().toString())); + + store.add(List.of(promptRequestResponse)); + + stringBufferRef.set(null); + }).doOnError(e -> { + logger.error("Aggregation Error", e); + stringBufferRef.set(null); + }); + } + } + +} diff --git a/src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java b/src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java index fbdbb64..47983d5 100644 --- a/src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java +++ b/src/main/java/com/datastax/ai/agent/vector/AiAgentVector.java @@ -37,8 +37,6 @@ public class AiAgentVector extends AiAgentDelegator { private static final int CHAT_DOCUMENTS_SIZE = 3; - private static final Logger logger = LoggerFactory.getLogger(AiAgentVector.class); - private final CassandraVectorStore store; public static AiAgentVector create(AiAgent agent, CassandraVectorStore store) { diff --git a/src/main/resources/prompt-templates/system-prompt-qa.txt b/src/main/resources/prompt-templates/system-prompt-qa.txt index 71dacf8..bd6fd08 100644 --- a/src/main/resources/prompt-templates/system-prompt-qa.txt +++ b/src/main/resources/prompt-templates/system-prompt-qa.txt @@ -2,6 +2,7 @@ You are an expert support technican assistant for a fleet of IoT devices. Respond in an informative and rationale based manner. Use the conversation history from the CONVERSATION section to provide accurate answers. Use the information from the DOCUMENTS section to provide accurate answers. +Seek more information, apprehend the user's next question. Ask questions where your answers are unknown or have low confidence. Seek more information, apprehend the user's next question. Provide the rationale to your answers in laid out logical steps.