diff --git a/.env.example b/.env.example index a3f5c88..c00749a 100644 --- a/.env.example +++ b/.env.example @@ -144,6 +144,12 @@ LANGFUSE_ENABLED=true LANGFUSE_PUBLIC_KEY=your-langfuse-public-key LANGFUSE_SECRET_KEY=your-langfuse-secret-key LANGFUSE_HOST=https://cloud.langfuse.com +# Comma-separated list of trace attributes to set as Langfuse tags +# Available values: ticket_key, ticket_type, project_id, workflow_step, repo, pr_number, ci_status, event_source, event_type, llm_model +LANGFUSE_TRACE_TAGS=ticket_key,ticket_type,project_id,workflow_step,repo,pr_number,ci_status,event_source,event_type,llm_model +# Comma-separated list of trace attributes to set as Langfuse metadata +# Available values: ticket_key, ticket_type, project_id, workflow_step, repo, pr_number, ci_status, event_source, event_type, retry_count, system_prompt_length, llm_model +LANGFUSE_TRACE_METADATA=ticket_key,ticket_type,project_id,workflow_step,repo,pr_number,ci_status,event_source,event_type,retry_count,system_prompt_length,llm_model # OpenTelemetry distributed tracing (separate from Langfuse LLM tracing above) # Enable/disable OTLP trace export diff --git a/src/forge/config.py b/src/forge/config.py index 89829eb..3cb0af8 100644 --- a/src/forge/config.py +++ b/src/forge/config.py @@ -1,11 +1,17 @@ """Configuration management using Pydantic settings.""" -from functools import lru_cache -from typing import Literal +import logging +from functools import cached_property, lru_cache +from typing import TYPE_CHECKING, Literal from pydantic import Field, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict +if TYPE_CHECKING: + from forge.integrations.langfuse.fields import TracingField + +logger = logging.getLogger(__name__) + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -188,6 +194,14 @@ def detect_model_provider(model_name: str) -> str: langfuse_host: str = Field( default="https://cloud.langfuse.com", description="Langfuse host URL" ) + langfuse_trace_tags: str = Field( + default="", + description="Comma-separated list of TracingField names to include as Langfuse trace tags", + ) + langfuse_trace_metadata: str = Field( + default="", + description="Comma-separated list of TracingField names to include as Langfuse trace metadata", + ) # Claude Agent SDK Configuration agent_enable_tools: bool = Field( @@ -330,6 +344,32 @@ def langfuse_enabled(self) -> bool: and self.langfuse_secret_key.get_secret_value() ) + @cached_property + def trace_tag_fields(self) -> list["TracingField"]: + """Parse and validate configured Langfuse trace tag fields.""" + from forge.integrations.langfuse.fields import parse_trace_fields + + fields = parse_trace_fields(self.langfuse_trace_tags, allow_tags=True) + if fields: + logger.info( + "Langfuse trace tags configured: %s", + ", ".join(f.value for f in fields), + ) + return fields + + @cached_property + def trace_metadata_fields(self) -> list["TracingField"]: + """Parse and validate configured Langfuse trace metadata fields.""" + from forge.integrations.langfuse.fields import parse_trace_fields + + fields = parse_trace_fields(self.langfuse_trace_metadata, allow_tags=False) + if fields: + logger.info( + "Langfuse trace metadata configured: %s", + ", ".join(f.value for f in fields), + ) + return fields + @property def use_vertex_ai(self) -> bool: """Check if using Vertex AI instead of direct Anthropic API.""" diff --git a/src/forge/integrations/agents/agent.py b/src/forge/integrations/agents/agent.py index 4daff43..b595855 100644 --- a/src/forge/integrations/agents/agent.py +++ b/src/forge/integrations/agents/agent.py @@ -32,6 +32,7 @@ from forge.config import Settings, get_settings from forge.integrations.langfuse import get_langfuse_config, get_langfuse_context +from forge.integrations.langfuse.fields import resolve_trace_fields from forge.prompts import load_prompt, set_default_version from forge.skills.resolver import resolve_skill_paths @@ -57,6 +58,29 @@ logger = logging.getLogger(__name__) +_TRACE_FIELD_KEYS = frozenset( + { + "ticket_key", + "ticket_type", + "current_node", + "current_repo", + "current_pr_number", + "ci_status", + "event_type", + "event_source", + "retry_count", + "repo", + "pr_number", + } +) + + +def _forward_trace_fields(context: dict[str, Any] | None) -> dict[str, Any]: + """Extract trace-relevant fields from an incoming context dict.""" + if not context: + return {} + return {k: v for k, v in context.items() if k in _TRACE_FIELD_KEYS} + def get_weather(city: str) -> str: """Placeholder tool for agent testing.""" @@ -531,6 +555,8 @@ async def _run_agent( session_id: str | None = None, trace_name: str | None = None, ticket_key: str | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, ) -> str: """Run the agent with the given prompt. @@ -543,6 +569,8 @@ async def _run_agent( session_id: Optional session ID for Langfuse (e.g., ticket key). trace_name: Optional trace name for Langfuse. ticket_key: Optional ticket key for per-project skill resolution. + tags: Optional list of trace tags for Langfuse. + metadata: Optional metadata dict for Langfuse. Returns: Agent response text. @@ -564,7 +592,8 @@ async def _run_agent( langfuse_config = get_langfuse_config( trace_name=trace_name or "deep_agent_invocation", session_id=session_id, - metadata={"system_prompt_length": str(len(system_prompt))}, + tags=tags, + metadata=metadata, ) if langfuse_config: # Extract context params and remove from config @@ -579,6 +608,7 @@ async def _run_agent( session_id=langfuse_ctx_params.get("session_id"), user_id=langfuse_ctx_params.get("user_id"), tags=langfuse_ctx_params.get("tags"), + metadata=langfuse_ctx_params.get("metadata"), ): last_error: Exception | None = None for attempt in range(self.MAX_RETRIES): @@ -688,6 +718,13 @@ async def run_task( logger.info(f"Running task '{task}' using Deep Agents") record_agent_invocation(task_type=task) _start = time.monotonic() + trace_state: dict[str, Any] = { + **(context or {}), + "system_prompt_length": len(system_prompt), + "llm_model": self.settings.claude_model, + } + trace_tags, trace_metadata = resolve_trace_fields(trace_state) + result = await self._run_agent( prompt=prompt, system_prompt=system_prompt, @@ -695,6 +732,8 @@ async def run_task( session_id=ticket_key, trace_name=f"task:{task}", ticket_key=ticket_key, + tags=trace_tags or None, + metadata=trace_metadata or None, ) observe_agent_duration(task_type=task, duration=time.monotonic() - _start) @@ -849,13 +888,12 @@ async def generate_prd( ) logger.info("Generating PRD using Deep Agents with skill") + task_context = _forward_trace_fields(context) + task_context["project_key"] = context.get("project_key", "") if context else "" result = await self.run_task( task="generate-prd", prompt=prompt, - context={ - "ticket_key": context.get("ticket_key", "") if context else "", - "project_key": context.get("project_key", "") if context else "", - }, + context=task_context, ) result = self._strip_preamble(result) @@ -885,13 +923,12 @@ async def generate_spec( ) logger.info("Generating Spec using Deep Agents with skill") + task_context = _forward_trace_fields(context) + task_context["project_key"] = context.get("project_key", "") if context else "" result = await self.run_task( task="generate-spec", prompt=prompt, - context={ - "ticket_key": context.get("ticket_key", "") if context else "", - "project_key": context.get("project_key", "") if context else "", - }, + context=task_context, ) result = self._strip_preamble(result) @@ -938,15 +975,14 @@ async def generate_epics( ) logger.info("Generating Epics using Deep Agents with skill") + task_context = _forward_trace_fields(context) + task_context["project_key"] = context.get("project_key", "") if context else "" + task_context["feature_summary"] = context.get("feature_summary", "") if context else "" + task_context["available_repos"] = available_repos result = await self.run_task( task="decompose-epics", prompt=prompt, - context={ - "ticket_key": context.get("ticket_key", "") if context else "", - "project_key": context.get("project_key", "") if context else "", - "feature_summary": context.get("feature_summary", "") if context else "", - "available_repos": available_repos, - }, + context=task_context, ) epics = self._parse_epics_response(result) @@ -959,6 +995,7 @@ async def regenerate_with_feedback( feedback: str, content_type: str, ticket_key: str | None = None, + context: dict[str, Any] | None = None, ) -> str: """Regenerate content incorporating feedback. @@ -966,6 +1003,8 @@ async def regenerate_with_feedback( original_content: The original generated content. feedback: User feedback/revision request. content_type: Type of content (prd, spec, epic). + ticket_key: Optional ticket key for session tracking. + context: Optional context with trace fields from workflow state. Returns: Regenerated content. @@ -986,10 +1025,13 @@ async def regenerate_with_feedback( ) logger.info(f"Regenerating {content_type} with feedback using Deep Agents") + task_context = _forward_trace_fields(context) + task_context["is_revision"] = True + task_context["ticket_key"] = ticket_key or "" result = await self.run_task( task=skill_name, prompt=prompt, - context={"is_revision": True, "ticket_key": ticket_key or ""}, + context=task_context, ) result = self._strip_preamble(result) @@ -1074,13 +1116,12 @@ async def answer_question( ) logger.info(f"Answering question about {artifact_type}") + task_context = _forward_trace_fields(context) + task_context["artifact_type"] = artifact_type result = await self.run_task( task="answer-question", prompt=prompt, - context={ - "artifact_type": artifact_type, - "ticket_key": context.get("ticket_key", ""), - }, + context=task_context, # Q&A gets read-only MCP tools for lookups (filtered by agent_mcp_read_only) ) diff --git a/src/forge/integrations/langfuse/__init__.py b/src/forge/integrations/langfuse/__init__.py index 5586bcd..b8f39d3 100644 --- a/src/forge/integrations/langfuse/__init__.py +++ b/src/forge/integrations/langfuse/__init__.py @@ -1,5 +1,9 @@ """Langfuse integration for LLM observability.""" +from forge.integrations.langfuse.fields import ( + TracingField, + resolve_trace_fields, +) from forge.integrations.langfuse.tracing import ( get_langfuse_config, get_langfuse_context, @@ -9,9 +13,11 @@ ) __all__ = [ + "TracingField", "get_langfuse_config", "get_langfuse_context", "get_langfuse_handler", + "resolve_trace_fields", "shutdown_langfuse", "trace_llm_call", ] diff --git a/src/forge/integrations/langfuse/fields.py b/src/forge/integrations/langfuse/fields.py new file mode 100644 index 0000000..fe121a3 --- /dev/null +++ b/src/forge/integrations/langfuse/fields.py @@ -0,0 +1,213 @@ +"""Configurable Langfuse trace tag and metadata fields. + +Admins configure which fields to include as tags/metadata via env vars: + LANGFUSE_TRACE_TAGS=ticket_type,project_id,workflow_step + LANGFUSE_TRACE_METADATA=ticket_key,ticket_type,project_id,retry_count +""" + +import logging +from enum import StrEnum +from typing import Any + +logger = logging.getLogger(__name__) + + +class TracingField(StrEnum): + """Available fields for Langfuse trace tags and metadata.""" + + TICKET_KEY = "ticket_key" + TICKET_TYPE = "ticket_type" + PROJECT_ID = "project_id" + WORKFLOW_STEP = "workflow_step" + REPO = "repo" + PR_NUMBER = "pr_number" + CI_STATUS = "ci_status" + EVENT_SOURCE = "event_source" + EVENT_TYPE = "event_type" + RETRY_COUNT = "retry_count" + SYSTEM_PROMPT_LENGTH = "system_prompt_length" + LLM_MODEL = "llm_model" + + @property + def tag_eligible(self) -> bool: + return self not in _METADATA_ONLY_FIELDS + + +_METADATA_ONLY_FIELDS = frozenset({TracingField.RETRY_COUNT, TracingField.SYSTEM_PROMPT_LENGTH}) + + +def resolve_field(field: TracingField, state: dict[str, Any]) -> str | None: + """Resolve a single tracing field from workflow state. + + Args: + field: The field to resolve. + state: Workflow state dict. + + Returns: + String value or None if the data isn't available. + """ + resolver = _RESOLVERS.get(field) + if resolver is None: + return None + return resolver(state) + + +def _resolve_ticket_key(state: dict[str, Any]) -> str | None: + val = state.get("ticket_key") + return str(val) if val is not None else None + + +def _resolve_ticket_type(state: dict[str, Any]) -> str | None: + val = state.get("ticket_type") + return str(val) if val is not None else None + + +def _resolve_project_id(state: dict[str, Any]) -> str | None: + ticket_key = state.get("ticket_key") + if not ticket_key or "-" not in str(ticket_key): + return None + return str(ticket_key).rsplit("-", 1)[0] + + +def _resolve_workflow_step(state: dict[str, Any]) -> str | None: + val = state.get("current_node") + return str(val) if val is not None else None + + +def _resolve_repo(state: dict[str, Any]) -> str | None: + val = state.get("repo") or state.get("current_repo") + return str(val) if val is not None else None + + +def _resolve_pr_number(state: dict[str, Any]) -> str | None: + val = state.get("pr_number") or state.get("current_pr_number") + return str(val) if val is not None else None + + +def _resolve_ci_status(state: dict[str, Any]) -> str | None: + val = state.get("ci_status") + return str(val) if val is not None else None + + +def _resolve_event_source(state: dict[str, Any]) -> str | None: + val = state.get("event_source") + if val is not None: + return str(val) + ctx = state.get("context") + if isinstance(ctx, dict): + val = ctx.get("source") + if val is not None: + return str(val) + return None + + +def _resolve_event_type(state: dict[str, Any]) -> str | None: + val = state.get("event_type") + return str(val) if val is not None else None + + +def _resolve_retry_count(state: dict[str, Any]) -> str | None: + val = state.get("retry_count") + return str(val) if val is not None else None + + +def _resolve_system_prompt_length(state: dict[str, Any]) -> str | None: + val = state.get("system_prompt_length") + return str(val) if val is not None else None + + +def _resolve_llm_model(state: dict[str, Any]) -> str | None: + val = state.get("llm_model") + return str(val) if val is not None else None + + +_RESOLVERS: dict[TracingField, Any] = { + TracingField.TICKET_KEY: _resolve_ticket_key, + TracingField.TICKET_TYPE: _resolve_ticket_type, + TracingField.PROJECT_ID: _resolve_project_id, + TracingField.WORKFLOW_STEP: _resolve_workflow_step, + TracingField.REPO: _resolve_repo, + TracingField.PR_NUMBER: _resolve_pr_number, + TracingField.CI_STATUS: _resolve_ci_status, + TracingField.EVENT_SOURCE: _resolve_event_source, + TracingField.EVENT_TYPE: _resolve_event_type, + TracingField.RETRY_COUNT: _resolve_retry_count, + TracingField.SYSTEM_PROMPT_LENGTH: _resolve_system_prompt_length, + TracingField.LLM_MODEL: _resolve_llm_model, +} + + +def parse_trace_fields(config_str: str, *, allow_tags: bool) -> list[TracingField]: + """Parse a comma-separated config string into validated TracingField list. + + Args: + config_str: Comma-separated field names (e.g., "ticket_key,ticket_type"). + allow_tags: If True, validates tag eligibility; if False, allows all fields. + + Returns: + List of valid TracingField values. Invalid names are warned and skipped. + """ + if not config_str or not config_str.strip(): + return [] + + available = ", ".join(sorted(f.value for f in TracingField)) + result: list[TracingField] = [] + + for raw in config_str.split(","): + name = raw.strip() + if not name: + continue + + try: + field = TracingField(name) + except ValueError: + logger.warning( + "Invalid Langfuse trace field '%s' - not a recognized field name. Available: %s", + name, + available, + ) + continue + + if allow_tags and not field.tag_eligible: + logger.warning( + "Field '%s' is not eligible for tags (quantitative field) - skipping", + name, + ) + continue + + result.append(field) + + return result + + +def resolve_trace_fields(state: dict[str, Any]) -> tuple[list[str], dict[str, Any]]: + """Resolve configured tracing fields from workflow state. + + Reads the configured tag/metadata fields from settings, resolves each + against the state dict, and returns the results. Fields that resolve + to None are silently omitted. + + Args: + state: Workflow state dict containing trace data. + + Returns: + (tags, metadata) — tags is a list of raw string values, + metadata is a dict of field_name -> string value. + """ + from forge.config import get_settings + + settings = get_settings() + + tags: list[str] = [] + for field in settings.trace_tag_fields: + value = resolve_field(field, state) + if value: + tags.append(value) + + metadata: dict[str, Any] = {} + for field in settings.trace_metadata_fields: + value = resolve_field(field, state) + if value is not None: + metadata[field.value] = value + + return tags, metadata diff --git a/src/forge/integrations/langfuse/tracing.py b/src/forge/integrations/langfuse/tracing.py index d514a4c..9b3086e 100644 --- a/src/forge/integrations/langfuse/tracing.py +++ b/src/forge/integrations/langfuse/tracing.py @@ -210,6 +210,7 @@ def get_langfuse_config( "session_id": session_id, "user_id": user_id, "tags": tags, + "metadata": metadata, } return config diff --git a/src/forge/workflow/base.py b/src/forge/workflow/base.py index 668915c..98155cb 100644 --- a/src/forge/workflow/base.py +++ b/src/forge/workflow/base.py @@ -17,6 +17,9 @@ class BaseState(TypedDict, total=False): thread_id: str ticket_key: str + # Event origin + event_type: str + # Execution control current_node: str is_paused: bool diff --git a/src/forge/workflow/nodes/code_review.py b/src/forge/workflow/nodes/code_review.py index b3f4e9a..fa99074 100644 --- a/src/forge/workflow/nodes/code_review.py +++ b/src/forge/workflow/nodes/code_review.py @@ -145,7 +145,18 @@ async def sync_pr_description( updated_body = await agent.run_task( task="sync-pr-description", prompt=prompt, - context={"owner": owner, "repo": repo, "pr_number": pr_number}, + context={ + "ticket_key": state.get("ticket_key", ""), + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "repo": repo, + "pr_number": pr_number, + "ci_status": state.get("ci_status", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), + "owner": owner, + }, include_tools=False, ) finally: diff --git a/src/forge/workflow/nodes/epic_decomposition.py b/src/forge/workflow/nodes/epic_decomposition.py index 35262b2..ec38dbd 100644 --- a/src/forge/workflow/nodes/epic_decomposition.py +++ b/src/forge/workflow/nodes/epic_decomposition.py @@ -107,6 +107,11 @@ async def decompose_epics(state: WorkflowState) -> WorkflowState: # Build context for Epic generation context: dict[str, Any] = { "ticket_key": ticket_key, + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), "project_key": project_key, "feature_summary": parent_issue.summary, "available_repos": available_repos, @@ -312,6 +317,13 @@ async def update_single_epic(state: WorkflowState) -> WorkflowState: feedback=feedback, content_type="epic", ticket_key=ticket_key, + context={ + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), + }, ) # Update Epic description diff --git a/src/forge/workflow/nodes/pr_creation.py b/src/forge/workflow/nodes/pr_creation.py index 1f931a0..f54b531 100644 --- a/src/forge/workflow/nodes/pr_creation.py +++ b/src/forge/workflow/nodes/pr_creation.py @@ -457,7 +457,14 @@ async def _generate_pr_body_with_agent( prompt=prompt, context={ "ticket_key": ticket_key, + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), "repo": current_repo, + "pr_number": state.get("current_pr_number", ""), + "ci_status": state.get("ci_status", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), "task_count": len(implemented_tasks), }, include_tools=False, # No tools needed for text generation diff --git a/src/forge/workflow/nodes/prd_generation.py b/src/forge/workflow/nodes/prd_generation.py index 90e0181..3873208 100644 --- a/src/forge/workflow/nodes/prd_generation.py +++ b/src/forge/workflow/nodes/prd_generation.py @@ -53,6 +53,11 @@ async def generate_prd(state: WorkflowState) -> WorkflowState: # Build context from issue metadata context: dict[str, Any] = { "ticket_key": ticket_key, + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), "summary": issue.summary, "project_key": issue.project_key, } @@ -156,6 +161,13 @@ async def regenerate_prd_with_feedback(state: WorkflowState) -> WorkflowState: feedback=feedback, content_type="prd", ticket_key=ticket_key, + context={ + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), + }, ) # Update Jira with regenerated PRD diff --git a/src/forge/workflow/nodes/qa_handler.py b/src/forge/workflow/nodes/qa_handler.py index b0bb389..1b62d5a 100644 --- a/src/forge/workflow/nodes/qa_handler.py +++ b/src/forge/workflow/nodes/qa_handler.py @@ -74,9 +74,14 @@ async def answer_question(state: WorkflowState) -> WorkflowState: question=question, artifact_content=artifact_content, context={ + "ticket_key": ticket_key, + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), "artifact_type": artifact_type, "generation_context": generation_context, - "ticket_key": ticket_key, }, ) diff --git a/src/forge/workflow/nodes/spec_generation.py b/src/forge/workflow/nodes/spec_generation.py index 389e3d4..e2a9e51 100644 --- a/src/forge/workflow/nodes/spec_generation.py +++ b/src/forge/workflow/nodes/spec_generation.py @@ -62,6 +62,11 @@ async def generate_spec(state: WorkflowState) -> WorkflowState: # Build context context: dict[str, Any] = { "ticket_key": ticket_key, + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), } # Generate specification using Claude - primary operation @@ -167,6 +172,13 @@ async def regenerate_spec_with_feedback(state: WorkflowState) -> WorkflowState: feedback=feedback, content_type="spec", ticket_key=ticket_key, + context={ + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), + }, ) # Store updated spec in Jira (comment or custom field based on config) diff --git a/src/forge/workflow/nodes/task_generation.py b/src/forge/workflow/nodes/task_generation.py index 3fe2cc4..2c7f6d0 100644 --- a/src/forge/workflow/nodes/task_generation.py +++ b/src/forge/workflow/nodes/task_generation.py @@ -99,6 +99,12 @@ async def generate_tasks(state: WorkflowState) -> WorkflowState: # Build context context: dict[str, Any] = { + "ticket_key": ticket_key, + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), "epic_key": epic_key, "epic_summary": epic_summary, "feature_key": ticket_key, @@ -512,6 +518,13 @@ async def update_single_task(state: WorkflowState) -> WorkflowState: feedback=feedback, content_type="task", ticket_key=ticket_key, + context={ + "ticket_type": state.get("ticket_type", ""), + "current_node": state.get("current_node", ""), + "event_type": state.get("event_type", ""), + "event_source": state.get("context", {}).get("source", ""), + "retry_count": state.get("retry_count", 0), + }, ) # Update Task in Jira diff --git a/tests/unit/integrations/agents/test_run_task_tracing.py b/tests/unit/integrations/agents/test_run_task_tracing.py new file mode 100644 index 0000000..76ac96f --- /dev/null +++ b/tests/unit/integrations/agents/test_run_task_tracing.py @@ -0,0 +1,147 @@ +"""Tests for run_task() trace field resolution. + +Verifies that run_task() builds the trace_state correctly, calls +resolve_trace_fields(), and passes the resolved tags/metadata to +_run_agent(). +""" + +from typing import Any +from unittest.mock import ANY, AsyncMock, patch + +import pytest + +from forge.integrations.agents.agent import ForgeAgent + + +@pytest.fixture +def agent() -> ForgeAgent: + return ForgeAgent() + + +def _metrics_patches(): + """Common patches for the inline-imported metrics helpers.""" + return ( + patch("forge.api.routes.metrics.record_agent_invocation"), + patch("forge.api.routes.metrics.observe_agent_duration"), + ) + + +class TestRunTaskTraceResolution: + """run_task() resolves trace fields and forwards them to _run_agent().""" + + @pytest.mark.asyncio + async def test_builds_trace_state_from_context_and_system_prompt( + self, agent: ForgeAgent + ) -> None: + context = {"ticket_key": "PROJ-42", "current_node": "generate_prd"} + + with ( + patch.object(agent, "_run_agent", new_callable=AsyncMock) as mock_run, + patch( + "forge.integrations.agents.agent.resolve_trace_fields" + ) as mock_resolve, + patch("forge.integrations.agents.agent.load_prompt", return_value="prompt"), + ): + mock_run.return_value = "result" + mock_resolve.return_value = (["PROJ-42"], {"ticket_key": "PROJ-42"}) + + await agent.run_task(task="generate-prd", prompt="test", context=context) + + # resolve_trace_fields should receive merged state with system_prompt_length and llm_model + resolve_call_state = mock_resolve.call_args[0][0] + assert resolve_call_state["ticket_key"] == "PROJ-42" + assert resolve_call_state["current_node"] == "generate_prd" + assert "system_prompt_length" in resolve_call_state + assert isinstance(resolve_call_state["system_prompt_length"], int) + assert resolve_call_state["llm_model"] == agent.settings.claude_model + + @pytest.mark.asyncio + async def test_passes_resolved_tags_to_run_agent(self, agent: ForgeAgent) -> None: + with ( + patch.object(agent, "_run_agent", new_callable=AsyncMock) as mock_run, + patch( + "forge.integrations.agents.agent.resolve_trace_fields", + return_value=(["Bug", "PROJ"], {"ticket_key": "PROJ-42"}), + ), + patch("forge.integrations.agents.agent.load_prompt", return_value="prompt"), + ): + mock_run.return_value = "result" + await agent.run_task( + task="test-task", + prompt="test", + context={"ticket_key": "PROJ-42"}, + ) + + call_kwargs = mock_run.call_args.kwargs + assert call_kwargs["tags"] == ["Bug", "PROJ"] + assert call_kwargs["metadata"] == {"ticket_key": "PROJ-42"} + + @pytest.mark.asyncio + async def test_empty_tags_passed_as_none(self, agent: ForgeAgent) -> None: + with ( + patch.object(agent, "_run_agent", new_callable=AsyncMock) as mock_run, + patch( + "forge.integrations.agents.agent.resolve_trace_fields", + return_value=([], {}), + ), + patch("forge.integrations.agents.agent.load_prompt", return_value="prompt"), + ): + mock_run.return_value = "result" + await agent.run_task(task="test-task", prompt="test", context={}) + + call_kwargs = mock_run.call_args.kwargs + assert call_kwargs["tags"] is None + assert call_kwargs["metadata"] is None + + @pytest.mark.asyncio + async def test_none_context_produces_trace_state_with_prompt_and_model( + self, agent: ForgeAgent + ) -> None: + with ( + patch.object(agent, "_run_agent", new_callable=AsyncMock) as mock_run, + patch( + "forge.integrations.agents.agent.resolve_trace_fields" + ) as mock_resolve, + patch("forge.integrations.agents.agent.load_prompt", return_value="prompt"), + ): + mock_run.return_value = "result" + mock_resolve.return_value = ([], {}) + await agent.run_task(task="test-task", prompt="test", context=None) + + resolve_call_state = mock_resolve.call_args[0][0] + assert "system_prompt_length" in resolve_call_state + assert "llm_model" in resolve_call_state + # No other context keys should be present + assert len(resolve_call_state) == 2 + + @pytest.mark.asyncio + async def test_trace_name_uses_task_prefix(self, agent: ForgeAgent) -> None: + with ( + patch.object(agent, "_run_agent", new_callable=AsyncMock) as mock_run, + patch( + "forge.integrations.agents.agent.resolve_trace_fields", + return_value=([], {}), + ), + patch("forge.integrations.agents.agent.load_prompt", return_value="prompt"), + ): + mock_run.return_value = "result" + await agent.run_task(task="generate-prd", prompt="test") + + assert mock_run.call_args.kwargs["trace_name"] == "task:generate-prd" + + @pytest.mark.asyncio + async def test_session_id_from_ticket_key(self, agent: ForgeAgent) -> None: + with ( + patch.object(agent, "_run_agent", new_callable=AsyncMock) as mock_run, + patch( + "forge.integrations.agents.agent.resolve_trace_fields", + return_value=([], {}), + ), + patch("forge.integrations.agents.agent.load_prompt", return_value="prompt"), + ): + mock_run.return_value = "result" + await agent.run_task( + task="test", prompt="test", context={"ticket_key": "PROJ-42"} + ) + + assert mock_run.call_args.kwargs["session_id"] == "PROJ-42" diff --git a/tests/unit/integrations/agents/test_trace_forwarding.py b/tests/unit/integrations/agents/test_trace_forwarding.py new file mode 100644 index 0000000..a4063e0 --- /dev/null +++ b/tests/unit/integrations/agents/test_trace_forwarding.py @@ -0,0 +1,316 @@ +"""Tests for trace field forwarding in ForgeAgent. + +Covers _forward_trace_fields() utility and the agent methods that +pass trace context through to run_task(). +""" + +from typing import Any +from unittest.mock import ANY, AsyncMock, MagicMock, patch + +import pytest + +from forge.integrations.agents.agent import ForgeAgent, _forward_trace_fields + + +class TestForwardTraceFields: + """Unit tests for _forward_trace_fields().""" + + def test_extracts_known_trace_keys(self) -> None: + context = { + "ticket_key": "PROJ-42", + "ticket_type": "Bug", + "current_node": "analyze_bug", + "current_repo": "acme/widgets", + "current_pr_number": 99, + "ci_status": "passed", + "event_type": "issue_updated", + "event_source": "jira", + "retry_count": 3, + "repo": "acme/widgets", + "pr_number": 55, + } + result = _forward_trace_fields(context) + assert result == context + + def test_filters_out_non_trace_keys(self) -> None: + context = { + "ticket_key": "PROJ-42", + "project_key": "PROJ", + "summary": "Fix the thing", + "workspace_path": "/tmp/ws", + "feature_summary": "A feature", + "available_repos": ["acme/widgets"], + } + result = _forward_trace_fields(context) + assert result == {"ticket_key": "PROJ-42"} + + def test_returns_empty_for_none(self) -> None: + assert _forward_trace_fields(None) == {} + + def test_returns_empty_for_empty_dict(self) -> None: + assert _forward_trace_fields({}) == {} + + def test_returns_empty_when_no_trace_keys_present(self) -> None: + context = {"summary": "Something", "workspace_path": "/tmp"} + assert _forward_trace_fields(context) == {} + + def test_preserves_original_values_unchanged(self) -> None: + context = { + "ticket_key": "PROJ-42", + "retry_count": 0, + "ci_status": "", + } + result = _forward_trace_fields(context) + assert result["ticket_key"] == "PROJ-42" + assert result["retry_count"] == 0 + assert result["ci_status"] == "" + + def test_does_not_mutate_input(self) -> None: + context = {"ticket_key": "PROJ-42", "summary": "Test"} + original = dict(context) + _forward_trace_fields(context) + assert context == original + + +class TestGeneratePrdTraceForwarding: + """generate_prd() uses _forward_trace_fields() and adds project_key.""" + + @pytest.mark.asyncio + async def test_forwards_trace_fields_to_run_task(self) -> None: + agent = ForgeAgent() + context = { + "ticket_key": "PROJ-42", + "ticket_type": "Feature", + "current_node": "generate_prd", + "event_type": "issue_updated", + "event_source": "jira", + "retry_count": 1, + "project_key": "PROJ", + "summary": "Build auth", + } + + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "# PRD\n\nContent" + await agent.generate_prd("Build auth system", context=context) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx["ticket_key"] == "PROJ-42" + assert call_ctx["ticket_type"] == "Feature" + assert call_ctx["current_node"] == "generate_prd" + assert call_ctx["event_type"] == "issue_updated" + assert call_ctx["event_source"] == "jira" + assert call_ctx["retry_count"] == 1 + assert call_ctx["project_key"] == "PROJ" + # Non-trace keys should not be forwarded + assert "summary" not in call_ctx + + @pytest.mark.asyncio + async def test_handles_none_context(self) -> None: + agent = ForgeAgent() + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "# PRD\n\nContent" + await agent.generate_prd("Build something", context=None) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx == {"project_key": ""} + + @pytest.mark.asyncio + async def test_project_key_defaults_to_empty(self) -> None: + agent = ForgeAgent() + context: dict[str, Any] = {"ticket_key": "PROJ-42"} + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "# PRD" + await agent.generate_prd("Requirements", context=context) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx["project_key"] == "" + + +class TestGenerateSpecTraceForwarding: + """generate_spec() uses _forward_trace_fields() and adds project_key.""" + + @pytest.mark.asyncio + async def test_forwards_trace_fields(self) -> None: + agent = ForgeAgent() + context = { + "ticket_key": "PROJ-42", + "ticket_type": "Feature", + "current_node": "generate_spec", + "event_type": "issue_updated", + "project_key": "PROJ", + } + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "# Spec\n\nContent" + await agent.generate_spec("PRD content", context=context) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx["ticket_key"] == "PROJ-42" + assert call_ctx["ticket_type"] == "Feature" + assert call_ctx["current_node"] == "generate_spec" + assert call_ctx["project_key"] == "PROJ" + + +class TestGenerateEpicsTraceForwarding: + """generate_epics() uses _forward_trace_fields() and adds extra context.""" + + @pytest.mark.asyncio + async def test_forwards_trace_fields_plus_extra(self) -> None: + agent = ForgeAgent() + context = { + "ticket_key": "PROJ-42", + "ticket_type": "Feature", + "current_node": "decompose_epics", + "event_type": "issue_updated", + "event_source": "jira", + "retry_count": 0, + "project_key": "PROJ", + "feature_summary": "Auth system", + "available_repos": ["acme/backend", "acme/frontend"], + } + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "---\nEPIC: Test\nREPO: acme/backend\nPLAN:\n1. Do it\n---" + await agent.generate_epics("Spec content", context=context) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx["ticket_key"] == "PROJ-42" + assert call_ctx["ticket_type"] == "Feature" + assert call_ctx["current_node"] == "decompose_epics" + assert call_ctx["project_key"] == "PROJ" + assert call_ctx["feature_summary"] == "Auth system" + assert call_ctx["available_repos"] == ["acme/backend", "acme/frontend"] + # Non-forwarded keys should not leak + assert "summary" not in call_ctx + + +class TestRegenerateWithFeedbackTraceForwarding: + """regenerate_with_feedback() accepts context and uses _forward_trace_fields().""" + + @pytest.mark.asyncio + async def test_forwards_trace_fields_from_context(self) -> None: + agent = ForgeAgent() + context = { + "ticket_type": "Feature", + "current_node": "regenerate_prd", + "event_type": "issue_updated", + "event_source": "jira", + "retry_count": 2, + } + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "# Revised PRD" + await agent.regenerate_with_feedback( + original_content="# Old PRD", + feedback="Add more detail", + content_type="prd", + ticket_key="PROJ-42", + context=context, + ) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx["ticket_type"] == "Feature" + assert call_ctx["current_node"] == "regenerate_prd" + assert call_ctx["event_type"] == "issue_updated" + assert call_ctx["event_source"] == "jira" + assert call_ctx["retry_count"] == 2 + assert call_ctx["is_revision"] is True + assert call_ctx["ticket_key"] == "PROJ-42" + + @pytest.mark.asyncio + async def test_handles_none_context(self) -> None: + agent = ForgeAgent() + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "# Revised" + await agent.regenerate_with_feedback( + original_content="# Old", + feedback="Fix it", + content_type="spec", + ticket_key="PROJ-42", + context=None, + ) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx == {"is_revision": True, "ticket_key": "PROJ-42"} + + @pytest.mark.asyncio + async def test_ticket_key_defaults_to_empty(self) -> None: + agent = ForgeAgent() + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "# Revised" + await agent.regenerate_with_feedback( + original_content="# Old", + feedback="Fix it", + content_type="prd", + ticket_key=None, + context=None, + ) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx["ticket_key"] == "" + + @pytest.mark.asyncio + async def test_content_type_maps_to_correct_skill(self) -> None: + agent = ForgeAgent() + skill_map = { + "prd": "generate-prd", + "spec": "generate-spec", + "epic": "decompose-epics", + "task": "generate-tasks", + } + for content_type, expected_task in skill_map.items(): + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "# Result" + await agent.regenerate_with_feedback( + original_content="# Old", + feedback="Fix", + content_type=content_type, + ) + assert mock_run.call_args.kwargs["task"] == expected_task + + +class TestAnswerQuestionTraceForwarding: + """answer_question() uses _forward_trace_fields() for context.""" + + @pytest.mark.asyncio + async def test_forwards_trace_fields(self) -> None: + agent = ForgeAgent() + context = { + "ticket_key": "PROJ-42", + "ticket_type": "Feature", + "current_node": "answer_question", + "event_type": "issue_updated", + "event_source": "jira", + "retry_count": 0, + "artifact_type": "prd", + "generation_context": {"raw_requirements": "Build API"}, + } + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "The answer" + await agent.answer_question( + question="Why REST?", + artifact_content="# PRD\n\nWe use REST", + context=context, + ) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx["ticket_key"] == "PROJ-42" + assert call_ctx["ticket_type"] == "Feature" + assert call_ctx["current_node"] == "answer_question" + assert call_ctx["event_type"] == "issue_updated" + assert call_ctx["event_source"] == "jira" + assert call_ctx["retry_count"] == 0 + assert call_ctx["artifact_type"] == "prd" + # generation_context is not a trace field, should not be forwarded + assert "generation_context" not in call_ctx + + @pytest.mark.asyncio + async def test_artifact_type_defaults_to_document(self) -> None: + agent = ForgeAgent() + with patch.object(agent, "run_task", new_callable=AsyncMock) as mock_run: + mock_run.return_value = "Answer" + await agent.answer_question( + question="What?", + artifact_content="Content", + context={}, + ) + + call_ctx = mock_run.call_args.kwargs["context"] + assert call_ctx["artifact_type"] == "document" diff --git a/tests/unit/integrations/langfuse/test_fields.py b/tests/unit/integrations/langfuse/test_fields.py new file mode 100644 index 0000000..6623e20 --- /dev/null +++ b/tests/unit/integrations/langfuse/test_fields.py @@ -0,0 +1,395 @@ +"""Tests for Langfuse tracing field resolvers.""" + +import logging +from typing import Any +from unittest.mock import PropertyMock, patch + +import pytest + +from forge.config import Settings +from forge.integrations.langfuse.fields import ( + TracingField, + parse_trace_fields, + resolve_field, + resolve_trace_fields, +) + + +def _make_state(**overrides: Any) -> dict[str, Any]: + """Build a minimal workflow state dict for testing.""" + base: dict[str, Any] = { + "ticket_key": "PROJ-42", + "ticket_type": "Bug", + "current_node": "analyze_bug", + "current_repo": "acme/widgets", + "current_pr_number": 99, + "ci_status": "passed", + "event_type": "issue_updated", + "retry_count": 3, + "context": {"source": "jira"}, + } + base.update(overrides) + return base + + +class TestFieldResolvers: + """Each TracingField resolver extracts the right value from state.""" + + def test_ticket_key(self) -> None: + assert resolve_field(TracingField.TICKET_KEY, _make_state()) == "PROJ-42" + + def test_ticket_key_missing(self) -> None: + state = _make_state() + del state["ticket_key"] + assert resolve_field(TracingField.TICKET_KEY, state) is None + + def test_ticket_type(self) -> None: + assert resolve_field(TracingField.TICKET_TYPE, _make_state()) == "Bug" + + def test_ticket_type_missing(self) -> None: + state = _make_state() + del state["ticket_type"] + assert resolve_field(TracingField.TICKET_TYPE, state) is None + + def test_project_id(self) -> None: + assert resolve_field(TracingField.PROJECT_ID, _make_state()) == "PROJ" + + def test_project_id_no_ticket_key(self) -> None: + state = _make_state() + del state["ticket_key"] + assert resolve_field(TracingField.PROJECT_ID, state) is None + + def test_project_id_no_dash(self) -> None: + assert resolve_field(TracingField.PROJECT_ID, _make_state(ticket_key="NODASH")) is None + + def test_workflow_step(self) -> None: + assert resolve_field(TracingField.WORKFLOW_STEP, _make_state()) == "analyze_bug" + + def test_workflow_step_missing(self) -> None: + state = _make_state() + del state["current_node"] + assert resolve_field(TracingField.WORKFLOW_STEP, state) is None + + def test_repo(self) -> None: + assert resolve_field(TracingField.REPO, _make_state()) == "acme/widgets" + + def test_repo_from_short_key(self) -> None: + state = _make_state() + del state["current_repo"] + state["repo"] = "acme/other" + assert resolve_field(TracingField.REPO, state) == "acme/other" + + def test_repo_prefers_short_key(self) -> None: + state = _make_state(repo="acme/preferred") + assert resolve_field(TracingField.REPO, state) == "acme/preferred" + + def test_repo_missing(self) -> None: + state = _make_state() + del state["current_repo"] + assert resolve_field(TracingField.REPO, state) is None + + def test_pr_number(self) -> None: + assert resolve_field(TracingField.PR_NUMBER, _make_state()) == "99" + + def test_pr_number_from_short_key(self) -> None: + state = _make_state() + del state["current_pr_number"] + state["pr_number"] = 42 + assert resolve_field(TracingField.PR_NUMBER, state) == "42" + + def test_pr_number_prefers_short_key(self) -> None: + state = _make_state(pr_number=42) + assert resolve_field(TracingField.PR_NUMBER, state) == "42" + + def test_pr_number_missing(self) -> None: + state = _make_state() + del state["current_pr_number"] + assert resolve_field(TracingField.PR_NUMBER, state) is None + + def test_ci_status(self) -> None: + assert resolve_field(TracingField.CI_STATUS, _make_state()) == "passed" + + def test_ci_status_missing(self) -> None: + state = _make_state() + del state["ci_status"] + assert resolve_field(TracingField.CI_STATUS, state) is None + + def test_event_source(self) -> None: + assert resolve_field(TracingField.EVENT_SOURCE, _make_state()) == "jira" + + def test_event_source_missing_context(self) -> None: + assert resolve_field(TracingField.EVENT_SOURCE, _make_state(context={})) is None + + def test_event_source_no_context_key(self) -> None: + state = _make_state() + del state["context"] + assert resolve_field(TracingField.EVENT_SOURCE, state) is None + + def test_event_type(self) -> None: + assert resolve_field(TracingField.EVENT_TYPE, _make_state()) == "issue_updated" + + def test_event_type_missing(self) -> None: + state = _make_state() + del state["event_type"] + assert resolve_field(TracingField.EVENT_TYPE, state) is None + + def test_retry_count(self) -> None: + assert resolve_field(TracingField.RETRY_COUNT, _make_state()) == "3" + + def test_retry_count_missing(self) -> None: + state = _make_state() + del state["retry_count"] + assert resolve_field(TracingField.RETRY_COUNT, state) is None + + def test_retry_count_zero(self) -> None: + assert resolve_field(TracingField.RETRY_COUNT, _make_state(retry_count=0)) == "0" + + def test_system_prompt_length(self) -> None: + state = _make_state(system_prompt_length=4523) + assert resolve_field(TracingField.SYSTEM_PROMPT_LENGTH, state) == "4523" + + def test_system_prompt_length_missing(self) -> None: + assert resolve_field(TracingField.SYSTEM_PROMPT_LENGTH, _make_state()) is None + + def test_llm_model(self) -> None: + state = _make_state(llm_model="claude-sonnet-4-6-20250514") + assert resolve_field(TracingField.LLM_MODEL, state) == "claude-sonnet-4-6-20250514" + + def test_llm_model_missing(self) -> None: + assert resolve_field(TracingField.LLM_MODEL, _make_state()) is None + + +class TestTagEligibility: + """Verify which fields are tag-eligible vs metadata-only.""" + + @pytest.mark.parametrize( + "field", + [ + TracingField.TICKET_KEY, + TracingField.TICKET_TYPE, + TracingField.PROJECT_ID, + TracingField.WORKFLOW_STEP, + TracingField.REPO, + TracingField.PR_NUMBER, + TracingField.CI_STATUS, + TracingField.EVENT_SOURCE, + TracingField.EVENT_TYPE, + TracingField.LLM_MODEL, + ], + ) + def test_tag_eligible_fields(self, field: TracingField) -> None: + assert field.tag_eligible is True + + @pytest.mark.parametrize( + "field", + [ + TracingField.RETRY_COUNT, + TracingField.SYSTEM_PROMPT_LENGTH, + ], + ) + def test_metadata_only_fields(self, field: TracingField) -> None: + assert field.tag_eligible is False + + +class TestParseTraceFields: + """Config string parsing and validation.""" + + def test_valid_metadata_fields(self) -> None: + fields = parse_trace_fields("ticket_key,ticket_type,retry_count", allow_tags=False) + assert fields == [ + TracingField.TICKET_KEY, + TracingField.TICKET_TYPE, + TracingField.RETRY_COUNT, + ] + + def test_valid_tag_fields(self) -> None: + fields = parse_trace_fields("ticket_type,project_id", allow_tags=True) + assert fields == [TracingField.TICKET_TYPE, TracingField.PROJECT_ID] + + def test_empty_string_returns_empty_list(self) -> None: + assert parse_trace_fields("", allow_tags=True) == [] + + def test_whitespace_only_returns_empty_list(self) -> None: + assert parse_trace_fields(" , , ", allow_tags=True) == [] + + def test_strips_whitespace(self) -> None: + fields = parse_trace_fields(" ticket_key , ticket_type ", allow_tags=False) + assert fields == [TracingField.TICKET_KEY, TracingField.TICKET_TYPE] + + def test_invalid_name_warns_and_skips(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + fields = parse_trace_fields("ticket_key,foobar,ticket_type", allow_tags=False) + assert fields == [TracingField.TICKET_KEY, TracingField.TICKET_TYPE] + assert "foobar" in caplog.text + assert "not a recognized field name" in caplog.text + + def test_tag_ineligible_field_in_tags_warns_and_skips( + self, caplog: pytest.LogCaptureFixture + ) -> None: + with caplog.at_level(logging.WARNING): + fields = parse_trace_fields("ticket_type,retry_count,project_id", allow_tags=True) + assert fields == [TracingField.TICKET_TYPE, TracingField.PROJECT_ID] + assert "retry_count" in caplog.text + assert "not eligible for tags" in caplog.text + + def test_tag_ineligible_field_allowed_in_metadata(self) -> None: + fields = parse_trace_fields("retry_count,system_prompt_length", allow_tags=False) + assert fields == [TracingField.RETRY_COUNT, TracingField.SYSTEM_PROMPT_LENGTH] + + def test_all_invalid_returns_empty(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + fields = parse_trace_fields("bad1,bad2", allow_tags=True) + assert fields == [] + + def test_duplicate_fields_preserved(self) -> None: + fields = parse_trace_fields("ticket_key,ticket_key", allow_tags=False) + assert fields == [TracingField.TICKET_KEY, TracingField.TICKET_KEY] + + +def _make_settings(**overrides: Any) -> Settings: + """Build a Settings instance with required fields and overrides.""" + defaults: dict[str, Any] = { + "jira_base_url": "https://test.atlassian.net", + "jira_api_token": "test", + "jira_user_email": "test@example.com", + "github_token": "test", + "anthropic_api_key": "test", + "langfuse_trace_tags": "", + "langfuse_trace_metadata": "", + } + defaults.update(overrides) + return Settings(**defaults) + + +class TestSettingsTraceFields: + """Settings properties parse and validate trace field config.""" + + def test_trace_tag_fields_default_empty(self) -> None: + settings = _make_settings() + assert settings.trace_tag_fields == [] + + def test_trace_metadata_fields_default_empty(self) -> None: + settings = _make_settings() + assert settings.trace_metadata_fields == [] + + def test_trace_tag_fields_parsed(self) -> None: + settings = _make_settings(langfuse_trace_tags="ticket_type,project_id") + assert settings.trace_tag_fields == [TracingField.TICKET_TYPE, TracingField.PROJECT_ID] + + def test_trace_metadata_fields_parsed(self) -> None: + settings = _make_settings(langfuse_trace_metadata="ticket_key,retry_count") + assert settings.trace_metadata_fields == [ + TracingField.TICKET_KEY, + TracingField.RETRY_COUNT, + ] + + def test_tag_ineligible_field_rejected_from_tags( + self, caplog: pytest.LogCaptureFixture + ) -> None: + with caplog.at_level(logging.WARNING): + settings = _make_settings(langfuse_trace_tags="retry_count,ticket_type") + assert settings.trace_tag_fields == [TracingField.TICKET_TYPE] + + def test_info_logged_when_fields_configured(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.INFO): + settings = _make_settings(langfuse_trace_tags="ticket_type") + _ = settings.trace_tag_fields + assert "Langfuse trace tags configured: ticket_type" in caplog.text + + def test_no_info_logged_when_empty(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.INFO): + settings = _make_settings(langfuse_trace_tags="") + _ = settings.trace_tag_fields + assert "Langfuse trace tags configured" not in caplog.text + + def test_no_info_logged_when_all_invalid(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + settings = _make_settings(langfuse_trace_tags="bad1,bad2") + _ = settings.trace_tag_fields + assert "Langfuse trace tags configured" not in caplog.text + + +class TestResolveTraceFields: + """Integration: resolve configured fields from workflow state.""" + + def test_resolves_tags_and_metadata(self) -> None: + state = _make_state() + tag_fields = [TracingField.TICKET_TYPE, TracingField.PROJECT_ID, TracingField.WORKFLOW_STEP] + metadata_fields = [TracingField.TICKET_KEY, TracingField.RETRY_COUNT] + + with ( + patch( + "forge.config.get_settings" + ) as mock_get_settings, + ): + mock_settings = mock_get_settings.return_value + type(mock_settings).trace_tag_fields = PropertyMock(return_value=tag_fields) + type(mock_settings).trace_metadata_fields = PropertyMock(return_value=metadata_fields) + + tags, metadata = resolve_trace_fields(state) + + assert tags == ["Bug", "PROJ", "analyze_bug"] + assert metadata == {"ticket_key": "PROJ-42", "retry_count": "3"} + + def test_skips_missing_fields(self) -> None: + state = _make_state() + del state["current_repo"] + tag_fields = [TracingField.TICKET_TYPE, TracingField.REPO] + metadata_fields = [TracingField.PR_NUMBER] + + with patch( + "forge.config.get_settings" + ) as mock_get_settings: + mock_settings = mock_get_settings.return_value + type(mock_settings).trace_tag_fields = PropertyMock(return_value=tag_fields) + type(mock_settings).trace_metadata_fields = PropertyMock(return_value=metadata_fields) + + tags, metadata = resolve_trace_fields(state) + + assert tags == ["Bug"] + assert metadata == {"pr_number": "99"} + + def test_empty_config_returns_empty(self) -> None: + with patch( + "forge.config.get_settings" + ) as mock_get_settings: + mock_settings = mock_get_settings.return_value + type(mock_settings).trace_tag_fields = PropertyMock(return_value=[]) + type(mock_settings).trace_metadata_fields = PropertyMock(return_value=[]) + + tags, metadata = resolve_trace_fields(_make_state()) + + assert tags == [] + assert metadata == {} + + def test_system_prompt_length_in_metadata(self) -> None: + state = _make_state(system_prompt_length=4523) + metadata_fields = [TracingField.SYSTEM_PROMPT_LENGTH] + + with patch( + "forge.config.get_settings" + ) as mock_get_settings: + mock_settings = mock_get_settings.return_value + type(mock_settings).trace_tag_fields = PropertyMock(return_value=[]) + type(mock_settings).trace_metadata_fields = PropertyMock(return_value=metadata_fields) + + tags, metadata = resolve_trace_fields(state) + + assert tags == [] + assert metadata == {"system_prompt_length": "4523"} + + def test_llm_model_in_tags(self) -> None: + state = _make_state(llm_model="claude-sonnet-4-6-20250514") + tag_fields = [TracingField.LLM_MODEL] + + with patch( + "forge.config.get_settings" + ) as mock_get_settings: + mock_settings = mock_get_settings.return_value + type(mock_settings).trace_tag_fields = PropertyMock(return_value=tag_fields) + type(mock_settings).trace_metadata_fields = PropertyMock(return_value=[]) + + tags, metadata = resolve_trace_fields(state) + + assert tags == ["claude-sonnet-4-6-20250514"] + assert metadata == {} diff --git a/tests/unit/integrations/langfuse/test_tracing.py b/tests/unit/integrations/langfuse/test_tracing.py new file mode 100644 index 0000000..7f097d7 --- /dev/null +++ b/tests/unit/integrations/langfuse/test_tracing.py @@ -0,0 +1,103 @@ +"""Tests for Langfuse tracing configuration. + +Covers get_langfuse_config() metadata passthrough and +get_langfuse_context() metadata parameter. +""" + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from forge.integrations.langfuse.tracing import ( + AsyncLangfuseContext, + get_langfuse_config, + get_langfuse_context, +) + + +class TestGetLangfuseConfigMetadata: + """get_langfuse_config() includes metadata in _langfuse_context.""" + + def _call_with_handler(self, **kwargs: Any) -> dict[str, Any]: + """Call get_langfuse_config with a mocked handler so it returns a config.""" + with patch( + "forge.integrations.langfuse.tracing.get_langfuse_handler", + return_value=MagicMock(), + ): + return get_langfuse_config(**kwargs) + + def test_metadata_included_in_langfuse_context(self) -> None: + metadata = {"ticket_key": "PROJ-42", "retry_count": "3"} + config = self._call_with_handler(metadata=metadata) + + assert config["_langfuse_context"]["metadata"] == metadata + + def test_metadata_none_when_not_provided(self) -> None: + config = self._call_with_handler(trace_name="test-trace") + assert config["_langfuse_context"]["metadata"] is None + + def test_metadata_merged_into_top_level_metadata(self) -> None: + metadata = {"system_prompt_length": "4523"} + config = self._call_with_handler( + trace_name="test-trace", + metadata=metadata, + ) + assert config["metadata"]["langfuse_trace_name"] == "test-trace" + assert config["metadata"]["system_prompt_length"] == "4523" + + def test_tags_and_metadata_both_passed_through(self) -> None: + tags = ["PROJ-42", "Bug"] + metadata = {"retry_count": "1"} + config = self._call_with_handler( + tags=tags, + metadata=metadata, + session_id="PROJ-42", + ) + ctx = config["_langfuse_context"] + assert ctx["tags"] == ["PROJ-42", "Bug"] + assert ctx["metadata"] == {"retry_count": "1"} + assert ctx["session_id"] == "PROJ-42" + + def test_returns_empty_dict_when_disabled(self) -> None: + with patch( + "forge.integrations.langfuse.tracing.get_langfuse_handler", + return_value=None, + ): + config = get_langfuse_config( + metadata={"key": "val"}, + tags=["tag"], + ) + assert config == {} + + def test_all_context_params_present(self) -> None: + config = self._call_with_handler( + session_id="sess-1", + user_id="user-1", + tags=["t1"], + metadata={"k": "v"}, + ) + ctx = config["_langfuse_context"] + assert ctx["session_id"] == "sess-1" + assert ctx["user_id"] == "user-1" + assert ctx["tags"] == ["t1"] + assert ctx["metadata"] == {"k": "v"} + + +class TestGetLangfuseContext: + """get_langfuse_context() accepts metadata parameter.""" + + def test_creates_context_with_metadata(self) -> None: + ctx = get_langfuse_context( + session_id="sess-1", + tags=["tag"], + metadata={"key": "val"}, + ) + assert isinstance(ctx, AsyncLangfuseContext) + assert ctx.metadata == {"key": "val"} + assert ctx.tags == ["tag"] + assert ctx.session_id == "sess-1" + + def test_metadata_defaults_to_none(self) -> None: + ctx = get_langfuse_context() + assert ctx.metadata is None diff --git a/tests/unit/workflow/nodes/test_trace_context_enrichment.py b/tests/unit/workflow/nodes/test_trace_context_enrichment.py new file mode 100644 index 0000000..eef9af3 --- /dev/null +++ b/tests/unit/workflow/nodes/test_trace_context_enrichment.py @@ -0,0 +1,435 @@ +"""Tests that workflow nodes pass trace-related context fields to the agent. + +Each workflow node was updated to include trace fields (ticket_type, current_node, +event_type, event_source, retry_count, etc.) in the context dict passed to the +agent. These tests verify the context dicts contain the expected trace keys. +""" + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.models.workflow import TicketType + +TRACE_CONTEXT_KEYS = { + "ticket_key", + "ticket_type", + "current_node", + "event_type", + "event_source", + "retry_count", +} + + +def _make_feature_state(**overrides: Any) -> dict[str, Any]: + """Build a minimal feature workflow state.""" + from forge.workflow.feature.state import create_initial_feature_state + + state = create_initial_feature_state( + ticket_key="TEST-123", + ticket_type=TicketType.FEATURE, + ) + state["current_node"] = "generate_prd" + state["event_type"] = "issue_updated" + state["context"] = {"source": "jira"} + state["retry_count"] = 0 + state.update(overrides) + return state + + +def _make_bug_state(**overrides: Any) -> dict[str, Any]: + """Build a minimal bug workflow state.""" + state: dict[str, Any] = { + "thread_id": "test-thread", + "ticket_key": "BUG-456", + "ticket_type": "Bug", + "current_node": "analyze_bug", + "event_type": "issue_updated", + "context": {"source": "jira"}, + "ci_status": "", + "retry_count": 0, + "is_paused": False, + "error_message": None, + } + state.update(overrides) + return state + + +class TestPrdGenerationTraceContext: + """generate_prd node includes trace fields in agent context.""" + + @pytest.mark.asyncio + async def test_generate_prd_passes_trace_fields(self) -> None: + from forge.workflow.nodes.prd_generation import generate_prd + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.get_issue = AsyncMock( + return_value=MagicMock( + summary="Test Feature", + description="Build something", + project_key="TEST", + ) + ) + mock_jira.add_comment = AsyncMock() + mock_jira.add_structured_comment = AsyncMock() + mock_jira.update_description = AsyncMock() + mock_jira.set_workflow_label = AsyncMock() + + mock_agent = MagicMock() + mock_agent.close = AsyncMock() + captured_context: dict[str, Any] = {} + + async def capture_generate_prd(raw_req, context=None): + if context: + captured_context.update(context) + return "# PRD\n\nContent" + + mock_agent.generate_prd = capture_generate_prd + + state = _make_feature_state(current_node="generate_prd") + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + await generate_prd(state) + + for key in TRACE_CONTEXT_KEYS: + assert key in captured_context, f"Missing trace key '{key}' in PRD context" + + @pytest.mark.asyncio + async def test_regenerate_prd_passes_trace_fields(self) -> None: + from forge.workflow.nodes.prd_generation import regenerate_prd_with_feedback + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.update_description = AsyncMock() + mock_jira.set_workflow_label = AsyncMock() + mock_jira.add_comment = AsyncMock() + + mock_agent = MagicMock() + mock_agent.close = AsyncMock() + captured_context: dict[str, Any] = {} + + async def capture_regen(**kwargs): + if kwargs.get("context"): + captured_context.update(kwargs["context"]) + return "# Revised PRD" + + mock_agent.regenerate_with_feedback = capture_regen + + state = _make_feature_state( + current_node="regenerate_prd", + prd_content="# Old PRD", + feedback_comment="Add more detail", + revision_requested=True, + ) + + with ( + patch( + "forge.workflow.nodes.prd_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.prd_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + await regenerate_prd_with_feedback(state) + + for key in {"ticket_type", "current_node", "event_type", "event_source", "retry_count"}: + assert key in captured_context, f"Missing trace key '{key}' in PRD regen context" + + +class TestSpecGenerationTraceContext: + """generate_spec node includes trace fields in agent context.""" + + @pytest.mark.asyncio + async def test_generate_spec_passes_trace_fields(self) -> None: + from forge.workflow.nodes.spec_generation import generate_spec + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.update_description = AsyncMock() + mock_jira.set_workflow_label = AsyncMock() + mock_jira.add_comment = AsyncMock() + mock_jira.add_structured_comment = AsyncMock() + + mock_agent = MagicMock() + mock_agent.close = AsyncMock() + captured_context: dict[str, Any] = {} + + async def capture_generate_spec(prd, context=None): + if context: + captured_context.update(context) + return "# Spec\n\nContent" + + mock_agent.generate_spec = capture_generate_spec + + state = _make_feature_state( + current_node="generate_spec", + prd_content="# PRD content", + qa_history=[], + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + patch("forge.workflow.nodes.spec_generation.post_qa_summary_if_needed"), + ): + await generate_spec(state) + + for key in TRACE_CONTEXT_KEYS: + assert key in captured_context, f"Missing trace key '{key}' in spec context" + + @pytest.mark.asyncio + async def test_regenerate_spec_passes_trace_fields(self) -> None: + from forge.workflow.nodes.spec_generation import regenerate_spec_with_feedback + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.update_description = AsyncMock() + mock_jira.set_workflow_label = AsyncMock() + mock_jira.add_comment = AsyncMock() + mock_jira.add_structured_comment = AsyncMock() + + mock_agent = MagicMock() + mock_agent.close = AsyncMock() + captured_context: dict[str, Any] = {} + + async def capture_regen(**kwargs): + if kwargs.get("context"): + captured_context.update(kwargs["context"]) + return "# Revised Spec" + + mock_agent.regenerate_with_feedback = capture_regen + + state = _make_feature_state( + current_node="regenerate_spec", + spec_content="# Old Spec", + prd_content="# PRD", + feedback_comment="Change approach", + revision_requested=True, + ) + + with ( + patch( + "forge.workflow.nodes.spec_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.spec_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + await regenerate_spec_with_feedback(state) + + for key in {"ticket_type", "current_node", "event_type", "event_source", "retry_count"}: + assert key in captured_context, f"Missing trace key '{key}' in spec regen context" + + +class TestQaHandlerTraceContext: + """answer_question node includes trace fields in agent context.""" + + @pytest.mark.asyncio + async def test_answer_question_passes_trace_fields(self) -> None: + from forge.workflow.nodes.qa_handler import answer_question + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.add_comment = AsyncMock(return_value=MagicMock(id="c-1")) + + mock_agent = MagicMock() + mock_agent.close = AsyncMock() + captured_context: dict[str, Any] = {} + + async def capture_answer(question, artifact_content, context): + captured_context.update(context) + return "The answer" + + mock_agent.answer_question = capture_answer + + state = _make_feature_state( + current_node="prd_approval_gate", + feedback_comment="?Why this approach?", + is_question=True, + prd_content="# PRD", + ) + + with ( + patch( + "forge.workflow.nodes.qa_handler.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.qa_handler.ForgeAgent", + return_value=mock_agent, + ), + ): + await answer_question(state) + + for key in TRACE_CONTEXT_KEYS: + assert key in captured_context, f"Missing trace key '{key}' in QA context" + + +class TestEpicDecompositionTraceContext: + """decompose_epics node includes trace fields in agent context.""" + + @pytest.mark.asyncio + async def test_decompose_epics_passes_trace_fields(self) -> None: + from forge.workflow.nodes.epic_decomposition import decompose_epics + + mock_jira = AsyncMock() + mock_jira.get_issue = AsyncMock( + return_value=MagicMock( + project_key="TEST", + summary="Test Feature", + ) + ) + mock_jira.get_labels = AsyncMock(return_value=[]) + mock_jira.get_project_repos = AsyncMock(return_value=["acme/backend"]) + mock_jira.create_epic = AsyncMock(return_value="TEST-200") + mock_jira.set_workflow_label = AsyncMock() + mock_jira.add_comment = AsyncMock() + + mock_agent = AsyncMock() + captured_context: dict[str, Any] = {} + + async def capture_epics(spec, context=None): + if context: + captured_context.update(context) + return [{"summary": "Epic 1", "plan": "Do it", "repo": "acme/backend"}] + + mock_agent.generate_epics = capture_epics + + state = _make_feature_state( + current_node="decompose_epics", + spec_content="# Spec content", + generation_context={}, + qa_history=[], + ) + + with ( + patch( + "forge.workflow.nodes.epic_decomposition.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.epic_decomposition.ForgeAgent", + return_value=mock_agent, + ), + patch("forge.workflow.nodes.epic_decomposition.post_qa_summary_if_needed"), + ): + await decompose_epics(state) + + for key in TRACE_CONTEXT_KEYS: + assert key in captured_context, f"Missing trace key '{key}' in epic context" + + @pytest.mark.asyncio + async def test_update_single_epic_passes_trace_fields(self) -> None: + from forge.workflow.nodes.epic_decomposition import update_single_epic + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.get_issue = AsyncMock( + return_value=MagicMock(description="Original epic") + ) + mock_jira.update_description = AsyncMock() + mock_jira.add_comment = AsyncMock() + + mock_agent = MagicMock() + mock_agent.close = AsyncMock() + captured_context: dict[str, Any] = {} + + async def capture_regen(**kwargs): + if kwargs.get("context"): + captured_context.update(kwargs["context"]) + return "# Revised Epic" + + mock_agent.regenerate_with_feedback = capture_regen + + state = _make_feature_state( + current_node="update_single_epic", + epic_keys=["TEST-200"], + current_epic_key="TEST-200", + feedback_comment="Change scope", + revision_requested=True, + ) + + with ( + patch( + "forge.workflow.nodes.epic_decomposition.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.epic_decomposition.ForgeAgent", + return_value=mock_agent, + ), + ): + await update_single_epic(state) + + for key in {"ticket_type", "current_node", "event_type", "event_source", "retry_count"}: + assert key in captured_context, f"Missing trace key '{key}' in epic update context" + + +class TestTaskGenerationTraceContext: + """update_single_task node includes trace fields in agent context.""" + + @pytest.mark.asyncio + async def test_update_single_task_passes_trace_fields(self) -> None: + from forge.workflow.nodes.task_generation import update_single_task + + mock_jira = MagicMock() + mock_jira.close = AsyncMock() + mock_jira.get_issue = AsyncMock( + return_value=MagicMock(description="Original task") + ) + mock_jira.update_description = AsyncMock() + mock_jira.add_comment = AsyncMock() + + mock_agent = MagicMock() + mock_agent.close = AsyncMock() + captured_context: dict[str, Any] = {} + + async def capture_regen(**kwargs): + if kwargs.get("context"): + captured_context.update(kwargs["context"]) + return "# Revised Task" + + mock_agent.regenerate_with_feedback = capture_regen + + state = _make_feature_state( + current_node="update_single_task", + current_task_key="TEST-300", + feedback_comment="Make it smaller", + revision_requested=True, + ) + + with ( + patch( + "forge.workflow.nodes.task_generation.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.task_generation.ForgeAgent", + return_value=mock_agent, + ), + ): + await update_single_task(state) + + for key in {"ticket_type", "current_node", "event_type", "event_source", "retry_count"}: + assert key in captured_context, f"Missing trace key '{key}' in task update context"