diff --git a/README.md b/README.md index 865940e..bc05755 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ This codebase serves as a starter repository for AI Agents. ## Requirements - OpenAI API key saved as an environment variable `OPENAI_API_KEY` - Java 21 (or beyond) +- Apache Cassandra running (defaults to localhost, change via application.properties) ## Running the app Run the project using `./mvnw spring-boot:run` and open [http://localhost:8080](http://localhost:8080) in your browser. diff --git a/pom.xml b/pom.xml index a8c94c5..4f20e63 100644 --- a/pom.xml +++ b/pom.xml @@ -52,10 +52,21 @@ spring-ai-openai-spring-boot-starter ${spring.ai.version} + + org.springframework.ai + spring-ai-cassandra + ${spring.ai.version} + org.springframework.boot spring-boot-devtools + + com.fasterxml.uuid + java-uuid-generator + 4.0.1 + + org.springframework.boot spring-boot-starter-test diff --git a/src/main/java/com/datastax/ai/agent/AiApplication.java b/src/main/java/com/datastax/ai/agent/AiApplication.java index f9a0e3c..e635c3a 100644 --- a/src/main/java/com/datastax/ai/agent/AiApplication.java +++ b/src/main/java/com/datastax/ai/agent/AiApplication.java @@ -19,6 +19,8 @@ import java.util.Map; import com.datastax.ai.agent.base.AiAgent; +import com.datastax.ai.agent.history.AiAgentSession; +import com.datastax.oss.driver.api.core.CqlSession; import com.vaadin.flow.component.messages.MessageInput; import com.vaadin.flow.component.orderedlayout.VerticalLayout; @@ -34,6 +36,8 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; +import org.springframework.context.annotation.Import; import org.vaadin.firitin.components.messagelist.MarkdownMessage; import org.vaadin.firitin.components.messagelist.MarkdownMessage.Color; @@ -41,6 +45,7 @@ @Push @SpringBootApplication +@Import({CassandraAutoConfiguration.class}) public class AiApplication implements AppShellConfigurator { private static final Logger logger = LoggerFactory.getLogger(AiApplication.class); @@ -48,7 +53,9 @@ public class AiApplication implements AppShellConfigurator { @Route("") static class AiChatUI extends VerticalLayout { - public AiChatUI(AiAgent agent) { + public AiChatUI(AiAgent agent, CqlSession cqlSession) { + AiAgentSession session = AiAgentSession.create(agent, cqlSession); + var messageList = new VerticalLayout(); var messageInput = new MessageInput(); @@ -59,9 +66,9 @@ public AiChatUI(AiAgent agent) { messageList.add(userUI, assistantUI); - Prompt prompt = agent.createPrompt(new UserMessage(question), Map.of()); + Prompt prompt = session.createPrompt(new UserMessage(question), Map.of()); - agent.send(prompt) + session.send(prompt) .subscribe((response) -> { if (isValidResponse(response)) { diff --git a/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java b/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java new file mode 100644 index 0000000..56df346 --- /dev/null +++ b/src/main/java/com/datastax/ai/agent/history/AiAgentSession.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * See the NOTICE file distributed with this work for additional information + * regarding copyright ownership. + */ +package com.datastax.ai.agent.history; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import com.datastax.ai.agent.base.AiAgent; +import com.datastax.ai.agent.history.ChatHistory.ChatExchange; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; + + +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; + +import reactor.core.publisher.Flux; + + +public final class AiAgentSession implements AiAgent { + + private static final int CHAT_HISTORY_WINDOW_SIZE = 40; + + private final AiAgent agent; + private final ChatHistoryImpl chatHistory; + private ChatExchange exchange; + + public static AiAgentSession create(AiAgent agent, CqlSession cqlSession) { + return new AiAgentSession(agent, cqlSession); + } + + AiAgentSession(AiAgent agent, CqlSession cqlSession) { + this.agent = agent; + this.chatHistory = new ChatHistoryImpl(cqlSession); + this.exchange = new ChatExchange(); + } + + @Override + public Prompt createPrompt(UserMessage message, Map promptProperties) { + Prompt prompt = agent.createPrompt(message, promptProperties(promptProperties)); + exchange = new ChatExchange(exchange.sessionId()); + exchange.messages().add(message); + chatHistory.add(exchange); + return prompt; + } + + @Override + public Flux send(Prompt prompt) { + + Preconditions.checkArgument( + prompt.getInstructions().stream().anyMatch((i) -> exchange.messages().contains(i)), + "user message in prompt doesn't match"); + + Flux responseFlux = agent.send(prompt); + + return MessageAggregator.aggregate( + responseFlux, + (completedMessage) -> { + exchange.messages().add(completedMessage); + chatHistory.add(exchange); + }); + } + + @Override + public Map promptProperties(Map promptProperties) { + + List history = chatHistory.getLastN(exchange.sessionId(), CHAT_HISTORY_WINDOW_SIZE); + + String conversation = history.stream() + .flatMap(e -> e.messages().stream()) + .map(e -> e.getMessageType().name().toLowerCase() + ": " + e.getContent()) + .collect(Collectors.joining(System.lineSeparator())); + + return new HashMap<>() {{ + putAll(agent.promptProperties(promptProperties)); + put("conversation", conversation); + }}; + } +} diff --git a/src/main/java/com/datastax/ai/agent/history/ChatHistory.java b/src/main/java/com/datastax/ai/agent/history/ChatHistory.java new file mode 100644 index 0000000..72a1e31 --- /dev/null +++ b/src/main/java/com/datastax/ai/agent/history/ChatHistory.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * See the NOTICE file distributed with this work for additional information + * regarding copyright ownership. + */ +package com.datastax.ai.agent.history; + +import java.util.ArrayList; +import java.time.Instant; +import java.util.List; + +import com.fasterxml.uuid.Generators; +import org.springframework.ai.chat.messages.Message; + +/** + * Coming in Spring-AI + * + * see https://github.com/spring-projects/spring-ai/pull/536 + */ +public interface ChatHistory { + + public record ChatExchange(String sessionId, List messages, Instant timestamp) { + + public ChatExchange() { + this(Generators.timeBasedGenerator().generate().toString(), new ArrayList<>(), Instant.now()); + } + + public ChatExchange(String sessionId) { + this(sessionId, Instant.now()); + } + + public ChatExchange(String sessionId, Instant timestamp) { + this(sessionId, new ArrayList<>(), timestamp); + } + + } + + void add(ChatExchange exchange); + + List get(String sessionId); + + void clear(String sessionId); + +} diff --git a/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java b/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java new file mode 100644 index 0000000..a41ef8e --- /dev/null +++ b/src/main/java/com/datastax/ai/agent/history/ChatHistoryImpl.java @@ -0,0 +1,354 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * See the NOTICE file distributed with this work for additional information + * regarding copyright ownership. + */ +package com.datastax.ai.agent.history; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; +import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder; +import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; +import com.datastax.oss.driver.api.core.type.DataType; +import com.datastax.oss.driver.api.core.type.DataTypes; +import com.datastax.oss.driver.api.core.type.ListType; +import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry; +import com.datastax.oss.driver.api.core.type.reflect.GenericType; +import com.datastax.oss.driver.api.querybuilder.QueryBuilder; +import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; +import com.datastax.oss.driver.api.querybuilder.delete.Delete; +import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection; +import com.datastax.oss.driver.api.querybuilder.insert.InsertInto; +import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; +import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn; +import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd; +import com.datastax.oss.driver.api.querybuilder.schema.CreateTable; +import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart; +import com.datastax.oss.driver.api.querybuilder.select.Select; +import com.datastax.oss.driver.internal.core.type.DefaultListType; +import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; +import java.time.Instant; +import java.util.Map; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; + +import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; + +final class ChatHistoryImpl implements ChatHistory { + + private static final Logger logger = LoggerFactory.getLogger(ChatHistoryImpl.class); + + private final CqlSession session; + private final Config config; + + public static ChatHistoryImpl create(CqlSession session) { + return new ChatHistoryImpl(session); + } + + ChatHistoryImpl(CqlSession session) { + this.session = session; + this.config = new Config(session); + } + + @Override + public void add(ChatExchange exchange) { + List primaryKeyValues = config.chatExchangeToPrimaryKeyTranslator.apply(exchange); + + BoundStatementBuilder builder = config.addStmt.boundStatementBuilder(); + for (int k = 0; k < primaryKeyValues.size(); ++k) { + Config.SchemaColumn keyColumn = config.getPrimaryKeyColumn(k); + builder = builder.set(keyColumn.name(), primaryKeyValues.get(k), keyColumn.javaType()); + } + + builder = builder.setList( + config.schema.messages(), + exchange.messages().stream().map((msg) -> msg.getContent()).toList(), + String.class); + + session.execute(builder.build()); + } + + @Override + public List get(String sessionId) { + return getLastN(sessionId, Integer.MAX_VALUE); + } + + @Override + public void clear(String sessionId) { + ChatExchange dummy = new ChatExchange(sessionId); + List primaryKeyValues = config.chatExchangeToPrimaryKeyTranslator.apply(dummy); + BoundStatementBuilder builder = config.deleteStmt.boundStatementBuilder(); + for (int k = 0; k < primaryKeyValues.size(); ++k) { + Config.SchemaColumn keyColumn = config.getPrimaryKeyColumn(k); + builder = builder.set(keyColumn.name(), primaryKeyValues.get(k), keyColumn.javaType()); + } + session.execute(builder.build()); + } + + List getLastN(String sessionId, int lastN) { + ChatExchange dummy = new ChatExchange(sessionId); + List primaryKeyValues = config.chatExchangeToPrimaryKeyTranslator.apply(dummy); + + BoundStatementBuilder builder = config.getStatement.boundStatementBuilder(); + for (int k = 0; k < primaryKeyValues.size(); ++k) { + Config.SchemaColumn keyColumn = config.getPrimaryKeyColumn(k); + // TODO make compatible with configurable ChatExchangeToPrimaryKeyTranslator + // this assumes there's only one clustering key (for the chatExchange timestamp) + if (!Config.DEFAULT_EXCHANGE_ID_NAME.equals(keyColumn.name())) { + builder = builder.set(keyColumn.name(), primaryKeyValues.get(k), keyColumn.javaType()); + } + } + builder = builder.setInt("lastN", lastN); + List exchanges = new ArrayList<>(); + for (Row r : session.execute(builder.build())) { + List msgs = r.getList(Config.DEFAULT_MESSAGES_COLUMN_NAME, String.class); + exchanges.add( + new ChatExchange( + r.getUuid(Config.DEFAULT_SESSION_ID_NAME).toString(), + List.of( + new UserMessage(msgs.get(0)), + new AssistantMessage(msgs.get(1))), + r.get(Config.DEFAULT_EXCHANGE_ID_NAME, Instant.class))); + } + return exchanges; + } + + public static class Config { + + record Schema( + String keyspace, + String table, + List partitionKeys, + List clusteringKeys, + String messages) { + + } + + record SchemaColumn(String name, DataType type) { + GenericType javaType() { + return CodecRegistry.DEFAULT.codecFor(type).getJavaType(); + } + } + + public interface ChatExchangeToPrimaryKeyTranslator extends Function> {} + + public interface PrimaryKeyToChatExchangeTranslator extends Function, ChatExchange> {} + + public static final String DEFAULT_KEYSPACE_NAME = "datastax_ai_agent"; + public static final String DEFAULT_TABLE_NAME = "agent_conversations"; + public static final String DEFAULT_SESSION_ID_NAME = "session_id"; + public static final String DEFAULT_EXCHANGE_ID_NAME = "exchange_timestamp"; + public static final String DEFAULT_MESSAGES_COLUMN_NAME = "messages"; + + private static final ListType DEFAULT_MESSAGES_COLUMN_TYPE = new DefaultListType(DataTypes.TEXT, true); + + private final CqlSession session; + + private final Schema schema = new Schema( + DEFAULT_KEYSPACE_NAME, + DEFAULT_TABLE_NAME, + List.of(new SchemaColumn(DEFAULT_SESSION_ID_NAME, DataTypes.TIMEUUID)), + List.of(new SchemaColumn(DEFAULT_EXCHANGE_ID_NAME, DataTypes.TIMESTAMP)), + DEFAULT_MESSAGES_COLUMN_NAME); + + private final ChatExchangeToPrimaryKeyTranslator chatExchangeToPrimaryKeyTranslator + = (e) -> List.of(UUID.fromString(e.sessionId()), e.timestamp()); + + private final PrimaryKeyToChatExchangeTranslator primaryKeyToChatExchangeTranslator + = (primaryKeys) + -> new ChatExchange(primaryKeys.get(0).toString(), (Instant) primaryKeys.get(1)); + + private final boolean disallowSchemaChanges = false; + + private final PreparedStatement addStmt, getStatement, deleteStmt; + + Config(CqlSession session) { + this.session = session; + ensureSchemaExists(); + addStmt = prepareAddStmt(); + getStatement = prepareGetStatement(); + deleteStmt = prepareDeleteStmt(); + } + + private SchemaColumn getPrimaryKeyColumn(int index) { + return index < this.schema.partitionKeys().size() + ? this.schema.partitionKeys().get(index) + : this.schema.clusteringKeys().get(index - this.schema.partitionKeys().size()); + } + + private void ensureSchemaExists() { + if (!disallowSchemaChanges) { + ensureKeyspaceExists(); + ensureTableExists(); + ensureTableColumnsExist(); + checkSchemaAgreement(); + } else { + checkSchemaValid(); + } + } + + private void checkSchemaAgreement() throws IllegalStateException { + if (!session.checkSchemaAgreement()) { + logger.warn("Waiting for cluster schema agreement, sleeping 10s…"); + try { + Thread.sleep(Duration.ofSeconds(10).toMillis()); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new IllegalStateException(ex); + } + if (!session.checkSchemaAgreement()) { + logger.error("no cluster schema agreement still, continuing, let's hope this works…"); + } + } + } + + void checkSchemaValid() { + + Preconditions.checkState(session.getMetadata().getKeyspace(schema.keyspace).isPresent(), + "keyspace %s does not exist", schema.keyspace); + + Preconditions.checkState(session.getMetadata() + .getKeyspace(schema.keyspace) + .get() + .getTable(schema.table) + .isPresent(), "table %s does not exist"); + + TableMetadata tableMetadata = session.getMetadata() + .getKeyspace(schema.keyspace) + .get() + .getTable(schema.table) + .get(); + + Preconditions.checkState(tableMetadata.getColumn(schema.messages).isPresent(), "column %s does not exist", + schema.messages); + + } + + private void ensureKeyspaceExists() { + + SimpleStatement keyspaceStmt = SchemaBuilder.createKeyspace(schema.keyspace) + .ifNotExists() + .withSimpleStrategy(1) + .build(); + + logger.debug("Executing {}", keyspaceStmt.getQuery()); + session.execute(keyspaceStmt); + } + + private void ensureTableExists() { + + CreateTable createTable = null; + + CreateTableStart createTableStart = SchemaBuilder.createTable(schema.keyspace, schema.table) + .ifNotExists(); + + for (SchemaColumn partitionKey : schema.partitionKeys) { + createTable = (null != createTable ? createTable : createTableStart).withPartitionKey(partitionKey.name, + partitionKey.type); + } + for (SchemaColumn clusteringKey : schema.clusteringKeys) { + createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type); + } + + createTable = createTable.withColumn(schema.messages, DEFAULT_MESSAGES_COLUMN_TYPE); + + session.execute( + createTable.withClusteringOrder(DEFAULT_EXCHANGE_ID_NAME, ClusteringOrder.DESC) + // set this if you want sessions to expire after a period of time + // TODO create option, and append TTL value to select queries (performance) + //.withDefaultTimeToLiveSeconds((int) Duration.ofDays(120).toSeconds()) + + // TODO replace when SchemaBuilder.unifiedCompactionStrategy() becomes available + .withOption("compaction", Map.of("class", "UnifiedCompactionStrategy")) + //.withCompaction(SchemaBuilder.unifiedCompactionStrategy())) + .build()); + } + + private void ensureTableColumnsExist() { + + TableMetadata tableMetadata = session.getMetadata() + .getKeyspace(schema.keyspace()) + .get() + .getTable(schema.table()) + .get(); + + boolean addContent = tableMetadata.getColumn(schema.messages()).isEmpty(); + + if (addContent) { + AlterTableAddColumn alterTable = SchemaBuilder + .alterTable(schema.keyspace(), schema.table()) + .addColumn(schema.messages(), DEFAULT_MESSAGES_COLUMN_TYPE); + + SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build(); + logger.debug("Executing {}", stmt.getQuery()); + session.execute(stmt); + } + } + + private PreparedStatement prepareAddStmt() { + RegularInsert stmt = null; + InsertInto stmtStart = QueryBuilder.insertInto(schema.keyspace(), schema.table()); + for (var c : schema.partitionKeys()) { + stmt = (null != stmt ? stmt : stmtStart) + .value(c.name(), QueryBuilder.bindMarker(c.name())); + } + for (var c : schema.clusteringKeys()) { + stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name())); + } + stmt = stmt.value(schema.messages(), QueryBuilder.bindMarker(schema.messages())); + return session.prepare(stmt.build()); + } + + private PreparedStatement prepareGetStatement() { + Select stmt = QueryBuilder.selectFrom(schema.keyspace, schema.table).all(); + // TODO make compatible with configurable ChatExchangeToPrimaryKeyTranslator + // this assumes there's only one clustering key (for the chatExchange timestamp) + for (var c : schema.partitionKeys()) { + stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); + } + stmt = stmt.limit(QueryBuilder.bindMarker("lastN")); + return session.prepare(stmt.build()); + } + + private PreparedStatement prepareDeleteStmt() { + Delete stmt = null; + DeleteSelection stmtStart = QueryBuilder.deleteFrom(schema.keyspace, schema.table); + for (var c : schema.partitionKeys()) { + stmt = (null != stmt ? stmt : stmtStart) + .whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); + } + for (var c : schema.clusteringKeys()) { + stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); + } + return session.prepare(stmt.build()); + } + + } + +} diff --git a/src/main/java/com/datastax/ai/agent/history/MessageAggregator.java b/src/main/java/com/datastax/ai/agent/history/MessageAggregator.java new file mode 100644 index 0000000..6d41bf8 --- /dev/null +++ b/src/main/java/com/datastax/ai/agent/history/MessageAggregator.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.datastax.ai.agent.history; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; + +/** from https://github.com/spring-projects/spring-ai/pull/536 + * + * Helper that for streaming chat responses, aggregate the chat response messages into a + * single AssistantMessage. Job is performed in parallel to the chat response processing. + * + * @author Christian Tzolov + */ +final class MessageAggregator { + + private static final Logger logger = LoggerFactory.getLogger(MessageAggregator.class); + + private MessageAggregator() {} + + static Flux aggregate(Flux fluxChatResponse, + Consumer onAggregationComplete) { + + AtomicReference stringBufferRef = new AtomicReference<>(new StringBuilder()); + AtomicReference> mapRef = new AtomicReference<>(); + + return fluxChatResponse.doOnSubscribe(subscription -> { + // logger.info("Aggregation Subscribe:" + subscription); + stringBufferRef.set(new StringBuilder()); + mapRef.set(new HashMap<>()); + }).doOnNext(chatResponse -> { + // logger.info("Aggregation Next:" + chatResponse); + if (chatResponse.getResult() != null) { + if (chatResponse.getResult().getOutput().getContent() != null) { + stringBufferRef.get().append(chatResponse.getResult().getOutput().getContent()); + } + if (chatResponse.getResult().getOutput().getProperties() != null) { + mapRef.get().putAll(chatResponse.getResult().getOutput().getProperties()); + } + } + }).doOnComplete(() -> { + // logger.debug("Aggregation Complete"); + onAggregationComplete.accept(new AssistantMessage(stringBufferRef.get().toString(), mapRef.get())); + stringBufferRef.set(new StringBuilder()); + mapRef.set(new HashMap<>()); + }).doOnError(e -> { + logger.error("Aggregation Error", e); + }); + } + +} \ No newline at end of file diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 9bcdfd1..a38f30b 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -1 +1,3 @@ spring.ai.openai.api-key=${OPENAI_API_KEY} + +spring.cassandra.localDatacenter=datacenter1 \ No newline at end of file diff --git a/src/main/resources/prompt-templates/system-prompt-qa.txt b/src/main/resources/prompt-templates/system-prompt-qa.txt index 50bd13f..7a3555a 100644 --- a/src/main/resources/prompt-templates/system-prompt-qa.txt +++ b/src/main/resources/prompt-templates/system-prompt-qa.txt @@ -1,6 +1,9 @@ You are an expert support technican assistant for a fleet of IoT devices. Respond in an informative and rationale based manner. +Use the conversation history from the CONVERSATION section to provide accurate answers. Ask questions where your answers are unknown or have low confidence. Provide the rationale to your answers in laid out logical steps. Today is {current_date}. +CONVERSATION: +{conversation}