diff --git a/src/engine.py b/src/engine.py
index 48e8aac..1605302 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -113,7 +113,7 @@ def _initialize_llm(self):
             return engine
         except Exception as e:
             logging.error("Error initializing vLLM engine: %s", e)
-            raise e
+            raise
 
 
 class OpenAIvLLMEngine(vLLMEngine):
@@ -166,7 +166,7 @@ async def generate(self, openai_request: JobInput):
             async for response in self._handle_chat_or_completion_request(openai_request):
                 yield response
         else:
-            yield create_error_response("Invalid route").model_dump()
+            yield {"error": create_error_response("Invalid route").model_dump()}
 
     async def _handle_model_request(self):
         models = await self.chat_engine.show_available_models()
@@ -184,8 +184,8 @@ async def _handle_chat_or_completion_request(self, openai_request: JobInput):
             request = request_class(
                 **openai_request.openai_input
             )
         except Exception as e:
-            yield create_error_response(str(e)).model_dump()
+            yield {"error": create_error_response(str(e)).model_dump()}
             return
 
         dummy_request = DummyRequest()
@@ -219,4 +219,3 @@ async def _handle_chat_or_completion_request(self, openai_request: JobInput):
             if self.raw_openai_output:
                 batch = "".join(batch)
             yield batch
-        
\ No newline at end of file
diff --git a/src/handler.py b/src/handler.py
index 176ec7e..7a6bbdc 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -7,11 +7,19 @@
 OpenAIvLLMEngine = OpenAIvLLMEngine(vllm_engine)
 
 async def handler(job):
-    job_input = JobInput(job["input"])
-    engine = OpenAIvLLMEngine if job_input.openai_route else vllm_engine
-    results_generator = engine.generate(job_input)
-    async for batch in results_generator:
-        yield batch
+    try:
+        job_input = JobInput(job["input"])
+        engine = OpenAIvLLMEngine if job_input.openai_route else vllm_engine
+        results_generator = engine.generate(job_input)
+        async for batch in results_generator:
+            # If there's any kind of error in the batch, format it
+            if isinstance(batch, dict) and 'error' in batch:
+                yield {"error": str(batch["error"])}
+            else:
+                yield batch
+    except Exception as e:
+        yield {"error": str(e)}
+        return
 
 runpod.serverless.start(
     {
diff --git a/src/utils.py b/src/utils.py
index bfc8ce9..78810fd 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -78,6 +78,7 @@ def create_error_response(message: str, err_type: str = "BadRequestError", statu
     return ErrorResponse(message=message, type=err_type, code=status_code.value)
 
+
 def get_int_bool_env(env_var: str, default: bool) -> bool:
     return int(os.getenv(env_var, int(default))) == 1
 
 