Skip to content

Commit

Permalink
Implement Chat History
Browse files Browse the repository at this point in the history
Feature is added in the com.datastax.ai.agent.history package.

ChatHistoryImpl implements it as a Cassandra table. Schema is flexible. This will be upstreamed to spring-ai.

Based of upcoming work in spring-projects/spring-ai#536
  • Loading branch information
michaelsembwever committed Apr 23, 2024
1 parent b5c10d0 commit 08f51ec
Show file tree
Hide file tree
Showing 9 changed files with 606 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This codebase serves as a starter repository for AI Agents.
## Requirements
- OpenAI API key saved as an environment variable `OPENAI_API_KEY`
- Java 21 (or beyond)
- Apache Cassandra running (defaults to localhost, change via application.properties)

## Running the app
Run the project using `./mvnw spring-boot:run` and open [http://localhost:8080](http://localhost:8080) in your browser.
Expand Down
11 changes: 11 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,21 @@
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<version>${spring.ai.version}</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-cassandra</artifactId>
<version>${spring.ai.version}</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-devtools</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.uuid</groupId>
<artifactId>java-uuid-generator</artifactId>
<version>4.0.1</version>
</dependency>
<!-- test scope -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
Expand Down
13 changes: 10 additions & 3 deletions src/main/java/com/datastax/ai/agent/AiApplication.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.util.Map;

import com.datastax.ai.agent.base.AiAgent;
import com.datastax.ai.agent.history.AiAgentSession;
import com.datastax.oss.driver.api.core.CqlSession;

import com.vaadin.flow.component.messages.MessageInput;
import com.vaadin.flow.component.orderedlayout.VerticalLayout;
Expand All @@ -34,21 +36,26 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
import org.springframework.context.annotation.Import;

import org.vaadin.firitin.components.messagelist.MarkdownMessage;
import org.vaadin.firitin.components.messagelist.MarkdownMessage.Color;


@Push
@SpringBootApplication
@Import({CassandraAutoConfiguration.class})
public class AiApplication implements AppShellConfigurator {

private static final Logger logger = LoggerFactory.getLogger(AiApplication.class);

@Route("")
static class AiChatUI extends VerticalLayout {

public AiChatUI(AiAgent agent) {
public AiChatUI(AiAgent agent, CqlSession cqlSession) {
AiAgentSession session = AiAgentSession.create(agent, cqlSession);

var messageList = new VerticalLayout();
var messageInput = new MessageInput();

Expand All @@ -59,9 +66,9 @@ public AiChatUI(AiAgent agent) {

messageList.add(userUI, assistantUI);

Prompt prompt = agent.createPrompt(new UserMessage(question), Map.of());
Prompt prompt = session.createPrompt(new UserMessage(question), Map.of());

agent.send(prompt)
session.send(prompt)
.subscribe((response) -> {
if (isValidResponse(response)) {

Expand Down
96 changes: 96 additions & 0 deletions src/main/java/com/datastax/ai/agent/history/AiAgentSession.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* See the NOTICE file distributed with this work for additional information
* regarding copyright ownership.
*/
package com.datastax.ai.agent.history;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.datastax.ai.agent.base.AiAgent;
import com.datastax.ai.agent.history.ChatHistory.ChatExchange;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;


import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;

import reactor.core.publisher.Flux;


public final class AiAgentSession implements AiAgent {

private static final int CHAT_HISTORY_WINDOW_SIZE = 40;

private final AiAgent agent;
private final ChatHistoryImpl chatHistory;
private ChatExchange exchange;

public static AiAgentSession create(AiAgent agent, CqlSession cqlSession) {
return new AiAgentSession(agent, cqlSession);
}

AiAgentSession(AiAgent agent, CqlSession cqlSession) {
this.agent = agent;
this.chatHistory = new ChatHistoryImpl(cqlSession);
this.exchange = new ChatExchange();
}

@Override
public Prompt createPrompt(UserMessage message, Map<String,Object> promptProperties) {
Prompt prompt = agent.createPrompt(message, promptProperties(promptProperties));
exchange = new ChatExchange(exchange.sessionId());
exchange.messages().add(message);
chatHistory.add(exchange);
return prompt;
}

@Override
public Flux<ChatResponse> send(Prompt prompt) {

Preconditions.checkArgument(
prompt.getInstructions().stream().anyMatch((i) -> exchange.messages().contains(i)),
"user message in prompt doesn't match");

Flux<ChatResponse> responseFlux = agent.send(prompt);

return MessageAggregator.aggregate(
responseFlux,
(completedMessage) -> {
exchange.messages().add(completedMessage);
chatHistory.add(exchange);
});
}

@Override
public Map<String,Object> promptProperties(Map<String,Object> promptProperties) {

List<ChatExchange> history = chatHistory.getLastN(exchange.sessionId(), CHAT_HISTORY_WINDOW_SIZE);

String conversation = history.stream()
.flatMap(e -> e.messages().stream())
.map(e -> e.getMessageType().name().toLowerCase() + ": " + e.getContent())
.collect(Collectors.joining(System.lineSeparator()));

return new HashMap<>() {{
putAll(agent.promptProperties(promptProperties));
put("conversation", conversation);
}};
}
}
55 changes: 55 additions & 0 deletions src/main/java/com/datastax/ai/agent/history/ChatHistory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* See the NOTICE file distributed with this work for additional information
* regarding copyright ownership.
*/
package com.datastax.ai.agent.history;

import java.util.ArrayList;
import java.time.Instant;
import java.util.List;

import com.fasterxml.uuid.Generators;
import org.springframework.ai.chat.messages.Message;

/**
* Coming in Spring-AI
*
* see https://github.com/spring-projects/spring-ai/pull/536
*/
public interface ChatHistory {

public record ChatExchange(String sessionId, List<Message> messages, Instant timestamp) {

public ChatExchange() {
this(Generators.timeBasedGenerator().generate().toString(), new ArrayList<>(), Instant.now());
}

public ChatExchange(String sessionId) {
this(sessionId, Instant.now());
}

public ChatExchange(String sessionId, Instant timestamp) {
this(sessionId, new ArrayList<>(), timestamp);
}

}

void add(ChatExchange exchange);

List<ChatExchange> get(String sessionId);

void clear(String sessionId);

}
Loading

0 comments on commit 08f51ec

Please sign in to comment.