diff --git a/README.md b/README.md
index 865940e..bc05755 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ This codebase serves as a starter repository for AI Agents.
## Requirements
- OpenAI API key saved as an environment variable `OPENAI_API_KEY`
- Java 21 (or beyond)
+- Apache Cassandra running (defaults to localhost, change via application.properties)
## Running the app
Run the project using `./mvnw spring-boot:run` and open [http://localhost:8080](http://localhost:8080) in your browser.
diff --git a/pom.xml b/pom.xml
index a8c94c5..4f20e63 100644
--- a/pom.xml
+++ b/pom.xml
@@ -52,10 +52,21 @@
spring-ai-openai-spring-boot-starter
${spring.ai.version}
+
+ org.springframework.ai
+ spring-ai-cassandra
+ ${spring.ai.version}
+
org.springframework.boot
spring-boot-devtools
+
+ com.fasterxml.uuid
+ java-uuid-generator
+ 4.0.1
+
+
org.springframework.boot
spring-boot-starter-test
diff --git a/src/main/java/com/datastax/ai/agent/AiApplication.java b/src/main/java/com/datastax/ai/agent/AiApplication.java
index f9a0e3c..e635c3a 100644
--- a/src/main/java/com/datastax/ai/agent/AiApplication.java
+++ b/src/main/java/com/datastax/ai/agent/AiApplication.java
@@ -19,6 +19,8 @@
import java.util.Map;
import com.datastax.ai.agent.base.AiAgent;
+import com.datastax.ai.agent.history.AiAgentSession;
+import com.datastax.oss.driver.api.core.CqlSession;
import com.vaadin.flow.component.messages.MessageInput;
import com.vaadin.flow.component.orderedlayout.VerticalLayout;
@@ -34,6 +36,8 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
+import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
+import org.springframework.context.annotation.Import;
import org.vaadin.firitin.components.messagelist.MarkdownMessage;
import org.vaadin.firitin.components.messagelist.MarkdownMessage.Color;
@@ -41,6 +45,7 @@
@Push
@SpringBootApplication
+@Import({CassandraAutoConfiguration.class})
public class AiApplication implements AppShellConfigurator {
private static final Logger logger = LoggerFactory.getLogger(AiApplication.class);
@@ -48,7 +53,9 @@ public class AiApplication implements AppShellConfigurator {
@Route("")
static class AiChatUI extends VerticalLayout {
- public AiChatUI(AiAgent agent) {
+ public AiChatUI(AiAgent agent, CqlSession cqlSession) {
+ AiAgentSession session = AiAgentSession.create(agent, cqlSession);
+
var messageList = new VerticalLayout();
var messageInput = new MessageInput();
@@ -59,9 +66,9 @@ public AiChatUI(AiAgent agent) {
messageList.add(userUI, assistantUI);
- Prompt prompt = agent.createPrompt(new UserMessage(question), Map.of());
+ Prompt prompt = session.createPrompt(new UserMessage(question), Map.of());
- agent.send(prompt)
+ session.send(prompt)
.subscribe((response) -> {
if (isValidResponse(response)) {
diff --git a/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java b/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java
new file mode 100644
index 0000000..56df346
--- /dev/null
+++ b/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * See the NOTICE file distributed with this work for additional information
+ * regarding copyright ownership.
+ */
+package com.datastax.ai.agent.history;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import com.datastax.ai.agent.base.AiAgent;
+import com.datastax.ai.agent.history.ChatHistory.ChatExchange;
+import com.datastax.oss.driver.api.core.CqlSession;
+import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
+
+
+import org.springframework.ai.chat.ChatResponse;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.prompt.Prompt;
+
+import reactor.core.publisher.Flux;
+
+
+public final class AiAgentSession implements AiAgent {
+
+ private static final int CHAT_HISTORY_WINDOW_SIZE = 40;
+
+ private final AiAgent agent;
+ private final ChatHistoryImpl chatHistory;
+ private ChatExchange exchange;
+
+ public static AiAgentSession create(AiAgent agent, CqlSession cqlSession) {
+ return new AiAgentSession(agent, cqlSession);
+ }
+
+ AiAgentSession(AiAgent agent, CqlSession cqlSession) {
+ this.agent = agent;
+ this.chatHistory = new ChatHistoryImpl(cqlSession);
+ this.exchange = new ChatExchange();
+ }
+
+ @Override
+ public Prompt createPrompt(UserMessage message, Map promptProperties) {
+ Prompt prompt = agent.createPrompt(message, promptProperties(promptProperties));
+ exchange = new ChatExchange(exchange.sessionId());
+ exchange.messages().add(message);
+ chatHistory.add(exchange);
+ return prompt;
+ }
+
+ @Override
+ public Flux send(Prompt prompt) {
+
+ Preconditions.checkArgument(
+ prompt.getInstructions().stream().anyMatch((i) -> exchange.messages().contains(i)),
+ "user message in prompt doesn't match");
+
+ Flux responseFlux = agent.send(prompt);
+
+ return MessageAggregator.aggregate(
+ responseFlux,
+ (completedMessage) -> {
+ exchange.messages().add(completedMessage);
+ chatHistory.add(exchange);
+ });
+ }
+
+ @Override
+ public Map promptProperties(Map promptProperties) {
+
+ List history = chatHistory.getLastN(exchange.sessionId(), CHAT_HISTORY_WINDOW_SIZE);
+
+ String conversation = history.stream()
+ .flatMap(e -> e.messages().stream())
+ .map(e -> e.getMessageType().name().toLowerCase() + ": " + e.getContent())
+ .collect(Collectors.joining(System.lineSeparator()));
+
+ return new HashMap<>() {{
+ putAll(agent.promptProperties(promptProperties));
+ put("conversation", conversation);
+ }};
+ }
+}
diff --git a/src/main/java/com/datastax/ai/agent/history/ChatHistory.java b/src/main/java/com/datastax/ai/agent/history/ChatHistory.java
new file mode 100644
index 0000000..72a1e31
--- /dev/null
+++ b/src/main/java/com/datastax/ai/agent/history/ChatHistory.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * See the NOTICE file distributed with this work for additional information
+ * regarding copyright ownership.
+ */
+package com.datastax.ai.agent.history;
+
+import java.util.ArrayList;
+import java.time.Instant;
+import java.util.List;
+
+import com.fasterxml.uuid.Generators;
+import org.springframework.ai.chat.messages.Message;
+
+/**
+ * Coming in Spring-AI
+ *
+ * see https://github.com/spring-projects/spring-ai/pull/536
+ */
+public interface ChatHistory {
+
+ public record ChatExchange(String sessionId, List messages, Instant timestamp) {
+
+ public ChatExchange() {
+ this(Generators.timeBasedGenerator().generate().toString(), new ArrayList<>(), Instant.now());
+ }
+
+ public ChatExchange(String sessionId) {
+ this(sessionId, Instant.now());
+ }
+
+ public ChatExchange(String sessionId, Instant timestamp) {
+ this(sessionId, new ArrayList<>(), timestamp);
+ }
+
+ }
+
+ void add(ChatExchange exchange);
+
+ List get(String sessionId);
+
+ void clear(String sessionId);
+
+}
diff --git a/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java b/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java
new file mode 100644
index 0000000..a41ef8e
--- /dev/null
+++ b/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java
@@ -0,0 +1,354 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * See the NOTICE file distributed with this work for additional information
+ * regarding copyright ownership.
+ */
+package com.datastax.ai.agent.history;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+import com.datastax.oss.driver.api.core.CqlSession;
+import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
+import com.datastax.oss.driver.api.core.cql.PreparedStatement;
+import com.datastax.oss.driver.api.core.cql.Row;
+import com.datastax.oss.driver.api.core.cql.SimpleStatement;
+import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
+import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
+import com.datastax.oss.driver.api.core.type.DataType;
+import com.datastax.oss.driver.api.core.type.DataTypes;
+import com.datastax.oss.driver.api.core.type.ListType;
+import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
+import com.datastax.oss.driver.api.core.type.reflect.GenericType;
+import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
+import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
+import com.datastax.oss.driver.api.querybuilder.delete.Delete;
+import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
+import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
+import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
+import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn;
+import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd;
+import com.datastax.oss.driver.api.querybuilder.schema.CreateTable;
+import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart;
+import com.datastax.oss.driver.api.querybuilder.select.Select;
+import com.datastax.oss.driver.internal.core.type.DefaultListType;
+import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
+import java.time.Instant;
+import java.util.Map;
+import java.util.function.Function;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.UserMessage;
+
+import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.context.annotation.Import;
+
+final class ChatHistoryImpl implements ChatHistory {
+
+ private static final Logger logger = LoggerFactory.getLogger(ChatHistoryImpl.class);
+
+ private final CqlSession session;
+ private final Config config;
+
+ public static ChatHistoryImpl create(CqlSession session) {
+ return new ChatHistoryImpl(session);
+ }
+
+ ChatHistoryImpl(CqlSession session) {
+ this.session = session;
+ this.config = new Config(session);
+ }
+
+ @Override
+ public void add(ChatExchange exchange) {
+ List