12 changes: 8 additions & 4 deletions nemoguardrails/colang/v2_x/runtime/runtime.py
@@ -33,6 +33,7 @@
ColangSyntaxError,
)
from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus
from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state
from nemoguardrails.colang.v2_x.runtime.statemachine import (
FlowConfig,
InternalEvent,
@@ -439,10 +440,13 @@ async def process_events(
)
initialize_state(state)
elif isinstance(state, dict):
- # TODO: Implement dict to State conversion
- raise NotImplementedError()
- # if isinstance(state, dict):
- #     state = State.from_dict(state)
# Convert dict to State object
if state.get("version") == "2.x" and "state" in state:
    # Handle the serialized state format from API calls
    state = json_to_state(state["state"])
else:
    # TODO: Implement other dict to State conversion formats if needed
    raise NotImplementedError("Unsupported state dict format")

assert isinstance(state, State)
assert state.main_flow_state is not None
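For reference, the dict format the new branch accepts appears to be the envelope the API server hands back for Colang 2.x conversations, i.e. {"version": "2.x", "state": <serialized state>}. A minimal sketch of the deserialization path, mirroring the branch above (only the envelope shape and the json_to_state import are taken from this diff; everything else is an assumption):

    # Sketch only: mirrors the new process_events() branch; not part of this PR.
    from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state

    def deserialize_state(payload: dict):
        """Convert the state dict sent back by an API client into a State object."""
        if payload.get("version") == "2.x" and "state" in payload:
            return json_to_state(payload["state"])
        raise NotImplementedError("Unsupported state dict format")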
270 changes: 230 additions & 40 deletions nemoguardrails/server/api.py
@@ -20,6 +20,7 @@
import os.path
import re
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from typing import Any, List, Optional
@@ -207,10 +208,53 @@ class RequestBody(BaseModel):
default=None,
description="A state object that should be used to continue the interaction.",
)
# Standard OpenAI completion parameters
Reviewer comment (Contributor): Let's bring the OpenAI schema into a separate file, perhaps server/schemes/openai.
model: Optional[str] = Field(
default=None,
description="The model to use for chat completion. Maps to config_id for backward compatibility.",
)
max_tokens: Optional[int] = Field(
default=None,
description="The maximum number of tokens to generate.",
)
temperature: Optional[float] = Field(
default=None,
description="Sampling temperature to use.",
)
top_p: Optional[float] = Field(
default=None,
description="Top-p sampling parameter.",
)
stop: Optional[str] = Field(
Reviewer comment (Contributor): stop needs to be Optional[Union[str, List[str]]].
default=None,
description="Stop sequences.",
)
presence_penalty: Optional[float] = Field(
default=None,
description="Presence penalty parameter.",
)
frequency_penalty: Optional[float] = Field(
default=None,
description="Frequency penalty parameter.",
)
function_call: Optional[dict] = Field(
default=None,
description="Function call parameter.",
)
logit_bias: Optional[dict] = Field(
default=None,
description="Logit bias parameter.",
)
log_probs: Optional[bool] = Field(
default=None,
description="Log probabilities parameter.",
)

@root_validator(pre=True)
def ensure_config_id(cls, data: Any) -> Any:
if isinstance(data, dict):
if data.get("model") is not None and data.get("config_id") is None:
data["config_id"] = data["model"]
if data.get("config_id") is not None and data.get("config_ids") is not None:
raise ValueError(
"Only one of config_id or config_ids should be specified"
@@ -231,25 +275,113 @@ def ensure_config_ids(cls, v, values):
return v
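With the new fields and the ensure_config_id validator above, an OpenAI-style request can address a guardrails configuration through model, which is copied into config_id. An illustrative call (the config name, host, and port are assumptions, not part of this diff):

    # Illustrative only: exercises the OpenAI-compatible request fields.
    import requests

    payload = {
        "model": "my_config",  # hypothetical config id; mapped to config_id by ensure_config_id
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 256,
        "temperature": 0.2,
        "top_p": 0.95,
    }
    # Assumes a guardrails server running locally on the default port.
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    print(resp.json())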


class Choice(BaseModel):
index: Optional[int] = Field(
default=None, description="The index of the choice in the list of choices."
)
messages: Optional[dict] = Field(
Reviewer comment (Contributor): typo: messages needs to be message (no s).
default=None, description="The message of the choice"
)
logprobs: Optional[dict] = Field(
default=None, description="The log probabilities of the choice"
)
finish_reason: Optional[str] = Field(
default=None, description="The reason the model stopped generating tokens."
)


class ResponseBody(BaseModel):
- messages: List[dict] = Field(
-     default=None, description="The new messages in the conversation"
- )
- llm_output: Optional[dict] = Field(
-     default=None,
-     description="Contains any additional output coming from the LLM.",
- )
- output_data: Optional[dict] = Field(
-     default=None,
-     description="The output data, i.e. a dict with the values corresponding to the `output_vars`.",
- )
- log: Optional[GenerationLog] = Field(
-     default=None, description="Additional logging information."
- )
- state: Optional[dict] = Field(
-     default=None,
-     description="A state object that should be used to continue the interaction in the future.",
- )
# OpenAI-compatible fields
id: Optional[str] = Field(
default=None, description="A unique identifier for the chat completion."
)
object: str = Field(
default="chat.completion",
description="The object type, which is always chat.completion",
)
created: Optional[int] = Field(
default=None,
description="The Unix timestamp (in seconds) of when the chat completion was created.",
)
model: Optional[str] = Field(
default=None, description="The model used for the chat completion."
)
choices: Optional[List[Choice]] = Field(
default=None, description="A list of chat completion choices."
)
# NeMo-Guardrails specific fields for backward compatibility
state: Optional[dict] = Field(
default=None, description="State object for continuing the conversation."
)
llm_output: Optional[dict] = Field(
default=None, description="Additional LLM output data."
)
output_data: Optional[dict] = Field(
default=None, description="Additional output data."
)
log: Optional[dict] = Field(default=None, description="Generation log data.")
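With Choice and the reshaped ResponseBody, a successful completion now comes back as an OpenAI-style envelope with the guardrails fields attached. A sketch of the shape (all values are made up for illustration):

    # Illustrative response shape only; ids, timestamps, and content are invented.
    example_response = {
        "id": "chatcmpl-123e4567-e89b-12d3-a456-426614174000",
        "object": "chat.completion",
        "created": 1700000000,
        "model": "my_config",
        "choices": [
            {
                "index": 0,
                # per the reviewer note above, this key will likely become "message"
                "messages": {"role": "assistant", "content": "Hello!"},
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        # NeMo-Guardrails specific fields for backward compatibility
        "state": None,
        "llm_output": None,
        "output_data": None,
        "log": None,
    }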


class Model(BaseModel):
id: str = Field(
description="The model identifier, which can be referenced in the API endpoints."
)
object: str = Field(
default="model", description="The object type, which is always 'model'."
)
created: int = Field(
description="The Unix timestamp (in seconds) of when the model was created."
)
owned_by: str = Field(
default="nemo-guardrails", description="The organization that owns the model."
)


class ModelsResponse(BaseModel):
object: str = Field(
default="list", description="The object type, which is always 'list'."
)
data: List[Model] = Field(description="The list of models.")


@app.get(
"/v1/models",
response_model=ModelsResponse,
summary="List available models",
description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.",
)
async def get_models():
"""Returns the list of available models (guardrails configurations) in OpenAI-compatible format."""

# Use the same logic as get_rails_configs to find available configurations
if app.single_config_mode:
config_ids = [app.single_config_id] if app.single_config_id else []
else:
config_ids = [
f
for f in os.listdir(app.rails_config_path)
if os.path.isdir(os.path.join(app.rails_config_path, f))
and f[0] != "."
and f[0] != "_"
# Filter out all the configs for which there is no `config.yml` file.
and (
os.path.exists(os.path.join(app.rails_config_path, f, "config.yml"))
or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml"))
)
]

# Convert configurations to OpenAI model format
models = []
for config_id in config_ids:
model = Model(
id=config_id,
object="model",
created=int(time.time()), # Use current time as created timestamp
owned_by="nemo-guardrails",
)
models.append(model)

return ModelsResponse(data=models)
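A quick way to exercise the new endpoint once a server is running (host and port are assumptions):

    # Lists guardrails configurations in the OpenAI "models" format.
    import requests

    models = requests.get("http://localhost:8000/v1/models").json()
    # Expected shape: {"object": "list", "data": [{"id": "<config_id>", "object": "model", ...}]}
    for m in models["data"]:
        print(m["id"], m["owned_by"])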


@app.get(
@@ -372,15 +504,24 @@ async def chat_completion(body: RequestBody, request: Request):
llm_rails = _get_rails(config_ids)
except ValueError as ex:
log.exception(ex)
- return {
-     "messages": [
-         {
-             "role": "assistant",
-             "content": f"Could not load the {config_ids} guardrails configuration. "
-             f"An internal error has occurred.",
-         }
-     ]
- }
return ResponseBody(
id=f"chatcmpl-{uuid.uuid4()}",
object="chat.completion",
created=int(time.time()),
model=config_ids[0] if config_ids else None,
choices=[
Choice(
index=0,
messages={
"content": f"Could not load the {config_ids} guardrails configuration. "
f"An internal error has occurred.",
"role": "assistant",
},
finish_reason="error",
logprobs=None,
)
],
)

try:
messages = body.messages
@@ -396,14 +537,23 @@

# We make sure the `thread_id` meets the minimum complexity requirement.
if len(body.thread_id) < 16:
- return {
-     "messages": [
-         {
-             "role": "assistant",
-             "content": "The `thread_id` must have a minimum length of 16 characters.",
-         }
-     ]
- }
return ResponseBody(
id=f"chatcmpl-{uuid.uuid4()}",
object="chat.completion",
created=int(time.time()),
model=None,
choices=[
Choice(
index=0,
messages={
"content": "The `thread_id` must have a minimum length of 16 characters.",
"role": "assistant",
},
finish_reason="error",
logprobs=None,
)
],
)

# Fetch the existing thread messages. For easier management, we prepend
# the string `thread-` to all thread keys.
@@ -413,6 +563,20 @@
# And prepend them.
messages = thread_messages + messages

generation_options = body.options
if body.max_tokens:
generation_options.max_tokens = body.max_tokens
if body.temperature is not None:
generation_options.temperature = body.temperature
if body.top_p is not None:
generation_options.top_p = body.top_p
if body.stop:
generation_options.stop = body.stop
if body.presence_penalty is not None:
generation_options.presence_penalty = body.presence_penalty
if body.frequency_penalty is not None:
generation_options.frequency_penalty = body.frequency_penalty

if (
body.stream
and llm_rails.config.streaming_supported
@@ -431,8 +595,6 @@
)
)

- # TODO: Add support for thread_ids in streaming mode

return StreamingResponse(streaming_handler)
else:
res = await llm_rails.generate_async(
@@ -450,22 +612,50 @@
if body.thread_id:
await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

- result = {"messages": [bot_message]}
# Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
response_kwargs = {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": config_ids[0] if config_ids else None,
"choices": [
Choice(
index=0,
messages=bot_message,
finish_reason="stop",
logprobs=None,
)
],
}

- # If we have additional GenerationResponse fields, we return as well
# If we have additional GenerationResponse fields, include them for backward compatibility
if isinstance(res, GenerationResponse):
result["llm_output"] = res.llm_output
result["output_data"] = res.output_data
result["log"] = res.log
result["state"] = res.state
response_kwargs["llm_output"] = res.llm_output
response_kwargs["output_data"] = res.output_data
response_kwargs["log"] = res.log
response_kwargs["state"] = res.state

- return result
return ResponseBody(**response_kwargs)

except Exception as ex:
log.exception(ex)
- return {
-     "messages": [{"role": "assistant", "content": "Internal server error."}]
- }
return ResponseBody(
id=f"chatcmpl-{uuid.uuid4()}",
object="chat.completion",
created=int(time.time()),
model=None,
choices=[
Choice(
index=0,
messages={
"content": "Internal server error",
"role": "assistant",
},
finish_reason="error",
logprobs=None,
)
],
)
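Since both the success and the error paths now return the OpenAI envelope, the endpoint should in principle be usable from a standard OpenAI client pointed at the guardrails server. A hedged sketch (base URL, port, and config name are assumptions):

    # Sketch only: uses the openai>=1.x client against the guardrails server.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
    completion = client.chat.completions.create(
        model="my_config",  # hypothetical guardrails configuration id
        messages=[{"role": "user", "content": "Hello!"}],
        temperature=0.2,
    )
    print(completion.choices[0].message)

Note that, as the reviewer pointed out on the Choice model, a strict OpenAI client may fail to parse the response until choices[*].messages is renamed to message.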


# By default, there are no challenges