67 changes: 47 additions & 20 deletions docs/getting_started/model_configuration.md

Large diffs are not rendered by default.

37 changes: 33 additions & 4 deletions sygra/config/models.yaml
@@ -136,18 +136,47 @@ gemini_2_5_pro:
 # Bedrock model
 bedrock_model:
   model_type: bedrock
-  model: anthropic.claude-sonnet-4-5-20250929-v1:0
+  model: us.anthropic.claude-sonnet-4-5-20250929-v1:0
   # aws_access_key_id, aws_secret_access_key, aws_region_name should be defined at .env file as SYGRA_BEDROCK_MODEL_AWS_ACCESS_KEY_ID, SYGRA_BEDROCK_MODEL_AWS_SECRET_ACCESS_KEY, SYGRA_BEDROCK_MODEL_AWS_REGION_NAME
   parameters:
     max_tokens: 5000
     temperature: 0.5


+#gpt 5
+openai_gpt41:
+  model_type: azure_openai
+  model: gpt-4.1
+  api_version: 2024-02-15-preview
+  parameters:
+    max_tokens: 5000
+    temperature: 0.1
+
 gpt-5:
   model_type: azure_openai
   model: gpt-5
-  # URL and auth_token should be defined at .env file as SYGRA_GPT-4O_URL and SYGRA_GPT-4O_TOKEN
-  api_version: 2024-08-01-preview
+  # URL and auth_token should be defined at .env file as SYGRA_GPT-5_URL and SYGRA_GPT-5_TOKEN
+  api_version: 2024-02-15-preview
   parameters:
     max_completion_tokens: 5000
     temperature: 1
+
+claude_large:
+  model_type: claude_proxy
+  client_type: http
+  backend: proxy
+  additional_params:
+    thinking:
+      type: disabled
+    anthropic_beta:
+      - computer-use-2025-01-24
+  parameters:
+    maxTokens: 2500
+    temperature: 0
+
+gemini_2.5_proxy:
+  model_type: gemini_proxy
+  client_type: http
+  backend: proxy
+  parameters:
+    maxOutputTokens: 2500
+    temperature: 0
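
The comments in these entries encode SyGra's convention that per-model secrets live in `.env` under a key derived from the model's name, e.g. `SYGRA_GPT-5_URL`/`SYGRA_GPT-5_TOKEN` for `gpt-5` and `SYGRA_BEDROCK_MODEL_AWS_REGION_NAME` for `bedrock_model`. A minimal sketch of that derivation, assuming it is a plain uppercase of the model key (the actual lookup code in SyGra is not shown in this diff and may differ):

```python
import os

def resolve_model_env(model_name: str, field: str) -> str | None:
    # Hypothetical helper, not part of SyGra: builds the env-var name the
    # comments in models.yaml imply, e.g. ("gpt-5", "url") -> "SYGRA_GPT-5_URL"
    # or ("bedrock_model", "aws_region_name") -> "SYGRA_BEDROCK_MODEL_AWS_REGION_NAME",
    # then reads it from the environment.
    return os.environ.get(f"SYGRA_{model_name.upper()}_{field.upper()}")
```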
4 changes: 2 additions & 2 deletions sygra/core/dataset/dataset_processor.py
@@ -441,10 +441,10 @@ async def _process_and_store_results(self):
                 metadata = {"output_file": output_file}
                 processor = utils.get_func_from_str(post_processor)
                 processor_name = post_processor.split(".")[-1]
-                processed_output_data = processor().process(output_data, metadata)
+                processed_output_data = await processor().process(output_data, metadata)
                 new_output_file = output_file[: output_file.rfind("/") + 1] + output_file[
                     output_file.rfind("/") + 1 :
-                ].replace("output_", processor_name + "_", 1)
+                ].replace("output", processor_name, 1)
                 with open(new_output_file, "w") as f:
                     logger.info(f"Writing metrics output to file {new_output_file}")
                     json.dump(processed_output_data, f, indent=4)
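
The widened `replace("output", ...)` means the rename now also matches output files without the `output_` prefix; under the old pattern, a file like `output.json` slipped through unrenamed. A quick illustration, with the made-up name "metrics" standing in for `processor_name`:

```python
fname = "output.json"
fname.replace("output_", "metrics" + "_", 1)  # old: no match -> "output.json"
fname.replace("output", "metrics", 1)         # new: -> "metrics.json"
```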
2 changes: 1 addition & 1 deletion sygra/core/graph/graph_postprocessor.py
@@ -8,6 +8,6 @@ class GraphPostProcessor(ABC):
     """

     @abstractmethod
-    def process(self, data: list, metadata: dict) -> list:
+    async def process(self, data: list, metadata: dict) -> list:
         # implement post processing logic with whole data, return the final data list
         pass
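
With `process` now a coroutine, every concrete post-processor must declare it `async def`, matching the `await processor().process(...)` call added in `dataset_processor.py` above. A minimal sketch of a conforming subclass — the class name and dedup logic are hypothetical; only the signature comes from the ABC:

```python
class DedupPostProcessor(GraphPostProcessor):
    # Hypothetical example: drop exact duplicate records before the
    # results are written to the renamed output file.
    async def process(self, data: list, metadata: dict) -> list:
        seen: set[str] = set()
        result = []
        for record in data:
            key = str(record)
            if key not in seen:
                seen.add(key)
                result.append(record)
        return result
```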
2 changes: 1 addition & 1 deletion sygra/core/graph/langgraph/langchain_callback.py
@@ -59,7 +59,7 @@ def calculate_cost(model_name: str, prompt_tokens: int, completion_tokens: int)
            return _get_anthropic_claude_token_cost(
                prompt_tokens, completion_tokens, model_name
            )
-        except (KeyError, ValueError):
+        except (KeyError, ValueError, IndexError):
            # Claude model but not in Bedrock pricing table
            pass

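
Adding `IndexError` to the caught exceptions suggests the pricing lookup slices a split model id; the body of `_get_anthropic_claude_token_cost` is not shown in this diff, so the exact trigger is an assumption. The failure mode would look like this:

```python
# Hypothetical illustration only: slicing a split model id raises
# IndexError when the id has fewer segments than the lookup expects.
model_name = "claude"     # no "." in the id
model_name.split(".")[1]  # raises IndexError, now caught above
```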
6 changes: 3 additions & 3 deletions sygra/core/graph/nodes/llm_node.py
@@ -240,13 +240,13 @@ def _generate_prompt_tmpl_from_msg(
         for message in chat_frmt_messages:
             contents = message["content"]
             # if it's a normal text conversation then no need to expand
-            if isinstance(contents, str):
+            if isinstance(contents, str) or message.get("role", "") == "tool":
                 continue
             expanded_contents = []
             for item in contents:
-                if item["type"] == "image_url":
+                if item.get("type") and item.get("type") == "image_url":
                     expanded_contents.extend(expand_image_item(item, state))
-                elif item["type"] == "audio_url":
+                elif item.get("type") and item.get("type") == "audio_url":
                     expanded_contents.extend(expand_audio_item(item, state))
                 else:
                     expanded_contents.append(item)
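
The switch from `item["type"]` to `item.get("type")` makes prompt expansion tolerant of content items that lack a `"type"` key, which previously raised `KeyError`; such items now fall through to the `else` branch and are kept as-is:

```python
item = {"text": "hello"}         # hypothetical content item without a "type" key
# item["type"]                   # old code: raises KeyError
item.get("type") == "image_url"  # new code: False, falls through to else
```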
55 changes: 50 additions & 5 deletions sygra/core/models/client/client_factory.py
@@ -17,6 +17,7 @@
 class ModelType(Enum):
     """Enum representing the supported model types for client creation."""

+    # model type or server type
     VLLM = "vllm"
     OPENAI = "openai"
     AZURE_OPENAI = "azure_openai"
@@ -27,8 +28,14 @@ class ModelType(Enum):
     TRITON = "triton"


+class ClientType(Enum):
+    # client type
+    HTTP = "http"
+
+
 # Define which model types do not require a AUTH_TOKEN
 NO_AUTH_TOKEN_NEEDED_MODEL_TYPES = [ModelType.OLLAMA.value]
+NO_AUTH_TOKEN_NEEDED_CLIENT_TYPES = [ClientType.HTTP.value]


 class ClientFactory:
@@ -65,22 +72,49 @@ def create_client(
         # Validate model_type is present
         utils.validate_required_keys(["model_type"], model_config, "model")

+        # read model type or server type
         model_type = model_config["model_type"].lower()

+        # read client type
+        client_type = (
+            model_config["client_type"].lower() if model_config.get("client_type") else None
+        )
+
         # Validate if url is present
         if url is None:
             logger.error("URL is required for client creation.")
             raise ValueError("URL is required for client creation.")

         # Validate if auth_token is present
-        if auth_token is None and model_type not in NO_AUTH_TOKEN_NEEDED_MODEL_TYPES:
+        if auth_token is None and (
+            model_type not in NO_AUTH_TOKEN_NEEDED_MODEL_TYPES
+            and client_type not in NO_AUTH_TOKEN_NEEDED_CLIENT_TYPES
+        ):
             logger.error("Auth token/API key is required for client creation.")
             raise ValueError("Auth token/API key is required for client creation.")

+        # TODO: add backend condition
+
         # Validate that the model_type is supported
-        supported_types = [type_enum.value for type_enum in ModelType]
-        if model_type not in supported_types:
-            supported_types_str = ", ".join(supported_types)
+        supported_model_types = [type_enum.value for type_enum in ModelType]
+        if client_type is not None:
+            supported_client_types = [type_enum.value for type_enum in ClientType]
+            if client_type not in supported_client_types:
+                supported_types_str = ", ".join(supported_client_types)
+                logger.error(
+                    f"Unsupported client type: {client_type}. Supported types: {supported_types_str}"
+                )
+                raise ValueError(
+                    f"Unsupported client type: {client_type}. Must be one of: {supported_types_str}"
+                )
+
+            if client_type == ClientType.HTTP.value:
+                return cls._create_http_client(model_config, url, auth_token)
+            else:
+                logger.error(f"Unsupported client type: {client_type}")
+                # Add more client type which can be common
+        elif model_type not in supported_model_types:
+            supported_types_str = ", ".join(supported_model_types)
             logger.error(
                 f"Unsupported model type: {model_type}. Supported types: {supported_types_str}"
             )
@@ -89,8 +123,9 @@
             )

         # Create client based on model type
-        if model_type == ModelType.VLLM.value or model_type == ModelType.OPENAI.value:
+        elif model_type == ModelType.VLLM.value or model_type == ModelType.OPENAI.value:
             # Initialize the client with default chat_completions_api
+            # for VLLM server type, client type is openai ONLY
             return cls._create_openai_client(
                 model_config,
                 url,
@@ -99,6 +134,7 @@
                 not model_config.get("completions_api", False),
             )
         elif model_type == ModelType.AZURE_OPENAI.value:
+            # for AZURE_OPENAI server type, client type is openai_azure ONLY
             return cls._create_openai_azure_client(
                 model_config,
                 url,
@@ -111,10 +147,13 @@
             or model_type == ModelType.TGI.value
             or model_type == ModelType.TRITON.value
         ):
+            # for AZURE/TGI/TRITON server type, client type is http ONLY
             return cls._create_http_client(model_config, url, auth_token)
         elif model_type == ModelType.MISTRALAI.value:
+            # for Mistral server type, client type is mistral ONLY
             return cls._create_mistral_client(model_config, url, auth_token, async_client)
         elif model_type == ModelType.OLLAMA.value:
+            # for OLLAMA server type, client type is ollama
             return cls._create_ollama_client(
                 model_config,
                 url,
@@ -306,6 +345,12 @@ def _create_http_client(
             auth_token = auth_token.replace("Bearer ", "")
             headers["Authorization"] = f"Bearer {auth_token}"

+        if model_config.get("extra_headers"):
+            extra_headers = model_config.get("extra_headers")
+            if extra_headers:
+                for k, v in extra_headers.items():
+                    headers[k] = v
+
         timeout = model_config.get("timeout", constants.DEFAULT_TIMEOUT)
         max_retries = model_config.get("max_retries", 3)

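
These factory changes are what let the new `claude_large` and `gemini_2.5_proxy` entries in `models.yaml` work: a config with `client_type: http` is routed straight to `_create_http_client`, skips the `ModelType` check, and may omit the auth token. A hedged sketch of that path — the `create_client` signature beyond `model_config`, `url`, and `auth_token` is assumed, and the endpoint is a placeholder:

```python
model_config = {
    "model_type": "claude_proxy",          # not a ModelType member; not checked when client_type is set
    "client_type": "http",                 # routes to _create_http_client(...)
    "backend": "proxy",
    "extra_headers": {"x-team": "sygra"},  # hypothetical header, merged into request headers
}

client = ClientFactory.create_client(
    model_config,
    url="https://proxy.example.com/v1",  # placeholder endpoint
    auth_token=None,                     # allowed: "http" is in NO_AUTH_TOKEN_NEEDED_CLIENT_TYPES
)
```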