diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 84a8287e..ca8eb5b3 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -184,7 +184,7 @@ private static class ChatMessages { private final List newMessages; private final List allMessages; - private final List newChatMessageContent; + private final List> newChatMessageContent; public ChatMessages(List allMessages) { this.allMessages = Collections.unmodifiableList(allMessages); @@ -195,7 +195,7 @@ public ChatMessages(List allMessages) { private ChatMessages( List allMessages, List newMessages, - List newChatMessageContent) { + List> newChatMessageContent) { this.allMessages = Collections.unmodifiableList(allMessages); this.newMessages = Collections.unmodifiableList(newMessages); this.newChatMessageContent = Collections.unmodifiableList(newChatMessageContent); @@ -219,8 +219,8 @@ public ChatMessages add(ChatRequestMessage requestMessage) { } @CheckReturnValue - public ChatMessages addChatMessage(List chatMessageContent) { - ArrayList tmpChatMessageContent = new ArrayList<>( + public ChatMessages addChatMessage(List> chatMessageContent) { + ArrayList> tmpChatMessageContent = new ArrayList<>( newChatMessageContent); tmpChatMessageContent.addAll(chatMessageContent); @@ -357,19 +357,16 @@ private Mono internalChatMessageContentsAsync( // If we don't want to attempt to invoke any functions // Or if we are auto-invoking, but we somehow end up with other than 1 choice even though only 1 was requested if (autoInvokeAttempts == 0 || responseMessages.size() != 1) { - return getChatMessageContentsAsync(completions) - .flatMap(m -> { - return Mono.just(messages.addChatMessage(m)); - }); + List> chatMessageContents = getChatMessageContentsAsync(completions); + return Mono.just(messages.addChatMessage(chatMessageContents)); } // Or if there are no tool calls to be done ChatResponseMessage response = responseMessages.get(0); List toolCalls = response.getToolCalls(); if (toolCalls == null || toolCalls.isEmpty()) { - return getChatMessageContentsAsync(completions) - .flatMap(m -> { - return Mono.just(messages.addChatMessage(m)); - }); + List> chatMessageContents = getChatMessageContentsAsync( + completions); + return Mono.just(messages.addChatMessage(chatMessageContents)); } ChatRequestAssistantMessage requestMessage = new ChatRequestAssistantMessage( @@ -592,7 +589,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall( arguments); } - private Mono> getChatMessageContentsAsync( + private List> getChatMessageContentsAsync( ChatCompletions completions) { FunctionResultMetadata completionMetadata = FunctionResultMetadata.build( completions.getId(), @@ -606,22 +603,28 @@ private Mono> getChatMessageContentsAsync( .filter(Objects::nonNull) .collect(Collectors.toList()); - return Flux.fromIterable(responseMessages) - .flatMap(response -> { + List> chatMessageContent = + responseMessages + .stream() + .map(response -> { try { - return Mono.just(new OpenAIChatMessageContent( + return new OpenAIChatMessageContent<>( AuthorRole.ASSISTANT, response.getContent(), this.getModelId(), null, null, completionMetadata, - formOpenAiToolCalls(response))); - } catch (Exception e) { - return Mono.error(e); + formOpenAiToolCalls(response)); + } catch (SKCheckedException e) { + LOGGER.warn("Failed to form chat message content", e); + return null; } }) - .collectList(); + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + return chatMessageContent; } private List> toOpenAIChatMessageContent( @@ -931,7 +934,7 @@ private static boolean hasToolCallBeenExecuted(List chatRequ } private static List getChatRequestMessages( - List messages) { + List> messages) { if (messages == null || messages.isEmpty()) { return new ArrayList<>(); } diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java index f2cbf858..89f45014 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java @@ -36,7 +36,7 @@ public OpenAIChatMessageContent( @Nullable String modelId, @Nullable T innerContent, @Nullable Charset encoding, - @Nullable FunctionResultMetadata metadata, + @Nullable FunctionResultMetadata metadata, @Nullable List toolCall) { super(authorRole, content, modelId, innerContent, encoding, metadata); diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java index 8d6bdce6..a8303061 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java @@ -5,11 +5,13 @@ import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageTextContent; import java.nio.charset.Charset; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.Spliterator; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.function.Consumer; import javax.annotation.Nullable; @@ -18,7 +20,7 @@ */ public class ChatHistory implements Iterable> { - private final List> chatMessageContents; + private final Collection> chatMessageContents; /** * The default constructor @@ -33,7 +35,7 @@ public ChatHistory() { * @param instructions The instructions to add to the chat history */ public ChatHistory(@Nullable String instructions) { - this.chatMessageContents = new ArrayList<>(); + this.chatMessageContents = new ConcurrentLinkedQueue<>(); if (instructions != null) { this.chatMessageContents.add( ChatMessageTextContent.systemMessage(instructions)); @@ -45,8 +47,8 @@ public ChatHistory(@Nullable String instructions) { * * @param chatMessageContents The chat message contents to add to the chat history */ - public ChatHistory(List chatMessageContents) { - this.chatMessageContents = new ArrayList(chatMessageContents); + public ChatHistory(List> chatMessageContents) { + this.chatMessageContents = new ConcurrentLinkedQueue<>(chatMessageContents); } /** @@ -55,7 +57,7 @@ public ChatHistory(List chatMessageContents) { * @return List of messages in the chat */ public List> getMessages() { - return Collections.unmodifiableList(chatMessageContents); + return Collections.unmodifiableList(new ArrayList<>(chatMessageContents)); } /** @@ -67,7 +69,7 @@ public Optional> getLastMessage() { if (chatMessageContents.isEmpty()) { return Optional.empty(); } - return Optional.of(chatMessageContents.get(chatMessageContents.size() - 1)); + return Optional.of(((ConcurrentLinkedQueue>)chatMessageContents).peek()); } /** @@ -114,7 +116,7 @@ public Spliterator> spliterator() { * @param metadata The metadata of the message */ public ChatHistory addMessage(AuthorRole authorRole, String content, Charset encoding, - FunctionResultMetadata metadata) { + FunctionResultMetadata metadata) { chatMessageContents.add( ChatMessageTextContent.builder() .withAuthorRole(authorRole)