feat: Modify endpoints for OpenAPI compatibility #1340
Draft: christinaexyou wants to merge 3 commits into NVIDIA-NeMo:develop from christinaexyou:openai-server-compat.
Changes from 2 commits
```diff
@@ -20,6 +20,7 @@
 import os.path
 import re
 import time
+import uuid
 import warnings
 from contextlib import asynccontextmanager
 from typing import Any, List, Optional
```
```diff
@@ -207,10 +208,53 @@ class RequestBody(BaseModel):
         default=None,
         description="A state object that should be used to continue the interaction.",
     )
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
+        default=None,
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="Sampling temperature to use.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="Top-p sampling parameter.",
+    )
+    stop: Optional[str] = Field(
+        default=None,
+        description="Stop sequences.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="Presence penalty parameter.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="Frequency penalty parameter.",
+    )
+    function_call: Optional[dict] = Field(
+        default=None,
+        description="Function call parameter.",
+    )
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
+    )
+    log_probs: Optional[bool] = Field(
+        default=None,
+        description="Log probabilities parameter.",
+    )

     @root_validator(pre=True)
     def ensure_config_id(cls, data: Any) -> Any:
         if isinstance(data, dict):
+            if data.get("model") is not None and data.get("config_id") is None:
+                data["config_id"] = data["model"]
             if data.get("config_id") is not None and data.get("config_ids") is not None:
                 raise ValueError(
                     "Only one of config_id or config_ids should be specified"
```
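With these fields in place, an OpenAI-style request can be posted to the chat endpoint and the `model` value is copied onto `config_id` by the validator. A minimal sketch, assuming a locally running server on port 8000, the `/v1/chat/completions` route, and a hypothetical configuration named `my_config`:

```python
import requests

# Hypothetical server address and config name; adjust to your deployment.
payload = {
    "model": "my_config",  # ensure_config_id maps this to config_id
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.2,
    "max_tokens": 128,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
# Note: this PR names the Choice field `messages` (a dict), not `message`.
print(resp.json()["choices"][0]["messages"]["content"])
```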
```diff
@@ -231,25 +275,113 @@ def ensure_config_ids(cls, v, values):
         return v


+class Choice(BaseModel):
+    index: Optional[int] = Field(
+        default=None, description="The index of the choice in the list of choices."
+    )
+    messages: Optional[dict] = Field(
+        default=None, description="The message of the choice"
+    )
+    logprobs: Optional[dict] = Field(
+        default=None, description="The log probabilities of the choice"
+    )
+    finish_reason: Optional[str] = Field(
+        default=None, description="The reason the model stopped generating tokens."
+    )
+
+
 class ResponseBody(BaseModel):
-    messages: List[dict] = Field(
-        default=None, description="The new messages in the conversation"
+    # OpenAI-compatible fields
+    id: Optional[str] = Field(
+        default=None, description="A unique identifier for the chat completion."
     )
-    llm_output: Optional[dict] = Field(
-        default=None,
-        description="Contains any additional output coming from the LLM.",
+    object: str = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
     )
-    output_data: Optional[dict] = Field(
+    created: Optional[int] = Field(
         default=None,
-        description="The output data, i.e. a dict with the values corresponding to the `output_vars`.",
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
     )
+    model: Optional[str] = Field(
+        default=None, description="The model used for the chat completion."
+    )
-    log: Optional[GenerationLog] = Field(
-        default=None, description="Additional logging information."
+    choices: Optional[List[Choice]] = Field(
+        default=None, description="A list of chat completion choices."
     )
+    # NeMo-Guardrails specific fields for backward compatibility
     state: Optional[dict] = Field(
-        default=None,
-        description="A state object that should be used to continue the interaction in the future.",
+        default=None, description="State object for continuing the conversation."
     )
+    llm_output: Optional[dict] = Field(
+        default=None, description="Additional LLM output data."
+    )
+    output_data: Optional[dict] = Field(
+        default=None, description="Additional output data."
+    )
+    log: Optional[dict] = Field(default=None, description="Generation log data.")
+
+
+class Model(BaseModel):
+    id: str = Field(
+        description="The model identifier, which can be referenced in the API endpoints."
+    )
+    object: str = Field(
+        default="model", description="The object type, which is always 'model'."
+    )
+    created: int = Field(
+        description="The Unix timestamp (in seconds) of when the model was created."
+    )
+    owned_by: str = Field(
+        default="nemo-guardrails", description="The organization that owns the model."
+    )
+
+
+class ModelsResponse(BaseModel):
+    object: str = Field(
+        default="list", description="The object type, which is always 'list'."
+    )
+    data: List[Model] = Field(description="The list of models.")
```
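For a quick sanity check of the new schema, the models can be constructed and serialized directly. A minimal sketch with made-up values, assuming the Pydantic v1 style already used in this file:

```python
import time

# Made-up values purely to illustrate the serialized shape.
choice = Choice(
    index=0,
    messages={"role": "assistant", "content": "Hello!"},
    finish_reason="stop",
)
body = ResponseBody(
    id="chatcmpl-example",
    object="chat.completion",
    created=int(time.time()),
    model="my_config",
    choices=[choice],
)
print(body.dict(exclude_none=True))
```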
```diff
+@app.get(
+    "/v1/models",
+    response_model=ModelsResponse,
+    summary="List available models",
+    description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.",
+)
+async def get_models():
+    """Returns the list of available models (guardrails configurations) in OpenAI-compatible format."""
+
+    # Use the same logic as get_rails_configs to find available configurations
+    if app.single_config_mode:
+        config_ids = [app.single_config_id] if app.single_config_id else []
+    else:
+        config_ids = [
+            f
+            for f in os.listdir(app.rails_config_path)
+            if os.path.isdir(os.path.join(app.rails_config_path, f))
+            and f[0] != "."
+            and f[0] != "_"
+            # Filter out all the configs for which there is no `config.yml` file.
+            and (
+                os.path.exists(os.path.join(app.rails_config_path, f, "config.yml"))
+                or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml"))
+            )
+        ]
+
+    # Convert configurations to OpenAI model format
+    models = []
+    for config_id in config_ids:
+        model = Model(
+            id=config_id,
+            object="model",
+            created=int(time.time()),  # Use current time as created timestamp
+            owned_by="nemo-guardrails",
+        )
+        models.append(model)
+
+    return ModelsResponse(data=models)
+
+
 @app.get(
```
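As a quick check of the new endpoint, listing the models should return one entry per guardrails configuration. A sketch, assuming a hypothetical local server with a configuration named `abc`:

```python
import requests

# Hypothetical local server address.
resp = requests.get("http://localhost:8000/v1/models")
print(resp.json())
# Expected shape (ids and timestamps depend on the configs folder):
# {"object": "list", "data": [{"id": "abc", "object": "model",
#   "created": 1700000000, "owned_by": "nemo-guardrails"}]}
```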
```diff
@@ -372,15 +504,24 @@ async def chat_completion(body: RequestBody, request: Request):
         llm_rails = _get_rails(config_ids)
     except ValueError as ex:
         log.exception(ex)
-        return {
-            "messages": [
-                {
-                    "role": "assistant",
-                    "content": f"Could not load the {config_ids} guardrails configuration. "
-                    f"An internal error has occurred.",
-                }
-            ]
-        }
+        return ResponseBody(
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=config_ids[0] if config_ids else None,
+            choices=[
+                Choice(
+                    index=0,
+                    messages={
+                        "content": f"Could not load the {config_ids} guardrails configuration. "
+                        f"An internal error has occurred.",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
+        )

     try:
         messages = body.messages
```
```diff
@@ -396,14 +537,23 @@ async def chat_completion(body: RequestBody, request: Request):

             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
-                return {
-                    "messages": [
-                        {
-                            "role": "assistant",
-                            "content": "The `thread_id` must have a minimum length of 16 characters.",
-                        }
-                    ]
-                }
+                return ResponseBody(
+                    id=f"chatcmpl-{uuid.uuid4()}",
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=None,
+                    choices=[
+                        Choice(
+                            index=0,
+                            messages={
+                                "content": "The `thread_id` must have a minimum length of 16 characters.",
+                                "role": "assistant",
+                            },
+                            finish_reason="error",
+                            logprobs=None,
+                        )
+                    ],
+                )

             # Fetch the existing thread messages. For easier management, we prepend
             # the string `thread-` to all thread keys.
```
```diff
@@ -413,6 +563,20 @@ async def chat_completion(body: RequestBody, request: Request):
             # And prepend them.
             messages = thread_messages + messages

+        generation_options = body.options
+        if body.max_tokens:
+            generation_options.max_tokens = body.max_tokens
+        if body.temperature is not None:
+            generation_options.temperature = body.temperature
+        if body.top_p is not None:
+            generation_options.top_p = body.top_p
+        if body.stop:
+            generation_options.stop = body.stop
+        if body.presence_penalty is not None:
+            generation_options.presence_penalty = body.presence_penalty
+        if body.frequency_penalty is not None:
+            generation_options.frequency_penalty = body.frequency_penalty
+
         if (
             body.stream
             and llm_rails.config.streaming_supported
```
```diff
@@ -431,8 +595,6 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )

-            # TODO: Add support for thread_ids in streaming mode
-
             return StreamingResponse(streaming_handler)
         else:
             res = await llm_rails.generate_async(
```
```diff
@@ -450,22 +612,50 @@ async def chat_completion(body: RequestBody, request: Request):
         if body.thread_id:
             await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

-        result = {"messages": [bot_message]}
+        # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
+        response_kwargs = {
+            "id": f"chatcmpl-{uuid.uuid4()}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": config_ids[0] if config_ids else None,
+            "choices": [
+                Choice(
+                    index=0,
+                    messages=bot_message,
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+        }

-        # If we have additional GenerationResponse fields, we return as well
+        # If we have additional GenerationResponse fields, include them for backward compatibility
         if isinstance(res, GenerationResponse):
-            result["llm_output"] = res.llm_output
-            result["output_data"] = res.output_data
-            result["log"] = res.log
-            result["state"] = res.state
+            response_kwargs["llm_output"] = res.llm_output
+            response_kwargs["output_data"] = res.output_data
+            response_kwargs["log"] = res.log
+            response_kwargs["state"] = res.state

-        return result
+        return ResponseBody(**response_kwargs)

     except Exception as ex:
         log.exception(ex)
-        return {
-            "messages": [{"role": "assistant", "content": "Internal server error."}]
-        }
+        return ResponseBody(
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=None,
+            choices=[
+                Choice(
+                    index=0,
+                    messages={
+                        "content": "Internal server error",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
+        )


 # By default, there are no challenges
```
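With these changes, a successful call returns an OpenAI-style completion object with the guardrails-specific fields attached when a GenerationResponse is available. A sketch of the resulting shape, with illustrative values only:

```python
# Illustrative response shape; the id and timestamp are fabricated.
example_response = {
    "id": "chatcmpl-<uuid4>",
    "object": "chat.completion",
    "created": 1700000000,
    "model": "my_config",  # first config_id, if any
    "choices": [
        {
            "index": 0,
            "messages": {"role": "assistant", "content": "Hi there!"},
            "logprobs": None,
            "finish_reason": "stop",
        }
    ],
    # NeMo-Guardrails extensions, populated when GenerationResponse is returned:
    "state": None,
    "llm_output": None,
    "output_data": None,
    "log": None,
}
```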
Review comment: Let's bring the OpenAI schema into a separate file, perhaps `server/schemes/openai`.
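A possible shape for that refactor; the module path and contents below are only an illustration of the reviewer's suggestion, not part of the PR:

```python
# nemoguardrails/server/schemas/openai.py (hypothetical module path)
from typing import List, Optional

from pydantic import BaseModel, Field


class Choice(BaseModel):
    index: Optional[int] = Field(default=None, description="Index of the choice.")
    messages: Optional[dict] = Field(default=None, description="Message of the choice.")
    logprobs: Optional[dict] = Field(default=None, description="Log probabilities.")
    finish_reason: Optional[str] = Field(default=None, description="Finish reason.")


class Model(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "nemo-guardrails"


class ModelsResponse(BaseModel):
    object: str = "list"
    data: List[Model]


# api.py would then import these instead of defining them inline, e.g.:
# from nemoguardrails.server.schemas.openai import Choice, Model, ModelsResponse
```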