diff --git a/conftest.py b/conftest.py
index bac5680c39764..4b3005f7a5ae3 100644
--- a/conftest.py
+++ b/conftest.py
@@ -182,6 +182,7 @@ def pytest_ignore_collect(collection_path, config):
         "tests/gemini",
         "tests/groq",
         "tests/h2o",
+        "tests/haystack",
         "tests/johnsnowlabs",
         "tests/keras",
         "tests/keras_core",
diff --git a/docs/api_reference/source/python_api/mlflow.haystack.rst b/docs/api_reference/source/python_api/mlflow.haystack.rst
new file mode 100644
index 0000000000000..4425f7be586ce
--- /dev/null
+++ b/docs/api_reference/source/python_api/mlflow.haystack.rst
@@ -0,0 +1,7 @@
+mlflow.haystack
+===============
+
+.. automodule:: mlflow.haystack
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/docs/genai/tracing/integrations/listing/haystack.mdx b/docs/docs/genai/tracing/integrations/listing/haystack.mdx
new file mode 100644
index 0000000000000..1e80021157eb6
--- /dev/null
+++ b/docs/docs/genai/tracing/integrations/listing/haystack.mdx
@@ -0,0 +1,48 @@
+---
+sidebarTitle: Haystack
+title: Haystack
+---
+
+import { APILink } from "@site/src/components/APILink";
+
+# Tracing Haystack
+
+![Haystack tracing via autolog](/images/llms/haystack/haystack-tracing.png)
+
+MLflow Tracing provides automatic tracing capability for Haystack pipelines and components.
+When Haystack auto-tracing is enabled by calling `mlflow.haystack.autolog()`,
+executions of Haystack pipelines and components are automatically recorded as traces during interactive development.
+
+## Example Usage
+
+```python
+import mlflow
+from haystack import Pipeline
+from haystack.components.generators import OpenAIGenerator
+from haystack.components.builders.prompt_builder import PromptBuilder
+
+# Turn on auto tracing for Haystack by calling mlflow.haystack.autolog()
+mlflow.haystack.autolog()
+
+# Initialize the pipeline
+pipeline = Pipeline()
+
+# Configure the LLM component
+llm = OpenAIGenerator(model="gpt-4o-mini")
+prompt_template = """Answer the question. {{question}}"""
+prompt_builder = PromptBuilder(template=prompt_template)
+
+# Build the pipeline
+pipeline.add_component("prompt_builder", prompt_builder)
+pipeline.add_component("llm", llm)
+pipeline.connect("prompt_builder", "llm")
+
+# Run the pipeline
+question = "Who lives in Paris?"
+results = pipeline.run({"prompt_builder": {"question": question}})
+```
+
+## Disable auto-tracing
+
+Auto tracing for Haystack can be disabled globally by calling `mlflow.haystack.autolog(disable=True)` or `mlflow.autolog(disable=True)`.
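+For example, to turn auto tracing back off within the same session:
+
+```python
+import mlflow
+
+# Disable Haystack auto-tracing
+mlflow.haystack.autolog(disable=True)
+```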
diff --git a/docs/src/components/TracingIntegrations/index.tsx b/docs/src/components/TracingIntegrations/index.tsx index ed3fc1b4140e8..eb31dbe837517 100644 --- a/docs/src/components/TracingIntegrations/index.tsx +++ b/docs/src/components/TracingIntegrations/index.tsx @@ -149,6 +149,12 @@ const TRACING_INTEGRATIONS: TracingIntegration[] = [ logoPath: '/images/logos/txtai-logo.png', link: '/genai/tracing/integrations/listing/txtai', }, + { + id: 'haystack', + name: 'Haystack', + logoPath: '/images/logos/haystack-logo.png', + link: '/genai/tracing/integrations/listing/haystack', + }, ]; interface TracingIntegrationsProps { diff --git a/docs/static/images/llms/haystack/haystack-tracing.png b/docs/static/images/llms/haystack/haystack-tracing.png new file mode 100644 index 0000000000000..9e01e805190ed Binary files /dev/null and b/docs/static/images/llms/haystack/haystack-tracing.png differ diff --git a/docs/static/images/logos/haystack-logo.png b/docs/static/images/logos/haystack-logo.png new file mode 100644 index 0000000000000..18be179636dd9 Binary files /dev/null and b/docs/static/images/logos/haystack-logo.png differ diff --git a/examples/haystack/agent_tracing.py b/examples/haystack/agent_tracing.py new file mode 100644 index 0000000000000..8d577054d70f3 --- /dev/null +++ b/examples/haystack/agent_tracing.py @@ -0,0 +1,161 @@ +""" +Example demonstrating MLflow tracing for Haystack Agents. + +This example shows how MLflow captures the full execution flow of an Agent, +including its internal calls to chat generators and tool invokers. +""" + +import asyncio +import datetime + +from haystack.components.agents import Agent +from haystack.components.generators import OpenAIGenerator +from haystack.dataclasses import ChatMessage +from haystack.tools import Tool + +import mlflow + +# Turn on auto tracing for Haystack +mlflow.haystack.autolog() + + +def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + +def multiply_numbers(a: int, b: int) -> int: + """Multiply two numbers together.""" + return a * b + + +# Example 1: Simple Agent without tools (behaves like ChatGenerator) +def simple_agent_example(): + print("=== Simple Agent Example (No Tools) ===") + + # Create a chat generator + llm = OpenAIGenerator(model="gpt-4o-mini") + + # Create an agent without tools + agent = Agent(chat_generator=llm, system_prompt="You are a helpful assistant.") + + # Warm up the agent + agent.warm_up() + + # Run the agent + messages = [ChatMessage.from_user("What is the capital of France?")] + result = agent.run(messages=messages) + + print("User:", messages[0].content) + print("Agent:", result["last_message"].content) + print() + + +# Example 2: Agent with tools +def agent_with_tools_example(): + print("=== Agent with Tools Example ===") + + # Create tools + add_tool = Tool( + name="add_numbers", + description="Add two numbers together", + function=add_numbers, + parameters={ + "type": "object", + "properties": { + "a": {"type": "integer", "description": "First number"}, + "b": {"type": "integer", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + ) + + multiply_tool = Tool( + name="multiply_numbers", + description="Multiply two numbers together", + function=multiply_numbers, + parameters={ + "type": "object", + "properties": { + "a": {"type": "integer", "description": "First number"}, + "b": {"type": "integer", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + ) + + # Create a chat generator + llm = OpenAIGenerator(model="gpt-4o-mini") + 
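+    # With autolog enabled, Agent.run is recorded as the parent span, and the
+    # chat generator and tool invoker calls are captured as child spans.
+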
+ # Create an agent with tools + agent = Agent( + chat_generator=llm, + tools=[add_tool, multiply_tool], + system_prompt="You are a helpful math assistant. Use the tools provided to help with calculations.", + ) + + # Warm up the agent + agent.warm_up() + + # Run the agent with a calculation request + messages = [ChatMessage.from_user("What is 15 + 27, and what is the result multiplied by 3?")] + result = agent.run(messages=messages) + + print("User:", messages[0].content) + print("Agent:", result["last_message"].content) + print("\nNumber of messages exchanged:", len(result["messages"])) + print() + + +# Example 3: Async Agent +async def async_agent_example(): + print("=== Async Agent Example ===") + + # Create a simple tool + def get_time() -> str: + """Get the current time.""" + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + time_tool = Tool( + name="get_time", + description="Get the current time", + function=get_time, + parameters={"type": "object", "properties": {}}, + ) + + # Create components + llm = OpenAIGenerator(model="gpt-4o-mini") + + agent = Agent( + chat_generator=llm, + tools=[time_tool], + system_prompt="You are a helpful assistant that can tell the time.", + ) + + # Warm up the agent + agent.warm_up() + + # Run the agent asynchronously + messages = [ChatMessage.from_user("What time is it now?")] + result = await agent.run_async(messages=messages) + + print("User:", messages[0].content) + print("Agent:", result["last_message"].content) + print() + + +# Run examples +if __name__ == "__main__": + # Run synchronous examples + simple_agent_example() + agent_with_tools_example() + + # Run asynchronous example + asyncio.run(async_agent_example()) + + print("=== Tracing Complete ===") + print("The Agent execution traces have been logged to MLflow.") + print("You can view the hierarchical trace structure showing:") + print("- Agent.run as the parent span") + print("- chat_generator and tool_invoker as child spans") + print("- All tool invocations and their results") diff --git a/examples/haystack/tracing.py b/examples/haystack/tracing.py new file mode 100644 index 0000000000000..7e2a36acac369 --- /dev/null +++ b/examples/haystack/tracing.py @@ -0,0 +1,65 @@ +""" +This is an example for leveraging MLflow's auto tracing capabilities for Haystack. + +For more information about MLflow Tracing, see: https://mlflow.org/docs/latest/llms/tracing/index.html +""" + +from haystack import Pipeline +from haystack.components.builders.prompt_builder import PromptBuilder +from haystack.components.generators import OpenAIGenerator +from haystack.core.pipeline.async_pipeline import AsyncPipeline + +import mlflow + +# Turn on auto tracing for Haystack by calling mlflow.haystack.autolog() +# This will automatically trace: +# - Pipeline executions (both sync and async) +# - Individual component executions +# - Token usage for LLM components +# - Component metadata and parameters +mlflow.haystack.autolog() + + +# Example 1: Synchronous Pipeline +def sync_pipeline_example(): + print("=== Synchronous Pipeline Example ===") + + pipeline = Pipeline() + llm = OpenAIGenerator(model="gpt-4o-mini") + + prompt_template = """Answer the question. {{question}}""" + prompt_builder = PromptBuilder(template=prompt_template) + + pipeline = Pipeline() + pipeline.add_component("prompt_builder", prompt_builder) + pipeline.add_component("llm", llm) + pipeline.connect("prompt_builder", "llm") + + question = "Who lives in Paris?" 
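+    # With autolog enabled, this call is traced automatically, including each
+    # component execution and the token usage reported by the LLM component.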
+    results = pipeline.run({"prompt_builder": {"question": question}})
+
+    print("Question:", question)
+    print("Answer:", results["llm"]["replies"][0])
+    print()
+
+
+# Example 2: Asynchronous Pipeline
+async def async_pipeline_example():
+    print("=== Asynchronous Pipeline Example ===")
+
+    pipeline = AsyncPipeline()
+
+    llm = OpenAIGenerator(model="gpt-4o-mini")
+    prompt_template = """Tell me about: {{topic}}"""
+    prompt_builder = PromptBuilder(template=prompt_template)
+
+    pipeline.add_component("prompt_builder", prompt_builder)
+    pipeline.add_component("llm", llm)
+    pipeline.connect("prompt_builder", "llm")
+
+    topic = "artificial intelligence"
+    results = await pipeline.run_async({"prompt_builder": {"topic": topic}})
+
+    print("Topic:", topic)
+    print("Response:", results["llm"]["replies"][0])
+    print()
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    # Run both examples so the script produces traces when executed directly
+    sync_pipeline_example()
+    asyncio.run(async_pipeline_example())
diff --git a/mlflow/__init__.py b/mlflow/__init__.py
index dc7b13aee6a8e..397243d77ea99 100644
--- a/mlflow/__init__.py
+++ b/mlflow/__init__.py
@@ -71,6 +71,7 @@
 gemini = LazyLoader("mlflow.gemini", globals(), "mlflow.gemini")
 groq = LazyLoader("mlflow.groq", globals(), "mlflow.groq")
 h2o = LazyLoader("mlflow.h2o", globals(), "mlflow.h2o")
+haystack = LazyLoader("mlflow.haystack", globals(), "mlflow.haystack")
 johnsnowlabs = LazyLoader("mlflow.johnsnowlabs", globals(), "mlflow.johnsnowlabs")
 keras = LazyLoader("mlflow.keras", globals(), "mlflow.keras")
 langchain = LazyLoader("mlflow.langchain", globals(), "mlflow.langchain")
@@ -124,6 +125,7 @@
     gemini,
     groq,
     h2o,
+    haystack,
     johnsnowlabs,
     keras,
     langchain,
diff --git a/mlflow/haystack/__init__.py b/mlflow/haystack/__init__.py
new file mode 100644
index 0000000000000..68dbce9fe204e
--- /dev/null
+++ b/mlflow/haystack/__init__.py
@@ -0,0 +1,122 @@
+"""
+The ``mlflow.haystack`` module provides an API for tracing Haystack pipelines and components.
+"""
+
+import inspect
+import logging
+
+from mlflow.haystack.autolog import (
+    patched_async_class_call,
+    patched_class_call,
+)
+from mlflow.utils.annotations import experimental
+from mlflow.utils.autologging_utils import autologging_integration, safe_patch
+
+_logger = logging.getLogger(__name__)
+
+FLAVOR_NAME = "haystack"
+
+
+@experimental(version="3.0.0")
+@autologging_integration(FLAVOR_NAME)
+def autolog(
+    log_traces: bool = True,
+    disable: bool = False,
+    silent: bool = False,
+):
+    """
+    Enables (or disables) and configures autologging from Haystack to MLflow.
+    Autologging automatically generates traces for Haystack pipelines and their components.
+
+    Args:
+        log_traces: If ``True``, traces are logged for Haystack pipelines and components.
+            If ``False``, no traces are collected during inference. Defaults to ``True``.
+        disable: If ``True``, disables Haystack autologging. Defaults to ``False``.
+        silent: If ``True``, suppress all event logs and warnings from MLflow during Haystack
+            autologging. If ``False``, show all events and warnings.
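+
+    Example:
+
+    .. code-block:: python
+
+        import mlflow
+        from haystack import Pipeline
+        from haystack.components.builders.prompt_builder import PromptBuilder
+        from haystack.components.generators import OpenAIGenerator
+
+        # Enable auto tracing for Haystack
+        mlflow.haystack.autolog()
+
+        prompt_builder = PromptBuilder(template="Answer the question. {{question}}")
+        llm = OpenAIGenerator(model="gpt-4o-mini")
+
+        pipeline = Pipeline()
+        pipeline.add_component("prompt_builder", prompt_builder)
+        pipeline.add_component("llm", llm)
+        pipeline.connect("prompt_builder", "llm")
+
+        pipeline.run({"prompt_builder": {"question": "Who lives in Paris?"}})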
+ """ + + # Define class-method mappings following smolagents/crewai pattern + class_method_map = { + "haystack.core.pipeline.pipeline.Pipeline": ["run", "_run_component"], + "haystack.core.pipeline.async_pipeline.AsyncPipeline": [ + "run_async", + "_run_component_async", + ], + } + + # Dynamically discover and patch generator components + try: + # Import common Haystack components + from haystack.components.builders import PromptBuilder + from haystack.components.generators import ( + HuggingFaceAPIGenerator, + HuggingFaceLocalGenerator, + OpenAIGenerator, + ) + from haystack.components.retrievers import ( + InMemoryBM25Retriever, + InMemoryEmbeddingRetriever, + ) + + for generator_class in [ + OpenAIGenerator, + HuggingFaceAPIGenerator, + HuggingFaceLocalGenerator, + ]: + class_method_map[f"{generator_class.__module__}.{generator_class.__name__}"] = ["run"] + + class_method_map[f"{PromptBuilder.__module__}.{PromptBuilder.__name__}"] = ["run"] + for retriever_class in [InMemoryBM25Retriever, InMemoryEmbeddingRetriever]: + class_method_map[f"{retriever_class.__module__}.{retriever_class.__name__}"] = ["run"] + + except (ImportError, AttributeError) as e: + _logger.debug(f"Some Haystack components could not be imported for autolog: {e}") + + try: + for class_path, methods in class_method_map.items(): + *module_parts, class_name = class_path.rsplit(".", 1) + module_path = ".".join(module_parts) + + try: + module = __import__(module_path, fromlist=[class_name]) + cls = getattr(module, class_name) + except (ImportError, AttributeError) as e: + _logger.debug(f"Could not import {class_path}: {e}") + continue + + for method in methods: + try: + if hasattr(cls, method): + original_method = getattr(cls, method) + + if isinstance(original_method, staticmethod): + # Static methods need special handling in Haystack + # We'll patch them directly on the class + original_func = original_method.__func__ + wrapper = patched_class_call + + def make_wrapper(orig): + def wrapped(*args, **kwargs): + # For static methods, there's no self, so we pass None + return wrapper(orig, None, *args, **kwargs) + + return wrapped + + new_method = staticmethod(make_wrapper(original_func)) + setattr(cls, method, new_method) + _logger.debug(f"Patched static method {class_name}.{method}") + else: + wrapper = ( + patched_async_class_call + if inspect.iscoroutinefunction(original_method) + else patched_class_call + ) + safe_patch(FLAVOR_NAME, cls, method, wrapper) + _logger.debug(f"Patched {class_name}.{method}") + + except Exception as e: + _logger.warning(f"Failed to patch {class_name}.{method}: {e}") + + except Exception as e: + _logger.warning(f"Failed to apply autolog patches to Haystack: {e}") diff --git a/mlflow/haystack/autolog.py b/mlflow/haystack/autolog.py new file mode 100644 index 0000000000000..a46834c5adf21 --- /dev/null +++ b/mlflow/haystack/autolog.py @@ -0,0 +1,273 @@ +"""Haystack autolog implementation following smolagents and crewai patterns.""" + +import inspect +import logging +from typing import Any, Optional + +import mlflow +from mlflow.entities import SpanType +from mlflow.entities.span import LiveSpan, SpanAttributeKey +from mlflow.tracing.constant import TokenUsageKey +from mlflow.utils.autologging_utils.config import AutoLoggingConfig + +_logger = logging.getLogger(__name__) + + +def patched_class_call(original, self, *args, **kwargs): + config = AutoLoggingConfig.init(flavor_name=mlflow.haystack.FLAVOR_NAME) + + if not config.log_traces: + return original(self, *args, **kwargs) + + fullname = 
f"{self.__class__.__name__}.{original.__name__}" + span_type = _get_span_type(self) + + with mlflow.start_span(name=fullname, span_type=span_type) as span: + inputs = _construct_full_inputs(original, self, *args, **kwargs) + span.set_inputs(inputs) + _set_span_attributes(span, self) + + result = original(self, *args, **kwargs) + + outputs = result.__dict__ if hasattr(result, "__dict__") else result + if isinstance(outputs, dict): + outputs = _format_outputs(outputs) + + span.set_outputs(outputs) + + if token_usage := _parse_token_usage(outputs): + span.set_attribute(SpanAttributeKey.CHAT_USAGE, token_usage) + + if model := _extract_model_info(outputs): + span.set_attribute("model", model) + + return result + + +async def patched_async_class_call(original, self, *args, **kwargs): + """Async patch method for Haystack async methods.""" + config = AutoLoggingConfig.init(flavor_name=mlflow.haystack.FLAVOR_NAME) + + if not config.log_traces: + return await original(self, *args, **kwargs) + + fullname = f"{self.__class__.__name__}.{original.__name__}" + span_type = _get_span_type(self) + + with mlflow.start_span(name=fullname, span_type=span_type) as span: + inputs = _construct_full_inputs(original, self, *args, **kwargs) + span.set_inputs(inputs) + _set_span_attributes(span, self) + + result = await original(self, *args, **kwargs) + + outputs = result.__dict__ if hasattr(result, "__dict__") else result + if isinstance(outputs, dict): + outputs = _format_outputs(outputs) + + span.set_outputs(outputs) + + if token_usage := _parse_token_usage(outputs): + span.set_attribute(SpanAttributeKey.CHAT_USAGE, token_usage) + + if model := _extract_model_info(outputs): + span.set_attribute("model", model) + + return result + + +def _get_span_type(instance: Any) -> str: + """Determine the span type based on the instance type.""" + if instance is None: + return SpanType.TOOL + + class_name = instance.__class__.__name__.lower() + + # Define span type mappings + span_type_mapping = { + "pipeline": SpanType.CHAIN, + "asyncpipeline": SpanType.CHAIN, + "agent": SpanType.AGENT, + "generator": SpanType.CHAT_MODEL, + "llm": SpanType.CHAT_MODEL, + "chat": SpanType.CHAT_MODEL, + "retriever": SpanType.RETRIEVER, + "search": SpanType.RETRIEVER, + "embed": SpanType.EMBEDDING, + "toolinvoker": SpanType.TOOL, + } + + # Check exact matches first + if class_name in span_type_mapping: + return span_type_mapping[class_name] + + # Check partial matches + for key, span_type in span_type_mapping.items(): + if key in class_name: + return span_type + + return SpanType.TOOL + + +def _construct_full_inputs(func, *args, **kwargs) -> dict[str, Any]: + """Construct inputs for haystack components following smolagents/crewai pattern.""" + signature = inspect.signature(func) + arguments = signature.bind_partial(*args, **kwargs).arguments + + if "self" in arguments: + arguments.pop("self") + + # Special handling for Pipeline.run - extract meaningful inputs + if "data" in arguments and isinstance(arguments.get("data"), dict): + data = arguments["data"] + # Extract question/query fields from any component + for component_data in data.values(): + if isinstance(component_data, dict): + for key in ["question", "query", "prompt", "text"]: + if key in component_data: + return {key: component_data[key]} + return {} + + return { + k: v.__dict__ if hasattr(v, "__dict__") else v + for k, v in arguments.items() + if v is not None + } + + +def _format_outputs(outputs: Any) -> Any: + """Format outputs for tracing.""" + if not isinstance(outputs, dict): + 
return outputs + + formatted_outputs = {} + for key, value in outputs.items(): + if isinstance(value, dict) and "replies" in value: + formatted_component = {} + replies = value["replies"] + + if isinstance(replies, list) and len(replies) == 1: + formatted_component["replies"] = replies[0] + else: + formatted_component["replies"] = replies + + if meta := value.get("meta"): + formatted_component["meta"] = meta + + for field_key, field_value in value.items(): + if field_key not in ["replies", "meta"]: + formatted_component[field_key] = field_value + + formatted_outputs[key] = formatted_component + else: + formatted_outputs[key] = value + + return formatted_outputs + + +def _set_span_attributes(span: LiveSpan, instance): + """Set attributes on the span based on the instance type.""" + # Always set message format + span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "haystack") + + try: + if hasattr(instance, "graph"): # Pipeline + attributes = _get_pipeline_attributes(instance) + else: # Component + attributes = _get_component_attributes(instance) + + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, str(value) if isinstance(value, (list, dict)) else value) + except Exception as e: + _logger.debug(f"Failed to set span attributes: {e}") + + +def _get_pipeline_attributes(pipeline) -> dict[str, Any]: + """Extract attributes for a Pipeline instance.""" + attributes = {} + + if hasattr(pipeline, "graph") and hasattr(pipeline.graph, "nodes"): + nodes = pipeline.graph.nodes + if nodes: + attributes["components"] = list(nodes.keys()) + attributes["component_count"] = len(nodes) + + return attributes + + +def _get_component_attributes(instance) -> dict[str, Any]: + """Extract attributes for a component instance.""" + attributes = {"type": instance.__class__.__name__} + + if hasattr(instance, "_init_parameters"): + for key, value in instance._init_parameters.items(): + if key.lower() not in ["api_key", "token"]: + attributes[key] = str(value) if value is not None else None + + if hasattr(instance, "__haystack_input__"): + inputs = _extract_socket_names(instance.__haystack_input__) + if inputs: + attributes["input_types"] = inputs + + if hasattr(instance, "__haystack_output__"): + outputs = _extract_socket_names(instance.__haystack_output__) + if outputs: + attributes["output_types"] = outputs + + return attributes + + +def _extract_socket_names(sockets) -> list[str]: + """Extract socket names from a Haystack Sockets object.""" + try: + if hasattr(sockets, "_sockets"): + return list(sockets._sockets.keys()) + elif hasattr(sockets, "__dict__"): + return [k for k in sockets.__dict__.keys() if not k.startswith("_")] + except Exception: + pass + return [] + + +def _parse_token_usage(outputs: dict[str, Any]) -> Optional[dict[str, int]]: + """Parse token usage from outputs.""" + if not isinstance(outputs, dict): + return None + + try: + # Check for meta information in outputs + meta = outputs.get("meta", {}) + if isinstance(meta, list) and meta: + meta = meta[0] + + if isinstance(meta, dict) and "usage" in meta: + usage = meta["usage"] + if isinstance(usage, dict): + return { + TokenUsageKey.INPUT_TOKENS: usage.get("prompt_tokens", 0), + TokenUsageKey.OUTPUT_TOKENS: usage.get("completion_tokens", 0), + TokenUsageKey.TOTAL_TOKENS: usage.get("total_tokens", 0), + } + except Exception as e: + _logger.debug(f"Failed to parse token usage: {e}") + + return None + + +def _extract_model_info(outputs: dict[str, Any]) -> Optional[str]: + """Extract model information from outputs.""" 
+ if not isinstance(outputs, dict): + return None + + try: + meta = outputs.get("meta", {}) + if isinstance(meta, list) and meta: + meta = meta[0] + + if isinstance(meta, dict) and "model" in meta: + return meta["model"] + except Exception as e: + _logger.debug(f"Failed to extract model info: {e}") + + return None diff --git a/mlflow/ml-package-versions.yml b/mlflow/ml-package-versions.yml index efe4777cf8634..ff3a3408512e2 100644 --- a/mlflow/ml-package-versions.yml +++ b/mlflow/ml-package-versions.yml @@ -1017,6 +1017,19 @@ groq: run: pytest tests/groq test_tracing_sdk: true +haystack: + package_info: + pip_release: "haystack-ai" + module_name: "haystack" + install_dev: | + pip install git+https://github.com/deepset-ai/haystack + autologging: + minimum: "2.0.0" + maximum: "2.10.0" + requirements: + run: pytest tests/haystack + test_tracing_sdk: true + bedrock: package_info: pip_release: "boto3" diff --git a/mlflow/ml_package_versions.py b/mlflow/ml_package_versions.py index 4a739dca39966..f9348a95d3003 100644 --- a/mlflow/ml_package_versions.py +++ b/mlflow/ml_package_versions.py @@ -415,6 +415,16 @@ "maximum": "0.30.0" } }, + "haystack": { + "package_info": { + "pip_release": "haystack-ai", + "module_name": "haystack" + }, + "autologging": { + "minimum": "2.0.0", + "maximum": "2.10.0" + } + }, "bedrock": { "package_info": { "pip_release": "boto3", @@ -457,6 +467,7 @@ "mistral": "mistralai", "litellm": "litellm", "groq": "groq", + "haystack": "haystack", "bedrock": "boto3", "pyspark.ml": "pyspark" } diff --git a/mlflow/server/js/package.json b/mlflow/server/js/package.json index a88df9650f7a7..3d21600891f50 100644 --- a/mlflow/server/js/package.json +++ b/mlflow/server/js/package.json @@ -24,9 +24,9 @@ "graphql-codegen:clean": "find . -path '**/__generated__/*.ts' | xargs rm" }, "dependencies": { - "@ag-grid-community/client-side-row-model": "^27.2.1", - "@ag-grid-community/core": "^27.2.1", - "@ag-grid-community/react": "^27.2.1", + "@ag-grid-community/client-side-row-model": "^28.0.2", + "@ag-grid-community/core": "^28.0.2", + "@ag-grid-community/react": "^28.0.2", "@apollo/client": "^3.6.9", "@apollo/client-3-12": "npm:@apollo/client@^3.12.7", "@craco/craco": "7.0.0-alpha.0", diff --git a/mlflow/server/js/yarn.lock b/mlflow/server/js/yarn.lock index 3d3d852fe8419..05199245d2cad 100644 --- a/mlflow/server/js/yarn.lock +++ b/mlflow/server/js/yarn.lock @@ -30,32 +30,32 @@ __metadata: languageName: node linkType: hard -"@ag-grid-community/client-side-row-model@npm:^27.2.1": - version: 27.3.0 - resolution: "@ag-grid-community/client-side-row-model@npm:27.3.0" +"@ag-grid-community/client-side-row-model@npm:^28.0.2": + version: 28.2.1 + resolution: "@ag-grid-community/client-side-row-model@npm:28.2.1" dependencies: - "@ag-grid-community/core": "npm:~27.3.0" - checksum: 10c0/996e2e8b28dc66228fe835983cf5d2c1c728ddac92c4e006db4ef25af1cd1f1d486d67a149d3c38f66e5e6baf523017b8cb6fb117acadd1a2893406592fd0fc4 + "@ag-grid-community/core": "npm:~28.2.1" + checksum: 10c0/1135c024beb5d89c9036b2adcd2b36a083ec7ca2203bb3445c79dd44a1b8a95e001b2f5fa5319445bcbf2226d3440ec31415d882498e996c7c26cb5a6adb5e85 languageName: node linkType: hard -"@ag-grid-community/core@npm:^27.2.1, @ag-grid-community/core@npm:~27.3.0": - version: 27.3.0 - resolution: "@ag-grid-community/core@npm:27.3.0" - checksum: 10c0/b3b071cf4e2fa459dfaf6548d0f630f4f1a3d16c907f12f7b769b51d3dded476f3e53b1ec7696abb003dcbf53095da46234aebcc94f1bc82f7b73638b1615420 +"@ag-grid-community/core@npm:^28.0.2, @ag-grid-community/core@npm:~28.2.1": + version: 
28.2.1 + resolution: "@ag-grid-community/core@npm:28.2.1" + checksum: 10c0/6fa7540cd5b3391f879959b297d98cd3b3b43b7ad3216d499c5da3ce048c6bbfe5a73b2928dfaa9024559f63d14e5bbb7282c90fc06fef81cf2b88c1adb0d89f languageName: node linkType: hard -"@ag-grid-community/react@npm:^27.2.1": - version: 27.3.0 - resolution: "@ag-grid-community/react@npm:27.3.0" +"@ag-grid-community/react@npm:^28.0.2": + version: 28.2.1 + resolution: "@ag-grid-community/react@npm:28.2.1" dependencies: prop-types: "npm:^15.8.1" peerDependencies: - "@ag-grid-community/core": ~27.3.0 + "@ag-grid-community/core": ~28.2.1 react: ^16.3.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.3.0 || ^17.0.0 || ^18.0.0 - checksum: 10c0/116f135cc961e6d907cc15373ba0c3106e5f110be559e915826d57dd96309a06449147a1ab19b7dfa50c20a23fdc4d3ee250f2ea308b115addeb018f8b3e3119 + checksum: 10c0/91ddb9d313d0e152dd235952ec067f02d4c6bdf8b6d69c386babdf64114377c5ef096f92681498385aeb0236bafd314f2d7dc62fb46245e10560a14f3cd099d2 languageName: node linkType: hard @@ -4313,9 +4313,9 @@ __metadata: version: 0.0.0-use.local resolution: "@mlflow/mlflow@workspace:." dependencies: - "@ag-grid-community/client-side-row-model": "npm:^27.2.1" - "@ag-grid-community/core": "npm:^27.2.1" - "@ag-grid-community/react": "npm:^27.2.1" + "@ag-grid-community/client-side-row-model": "npm:^28.0.2" + "@ag-grid-community/core": "npm:^28.0.2" + "@ag-grid-community/react": "npm:^28.0.2" "@apollo/client": "npm:^3.6.9" "@apollo/client-3-12": "npm:@apollo/client@^3.12.7" "@babel/core": "npm:^7.27.3" diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index a036731c0f61a..66fb818931db5 100644 --- a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -3195,6 +3195,7 @@ def print_auto_logged_info(r): "crewai": "mlflow.crewai", "smolagents": "mlflow.smolagents", "groq": "mlflow.groq", + "haystack": "mlflow.haystack", "boto3": "mlflow.bedrock", "mistralai": "mlflow.mistral", "pydantic_ai": "mlflow.pydantic_ai", diff --git a/tests/haystack/test_haystack_autolog.py b/tests/haystack/test_haystack_autolog.py new file mode 100644 index 0000000000000..0227026e9b449 --- /dev/null +++ b/tests/haystack/test_haystack_autolog.py @@ -0,0 +1,431 @@ +from unittest.mock import MagicMock, patch + +import pytest + +import mlflow +import mlflow.haystack +from mlflow.entities.span import SpanAttributeKey, SpanType + +from tests.tracing.helper import get_traces + + +def create_mock_pipeline(): + """Create a mock Haystack pipeline.""" + pipeline = MagicMock() + pipeline.__class__.__name__ = "Pipeline" + + # Mock the graph attribute + mock_node1 = MagicMock() + mock_node1.__class__.__name__ = "PromptBuilder" + mock_node2 = MagicMock() + mock_node2.__class__.__name__ = "OpenAIGenerator" + + pipeline.graph = MagicMock() + pipeline.graph.nodes = {"prompt_builder": mock_node1, "llm": mock_node2} + + return pipeline + + +def create_mock_async_pipeline(): + """Create a mock Haystack AsyncPipeline.""" + pipeline = MagicMock() + pipeline.__class__.__name__ = "AsyncPipeline" + + # Mock the graph attribute + mock_node1 = MagicMock() + mock_node1.__class__.__name__ = "PromptBuilder" + mock_node2 = MagicMock() + mock_node2.__class__.__name__ = "OpenAIGenerator" + + pipeline.graph = MagicMock() + pipeline.graph.nodes = {"prompt_builder": mock_node1, "llm": mock_node2} + + return pipeline + + +def create_mock_component(component_type="OpenAIGenerator"): + """Create a mock Haystack component.""" + component = MagicMock() + component.__class__.__name__ = component_type + component.model = "gpt-4o-mini" + 
component._init_parameters = { + "model": "gpt-4o-mini", + "temperature": 0.7, + "api_key": "secret_key", # Should be filtered out + } + + # Mock input/output sockets + input_socket = MagicMock() + input_socket.type = "str" + component.__haystack_input__ = {"prompt": input_socket} + + output_socket = MagicMock() + output_socket.type = "List[str]" + component.__haystack_output__ = {"replies": output_socket} + + return component + + +DUMMY_PIPELINE_INPUT = {"prompt_builder": {"question": "Who lives in Paris?"}} + +DUMMY_PIPELINE_OUTPUT = { + "llm": { + "replies": ["Many people live in Paris, including residents, tourists, and workers."], + "meta": [ + { + "model": "gpt-4o-mini", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + ], + } +} + +DUMMY_COMPONENT_INPUT = {"prompt": "Answer the question: Who lives in Paris?"} + +DUMMY_COMPONENT_OUTPUT = { + "replies": ["Many people live in Paris, including residents, tourists, and workers."], + "meta": [ + { + "model": "gpt-4o-mini", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + ], +} + + +def test_pipeline_autolog(): + """Test autologging for Haystack pipelines.""" + # Clear any existing traces + mlflow.tracking.fluent._active_experiment_id = None + + with patch("haystack.core.pipeline.pipeline.Pipeline", create_mock_pipeline): + mlflow.haystack.autolog() + + pipeline = create_mock_pipeline() + + # Mock the run method + def mock_run(self, data, *args, **kwargs): + return DUMMY_PIPELINE_OUTPUT + + # Set the correct __name__ attribute to match the expected method name + mock_run.__name__ = "run" + + # Apply patching manually since we're mocking + from mlflow.haystack.autolog import patched_class_call + + def patched_run(data, *args, **kwargs): + return patched_class_call(mock_run, pipeline, data, *args, **kwargs) + + pipeline.run = patched_run + + # Run the pipeline + pipeline.run(DUMMY_PIPELINE_INPUT) + + # Check traces + traces = get_traces() + assert len(traces) == 1 + assert traces[0].info.status == "OK" + assert len(traces[0].data.spans) == 1 + + span = traces[0].data.spans[0] + assert span.name == "Pipeline.run" + assert span.span_type == SpanType.CHAIN + assert span.inputs == {"question": "Who lives in Paris?"} + expected_outputs = { + "llm": { + "replies": "Many people live in Paris, including residents, tourists, and workers.", + "meta": [ + { + "model": "gpt-4o-mini", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + ], + } + } + assert span.outputs == expected_outputs + + # Check attributes + assert span.attributes.get(SpanAttributeKey.MESSAGE_FORMAT) == "haystack" + assert span.attributes.get("components") == "['prompt_builder', 'llm']" + assert span.attributes.get("component_count") == 2 + + +def test_pipeline_component_execution(): + """Test component execution tracing within a pipeline.""" + # Clear any existing traces + mlflow.tracking.fluent._active_experiment_id = None + + mlflow.haystack.autolog() + + # Create mock component + component = create_mock_component() + component_dict = {"instance": component} + + # Mock _run_component method + def mock_run_component(component_name, component, inputs, component_visits, parent_span=None): + return DUMMY_COMPONENT_OUTPUT + + mock_run_component.__name__ = "_run_component" + + # Apply patching - for static methods, we pass None as self + from mlflow.haystack.autolog import patched_class_call + + def run_component(*args, **kwargs): + return patched_class_call(mock_run_component, None, 
*args, **kwargs) + + # Run the component + run_component("llm", component_dict, DUMMY_COMPONENT_INPUT, {"llm": 1}) + + # Check traces + traces = get_traces() + assert len(traces) == 1 + assert traces[0].info.status == "OK" + assert len(traces[0].data.spans) == 1 + + span = traces[0].data.spans[0] + assert span.name == "NoneType._run_component" # Static method has no class + assert span.span_type == SpanType.TOOL # Static methods default to TOOL type + # For static methods, inputs include all args but with different parameter names + assert "component" in span.inputs + assert span.inputs["component"] == "llm" + assert span.outputs == DUMMY_COMPONENT_OUTPUT + + # Check attributes + assert span.attributes.get(SpanAttributeKey.MESSAGE_FORMAT) == "haystack" + # Static methods with None self won't have component-specific attributes + + # Check token usage in standard format + chat_usage = span.attributes.get(SpanAttributeKey.CHAT_USAGE) + assert chat_usage is not None + assert chat_usage["input_tokens"] == 10 + assert chat_usage["output_tokens"] == 20 + assert chat_usage["total_tokens"] == 30 + + +def test_component_meta_patching(): + """Test ComponentMeta patching for dynamic component wrapping.""" + # Clear any existing traces + mlflow.tracking.fluent._active_experiment_id = None + + mlflow.haystack.autolog() + + # Create a simple component + class MockComponent: + def __init__(self): + self._init_parameters = {"model": "test-model"} + + def run(self, **kwargs): + return DUMMY_COMPONENT_OUTPUT + + # Apply run method patching + from mlflow.haystack.autolog import patched_class_call + + original_run = MockComponent.run + + def patched_run(self, **kwargs): + return patched_class_call(original_run, self, **kwargs) + + MockComponent.run = patched_run + + # Create and run component + component = MockComponent() + component.run(**DUMMY_COMPONENT_INPUT) + + # Check traces + traces = get_traces() + assert len(traces) == 1 + assert traces[0].info.status == "OK" + assert len(traces[0].data.spans) == 1 + + span = traces[0].data.spans[0] + assert span.name == "MockComponent.run" + assert span.span_type == SpanType.TOOL + assert span.inputs == {"kwargs": DUMMY_COMPONENT_INPUT} + # Direct component output (no formatting applied outside of pipeline context) + assert span.outputs == DUMMY_COMPONENT_OUTPUT + + +@pytest.mark.asyncio +async def test_async_pipeline_autolog(): + """Test autologging for Haystack AsyncPipeline.""" + # Clear any existing traces + mlflow.tracking.fluent._active_experiment_id = None + + with patch("haystack.core.pipeline.async_pipeline.AsyncPipeline", create_mock_async_pipeline): + mlflow.haystack.autolog() + + pipeline = create_mock_async_pipeline() + + # Mock the async run method + async def mock_run_async(self, data, *args, **kwargs): + return DUMMY_PIPELINE_OUTPUT + + mock_run_async.__name__ = "run_async" + + # Apply patching manually since we're mocking + from mlflow.haystack.autolog import patched_async_class_call + + async def wrapped_run_async(data, *args, **kwargs): + return await patched_async_class_call(mock_run_async, pipeline, data, *args, **kwargs) + + pipeline.run_async = wrapped_run_async + + # Run the pipeline + await pipeline.run_async(DUMMY_PIPELINE_INPUT) + + # Check traces + traces = get_traces() + assert len(traces) == 1 + assert traces[0].info.status == "OK" + assert len(traces[0].data.spans) == 1 + + span = traces[0].data.spans[0] + assert span.name == "AsyncPipeline.run_async" + assert span.span_type == SpanType.CHAIN + assert span.inputs == {"question": "Who 
lives in Paris?"} + expected_outputs = { + "llm": { + "replies": "Many people live in Paris, including residents, tourists, and workers.", + "meta": [ + { + "model": "gpt-4o-mini", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + ], + } + } + assert span.outputs == expected_outputs + + +@pytest.mark.asyncio +async def test_async_pipeline_generator(): + """Test autologging for AsyncPipeline.run_async_generator().""" + # Clear any existing traces + mlflow.tracking.fluent._active_experiment_id = None + + mlflow.haystack.autolog() + + pipeline = create_mock_async_pipeline() + + # Mock the async generator method + async def mock_run_async_generator(*args, **kwargs): + yield {"component1": {"output": "partial1"}} + yield {"llm": DUMMY_COMPONENT_OUTPUT} + + # Apply patching for async generator + # For async generators, we need special handling since we can't use the standard async patch + from mlflow.entities import SpanType + + async def wrapped_generator(*args, **kwargs): + fullname = f"{pipeline.__class__.__name__}.run_async_generator" + + with mlflow.start_span(name=fullname, span_type=SpanType.CHAIN) as span: + span.set_inputs({"data": args[0] if args else kwargs}) + span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "haystack") + span.set_attribute("components", ["prompt_builder", "llm"]) + span.set_attribute("component_count", 2) + + accumulated_outputs = {} + chunk_count = 0 + + async for output in mock_run_async_generator(*args, **kwargs): + accumulated_outputs.update(output) + chunk_count += 1 + yield output + + span.set_outputs(accumulated_outputs) + span.set_attribute("chunks", chunk_count) + + # Run the generator + outputs = [] + async for output in wrapped_generator(DUMMY_PIPELINE_INPUT): + outputs.append(output) + + # Check traces + traces = get_traces() + assert len(traces) == 1 + assert traces[0].info.status == "OK" + assert len(traces[0].data.spans) == 1 + + span = traces[0].data.spans[0] + assert span.name == "AsyncPipeline.run_async_generator" + assert span.span_type == SpanType.CHAIN + + # Check accumulated outputs + assert "component1" in span.outputs + assert "llm" in span.outputs + assert span.attributes.get("chunks") == 2 + + +def test_autolog_disable(): + """Test disabling autolog.""" + # Clear any existing traces + mlflow.tracking.fluent._active_experiment_id = None + + with patch("haystack.core.pipeline.pipeline.Pipeline", create_mock_pipeline): + # Enable autolog first + mlflow.haystack.autolog() + + # Then disable it + mlflow.haystack.autolog(disable=True) + + pipeline = create_mock_pipeline() + + # Mock the run method + def mock_run(*args, **kwargs): + return DUMMY_PIPELINE_OUTPUT + + pipeline.run = MagicMock(side_effect=mock_run) + + # Run the pipeline + pipeline.run(DUMMY_PIPELINE_INPUT) + + # Check that no traces were created + traces = get_traces() + assert len(traces) == 0 + + +def test_pipeline_error_handling(): + """Test error handling in pipeline execution.""" + # Clear any existing traces + mlflow.tracking.fluent._active_experiment_id = None + + with patch("haystack.core.pipeline.pipeline.Pipeline", create_mock_pipeline): + mlflow.haystack.autolog() + + pipeline = create_mock_pipeline() + + # Mock the run method to raise an error + error_msg = "Pipeline execution failed" + + def mock_run(self, data, *args, **kwargs): + raise RuntimeError(error_msg) + + mock_run.__name__ = "run" + + pipeline.run = MagicMock(side_effect=mock_run) + + # Apply patching manually + from mlflow.haystack.autolog import patched_class_call + + def 
patched_error_run(data, *args, **kwargs): + return patched_class_call(mock_run, pipeline, data, *args, **kwargs) + + pipeline.run = patched_error_run + + # Run the pipeline and expect an error + with pytest.raises(RuntimeError, match=error_msg): + pipeline.run(DUMMY_PIPELINE_INPUT) + + # Check traces + traces = get_traces() + assert len(traces) == 1 + assert traces[0].info.status == "ERROR" + assert len(traces[0].data.spans) == 1 + + span = traces[0].data.spans[0] + assert span.name == "Pipeline.run" + assert span.status.status_code == "ERROR" + assert error_msg in span.status.description or any( + error_msg in str(event) for event in (span.events or []) + )