40 changes: 36 additions & 4 deletions sygra/config/models.yaml
@@ -136,18 +136,50 @@ gemini_2_5_pro:
# Bedrock model
bedrock_model:
  model_type: bedrock
-  model: anthropic.claude-sonnet-4-5-20250929-v1:0
+  model: us.anthropic.claude-sonnet-4-5-20250929-v1:0
  # aws_access_key_id, aws_secret_access_key, aws_region_name should be defined in the .env file as SYGRA_BEDROCK_MODEL_AWS_ACCESS_KEY_ID, SYGRA_BEDROCK_MODEL_AWS_SECRET_ACCESS_KEY, SYGRA_BEDROCK_MODEL_AWS_REGION_NAME
  parameters:
    max_tokens: 5000
    temperature: 0.5


# gpt-4.1
openai_gpt41:
  model_type: azure_openai
  model: gpt-4.1
  api_version: 2024-02-15-preview
  parameters:
    max_tokens: 5000
    temperature: 0.1

gpt-5:
  model_type: azure_openai
  model: gpt-5
-  # URL and auth_token should be defined in the .env file as SYGRA_GPT-4O_URL and SYGRA_GPT-4O_TOKEN
-  api_version: 2024-08-01-preview
+  # URL and auth_token should be defined in the .env file as SYGRA_GPT-5_URL and SYGRA_GPT-5_TOKEN
+  api_version: 2024-02-15-preview
  parameters:
    max_completion_tokens: 5000
    temperature: 1

claude_large:
  model_type: claude_proxy
  client_type: http
  backend: proxy
  extra_headers:
    x-now-transaction-id: "19076269ff422210f429ffffffffff8f"
    x-now-request-metadata: "{\"trace_id\":\"1d076269664222105c1a6ddecc4f6aa6\",\"solution_name\":\"NAA Reasoning Engine (Claude Large)\",\"instance_id\":\"66e9ef02ad712210a6f9e69efb0d20db\",\"instance_name\":\"glide_db_dump\",\"instance_version\":\"27\",\"system_id\":\"192.168.29.108:node-a\",\"traceparent\":\"00-1d076269664222105c1a6ddecc4f6aa6-1d07626966422210-01\",\"solution_id\":\"473b82a5ff022210f429ffffffffffa1\",\"solution_capability\":\"NAA Reasoning Claude\",\"node_id\":\"node-a\",\"anthropic-beta\":\"computer-use-2025-01-24\"}"
  additional_params:
    thinking:
      type: disabled
    anthropic_beta:
      - computer-use-2025-01-24
  parameters:
    maxTokens: 2500
    temperature: 0

gemini_2.5_proxy:
  model_type: gemini_proxy
  client_type: http
  backend: proxy
  parameters:
    maxOutputTokens: 2500
    temperature: 0
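
A note on the credential convention used by the entries above: each model resolves its endpoint and token from environment variables derived from its config key, e.g. SYGRA_GPT-5_URL and SYGRA_GPT-5_TOKEN for the gpt-5 entry, per the comments in this file. A minimal sketch of that lookup in Python; the helper name is hypothetical and sygra's actual loader may differ:

    import os

    def resolve_model_credentials(model_name: str):
        # Hypothetical helper mirroring the SYGRA_<MODEL>_URL / SYGRA_<MODEL>_TOKEN
        # naming convention documented in the models.yaml comments.
        prefix = f"SYGRA_{model_name.upper()}"
        return os.environ.get(f"{prefix}_URL"), os.environ.get(f"{prefix}_TOKEN")

    url, token = resolve_model_credentials("gpt-5")  # reads SYGRA_GPT-5_URL / SYGRA_GPT-5_TOKEN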
4 changes: 2 additions & 2 deletions sygra/core/dataset/dataset_processor.py
@@ -441,10 +441,10 @@ async def _process_and_store_results(self):
metadata = {"output_file": output_file}
processor = utils.get_func_from_str(post_processor)
processor_name = post_processor.split(".")[-1]
-processed_output_data = processor().process(output_data, metadata)
+processed_output_data = await processor().process(output_data, metadata)
new_output_file = output_file[: output_file.rfind("/") + 1] + output_file[
    output_file.rfind("/") + 1 :
-].replace("output_", processor_name + "_", 1)
+].replace("output", processor_name, 1)
with open(new_output_file, "w") as f:
    logger.info(f"Writing metrics output to file {new_output_file}")
    json.dump(processed_output_data, f, indent=4)
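
An illustration of the rename change above: matching on "output" instead of "output_" also renames files whose basename does not have an underscore right after the prefix. With a hypothetical path and post-processor name:

    output_file = "runs/output_data.json"  # hypothetical path
    processor_name = "metrics_processor"   # hypothetical, last segment of the post_processor dotted path
    new_output_file = output_file[: output_file.rfind("/") + 1] + output_file[
        output_file.rfind("/") + 1 :
    ].replace("output", processor_name, 1)
    assert new_output_file == "runs/metrics_processor_data.json"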
2 changes: 1 addition & 1 deletion sygra/core/graph/graph_postprocessor.py
@@ -8,6 +8,6 @@ class GraphPostProcessor(ABC):
"""

@abstractmethod
def process(self, data: list, metadata: dict) -> list:
async def process(self, data: list, metadata: dict) -> list:
# implement post processing logic with whole data, return the final data list
pass
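
With process now a coroutine, every concrete post-processor must declare it async, and callers must await it (as dataset_processor.py above now does). A minimal sketch of a conforming subclass, assuming this module's import path; the class name and its logic are illustrative:

    from sygra.core.graph.graph_postprocessor import GraphPostProcessor

    class DedupPostProcessor(GraphPostProcessor):
        # Illustrative: drop duplicate records before the results are written out.
        async def process(self, data: list, metadata: dict) -> list:
            seen: set = set()
            unique = []
            for record in data:
                key = str(record)
                if key not in seen:
                    seen.add(key)
                    unique.append(record)
            return unique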
2 changes: 1 addition & 1 deletion sygra/core/graph/langgraph/langchain_callback.py
@@ -59,7 +59,7 @@ def calculate_cost(model_name: str, prompt_tokens: int, completion_tokens: int)
        return _get_anthropic_claude_token_cost(
            prompt_tokens, completion_tokens, model_name
        )
-   except (KeyError, ValueError):
+   except (KeyError, ValueError, IndexError):
        # Claude model but not in the Bedrock pricing table
        pass
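
Catching IndexError here presumably guards the cost lookup against model-id parsing that indexes into split segments of the name; an id with fewer segments than expected would otherwise escape the handler, e.g. (illustrative only):

    "claude".split("-", 1)[1]  # IndexError: list index out of range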

6 changes: 3 additions & 3 deletions sygra/core/graph/nodes/llm_node.py
@@ -240,13 +240,13 @@ def _generate_prompt_tmpl_from_msg(
for message in chat_frmt_messages:
    contents = message["content"]
    # if it's a normal text conversation, there is no need to expand
-   if isinstance(contents, str):
+   if isinstance(contents, str) or message.get("role", "") == "tool":
        continue
    expanded_contents = []
    for item in contents:
-       if item["type"] == "image_url":
+       if item.get("type") == "image_url":
            expanded_contents.extend(expand_image_item(item, state))
-       elif item["type"] == "audio_url":
+       elif item.get("type") == "audio_url":
            expanded_contents.extend(expand_audio_item(item, state))
        else:
            expanded_contents.append(item)
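
The .get("type") form tolerates content items that omit the type key, and tool-role messages are now skipped before expansion. A hypothetical message that would have raised KeyError under the old item["type"] lookup:

    message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"text": "plain item with no type key"},  # old code: KeyError on item["type"]
        ],
    }
    for item in message["content"]:
        if item.get("type") == "image_url":
            pass  # would be expanded via expand_image_item(item, state) in the node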
56 changes: 48 additions & 8 deletions sygra/core/models/client/client_factory.py
@@ -17,6 +17,7 @@
class ModelType(Enum):
    """Enum representing the supported model types for client creation."""

+   # model type or server type
    VLLM = "vllm"
    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
@@ -27,8 +28,14 @@ class ModelType(Enum):
    TRITON = "triton"


+class ClientType(Enum):
+   # client type
+   HTTP = "http"
+
+
# Define which model types do not require an AUTH_TOKEN
NO_AUTH_TOKEN_NEEDED_MODEL_TYPES = [ModelType.OLLAMA.value]
+NO_AUTH_TOKEN_NEEDED_CLIENT_TYPES = [ClientType.HTTP.value]


class ClientFactory:
@@ -65,22 +72,44 @@ def create_client(
    # Validate model_type is present
    utils.validate_required_keys(["model_type"], model_config, "model")

+   # read model type or server type
    model_type = model_config["model_type"].lower()

+   # read client type
+   client_type = (
+       model_config["client_type"].lower() if model_config.get("client_type") else None
+   )
+
    # Validate if url is present
    if url is None:
        logger.error("URL is required for client creation.")
        raise ValueError("URL is required for client creation.")

-   # Validate if auth_token is present
-   if auth_token is None and model_type not in NO_AUTH_TOKEN_NEEDED_MODEL_TYPES:
-       logger.error("Auth token/API key is required for client creation.")
-       raise ValueError("Auth token/API key is required for client creation.")
+   # auth_token validation is not required here; TODO: add backend condition
+   # if auth_token is None and (model_type not in NO_AUTH_TOKEN_NEEDED_MODEL_TYPES):
+   #     logger.error("Auth token/API key is required for client creation.")
+   #     raise ValueError("Auth token/API key is required for client creation.")

    # Validate that the model_type is supported
-   supported_types = [type_enum.value for type_enum in ModelType]
-   if model_type not in supported_types:
-       supported_types_str = ", ".join(supported_types)
+   supported_model_types = [type_enum.value for type_enum in ModelType]
+   if client_type is not None:
+       supported_client_types = [type_enum.value for type_enum in ClientType]
+       if client_type not in supported_client_types:
+           supported_types_str = ", ".join(supported_client_types)
+           logger.error(
+               f"Unsupported client type: {client_type}. Supported types: {supported_types_str}"
+           )
+           raise ValueError(
+               f"Unsupported client type: {client_type}. Must be one of: {supported_types_str}"
+           )
+
+       if client_type == ClientType.HTTP.value:
+           return cls._create_http_client(model_config, url, auth_token)
+       else:
+           logger.error(f"Unsupported client type: {client_type}")
+           # handle additional common client types here
+   elif model_type not in supported_model_types:
+       supported_types_str = ", ".join(supported_model_types)
        logger.error(
            f"Unsupported model type: {model_type}. Supported types: {supported_types_str}"
        )
@@ -89,8 +118,9 @@
)

    # Create client based on model type
-   if model_type == ModelType.VLLM.value or model_type == ModelType.OPENAI.value:
+   elif model_type == ModelType.VLLM.value or model_type == ModelType.OPENAI.value:
        # Initialize the client with default chat_completions_api
+       # for the VLLM server type, the client type is openai ONLY
        return cls._create_openai_client(
            model_config,
            url,
@@ -99,6 +129,7 @@
            not model_config.get("completions_api", False),
        )
    elif model_type == ModelType.AZURE_OPENAI.value:
+       # for the AZURE_OPENAI server type, the client type is openai_azure ONLY
        return cls._create_openai_azure_client(
            model_config,
            url,
@@ -111,10 +142,13 @@
        or model_type == ModelType.TGI.value
        or model_type == ModelType.TRITON.value
    ):
+       # for the AZURE/TGI/TRITON server types, the client type is http ONLY
        return cls._create_http_client(model_config, url, auth_token)
    elif model_type == ModelType.MISTRALAI.value:
+       # for the Mistral server type, the client type is mistral ONLY
        return cls._create_mistral_client(model_config, url, auth_token, async_client)
    elif model_type == ModelType.OLLAMA.value:
+       # for the OLLAMA server type, the client type is ollama
        return cls._create_ollama_client(
            model_config,
            url,
@@ -306,6 +340,12 @@ def _create_http_client(
        auth_token = auth_token.replace("Bearer ", "")
    headers["Authorization"] = f"Bearer {auth_token}"

+   extra_headers = model_config.get("extra_headers")
+   if extra_headers:
+       for k, v in extra_headers.items():
+           headers[k] = v
+
    timeout = model_config.get("timeout", constants.DEFAULT_TIMEOUT)
    max_retries = model_config.get("max_retries", 3)
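
Taken together with the models.yaml entries above, a config that sets client_type: http (such as claude_large) short-circuits to _create_http_client regardless of its model_type, with any extra_headers merged into the outgoing request headers; and since the auth_token check is now commented out, auth_token may be None on this path. A rough usage sketch, with the call shape approximate and all values hypothetical:

    model_config = {
        "model_type": "claude_proxy",  # not a ModelType value; irrelevant here, client_type wins
        "client_type": "http",
        "backend": "proxy",
        "extra_headers": {"x-now-transaction-id": "19076269ff422210f429ffffffffff8f"},
    }
    client = ClientFactory.create_client(model_config, "https://proxy.example.com", None)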
