Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions lib/crewai/src/crewai/flow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import ast
from collections import defaultdict, deque
from enum import Enum
import inspect
import textwrap
from typing import TYPE_CHECKING, Any
Expand All @@ -40,11 +41,123 @@
_printer = Printer()


def _extract_string_literals_from_type_annotation(
node: ast.expr,
function_globals: dict[str, Any] | None = None,
) -> list[str]:
"""Extract string literals from a type annotation AST node.

Handles:
- Literal["a", "b", "c"]
- "a" | "b" | "c" (union of string literals)
- Just "a" (single string constant annotation)
- Enum types with string values (e.g., class MyEnum(str, Enum))

Args:
node: The AST node representing a type annotation.
function_globals: The globals dict from the function, used to resolve Enum types.

Returns:
List of string literals found in the annotation.
"""

strings: list[str] = []

if isinstance(node, ast.Constant) and isinstance(node.value, str):
strings.append(node.value)

elif isinstance(node, ast.Name) and function_globals:
enum_class = function_globals.get(node.id)
if (
enum_class is not None
and isinstance(enum_class, type)
and issubclass(enum_class, Enum)
):
strings.extend(
member.value for member in enum_class if isinstance(member.value, str)
)

elif isinstance(node, ast.Attribute) and function_globals:
try:
if isinstance(node.value, ast.Name):
module = function_globals.get(node.value.id)
if module is not None:
enum_class = getattr(module, node.attr, None)
if (
enum_class is not None
and isinstance(enum_class, type)
and issubclass(enum_class, Enum)
):
strings.extend(
member.value
for member in enum_class
if isinstance(member.value, str)
)
except (AttributeError, TypeError):
pass

elif isinstance(node, ast.Subscript):
is_literal = False
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
is_literal = True
elif isinstance(node.value, ast.Attribute) and node.value.attr == "Literal":
is_literal = True

if is_literal:
if isinstance(node.slice, ast.Tuple):
strings.extend(
elt.value
for elt in node.slice.elts
if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
)
elif isinstance(node.slice, ast.Constant) and isinstance(
node.slice.value, str
):
strings.append(node.slice.value)

elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
strings.extend(
_extract_string_literals_from_type_annotation(node.left, function_globals)
)
strings.extend(
_extract_string_literals_from_type_annotation(node.right, function_globals)
)

return strings


def _unwrap_function(function: Any) -> Any:
"""Unwrap a function to get the original function with correct globals.

Flow methods are wrapped by decorators like @router, @listen, etc.
This function unwraps them to get the original function which has
the correct __globals__ for resolving type annotations like Enums.

Args:
function: The potentially wrapped function.

Returns:
The unwrapped original function.
"""
if hasattr(function, "__func__"):
function = function.__func__

if hasattr(function, "__wrapped__"):
wrapped = function.__wrapped__
if hasattr(wrapped, "unwrap"):
return wrapped.unwrap()
return wrapped

return function


def get_possible_return_constants(function: Any) -> list[str] | None:
"""Extract possible string return values from a function using AST parsing.

This function analyzes the source code of a router method to identify
all possible string values it might return. It handles:
- Return type annotations: -> Literal["a", "b"] or -> "a" | "b" | "c"
- Enum type annotations: -> MyEnum (extracts string values from members)
- Direct string literals: return "value"
- Variable assignments: x = "value"; return x
- Dictionary lookups: d = {"k": "v"}; return d[key]
Expand All @@ -57,6 +170,8 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
Returns:
List of possible string return values, or None if analysis fails.
"""
unwrapped = _unwrap_function(function)

try:
source = inspect.getsource(function)
except OSError:
Expand Down Expand Up @@ -97,6 +212,17 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
return None

return_values: set[str] = set()

function_globals = getattr(unwrapped, "__globals__", None)

for node in ast.walk(code_ast):
if isinstance(node, ast.FunctionDef):
if node.returns:
annotation_values = _extract_string_literals_from_type_annotation(
node.returns, function_globals
)
return_values.update(annotation_values)
break # Only process the first function definition
dict_definitions: dict[str, list[str]] = {}
variable_values: dict[str, list[str]] = {}
state_attribute_values: dict[str, list[str]] = {}
Expand Down
65 changes: 40 additions & 25 deletions lib/crewai/src/crewai/flow/visualization/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
import inspect
import logging
from typing import TYPE_CHECKING, Any

from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_wrappers import FlowCondition
from crewai.flow.types import FlowMethodName, FlowRouteName
from crewai.flow.types import FlowMethodName
from crewai.flow.utils import (
is_flow_condition_dict,
is_simple_flow_condition,
Expand All @@ -18,6 +18,9 @@
from crewai.flow.visualization.types import FlowStructure, NodeMetadata, StructureEdge


logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from crewai.flow.flow import Flow

Expand Down Expand Up @@ -346,34 +349,43 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
if trigger_method in nodes
)

all_string_triggers: set[str] = set()
for condition_data in flow._listeners.values():
if is_simple_flow_condition(condition_data):
_, methods = condition_data
for m in methods:
if str(m) not in nodes: # It's a string trigger, not a method name
all_string_triggers.add(str(m))
elif is_flow_condition_dict(condition_data):
for trigger in _extract_direct_or_triggers(condition_data):
if trigger not in nodes:
all_string_triggers.add(trigger)

all_router_outputs: set[str] = set()
for router_method_name in router_methods:
if router_method_name not in flow._router_paths:
flow._router_paths[FlowMethodName(router_method_name)] = []

inferred_paths: Iterable[FlowMethodName | FlowRouteName] = set(
flow._router_paths.get(FlowMethodName(router_method_name), [])
)

for condition_data in flow._listeners.values():
trigger_strings: list[str] = []

if is_simple_flow_condition(condition_data):
_, methods = condition_data
trigger_strings = [str(m) for m in methods]
elif is_flow_condition_dict(condition_data):
trigger_strings = _extract_direct_or_triggers(condition_data)

for trigger_str in trigger_strings:
if trigger_str not in nodes:
# This is likely a router path output
inferred_paths.add(trigger_str) # type: ignore[attr-defined]

if inferred_paths:
flow._router_paths[FlowMethodName(router_method_name)] = list(
inferred_paths # type: ignore[arg-type]
current_paths = flow._router_paths.get(FlowMethodName(router_method_name), [])
if current_paths and router_method_name in nodes:
nodes[router_method_name]["router_paths"] = [str(p) for p in current_paths]
all_router_outputs.update(str(p) for p in current_paths)

if not current_paths:
logger.warning(
f"Could not determine return paths for router '{router_method_name}'. "
f"Add a return type annotation like "
f"'-> Literal[\"path1\", \"path2\"]' or '-> YourEnum' "
f"to enable proper flow visualization."
)
if router_method_name in nodes:
nodes[router_method_name]["router_paths"] = list(inferred_paths)

orphaned_triggers = all_string_triggers - all_router_outputs
if orphaned_triggers:
logger.error(
f"Found listeners waiting for triggers {orphaned_triggers} "
f"but no router outputs these values explicitly. "
f"If your router returns a non-static value, check that your router has proper return type annotations."
)

for router_method_name in router_methods:
if router_method_name not in flow._router_paths:
Expand All @@ -383,6 +395,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:

for path in router_paths:
for listener_name, condition_data in flow._listeners.items():
if listener_name == router_method_name:
continue

trigger_strings_from_cond: list[str] = []

if is_simple_flow_condition(condition_data):
Expand Down
Loading
Loading