diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py
index 20044b8a6..19f4b68c0 100644
--- a/nemoguardrails/colang/v2_x/runtime/runtime.py
+++ b/nemoguardrails/colang/v2_x/runtime/runtime.py
@@ -33,6 +33,7 @@
     ColangSyntaxError,
 )
 from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus
+from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state
 from nemoguardrails.colang.v2_x.runtime.statemachine import (
     FlowConfig,
     InternalEvent,
@@ -439,10 +440,13 @@ async def process_events(
             )
             initialize_state(state)
         elif isinstance(state, dict):
-            # TODO: Implement dict to State conversion
-            raise NotImplementedError()
-            # if isinstance(state, dict):
-            #     state = State.from_dict(state)
+            # Convert dict to State object
+            if state.get("version") == "2.x" and "state" in state:
+                # Handle the serialized state format from API calls
+                state = json_to_state(state["state"])
+            else:
+                # TODO: Implement other dict to State conversion formats if needed
+                raise NotImplementedError("Unsupported state dict format")

         assert isinstance(state, State)
         assert state.main_flow_state is not None
diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py
index d07cb63df..1c2f922cb 100644
--- a/nemoguardrails/server/api.py
+++ b/nemoguardrails/server/api.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import asyncio
 import contextvars
 import importlib.util
@@ -20,23 +21,27 @@
 import os.path
 import re
 import time
+import uuid
 import warnings
 from contextlib import asynccontextmanager
 from typing import Any, List, Optional

 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field, root_validator, validator
+from pydantic import Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles

 from nemoguardrails import LLMRails, RailsConfig, utils
-from nemoguardrails.rails.llm.options import (
-    GenerationLog,
-    GenerationOptions,
-    GenerationResponse,
-)
+from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
+from nemoguardrails.server.schemas.openai import (
+    Choice,
+    Model,
+    ModelsResponse,
+    OpenAIRequestFields,
+    ResponseBody,
+)
 from nemoguardrails.streaming import StreamingHandler

 logging.basicConfig(level=logging.INFO)
@@ -168,7 +173,7 @@ async def root_handler():
 app.single_config_id = None


-class RequestBody(BaseModel):
+class RequestBody(OpenAIRequestFields):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -211,6 +216,8 @@ class RequestBody(BaseModel):
     @root_validator(pre=True)
     def ensure_config_id(cls, data: Any) -> Any:
         if isinstance(data, dict):
+            if data.get("model") is not None and data.get("config_id") is None:
+                data["config_id"] = data["model"]
             if data.get("config_id") is not None and data.get("config_ids") is not None:
                 raise ValueError(
                     "Only one of config_id or config_ids should be specified"
@@ -231,25 +238,44 @@ def ensure_config_ids(cls, v, values):
         return v


-class ResponseBody(BaseModel):
-    messages: List[dict] = Field(
-        default=None, description="The new messages in the conversation"
-    )
-    llm_output: Optional[dict] = Field(
-        default=None,
-        description="Contains any additional output coming from the LLM.",
-    )
-    output_data: Optional[dict] = Field(
-        default=None,
-        description="The output data, i.e. a dict with the values corresponding to the `output_vars`.",
-    )
-    log: Optional[GenerationLog] = Field(
-        default=None, description="Additional logging information."
-    )
-    state: Optional[dict] = Field(
-        default=None,
-        description="A state object that should be used to continue the interaction in the future.",
-    )
+@app.get(
+    "/v1/models",
+    response_model=ModelsResponse,
+    summary="List available models",
+    description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.",
+)
+async def get_models():
+    """Returns the list of available models (guardrails configurations) in OpenAI-compatible format."""
+
+    # Use the same logic as get_rails_configs to find available configurations
+    if app.single_config_mode:
+        config_ids = [app.single_config_id] if app.single_config_id else []
+    else:
+        config_ids = [
+            f
+            for f in os.listdir(app.rails_config_path)
+            if os.path.isdir(os.path.join(app.rails_config_path, f))
+            and f[0] != "."
+            and f[0] != "_"
+            # Filter out all the configs for which there is no `config.yml` file.
+            and (
+                os.path.exists(os.path.join(app.rails_config_path, f, "config.yml"))
+                or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml"))
+            )
+        ]
+
+    # Convert configurations to OpenAI model format
+    models = []
+    for config_id in config_ids:
+        model = Model(
+            id=config_id,
+            object="model",
+            created=int(time.time()),  # Use current time as created timestamp
+            owned_by="nemo-guardrails",
+        )
+        models.append(model)
+
+    return ModelsResponse(data=models)


 @app.get(
@@ -372,15 +398,24 @@ async def chat_completion(body: RequestBody, request: Request):
         llm_rails = _get_rails(config_ids)
     except ValueError as ex:
         log.exception(ex)
-        return {
-            "messages": [
-                {
-                    "role": "assistant",
-                    "content": f"Could not load the {config_ids} guardrails configuration. "
-                    f"An internal error has occurred.",
-                }
-            ]
-        }
+        return ResponseBody(
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=config_ids[0] if config_ids else None,
+            choices=[
+                Choice(
+                    index=0,
+                    message={
+                        "content": f"Could not load the {config_ids} guardrails configuration. "
+                        f"An internal error has occurred.",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
+        )

     try:
         messages = body.messages
@@ -396,14 +431,23 @@ async def chat_completion(body: RequestBody, request: Request):

             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
-                return {
-                    "messages": [
-                        {
-                            "role": "assistant",
-                            "content": "The `thread_id` must have a minimum length of 16 characters.",
-                        }
-                    ]
-                }
+                return ResponseBody(
+                    id=f"chatcmpl-{uuid.uuid4()}",
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=None,
+                    choices=[
+                        Choice(
+                            index=0,
+                            message={
+                                "content": "The `thread_id` must have a minimum length of 16 characters.",
+                                "role": "assistant",
+                            },
+                            finish_reason="error",
+                            logprobs=None,
+                        )
+                    ],
+                )

             # Fetch the existing thread messages. For easier management, we prepend
             # the string `thread-` to all thread keys.
@@ -413,6 +457,20 @@ async def chat_completion(body: RequestBody, request: Request):
             # And prepend them.
             messages = thread_messages + messages

+        generation_options = body.options
+        if body.max_tokens:
+            generation_options.max_tokens = body.max_tokens
+        if body.temperature is not None:
+            generation_options.temperature = body.temperature
+        if body.top_p is not None:
+            generation_options.top_p = body.top_p
+        if body.stop:
+            generation_options.stop = body.stop
+        if body.presence_penalty is not None:
+            generation_options.presence_penalty = body.presence_penalty
+        if body.frequency_penalty is not None:
+            generation_options.frequency_penalty = body.frequency_penalty
+
         if (
             body.stream
             and llm_rails.config.streaming_supported
@@ -431,8 +489,6 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )

-            # TODO: Add support for thread_ids in streaming mode
-
             return StreamingResponse(streaming_handler)
         else:
             res = await llm_rails.generate_async(
@@ -450,22 +506,50 @@ async def chat_completion(body: RequestBody, request: Request):
             if body.thread_id:
                 await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

-            result = {"messages": [bot_message]}
+            # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
+            response_kwargs = {
+                "id": f"chatcmpl-{uuid.uuid4()}",
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": config_ids[0] if config_ids else None,
+                "choices": [
+                    Choice(
+                        index=0,
+                        message=bot_message,
+                        finish_reason="stop",
+                        logprobs=None,
+                    )
+                ],
+            }

-            # If we have additional GenerationResponse fields, we return as well
+            # If we have additional GenerationResponse fields, include them for backward compatibility
             if isinstance(res, GenerationResponse):
-                result["llm_output"] = res.llm_output
-                result["output_data"] = res.output_data
-                result["log"] = res.log
-                result["state"] = res.state
+                response_kwargs["llm_output"] = res.llm_output
+                response_kwargs["output_data"] = res.output_data
+                response_kwargs["log"] = res.log
+                response_kwargs["state"] = res.state

-            return result
+            return ResponseBody(**response_kwargs)

     except Exception as ex:
         log.exception(ex)
-        return {
-            "messages": [{"role": "assistant", "content": "Internal server error."}]
-        }
+        return ResponseBody(
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=None,
+            choices=[
+                Choice(
+                    index=0,
+                    message={
+                        "content": "Internal server error",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
+        )


 # By default, there are no challenges
diff --git a/nemoguardrails/server/schemas/openai.py b/nemoguardrails/server/schemas/openai.py
new file mode 100644
index 000000000..78556d803
--- /dev/null
+++ b/nemoguardrails/server/schemas/openai.py
@@ -0,0 +1,143 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""OpenAI API schema definitions for the NeMo Guardrails server."""
+
+from typing import List, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class OpenAIRequestFields(BaseModel):
+    """OpenAI API request fields that can be mixed into other request schemas."""
+
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
+        default=None,
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="Sampling temperature to use.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="Top-p sampling parameter.",
+    )
+    stop: Optional[Union[str, List[str]]] = Field(
+        default=None,
+        description="Stop sequences.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="Presence penalty parameter.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="Frequency penalty parameter.",
+    )
+    function_call: Optional[dict] = Field(
+        default=None,
+        description="Function call parameter.",
+    )
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
+    )
+    log_probs: Optional[bool] = Field(
+        default=None,
+        description="Log probabilities parameter.",
+    )
+
+
+class Choice(BaseModel):
+    """OpenAI API choice structure in chat completion responses."""
+
+    index: Optional[int] = Field(
+        default=None, description="The index of the choice in the list of choices."
+    )
+    message: Optional[dict] = Field(
+        default=None, description="The message of the choice"
+    )
+    logprobs: Optional[dict] = Field(
+        default=None, description="The log probabilities of the choice"
+    )
+    finish_reason: Optional[str] = Field(
+        default=None, description="The reason the model stopped generating tokens."
+    )
+
+
+class ResponseBody(BaseModel):
+    """OpenAI API response body with NeMo-Guardrails extensions."""
+
+    # OpenAI API fields
+    id: Optional[str] = Field(
+        default=None, description="A unique identifier for the chat completion."
+    )
+    object: str = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
+    )
+    created: Optional[int] = Field(
+        default=None,
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    model: Optional[str] = Field(
+        default=None, description="The model used for the chat completion."
+    )
+    choices: Optional[List[Choice]] = Field(
+        default=None, description="A list of chat completion choices."
+    )
+    # NeMo-Guardrails specific fields for backward compatibility
+    state: Optional[dict] = Field(
+        default=None, description="State object for continuing the conversation."
+    )
+    llm_output: Optional[dict] = Field(
+        default=None, description="Additional LLM output data."
+    )
+    output_data: Optional[dict] = Field(
+        default=None, description="Additional output data."
+    )
+    log: Optional[dict] = Field(default=None, description="Generation log data.")
+
+
+class Model(BaseModel):
+    """OpenAI API model representation."""
+
+    id: str = Field(
+        description="The model identifier, which can be referenced in the API endpoints."
+    )
+    object: str = Field(
+        default="model", description="The object type, which is always 'model'."
+    )
+    created: int = Field(
+        description="The Unix timestamp (in seconds) of when the model was created."
+    )
+    owned_by: str = Field(
+        default="nemo-guardrails", description="The organization that owns the model."
+    )
+
+
+class ModelsResponse(BaseModel):
+    """OpenAI API models list response."""
+
+    object: str = Field(
+        default="list", description="The object type, which is always 'list'."
+    )
+    data: List[Model] = Field(description="The list of models.")
diff --git a/tests/test_api.py b/tests/test_api.py
index 819fb9381..2b77949c2 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -43,6 +43,26 @@ def test_get():
     assert len(result) > 0


+def test_get_models():
+    """Test the OpenAI-compatible /v1/models endpoint."""
+    response = client.get("/v1/models")
+    assert response.status_code == 200
+
+    result = response.json()
+
+    # Check OpenAI models list format
+    assert result["object"] == "list"
+    assert "data" in result
+    assert len(result["data"]) > 0
+
+    # Check each model has the required OpenAI format
+    for model in result["data"]:
+        assert "id" in model
+        assert model["object"] == "model"
+        assert "created" in model
+        assert model["owned_by"] == "nemo-guardrails"
+
+
 @pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.")
 def test_chat_completion():
     response = client.post(
@@ -59,8 +79,14 @@ def test_chat_completion():
     )
     assert response.status_code == 200
     res = response.json()
-    assert len(res["messages"]) == 1
-    assert res["messages"][0]["content"]
+    # Check OpenAI-compatible response structure
+    assert res["object"] == "chat.completion"
+    assert "id" in res
+    assert "created" in res
+    assert "model" in res
+    assert len(res["choices"]) == 1
+    assert res["choices"][0]["message"]["content"]
+    assert res["choices"][0]["message"]["role"] == "assistant"


 @pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.")
@@ -80,8 +106,14 @@ def test_chat_completion_with_default_configs():
     )
     assert response.status_code == 200
     res = response.json()
-    assert len(res["messages"]) == 1
-    assert res["messages"][0]["content"]
+    # Check OpenAI-compatible response structure
+    assert res["object"] == "chat.completion"
+    assert "id" in res
+    assert "created" in res
+    assert "model" in res
+    assert len(res["choices"]) == 1
+    assert res["choices"][0]["message"]["content"]
+    assert res["choices"][0]["message"]["role"] == "assistant"


 def test_request_body_validation():
@@ -117,6 +149,31 @@ def test_request_body_validation():
     assert request_body.config_ids is None


+def test_openai_model_field_mapping():
+    """Test OpenAI-compatible model field mapping to config_id."""
+
+    # Test model field maps to config_id
+    data = {
+        "model": "test_model",
+        "messages": [{"role": "user", "content": "Hello"}],
+    }
+    request_body = RequestBody.model_validate(data)
+    assert request_body.model == "test_model"
+    assert request_body.config_id == "test_model"
+    assert request_body.config_ids == ["test_model"]
+
+    # Test model and config_id both provided (config_id takes precedence)
+    data = {
+        "model": "test_model",
+        "config_id": "test_config",
+        "messages": [{"role": "user", "content": "Hello"}],
+    }
+    request_body = RequestBody.model_validate(data)
+    assert request_body.model == "test_model"
+    assert request_body.config_id == "test_config"
+    assert request_body.config_ids == ["test_config"]
+
+
 def test_request_body_state():
     """Test RequestBody state handling."""
     data = {
diff --git a/tests/test_server_calls_with_state.py b/tests/test_server_calls_with_state.py
index ab9d74d04..499ea5ddc 100644
--- a/tests/test_server_calls_with_state.py
+++ b/tests/test_server_calls_with_state.py
@@ -37,12 +37,15 @@ def _test_call(config_id):
     )
     assert response.status_code == 200
     res = response.json()
-    assert len(res["messages"]) == 1
-    assert res["messages"][0]["content"] == "Hello!"
+    print(res)
+    assert len(res["choices"][0]["message"]) == 2
+    assert res["choices"][0]["message"]["content"] == "Hello!"
     assert res.get("state")

     # When making a second call with the returned state, the conversations should continue
     # and we should get the "Hello again!" message.
+    # For Colang 2.x, we only send the new user message, not the conversation history
+    # since the state maintains the conversation context.
     response = client.post(
         "/v1/chat/completions",
         json={
@@ -57,7 +60,7 @@ def _test_call(config_id):
         },
     )
     res = response.json()
-    assert res["messages"][0]["content"] == "Hello again!"
+    assert res["choices"][0]["message"]["content"] == "Hello again!"


 def test_1():
diff --git a/tests/test_threads.py b/tests/test_threads.py
index 4dd4e12dd..27f666820 100644
--- a/tests/test_threads.py
+++ b/tests/test_threads.py
@@ -53,8 +53,8 @@ def test_1():
     )
     assert response.status_code == 200
     res = response.json()
-    assert len(res["messages"]) == 1
-    assert res["messages"][0]["content"] == "Hello!"
+    assert len(res["choices"][0]["message"]) == 2
+    assert res["choices"][0]["message"]["content"] == "Hello!"

     # When making a second call with the same thread_id, the conversations should continue
     # and we should get the "Hello again!" message.
@@ -72,7 +72,7 @@ def test_1():
         },
     )
     res = response.json()
-    assert res["messages"][0]["content"] == "Hello again!"
+    assert res["choices"][0]["message"]["content"] == "Hello again!"


 @pytest.mark.parametrize(
@@ -140,4 +140,4 @@ def test_with_redis():
         },
     )
     res = response.json()
-    assert res["messages"][0]["content"] == "Hello again!"
+    assert res["choices"][0]["message"]["content"] == "Hello again!"