diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderStreamHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderStreamHandler.java
index a5bbe6d6a2..97744ba1d8 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderStreamHandler.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderStreamHandler.java
@@ -1,13 +1,14 @@
 package com.comet.opik.domain.llmproviders;
 
 import com.comet.opik.utils.JsonUtils;
-import dev.ai4j.openai4j.OpenAiHttpException;
 import io.dropwizard.jersey.errors.ErrorMessage;
 import lombok.extern.slf4j.Slf4j;
 import org.glassfish.jersey.server.ChunkedOutput;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.util.function.Consumer;
+import java.util.function.Function;
 
 @Slf4j
 public class LlmProviderStreamHandler {
@@ -33,17 +34,18 @@ public void handleClose(ChunkedOutput<String> chunkedOutput) {
         }
     }
 
-    public void handleError(Throwable throwable, ChunkedOutput chunkedOutput) {
-        log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
-        var errorMessage = new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
-        if (throwable instanceof OpenAiHttpException openAiHttpException) {
-            errorMessage = new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
-        }
-        try {
-            handleMessage(errorMessage, chunkedOutput);
-        } catch (UncheckedIOException uncheckedIOException) {
-            log.error("Failed to stream error message to client", uncheckedIOException);
-        }
-        handleClose(chunkedOutput);
+    public Consumer<Throwable> getErrorHandler(
+            Function<Throwable, ErrorMessage> mapper, ChunkedOutput<String> chunkedOutput) {
+        return throwable -> {
+            log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
+
+            var errorMessage = mapper.apply(throwable);
+            try {
+                handleMessage(errorMessage, chunkedOutput);
+            } catch (UncheckedIOException uncheckedIOException) {
+                log.error("Failed to stream error message to client", uncheckedIOException);
+            }
+            handleClose(chunkedOutput);
+        };
     }
 }
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
index 9932844996..6b486247b5 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
@@ -6,6 +6,7 @@
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import dev.langchain4j.internal.RetryUtils;
+import io.dropwizard.jersey.errors.ErrorMessage;
 import jakarta.inject.Inject;
 import jakarta.ws.rs.ClientErrorException;
 import jakarta.ws.rs.InternalServerErrorException;
@@ -61,7 +62,7 @@ public ChunkedOutput<String> generateStream(@NonNull ChatCompletionRequest reque
                 .onPartialResponse(
                         chatCompletionResponse -> streamHandler.handleMessage(chatCompletionResponse, chunkedOutput))
                 .onComplete(() -> streamHandler.handleClose(chunkedOutput))
-                .onError(throwable -> streamHandler.handleError(throwable, chunkedOutput))
+                .onError(streamHandler.getErrorHandler(this::errorMapper, chunkedOutput))
                 .execute();
         log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
         return chunkedOutput;
@@ -97,4 +98,12 @@ private OpenAiClient newOpenAiClient(String apiKey) {
                 .openAiApiKey(apiKey)
                 .build();
     }
+
+    private ErrorMessage errorMapper(Throwable throwable) {
+        if (throwable instanceof OpenAiHttpException openAiHttpException) {
+            return new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
+        }
+
+        return new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
+    }
 }