def set_output_schema(self, output_model: Type[BaseModel]) -> "MobileAgent":
    """Attach the Pydantic model used for structured output extraction.

    Args:
        output_model: The model class itself (not an instance); it must be a
            subclass of ``BaseModel``.

    Returns:
        This agent, so the call can be chained.

    Raises:
        TypeError: If ``output_model`` is not a ``BaseModel`` subclass.
    """
    # ``issubclass`` raises on non-class arguments, so confirm it is a class first.
    is_model_class = isinstance(output_model, type) and issubclass(
        output_model, BaseModel
    )
    if not is_model_class:
        raise TypeError("output_model must be a Pydantic BaseModel subclass")

    self.output_model = output_model

    # No dedicated extraction LLM configured: fall back to the fast agent LLM.
    if self.structured_output_llm is None:
        self.structured_output_llm = self.fast_agent_llm

    # Keep an existing manager agent on the same schema.
    manager = self.manager_agent
    if manager is not None:
        manager.output_model = output_model

    return self
""" +import json import logging -from typing import Type +import re +from collections.abc import Iterator +from typing import Any, Type, TypeVar from llama_index.core.llms.llm import LLM from llama_index.core.prompts import PromptTemplate from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from mobilerun.agent.utils.inference import astructured_predict_with_retries logger = logging.getLogger("mobilerun") +T = TypeVar("T", bound=BaseModel) + +_FENCED_JSON_RE = re.compile( + r"```(?:json)?\s*(.*?)```", + re.IGNORECASE | re.DOTALL, +) + + +def coerce_structured_output_from_text( + pydantic_model: Type[T], answer_text: str +) -> T | None: + """Return a validated model when *answer_text* already contains JSON.""" + + for candidate in _iter_json_candidates(answer_text): + try: + if isinstance(candidate, str): + return pydantic_model.model_validate_json(candidate) + return pydantic_model.model_validate(candidate) + except (TypeError, ValueError, ValidationError): + continue + return None + + +def _iter_json_candidates(text: str) -> Iterator[str | Any]: + stripped = text.strip() + if not stripped: + return + + yield stripped + + for match in _FENCED_JSON_RE.finditer(text): + candidate = match.group(1).strip() + if candidate: + yield candidate + + decoder = json.JSONDecoder() + for index, char in enumerate(text): + if char not in "{[": + continue + try: + value, _ = decoder.raw_decode(text[index:]) + except json.JSONDecodeError: + continue + yield value + class StructuredOutputAgent(Workflow): """ Agent that extracts structured output from text answers. - Uses LLM.structured_predict() to parse text into Pydantic models. + Uses direct Pydantic validation for JSON answers, then + LLM.structured_predict() for natural-language answers. 
""" def __init__( self, - llm: LLM, + llm: LLM | None, pydantic_model: Type[BaseModel], answer_text: str, **kwargs, @@ -42,23 +92,41 @@ async def extract_structured_output( self, ctx: Context, ev: StartEvent ) -> StopEvent: """ - Extract structured output using structured_predict(). + Extract structured output using direct validation or structured_predict(). """ - logger.debug("🔍 Extracting structured output from final answer...") + logger.debug("Extracting structured output from final answer...") try: - # Create prompt for extraction + direct_output = coerce_structured_output_from_text( + self.pydantic_model, + self.answer_text, + ) + if direct_output is not None: + logger.debug("Parsed structured output directly from final answer") + return StopEvent( + result={ + "structured_output": direct_output, + "success": True, + "error_message": "", + } + ) + + if self.llm is None: + raise ValueError( + "No structured output LLM is configured and the final answer " + "does not contain valid JSON for the requested model" + ) + prompt = PromptTemplate( "Extract structured information from the following text:\n\n{text}" ) - # Use structured_predict to extract data - logger.info("🔍 StructuredOutput response:", extra={"color": "magenta"}) + logger.info("StructuredOutput response:", extra={"color": "magenta"}) structured_output = await astructured_predict_with_retries( self.llm, self.pydantic_model, prompt, text=self.answer_text ) - logger.debug("✅ Successfully extracted structured output") + logger.debug("Successfully extracted structured output") return StopEvent( result={ @@ -69,7 +137,7 @@ async def extract_structured_output( ) except Exception as e: - logger.error(f"❌ Failed to extract structured output: {e}") + logger.error(f"Failed to extract structured output: {e}") return StopEvent( result={ diff --git a/tests/test_structured_output.py b/tests/test_structured_output.py new file mode 100644 index 00000000..d55c322c --- /dev/null +++ b/tests/test_structured_output.py @@ 
import asyncio
import unittest

from pydantic import BaseModel, Field

from mobilerun import MobileAgent
from mobilerun.agent.oneflows.structured_output_agent import (
    StructuredOutputAgent,
    coerce_structured_output_from_text,
)
from mobilerun.config_manager import MobileConfig


# Schema shared by every test in this module.
class ContactInfo(BaseModel):
    name: str = Field(description="Full name")
    phone: str
    email: str | None = None


# Final answer that wraps the JSON payload in a fenced block after some prose.
_FENCED_ANSWER = """
Done.

```json
{"name": "Ada Lovelace", "phone": "+44 20 7946 0958"}
```
"""


def _extract_without_llm(answer_text):
    # Run StructuredOutputAgent to completion with no LLM and return its result dict.
    async def _drive():
        workflow = StructuredOutputAgent(
            llm=None,
            pydantic_model=ContactInfo,
            answer_text=answer_text,
        )
        return await workflow.run()

    return asyncio.run(_drive())


class StructuredOutputCoercionTest(unittest.TestCase):
    def test_validates_raw_json_answer(self):
        # The whole answer is a JSON object matching the model.
        raw_answer = '{"name": "Grace Liu", "phone": "+1 555 0100", "email": "grace@example.com"}'

        contact = coerce_structured_output_from_text(ContactInfo, raw_answer)

        self.assertIsInstance(contact, ContactInfo)
        self.assertEqual(contact.name, "Grace Liu")
        self.assertEqual(contact.phone, "+1 555 0100")

    def test_validates_fenced_json_answer(self):
        # The payload sits inside a fenced block; the optional email is omitted.
        contact = coerce_structured_output_from_text(ContactInfo, _FENCED_ANSWER)

        self.assertIsInstance(contact, ContactInfo)
        self.assertEqual(contact.name, "Ada Lovelace")
        self.assertIsNone(contact.email)

    def test_ignores_plain_text_without_json_shape(self):
        prose = "I found Grace Liu's phone number, but this is not JSON."

        self.assertIsNone(coerce_structured_output_from_text(ContactInfo, prose))

    def test_structured_output_agent_accepts_json_without_llm(self):
        # A JSON answer must be accepted even though no LLM is configured.
        outcome = _extract_without_llm('{"name": "Grace Liu", "phone": "+1 555 0100"}')

        self.assertTrue(outcome["success"])
        self.assertIsInstance(outcome["structured_output"], ContactInfo)
        self.assertEqual(outcome["structured_output"].name, "Grace Liu")

    def test_structured_output_agent_reports_missing_llm_for_plain_text(self):
        # Plain prose with no LLM to fall back on must be reported as a failure.
        outcome = _extract_without_llm("Grace Liu can be reached at +1 555 0100.")

        self.assertFalse(outcome["success"])
        self.assertIsNone(outcome["structured_output"])
        self.assertIn("No structured output LLM", outcome["error_message"])


class MobileAgentOutputSchemaTest(unittest.TestCase):
    def setUp(self):
        # Each test gets a fresh agent built from the same config.
        config = MobileConfig.from_dict({"agent": {"name": "external-agent"}})
        self.agent = MobileAgent("Find contact info", config=config)

    def test_no_schema_keeps_unstructured_mode(self):
        self.assertIsNone(self.agent.output_model)
        self.assertIsNone(self.agent.structured_output_llm)

    def test_set_output_schema_configures_model(self):
        returned = self.agent.set_output_schema(ContactInfo)

        # The setter is fluent and stores the class itself.
        self.assertIs(returned, self.agent)
        self.assertIs(self.agent.output_model, ContactInfo)

    def test_set_output_schema_rejects_non_model(self):
        with self.assertRaises(TypeError):
            self.agent.set_output_schema(dict)  # type: ignore[arg-type]


if __name__ == "__main__":
    unittest.main()