From 78911cf268452df1804355bc5138743ac973a13c Mon Sep 17 00:00:00 2001 From: David Grieve Date: Tue, 6 Aug 2024 16:51:53 -0400 Subject: [PATCH 1/3] make ChatHistory thread safe --- .../chatcompletion/OpenAIChatCompletion.java | 27 +++++++++++-------- .../OpenAIChatMessageContent.java | 2 +- .../services/chatcompletion/ChatHistory.java | 16 ++++++----- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 6bdb4f1c..f4a57f00 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -183,7 +183,7 @@ private static class ChatMessages { private final List newMessages; private final List allMessages; - private final List newChatMessageContent; + private final List> newChatMessageContent; public ChatMessages(List allMessages) { this.allMessages = Collections.unmodifiableList(allMessages); @@ -194,7 +194,7 @@ public ChatMessages(List allMessages) { private ChatMessages( List allMessages, List newMessages, - List newChatMessageContent) { + List> newChatMessageContent) { this.allMessages = Collections.unmodifiableList(allMessages); this.newMessages = Collections.unmodifiableList(newMessages); this.newChatMessageContent = Collections.unmodifiableList(newChatMessageContent); @@ -218,8 +218,8 @@ public ChatMessages add(ChatRequestMessage requestMessage) { } @CheckReturnValue - public ChatMessages addChatMessage(List chatMessageContent) { - ArrayList tmpChatMessageContent = new ArrayList<>( + public ChatMessages addChatMessage(List> chatMessageContent) { + ArrayList> tmpChatMessageContent = new ArrayList<>( newChatMessageContent); tmpChatMessageContent.addAll(chatMessageContent); @@ -580,7 +580,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall( arguments); } - private Mono> getChatMessageContentsAsync( + private Mono>> getChatMessageContentsAsync( ChatCompletions completions) { FunctionResultMetadata completionMetadata = FunctionResultMetadata.build( completions.getId(), @@ -594,22 +594,27 @@ private Mono> getChatMessageContentsAsync( .filter(Objects::nonNull) .collect(Collectors.toList()); - return Flux.fromIterable(responseMessages) - .flatMap(response -> { + List> chatMessageContent = + responseMessages + .stream() + .map(response -> { try { - return Mono.just(new OpenAIChatMessageContent( + return new OpenAIChatMessageContent<>( AuthorRole.ASSISTANT, response.getContent(), this.getModelId(), null, null, completionMetadata, - formOpenAiToolCalls(response))); + formOpenAiToolCalls(response)); } catch (Exception e) { - return Mono.error(e); + return null; } }) - .collectList(); + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + return Mono.just(chatMessageContent); } private List> toOpenAIChatMessageContent( diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java index f2cbf858..89f45014 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java @@ -36,7 +36,7 @@ public OpenAIChatMessageContent( @Nullable String modelId, @Nullable T innerContent, @Nullable Charset encoding, - @Nullable FunctionResultMetadata metadata, + @Nullable FunctionResultMetadata metadata, @Nullable List toolCall) { super(authorRole, content, modelId, innerContent, encoding, metadata); diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java index d2f391ff..170aac40 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java @@ -5,11 +5,13 @@ import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageTextContent; import java.nio.charset.Charset; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.Spliterator; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.function.Consumer; import javax.annotation.Nullable; @@ -18,7 +20,7 @@ */ public class ChatHistory implements Iterable> { - private final List> chatMessageContents; + private final Collection> chatMessageContents; /** * The default constructor @@ -33,7 +35,7 @@ public ChatHistory() { * @param instructions The instructions to add to the chat history */ public ChatHistory(@Nullable String instructions) { - this.chatMessageContents = new ArrayList<>(); + this.chatMessageContents = new ConcurrentLinkedQueue<>(); if (instructions != null) { this.chatMessageContents.add( ChatMessageTextContent.systemMessage(instructions)); @@ -45,8 +47,8 @@ public ChatHistory(@Nullable String instructions) { * * @param chatMessageContents The chat message contents to add to the chat history */ - public ChatHistory(List chatMessageContents) { - this.chatMessageContents = new ArrayList(chatMessageContents); + public ChatHistory(List> chatMessageContents) { + this.chatMessageContents = new ConcurrentLinkedQueue<>(chatMessageContents); } /** @@ -55,7 +57,7 @@ public ChatHistory(List chatMessageContents) { * @return List of messages in the chat */ public List> getMessages() { - return Collections.unmodifiableList(chatMessageContents); + return Collections.unmodifiableList(new ArrayList<>(chatMessageContents)); } /** @@ -67,7 +69,7 @@ public Optional> getLastMessage() { if (chatMessageContents.isEmpty()) { return Optional.empty(); } - return Optional.of(chatMessageContents.get(chatMessageContents.size() - 1)); + return Optional.of(((ConcurrentLinkedQueue>)chatMessageContents).peek()); } /** @@ -114,7 +116,7 @@ public Spliterator> spliterator() { * @param metadata The metadata of the message */ public void addMessage(AuthorRole authorRole, String content, Charset encoding, - FunctionResultMetadata metadata) { + FunctionResultMetadata metadata) { chatMessageContents.add( ChatMessageTextContent.builder() .withAuthorRole(authorRole) From 1257fee6250582d36aa48e3d24ac30a6042bf10d Mon Sep 17 00:00:00 2001 From: David Grieve Date: Wed, 7 Aug 2024 11:09:51 -0400 Subject: [PATCH 2/3] Log message if formOpenAiToolCalls throws exception --- .../openai/chatcompletion/OpenAIChatCompletion.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 4bd424dd..c02c8ef0 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -619,7 +619,8 @@ private Mono>> getChatMessageContentsAsync( null, completionMetadata, formOpenAiToolCalls(response)); - } catch (Exception e) { + } catch (SKCheckedException e) { + LOGGER.warn("Failed to form chat message content", e); return null; } }) @@ -936,7 +937,7 @@ private static boolean hasToolCallBeenExecuted(List chatRequ } private static List getChatRequestMessages( - List messages) { + List> messages) { if (messages == null || messages.isEmpty()) { return new ArrayList<>(); } From 8d8d5df3b3de90571defe0ec8661ce3206278a62 Mon Sep 17 00:00:00 2001 From: David Grieve Date: Wed, 7 Aug 2024 12:24:46 -0400 Subject: [PATCH 3/3] Remove Mono from private getChatMessageContentsAsync --- .../chatcompletion/OpenAIChatCompletion.java | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 4bd424dd..ca8eb5b3 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -357,19 +357,16 @@ private Mono internalChatMessageContentsAsync( // If we don't want to attempt to invoke any functions // Or if we are auto-invoking, but we somehow end up with other than 1 choice even though only 1 was requested if (autoInvokeAttempts == 0 || responseMessages.size() != 1) { - return getChatMessageContentsAsync(completions) - .flatMap(m -> { - return Mono.just(messages.addChatMessage(m)); - }); + List> chatMessageContents = getChatMessageContentsAsync(completions); + return Mono.just(messages.addChatMessage(chatMessageContents)); } // Or if there are no tool calls to be done ChatResponseMessage response = responseMessages.get(0); List toolCalls = response.getToolCalls(); if (toolCalls == null || toolCalls.isEmpty()) { - return getChatMessageContentsAsync(completions) - .flatMap(m -> { - return Mono.just(messages.addChatMessage(m)); - }); + List> chatMessageContents = getChatMessageContentsAsync( + completions); + return Mono.just(messages.addChatMessage(chatMessageContents)); } ChatRequestAssistantMessage requestMessage = new ChatRequestAssistantMessage( @@ -592,7 +589,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall( arguments); } - private Mono>> getChatMessageContentsAsync( + private List> getChatMessageContentsAsync( ChatCompletions completions) { FunctionResultMetadata completionMetadata = FunctionResultMetadata.build( completions.getId(), @@ -619,14 +616,15 @@ private Mono>> getChatMessageContentsAsync( null, completionMetadata, formOpenAiToolCalls(response)); - } catch (Exception e) { + } catch (SKCheckedException e) { + LOGGER.warn("Failed to form chat message content", e); return null; } }) .filter(Objects::nonNull) .collect(Collectors.toList()); - return Mono.just(chatMessageContent); + return chatMessageContent; } private List> toOpenAIChatMessageContent( @@ -936,7 +934,7 @@ private static boolean hasToolCallBeenExecuted(List chatRequ } private static List getChatRequestMessages( - List messages) { + List> messages) { if (messages == null || messages.isEmpty()) { return new ArrayList<>(); }