OPIK-610 code style

idoberko2 committed Dec 29, 2024
1 parent fecd645 commit 33481ac
Showing 5 changed files with 30 additions and 45 deletions.
ChatCompletionService.java
@@ -44,7 +44,6 @@ public ChatCompletionService(
}

public ChatCompletionResponse create(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());
llmProviderClient.validateRequest(request);

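For orientation: once the factory (changed later in this commit) resolves a provider-specific client, create reads roughly as below. A minimal sketch, assuming the method ends by returning the provider's generate result; the generate(request, workspaceId) signature is taken from the Anthropic hunk that follows.

    // Sketch only: resolve a provider client from the model, validate, delegate.
    public ChatCompletionResponse create(ChatCompletionRequest request, String workspaceId) {
        var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());
        llmProviderClient.validateRequest(request);
        return llmProviderClient.generate(request, workspaceId); // assumed return wiring
    }
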
Anthropic.java
@@ -49,12 +49,12 @@ public Anthropic(LlmProviderClientConfig llmProviderClientConfig, String apiKey)

@Override
public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
-var response = anthropicClient.createMessage(mapToAnthropicCreateMessageRequest(request));
+var response = anthropicClient.createMessage(toAnthropicCreateMessageRequest(request));

return ChatCompletionResponse.builder()
.id(response.id)
.model(response.model)
-.choices(response.content.stream().map(content -> mapContentToChoice(response, content))
+.choices(response.content.stream().map(content -> toChatCompletionChoice(response, content))
.toList())
.usage(Usage.builder()
.promptTokens(response.usage.inputTokens)
@@ -71,7 +71,7 @@ public void generateStream(
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose, @NonNull Consumer<Throwable> handleError) {
validateRequest(request);
-anthropicClient.createMessage(mapToAnthropicCreateMessageRequest(request),
+anthropicClient.createMessage(toAnthropicCreateMessageRequest(request),
new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
}

@@ -100,22 +100,23 @@ public int getHttpErrorStatusCode(Throwable throwable) {
return 500;
}

-private AnthropicCreateMessageRequest mapToAnthropicCreateMessageRequest(ChatCompletionRequest request) {
+private AnthropicCreateMessageRequest toAnthropicCreateMessageRequest(ChatCompletionRequest request) {
var builder = AnthropicCreateMessageRequest.builder();
Optional.ofNullable(request.toolChoice())
-.ifPresent(toolChoice -> builder.toolChoice(AnthropicToolChoice.from(request.toolChoice().toString())));
+.ifPresent(toolChoice -> builder.toolChoice(AnthropicToolChoice.from(
+request.toolChoice().toString())));
return builder
.stream(request.stream())
.model(request.model())
-.messages(request.messages().stream().map(this::mapMessage).toList())
+.messages(request.messages().stream().map(this::toMessage).toList())
.temperature(request.temperature())
.topP(request.topP())
.stopSequences(request.stop())
.maxTokens(request.maxCompletionTokens())
.build();
}

-private AnthropicMessage mapMessage(Message message) {
+private AnthropicMessage toMessage(Message message) {
if (message.role() == Role.ASSISTANT) {
return AnthropicMessage.builder()
.role(AnthropicRole.ASSISTANT)
@@ -124,23 +125,24 @@ private AnthropicMessage mapMessage(Message message) {
} else if (message.role() == Role.USER) {
return AnthropicMessage.builder()
.role(AnthropicRole.USER)
-.content(List.of(mapMessageContent(((UserMessage) message).content())))
+.content(List.of(toAnthropicMessageContent(((UserMessage) message).content())))
.build();
}

// Anthropic only supports assistant and user roles
throw new BadRequestException("not supported message role: " + message.role());
}

-private AnthropicMessageContent mapMessageContent(Object rawContent) {
+private AnthropicMessageContent toAnthropicMessageContent(Object rawContent) {
if (rawContent instanceof String content) {
return new AnthropicTextContent(content);
}

throw new BadRequestException("only text content is supported");
}

-private ChatCompletionChoice mapContentToChoice(AnthropicCreateMessageResponse response, AnthropicContent content) {
+private ChatCompletionChoice toChatCompletionChoice(
+AnthropicCreateMessageResponse response, AnthropicContent content) {
return ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.name(content.name)
@@ -172,6 +174,7 @@ private AnthropicClient newClient(String apiKey) {
Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
.map(LlmProviderClientConfig.AnthropicClientConfig::logResponses)
.ifPresent(anthropicClientBuilder::logResponses);
+// anthropic client builder only receives one timeout variant
Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
.ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
return anthropicClientBuilder
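The generateStream hunk above shows only the callback parameters. A hypothetical caller could wire them as below; the leading request and workspaceId arguments are an assumption, since the start of the parameter list sits outside the hunk.

    // Hypothetical wiring of the three generateStream callbacks.
    provider.generateStream(request, workspaceId,
            chunk -> System.out.println(chunk),        // handleMessage: one partial ChatCompletionResponse per chunk
            () -> System.out.println("stream closed"), // handleClose
            error -> error.printStackTrace());         // handleError
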
LlmProviderFactory.java
@@ -42,9 +42,6 @@ public LlmProviderService getService(@NonNull String workspaceId, @NonNull Strin

/**
* The agreed requirement is to resolve the LLM provider and its API key based on the model.
-* Currently, only OPEN AI is supported, so model param is ignored.
-* No further validation is needed on the model, as it's just forwarded in the OPEN AI request and will be rejected
-* if not valid.
*/
private LlmProvider getLlmProvider(String model) {
if (isModelBelongToProvider(model, ChatCompletionModel.class, ChatCompletionModel::toString)) {
@@ -58,7 +55,7 @@ private LlmProvider getLlmProvider(String model) {
}

/**
-* Finding API keys isn't paginated at the moment, since only OPEN AI is supported.
+* Finding API keys isn't paginated at the moment.
* Even in the future, the number of supported LLM providers per workspace is going to be very low.
*/
private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) {
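The isModelBelongToProvider helper called above sits outside this diff. A plausible shape, offered as an assumption rather than the actual implementation, is a generic scan over a provider's model enum; EnumUtils is from commons-lang3, which the test file below also imports.

    import java.util.function.Function;
    import org.apache.commons.lang3.EnumUtils;

    // Hypothetical sketch of isModelBelongToProvider; not taken from this diff.
    private static <E extends Enum<E>> boolean isModelBelongToProvider(
            String model, Class<E> enumClass, Function<E, String> nameOf) {
        return EnumUtils.getEnumList(enumClass).stream()
                .map(nameOf) // render each enum constant the way the call site does
                .anyMatch(model::equals);
    }
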
OpenAi.java
@@ -63,7 +63,6 @@ public int getHttpErrorStatusCode(Throwable throwable) {
}

/**
-* Initially, only OPEN AI is supported, so no need for a more sophisticated client resolution to start with.
* At the moment, openai4j client and also langchain4j wrappers, don't support dynamic API keys. That can imply
* an important performance penalty for next phases. The following options should be evaluated:
* - Cache clients, but can be unsafe.
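Of the options this Javadoc lists, "cache clients" could look roughly like the sketch below. Everything here is illustrative; the noted unsafety is that a cache hit keeps serving the old API key until the entry is invalidated.

    import java.util.concurrent.ConcurrentHashMap;
    import java.util.function.Function;

    // Hypothetical sketch of the "cache clients" option; all names are illustrative.
    final class ClientCache<C> {
        private final ConcurrentHashMap<String, C> cache = new ConcurrentHashMap<>();
        private final Function<String, C> factory; // builds a client from an API key

        ClientCache(Function<String, C> factory) {
            this.factory = factory;
        }

        C forWorkspace(String workspaceId, String apiKey) {
            // Unsafe part: a hit ignores apiKey, so key rotation requires invalidate().
            return cache.computeIfAbsent(workspaceId, id -> factory.apply(apiKey));
        }

        void invalidate(String workspaceId) {
            cache.remove(workspaceId);
        }
    }
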
LlmProviderFactoryTest.java
@@ -15,16 +15,20 @@
import io.dropwizard.jackson.Jackson;
import io.dropwizard.jersey.validation.Validators;
import jakarta.validation.Validator;
+import org.apache.commons.lang3.EnumUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.EnumSource;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;

import java.io.IOException;
import java.util.List;
import java.util.UUID;
+import java.util.stream.Stream;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.when;
@@ -52,15 +56,15 @@ public void tearDown() {
}

@ParameterizedTest
-@EnumSource(value = ChatCompletionModel.class)
-void testGetServiceOpenai(ChatCompletionModel model) {
+@MethodSource
+void testGetService(String model, LlmProvider llmProvider, Class<? extends LlmProviderService> providerClass) {
// setup
String workspaceId = UUID.randomUUID().toString();
String apiKey = UUID.randomUUID().toString();

when(llmProviderApiKeyService.find(workspaceId)).thenReturn(ProviderApiKey.ProviderApiKeyPage.builder()
.content(List.of(ProviderApiKey.builder()
-.provider(LlmProvider.OPEN_AI)
+.provider(llmProvider)
.apiKey(EncryptionUtils.encrypt(apiKey))
.build()))
.total(1)
@@ -71,35 +75,18 @@ void testGetServiceOpenai(ChatCompletionModel model) {
// SUT
var llmProviderFactory = new LlmProviderFactory(llmProviderClientConfig, llmProviderApiKeyService);

-LlmProviderService actual = llmProviderFactory.getService(workspaceId, model.toString());
+LlmProviderService actual = llmProviderFactory.getService(workspaceId, model);

// assertions
-assertThat(actual).isInstanceOf(OpenAi.class);
+assertThat(actual).isInstanceOf(providerClass);
}

-@ParameterizedTest
-@EnumSource(value = AnthropicChatModelName.class)
-void testGetServiceAnthropic(AnthropicChatModelName model) {
-// setup
-String workspaceId = UUID.randomUUID().toString();
-String apiKey = UUID.randomUUID().toString();
-
-when(llmProviderApiKeyService.find(workspaceId)).thenReturn(ProviderApiKey.ProviderApiKeyPage.builder()
-.content(List.of(ProviderApiKey.builder()
-.provider(LlmProvider.ANTHROPIC)
-.apiKey(EncryptionUtils.encrypt(apiKey))
-.build()))
-.total(1)
-.page(1)
-.size(1)
-.build());
+private static Stream<Arguments> testGetService() {
+var openAiModels = EnumUtils.getEnumList(ChatCompletionModel.class).stream()
+.map(model -> arguments(model.toString(), LlmProvider.OPEN_AI, OpenAi.class));
+var anthropicModels = EnumUtils.getEnumList(AnthropicChatModelName.class).stream()
+.map(model -> arguments(model.toString(), LlmProvider.ANTHROPIC, Anthropic.class));

-// SUT
-var llmProviderFactory = new LlmProviderFactory(llmProviderClientConfig, llmProviderApiKeyService);
-
-LlmProviderService actual = llmProviderFactory.getService(workspaceId, model.toString());
-
-// assertions
-assertThat(actual).isInstanceOf(Anthropic.class);
+return Stream.concat(openAiModels, anthropicModels);
}
}
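One note on the pattern above: when @MethodSource is given no value, JUnit 5 resolves a static factory method with the same name as the test, so testGetService(String, LlmProvider, Class) is fed by the static testGetService() stream. The explicit equivalent, shown only to make that resolution visible:

    // Explicit form of the same @MethodSource resolution.
    @ParameterizedTest
    @MethodSource("testGetService")
    void testGetService(String model, LlmProvider llmProvider,
            Class<? extends LlmProviderService> providerClass) {
        // same body as in the diff above
    }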
