Skip to content

Commit

Permalink
Merge pull request #162 from dsgrieve/main
Browse files Browse the repository at this point in the history
Make ChatHistory thread safe
  • Loading branch information
johnoliver authored Aug 8, 2024
2 parents b5b873b + 4b729df commit 88e7fe7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private static class ChatMessages {

private final List<ChatRequestMessage> newMessages;
private final List<ChatRequestMessage> allMessages;
private final List<OpenAIChatMessageContent> newChatMessageContent;
private final List<OpenAIChatMessageContent<?>> newChatMessageContent;

public ChatMessages(List<ChatRequestMessage> allMessages) {
this.allMessages = Collections.unmodifiableList(allMessages);
Expand All @@ -195,7 +195,7 @@ public ChatMessages(List<ChatRequestMessage> allMessages) {
private ChatMessages(
List<ChatRequestMessage> allMessages,
List<ChatRequestMessage> newMessages,
List<OpenAIChatMessageContent> newChatMessageContent) {
List<OpenAIChatMessageContent<?>> newChatMessageContent) {
this.allMessages = Collections.unmodifiableList(allMessages);
this.newMessages = Collections.unmodifiableList(newMessages);
this.newChatMessageContent = Collections.unmodifiableList(newChatMessageContent);
Expand All @@ -219,8 +219,8 @@ public ChatMessages add(ChatRequestMessage requestMessage) {
}

@CheckReturnValue
public ChatMessages addChatMessage(List<OpenAIChatMessageContent> chatMessageContent) {
ArrayList<OpenAIChatMessageContent> tmpChatMessageContent = new ArrayList<>(
public ChatMessages addChatMessage(List<OpenAIChatMessageContent<?>> chatMessageContent) {
ArrayList<OpenAIChatMessageContent<?>> tmpChatMessageContent = new ArrayList<>(
newChatMessageContent);
tmpChatMessageContent.addAll(chatMessageContent);

Expand Down Expand Up @@ -357,19 +357,16 @@ private Mono<ChatMessages> internalChatMessageContentsAsync(
// If we don't want to attempt to invoke any functions
// Or if we are auto-invoking, but we somehow end up with other than 1 choice even though only 1 was requested
if (autoInvokeAttempts == 0 || responseMessages.size() != 1) {
return getChatMessageContentsAsync(completions)
.flatMap(m -> {
return Mono.just(messages.addChatMessage(m));
});
List<OpenAIChatMessageContent<?>> chatMessageContents = getChatMessageContentsAsync(completions);
return Mono.just(messages.addChatMessage(chatMessageContents));
}
// Or if there are no tool calls to be done
ChatResponseMessage response = responseMessages.get(0);
List<ChatCompletionsToolCall> toolCalls = response.getToolCalls();
if (toolCalls == null || toolCalls.isEmpty()) {
return getChatMessageContentsAsync(completions)
.flatMap(m -> {
return Mono.just(messages.addChatMessage(m));
});
List<OpenAIChatMessageContent<?>> chatMessageContents = getChatMessageContentsAsync(
completions);
return Mono.just(messages.addChatMessage(chatMessageContents));
}

ChatRequestAssistantMessage requestMessage = new ChatRequestAssistantMessage(
Expand Down Expand Up @@ -592,7 +589,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall(
arguments);
}

private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
private List<OpenAIChatMessageContent<?>> getChatMessageContentsAsync(
ChatCompletions completions) {
FunctionResultMetadata<CompletionsUsage> completionMetadata = FunctionResultMetadata.build(
completions.getId(),
Expand All @@ -606,22 +603,28 @@ private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
.filter(Objects::nonNull)
.collect(Collectors.toList());

return Flux.fromIterable(responseMessages)
.flatMap(response -> {
List<OpenAIChatMessageContent<?>> chatMessageContent =
responseMessages
.stream()
.map(response -> {
try {
return Mono.just(new OpenAIChatMessageContent(
return new OpenAIChatMessageContent<>(
AuthorRole.ASSISTANT,
response.getContent(),
this.getModelId(),
null,
null,
completionMetadata,
formOpenAiToolCalls(response)));
} catch (Exception e) {
return Mono.error(e);
formOpenAiToolCalls(response));
} catch (SKCheckedException e) {
LOGGER.warn("Failed to form chat message content", e);
return null;
}
})
.collectList();
.filter(Objects::nonNull)
.collect(Collectors.toList());

return chatMessageContent;
}

private List<ChatMessageContent<?>> toOpenAIChatMessageContent(
Expand Down Expand Up @@ -931,7 +934,7 @@ private static boolean hasToolCallBeenExecuted(List<ChatRequestMessage> chatRequ
}

private static List<ChatRequestMessage> getChatRequestMessages(
List<? extends ChatMessageContent> messages) {
List<? extends ChatMessageContent<?>> messages) {
if (messages == null || messages.isEmpty()) {
return new ArrayList<>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public OpenAIChatMessageContent(
@Nullable String modelId,
@Nullable T innerContent,
@Nullable Charset encoding,
@Nullable FunctionResultMetadata metadata,
@Nullable FunctionResultMetadata<?> metadata,
@Nullable List<OpenAIFunctionToolCall> toolCall) {
super(authorRole, content, modelId, innerContent, encoding, metadata);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageTextContent;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Spliterator;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;
import javax.annotation.Nullable;

Expand All @@ -18,7 +20,7 @@
*/
public class ChatHistory implements Iterable<ChatMessageContent<?>> {

private final List<ChatMessageContent<?>> chatMessageContents;
private final Collection<ChatMessageContent<?>> chatMessageContents;

/**
* The default constructor
Expand All @@ -33,7 +35,7 @@ public ChatHistory() {
* @param instructions The instructions to add to the chat history
*/
public ChatHistory(@Nullable String instructions) {
this.chatMessageContents = new ArrayList<>();
this.chatMessageContents = new ConcurrentLinkedQueue<>();
if (instructions != null) {
this.chatMessageContents.add(
ChatMessageTextContent.systemMessage(instructions));
Expand All @@ -45,8 +47,8 @@ public ChatHistory(@Nullable String instructions) {
*
* @param chatMessageContents The chat message contents to add to the chat history
*/
public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
this.chatMessageContents = new ArrayList(chatMessageContents);
public ChatHistory(List<? extends ChatMessageContent<?>> chatMessageContents) {
this.chatMessageContents = new ConcurrentLinkedQueue<>(chatMessageContents);
}

/**
Expand All @@ -55,7 +57,7 @@ public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
* @return List of messages in the chat
*/
public List<ChatMessageContent<?>> getMessages() {
return Collections.unmodifiableList(chatMessageContents);
return Collections.unmodifiableList(new ArrayList<>(chatMessageContents));
}

/**
Expand All @@ -67,7 +69,7 @@ public Optional<ChatMessageContent<?>> getLastMessage() {
if (chatMessageContents.isEmpty()) {
return Optional.empty();
}
return Optional.of(chatMessageContents.get(chatMessageContents.size() - 1));
return Optional.of(((ConcurrentLinkedQueue<ChatMessageContent<?>>)chatMessageContents).peek());
}

/**
Expand Down Expand Up @@ -114,7 +116,7 @@ public Spliterator<ChatMessageContent<?>> spliterator() {
* @param metadata The metadata of the message
*/
public ChatHistory addMessage(AuthorRole authorRole, String content, Charset encoding,
FunctionResultMetadata metadata) {
FunctionResultMetadata<?> metadata) {
chatMessageContents.add(
ChatMessageTextContent.builder()
.withAuthorRole(authorRole)
Expand Down

0 comments on commit 88e7fe7

Please sign in to comment.