Skip to content

Commit

Permalink
refactor: use get_nested_field utility for file variable extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
srijanpatel committed Jan 20, 2025
1 parent f15fcc1 commit 877212e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
15 changes: 5 additions & 10 deletions backend/app/nodes/llm/single_llm_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jinja2 import Template
from pydantic import BaseModel, Field

from ...utils.pydantic_utils import json_schema_to_model
from ...utils.pydantic_utils import get_nested_field, json_schema_to_model

from ..base import (
BaseNodeInput,
Expand Down Expand Up @@ -111,15 +111,10 @@ async def run(self, input: BaseModel) -> BaseModel:
url_vars = {}
if "file" in self.config.url_variables:
# Split the input variable reference (e.g. "input_node.video_url")
parts = self.config.url_variables["file"].split(".")
if len(parts) == 2:
node_id, var_name = parts
else:
node_id, var_name = parts[0], parts[-1]
# Get the value from the input using the node_id
if node_id in raw_input_dict and var_name in raw_input_dict[node_id]:
# Always use image_url format regardless of file type
url_vars["image"] = raw_input_dict[node_id][var_name]
# Get the nested field value using the helper function
file_value = get_nested_field(self.config.url_variables["file"], input)
# Always use image_url format regardless of file type
url_vars["image"] = file_value

assistant_message_str = await generate_text(
messages=messages,
Expand Down
5 changes: 4 additions & 1 deletion backend/app/utils/pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ def get_nested_field(field_name_with_dots: str, model: BaseModel) -> Any:
field_names = field_name_with_dots.split(".")
value = model
for field_name in field_names:
value = getattr(value, field_name)
if isinstance(value, dict):
return value.get(field_name, None) # type: ignore
else:
value = getattr(value, field_name)
return value


Expand Down

0 comments on commit 877212e

Please sign in to comment.