diff --git a/integrations/mistral/examples/streaming_chat_with_rag.py b/integrations/mistral/examples/streaming_chat_with_rag.py index 776a11796c..a44e8fef18 100644 --- a/integrations/mistral/examples/streaming_chat_with_rag.py +++ b/integrations/mistral/examples/streaming_chat_with_rag.py @@ -1,4 +1,4 @@ -# To run this example, you will need an to set a `MISTRAL_API_KEY` environment variable. +# To run this example, you will need to set a `MISTRAL_API_KEY` environment variable. # This example streams chat replies to the console. from haystack import Pipeline diff --git a/integrations/mistral/src/haystack_integrations/components/agents/__init__.py b/integrations/mistral/src/haystack_integrations/components/agents/__init__.py new file mode 100644 index 0000000000..14a99b83de --- /dev/null +++ b/integrations/mistral/src/haystack_integrations/components/agents/__init__.py @@ -0,0 +1,3 @@ +from .mistral.agent import MistralAgent + +__all__ = ["MistralAgent"] diff --git a/integrations/mistral/src/haystack_integrations/components/agents/mistral/__init__.py b/integrations/mistral/src/haystack_integrations/components/agents/mistral/__init__.py new file mode 100644 index 0000000000..e0706300bb --- /dev/null +++ b/integrations/mistral/src/haystack_integrations/components/agents/mistral/__init__.py @@ -0,0 +1,3 @@ +from .agent import MistralAgent + +__all__ = ["MistralAgent"] diff --git a/integrations/mistral/src/haystack_integrations/components/agents/mistral/agent.py b/integrations/mistral/src/haystack_integrations/components/agents/mistral/agent.py new file mode 100644 index 0000000000..569440bee5 --- /dev/null +++ b/integrations/mistral/src/haystack_integrations/components/agents/mistral/agent.py @@ -0,0 +1,592 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Any, Literal, Optional, Union + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import ( + ChatMessage, + StreamingCallbackT, + StreamingChunk, + ToolCall, +) +from haystack.lazy_imports import LazyImport +from haystack.tools import ( + ToolsType, + _check_duplicate_tool_names, + deserialize_tools_or_toolset_inplace, + flatten_tools_or_toolsets, + serialize_tools_or_toolset, +) +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable + +with LazyImport(message="Run 'pip install mistralai'") as mistralai_import: + from mistralai import Mistral, models + +logger = logging.getLogger(__name__) + +ToolChoiceType = Union[ + Literal["auto", "none", "any", "required"], + dict[str, Any], +] + + +@component +class MistralAgent: + """ + Generates text using Mistral AI Agents via the official Mistral Python SDK. + + NOTE: + If you get the error message: + + "Cannot set function calling tools in the request and have tools in the agent" + + This is a Mistral API limitation - if your agent in the Mistral console already has tools configured, you cannot + pass additional tools in the API request. They are mutually exclusive. 
+ + For more information on Mistral Agents, see: + [Mistral Agents API](https://docs.mistral.ai/api/endpoint/agents) + + Usage example: + ```python + from haystack_integrations.components.agents.mistral import MistralAgent + from haystack.dataclasses import ChatMessage + + # Initialize with your agent ID from the Mistral console + agent = MistralAgent(agent_id="your-agent-id") + + messages = [ChatMessage.from_user("What can you help me with?")] + response = agent.run(messages) + print(response["replies"][0].text) + ``` + + Streaming example: + ```python + def my_callback(chunk): + print(chunk.content, end="", flush=True) + + agent = MistralAgent( + agent_id="your-agent-id", + streaming_callback=my_callback + ) + response = agent.run([ChatMessage.from_user("Tell me a story")]) + ``` + """ + + def __init__( + self, + agent_id: str, + api_key: Secret = Secret.from_env_var("MISTRAL_API_KEY"), + streaming_callback: Optional[StreamingCallbackT] = None, + tools: Optional[ToolsType] = None, + tool_choice: Optional[ToolChoiceType] = None, + parallel_tool_calls: bool = True, + generation_kwargs: Optional[dict[str, Any]] = None, + *, + timeout_ms: Optional[int] = 30000, + ): + """ + Creates an instance of MistralAgent. + + :param agent_id: + The ID of the Mistral Agent to use. Required. Get this from the + Mistral AI console after creating an agent. + :param api_key: + The Mistral API key. Defaults to environment variable `MISTRAL_API_KEY`. + :param streaming_callback: + A callback function called when a new token is received from the stream. + :param tools: + Additional tools the agent can use beyond its pre-configured tools. + A list of Tool and/or Toolset objects. + :param tool_choice: + Controls which tool is called. Options: + - "auto": Model decides whether to use tools + - "none": No tools, generate text only + - "any" or "required": Must call one or more tools + - {"type": "function", "function": {"name": "..."}}: Force specific tool + :param parallel_tool_calls: + Whether to enable parallel function calling. Defaults to True. + :param generation_kwargs: + Additional parameters for the API call. Supported parameters: + - `max_tokens`: Maximum tokens to generate + - `frequency_penalty`: Penalize word repetition (default: 0) + - `presence_penalty`: Encourage vocabulary diversity (default: 0) + - `n`: Number of completions to return + - `random_seed`: Seed for deterministic results + - `stop`: Stop sequences + - `response_format`: Output format (text/json_object/json_schema) + - `prediction`: Expected completion for optimization + - `prompt_mode`: Set to "reasoning" for reasoning models + :param timeout_ms: + Request timeout in milliseconds. Defaults to 30000 (30 seconds). 
+ """ + self.agent_id = agent_id + self.api_key = api_key + self.streaming_callback = streaming_callback + self.tools = tools + self.tool_choice = tool_choice + self.parallel_tool_calls = parallel_tool_calls + self.generation_kwargs = generation_kwargs or {} + self.timeout_ms = timeout_ms + + _check_duplicate_tool_names(flatten_tools_or_toolsets(self.tools)) + + self._client = None + self._async_client = None + + def warm_up(self): + if self._client: + return + mistralai_import.check() + self._client = Mistral(api_key=self.api_key.resolve_value(), timeout_ms=self.timeout_ms) + + @staticmethod + def _convert_messages(messages: list[ChatMessage]) -> list[dict[str, Any]]: + mistral_messages = [] + + for msg in messages: + # OpenAI format is compatible with Mistral + openai_format = msg.to_openai_dict_format() + + # ensure content is a string (not a list of content blocks) + content = openai_format.get("content", "") + if isinstance(content, list): + # Extract text from content blocks + text_parts = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif isinstance(block, str): + text_parts.append(block) + content = "".join(text_parts) + + mistral_message = { + "role": openai_format.get("role", "user"), + "content": content, + } + + # include tool_call_id for tool messages + if openai_format.get("tool_call_id"): + mistral_message["tool_call_id"] = openai_format["tool_call_id"] + + # include tool_calls for assistant messages + if openai_format.get("tool_calls"): + mistral_message["tool_calls"] = openai_format["tool_calls"] + + mistral_messages.append(mistral_message) + + return mistral_messages + + def _build_tools(self, tools: Optional[ToolsType] = None) -> Optional[list[dict[str, Any]]]: + """Convert Haystack tools to Mistral format.""" + flattened_tools = flatten_tools_or_toolsets(tools or self.tools) + if not flattened_tools: + return None + return [{"type": "function", "function": tool.tool_spec} for tool in flattened_tools] + + @staticmethod + def _parse_response(response: Any) -> list[ChatMessage]: + """ + Parse the Mistral response into Haystack ChatMessages. + + :param response: The response from mistral.agents.complete() + :returns: + List of ChatMessage objects + """ + messages = [] + + for choice in response.choices: + message = choice.message + content = message.content or "" + + # Parse tool calls if present + tool_calls = [] + if message.tool_calls: + for tc in message.tool_calls: + if tc.type == "function": + try: + arguments = json.loads(tc.function.arguments or "{}") + except json.JSONDecodeError: + arguments = {} + tool_calls.append( + ToolCall( + id=tc.id, + tool_name=tc.function.name, + arguments=arguments, + ) + ) + + # Build metadata + meta = { + "model": response.model, + "index": choice.index, + "finish_reason": choice.finish_reason, + "usage": { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + } + if response.usage + else None, + } + + chat_message = ChatMessage.from_assistant( + text=content if content else None, + tool_calls=tool_calls if tool_calls else None, + meta=meta, + ) + messages.append(chat_message) + + return messages + + @staticmethod + def _handle_streaming( + stream_response: Any, + callback: StreamingCallbackT, + ) -> list[ChatMessage]: + """ + Handle streaming response from the Mistral. 
+ + :param stream_response: The streaming response iterator + :param callback: The callback to invoke for each chunk + :returns: + List containing the final assembled ChatMessage + """ + collected_content = "" + collected_tool_calls: dict[int, dict] = {} + meta: dict[str, Any] = {} + + for chunk in stream_response: + # Extract metadata from response (model is on chunk.data, not chunk) + if not meta and chunk.data.model: + meta["model"] = chunk.data.model + + for choice in chunk.data.choices: + delta = choice.delta + + # Handle text content + if delta.content: + collected_content += delta.content + streaming_chunk = StreamingChunk( + content=delta.content, + meta={ + "model": chunk.data.model, + "index": choice.index, + "finish_reason": choice.finish_reason, + }, + ) + callback(streaming_chunk) + + # Handle tool calls + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index if hasattr(tc, "index") else 0 + if idx not in collected_tool_calls: + collected_tool_calls[idx] = { + "id": "", + "name": "", + "arguments": "", + } + if tc.id: + collected_tool_calls[idx]["id"] = tc.id + if hasattr(tc, "function") and tc.function: + if tc.function.name: + collected_tool_calls[idx]["name"] = tc.function.name + if tc.function.arguments: + collected_tool_calls[idx]["arguments"] += tc.function.arguments + + # Capture finish reason + if choice.finish_reason: + meta["finish_reason"] = choice.finish_reason + meta["index"] = choice.index + + # Capture usage from final chunk + if chunk.data.usage: + meta["usage"] = { + "prompt_tokens": chunk.data.usage.prompt_tokens, + "completion_tokens": chunk.data.usage.completion_tokens, + "total_tokens": chunk.data.usage.total_tokens, + } + + # Build final tool calls + tool_calls = [] + for idx in sorted(collected_tool_calls.keys()): + tc_data = collected_tool_calls[idx] + try: + arguments = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append( + ToolCall( + id=tc_data["id"], + tool_name=tc_data["name"], + arguments=arguments, + ) + ) + + # Create final message + chat_message = ChatMessage.from_assistant( + text=collected_content if collected_content else None, + tool_calls=tool_calls if tool_calls else None, + meta=meta, + ) + + return [chat_message] + + @component.output_types(replies=list[ChatMessage]) + def run( + self, + messages: list[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + tools: Optional[ToolsType] = None, + tool_choice: Optional[ToolChoiceType] = None, + generation_kwargs: Optional[dict[str, Any]] = None, + ) -> dict[str, list[ChatMessage]]: + """ + Invoke the Mistral Agent with the provided messages. + + :param messages: + A list of ChatMessage instances representing the conversation. + :param streaming_callback: + A callback function for streaming. Overrides the init callback. + :param tools: + Additional tools for this request. Overrides init tools. + :param tool_choice: + Tool choice for this request. Overrides init tool_choice. + :param generation_kwargs: + Additional generation parameters. Merged with init params. + + :returns: + A dictionary with key `replies` containing a list of ChatMessage responses. 
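+
+        Per-call overrides (a minimal sketch; `messages` and the callback are
+        placeholders):
+
+        ```python
+        result = agent.run(
+            messages,
+            generation_kwargs={"max_tokens": 64},
+            streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
+        )
+        ```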
+ """ + self.warm_up() + + if not messages: + return {"replies": []} + + effective_callback = streaming_callback or self.streaming_callback + mistral_messages = MistralAgent._convert_messages(messages) + merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + request_kwargs: dict[str, Any] = { + "agent_id": self.agent_id, + "messages": mistral_messages, + } + + # add tools if provided + mistral_tools = self._build_tools(tools) + if mistral_tools: + request_kwargs["tools"] = mistral_tools + request_kwargs["parallel_tool_calls"] = self.parallel_tool_calls + + # add tool_choice only when tools are present + effective_tool_choice = tool_choice or self.tool_choice + if effective_tool_choice: + request_kwargs["tool_choice"] = effective_tool_choice + + # add generation kwargs + for key, value in merged_kwargs.items(): + if value is not None: + request_kwargs[key] = value + + try: + if effective_callback: + # streaming + request_kwargs["stream"] = True + stream_response = self._client.agents.stream(**request_kwargs) + replies = MistralAgent._handle_streaming(stream_response, effective_callback) + else: + # non-streaming + response = self._client.agents.complete(**request_kwargs) + replies = self._parse_response(response) + + return {"replies": replies} + + except Exception as e: + if isinstance(e, models.HTTPValidationError): + msg = "Mistral validation error: {detail}" + logger.error(msg, detail=e.data.detail if hasattr(e, "data") else str(e)) + error_msg = f"Mistral validation error: {e}" + raise ValueError(error_msg) from e + + elif isinstance(e, models.MistralError): + msg = "Mistral API error: {status_code} - {message}" + logger.error(msg, status_code=e.status_code, message=e.message) + error_msg = f"Mistral API error ({e.status_code}): {e.message}" + raise ValueError(error_msg) from e + + raise + + @component.output_types(replies=list[ChatMessage]) + async def run_async( + self, + messages: list[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + tools: Optional[ToolsType] = None, + tool_choice: Optional[ToolChoiceType] = None, + generation_kwargs: Optional[dict[str, Any]] = None, + ) -> dict[str, list[ChatMessage]]: + """ + Asynchronously invoke the Mistral Agent. + + Same parameters as `run()`. 
+ """ + self.warm_up() + + if not messages: + return {"replies": []} + + effective_callback = streaming_callback or self.streaming_callback + sdk_messages = self._convert_messages(messages) + merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + request_kwargs: dict[str, Any] = { + "agent_id": self.agent_id, + "messages": sdk_messages, + } + + sdk_tools = self._build_tools(tools) + if sdk_tools: + request_kwargs["tools"] = sdk_tools + request_kwargs["parallel_tool_calls"] = self.parallel_tool_calls + effective_tool_choice = tool_choice or self.tool_choice + if effective_tool_choice: + request_kwargs["tool_choice"] = effective_tool_choice + + for key, value in merged_kwargs.items(): + if value is not None: + request_kwargs[key] = value + + try: + if effective_callback: + # Async streaming + request_kwargs["stream"] = True + stream_response = await self._client.agents.stream_async(**request_kwargs) + # Note: For full async streaming, we'd need an async callback + # This is a simplified version + replies = await Mistral._handle_async_streaming(stream_response, effective_callback) + else: + response = await self._client.agents.complete_async(**request_kwargs) + replies = self._parse_response(response) + + return {"replies": replies} + + except Exception as e: + if isinstance(e, (models.HTTPValidationError, models.MistralError)): + error_msg = f"Mistral API error: {e}" + raise ValueError(error_msg) from e + raise + + @staticmethod + async def _handle_async_streaming( + stream_response: Any, + callback: StreamingCallbackT, + ) -> list[ChatMessage]: + """Handle async streaming response.""" + collected_content = "" + collected_tool_calls: dict[int, dict] = {} + meta: dict[str, Any] = {} + + async for chunk in stream_response: + if not meta and chunk.data.model: + meta["model"] = chunk.data.model + + for choice in chunk.data.choices: + delta = choice.delta + + if delta.content: + collected_content += delta.content + streaming_chunk = StreamingChunk( + content=delta.content, + meta={ + "model": chunk.data.model, + "index": choice.index, + "finish_reason": choice.finish_reason, + }, + ) + # For async streaming, callback should be awaitable + if callable(callback): + result = callback(streaming_chunk) + if hasattr(result, "__await__"): + await result + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index if hasattr(tc, "index") else 0 + if idx not in collected_tool_calls: + collected_tool_calls[idx] = {"id": "", "name": "", "arguments": ""} + if tc.id: + collected_tool_calls[idx]["id"] = tc.id + if hasattr(tc, "function") and tc.function: + if tc.function.name: + collected_tool_calls[idx]["name"] = tc.function.name + if tc.function.arguments: + collected_tool_calls[idx]["arguments"] += tc.function.arguments + + if choice.finish_reason: + meta["finish_reason"] = choice.finish_reason + meta["index"] = choice.index + + if chunk.data.usage: + meta["usage"] = { + "prompt_tokens": chunk.data.usage.prompt_tokens, + "completion_tokens": chunk.data.usage.completion_tokens, + "total_tokens": chunk.data.usage.total_tokens, + } + + tool_calls = [] + for idx in sorted(collected_tool_calls.keys()): + tc_data = collected_tool_calls[idx] + try: + arguments = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall(id=tc_data["id"], tool_name=tc_data["name"], arguments=arguments)) + + return [ + ChatMessage.from_assistant( + text=collected_content if collected_content else None, + 
tool_calls=tool_calls if tool_calls else None, + meta=meta, + ) + ] + + def to_dict(self) -> dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: A dictionary representation of the component. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + + return default_to_dict( + self, + agent_id=self.agent_id, + api_key=self.api_key.to_dict(), + streaming_callback=callback_name, + tools=serialize_tools_or_toolset(self.tools), + tool_choice=self.tool_choice, + parallel_tool_calls=self.parallel_tool_calls, + generation_kwargs=self.generation_kwargs, + timeout_ms=self.timeout_ms, + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "MistralAgent": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of the component. + :returns: + An instance of MistralAgent + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools") + + init_params = data.get("init_parameters", {}) + if init_params.get("streaming_callback"): + data["init_parameters"]["streaming_callback"] = deserialize_callable(init_params["streaming_callback"]) + + return default_from_dict(cls, data) diff --git a/integrations/mistral/src/haystack_integrations/components/converters/mistral/ocr_document_converter.py b/integrations/mistral/src/haystack_integrations/components/converters/mistral/ocr_document_converter.py index 156461f2f1..18c64d339a 100644 --- a/integrations/mistral/src/haystack_integrations/components/converters/mistral/ocr_document_converter.py +++ b/integrations/mistral/src/haystack_integrations/components/converters/mistral/ocr_document_converter.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import json import re from pathlib import Path diff --git a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py index b87f33d7d4..e06064b6dd 100644 --- a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py +++ b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from typing import Any, Optional from haystack import component, default_to_dict diff --git a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py index 74e9b5f365..f703d2643c 100644 --- a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py +++ b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from typing import Any, Optional from haystack import component, default_to_dict diff --git a/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py b/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py index 990d625760..42925a81ce 100644 --- 
a/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py +++ b/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py @@ -181,7 +181,7 @@ def to_dict(self) -> dict[str, Any]: generation_kwargs["response_format"] = json_schema # if we didn't implement the to_dict method here then the to_dict method of the superclass would be used - # which would serialiaze some fields that we don't want to serialize (e.g. the ones we don't have in + # which would serialize some fields that we don't want to serialize (e.g. the ones we don't have in # the __init__) return default_to_dict( self, diff --git a/integrations/mistral/tests/agent_mocks.py b/integrations/mistral/tests/agent_mocks.py new file mode 100644 index 0000000000..4441fc1698 --- /dev/null +++ b/integrations/mistral/tests/agent_mocks.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +"""Shared mock classes for MistralAgent tests.""" + +from typing import Optional + + +class MockUsage: + def __init__(self, prompt_tokens: int = 10, completion_tokens: int = 20, total_tokens: int = 30): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.total_tokens = total_tokens + + +class MockMessage: + def __init__(self, content: str = "", tool_calls: Optional[list] = None): + self.content = content + self.tool_calls = tool_calls or [] + + +class MockChoice: + def __init__(self, message: MockMessage, index: int = 0, finish_reason: str = "stop"): + self.message = message + self.index = index + self.finish_reason = finish_reason + + +class MockResponse: + def __init__(self, choices: list, model: str = "agent-model", usage: Optional[MockUsage] = None): + self.choices = choices + self.model = model + self.usage = usage or MockUsage() diff --git a/integrations/mistral/tests/conftest.py b/integrations/mistral/tests/conftest.py new file mode 100644 index 0000000000..5b19939f80 --- /dev/null +++ b/integrations/mistral/tests/conftest.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +"""Shared fixtures for MistralAgent tests.""" + +import pytest +from haystack.dataclasses import ChatMessage + +from tests.agent_mocks import MockChoice, MockMessage, MockResponse, MockUsage + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France?"), + ] + + +@pytest.fixture +def mock_sdk_response(): + return MockResponse( + choices=[ + MockChoice( + message=MockMessage(content="The capital of France is Paris."), + finish_reason="stop", + index=0, + ) + ], + model="mistral-agent-model", + usage=MockUsage(prompt_tokens=15, completion_tokens=25, total_tokens=40), + ) diff --git a/integrations/mistral/tests/test_mistral_agent.py b/integrations/mistral/tests/test_mistral_agent.py new file mode 100644 index 0000000000..5d53e71b3a --- /dev/null +++ b/integrations/mistral/tests/test_mistral_agent.py @@ -0,0 +1,443 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional +from unittest.mock import MagicMock, patch + +import pytest +from haystack import Pipeline +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall +from haystack.tools import Tool, Toolset +from haystack.utils.auth import Secret + +from 
haystack_integrations.components.agents.mistral.agent import MistralAgent + +# Import shared mocks +from tests.agent_mocks import MockChoice, MockMessage, MockResponse, MockUsage + + +class MockFunction: + def __init__(self, name: str, arguments: str): + self.name = name + self.arguments = arguments + + +class MockToolCall: + def __init__(self, call_id: str, name: str, arguments: str, call_type: str = "function"): + self.id = call_id + self.type = call_type + self.function = MockFunction(name=name, arguments=arguments) + + +class MockDelta: + def __init__(self, content: Optional[str] = None, tool_calls: Optional[list] = None): + self.content = content + self.tool_calls = tool_calls + + +class MockStreamChoice: + def __init__(self, delta: MockDelta, index: int = 0, finish_reason: Optional[str] = None): + self.delta = delta + self.index = index + self.finish_reason = finish_reason + + +class MockStreamData: + def __init__(self, choices: list, model: str = "agent-model", usage: Optional[MockUsage] = None): + self.choices = choices + self.model = model + self.usage = usage + + +class MockStreamChunk: + def __init__(self, data: MockStreamData): + self.data = data + + +# ============================================================================= +# FIXTURES (not shared) +# ============================================================================= + + +@pytest.fixture +def mock_tool_function(): + def weather(city: str) -> str: + return f"Weather in {city}: Sunny, 22°C" + + return weather + + +@pytest.fixture +def tools(mock_tool_function): + return [ + Tool( + name="weather", + description="Get weather for a city", + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + function=mock_tool_function, + ) + ] + + +@pytest.fixture +def mock_sdk_response_with_tool_call(): + tool_call = MockToolCall( + call_id="call_123", + name="weather", + arguments='{"city": "Paris"}', + ) + return MockResponse( + choices=[ + MockChoice( + message=MockMessage(content="", tool_calls=[tool_call]), + finish_reason="tool_calls", + index=0, + ) + ], + model="mistral-agent-model", + usage=MockUsage(prompt_tokens=20, completion_tokens=15, total_tokens=35), + ) + + +class TestMistralAgent: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key") + + agent = MistralAgent(agent_id="ag-test-123") + + assert agent.agent_id == "ag-test-123" + assert agent.streaming_callback is None + assert agent.tools is None + assert agent.tool_choice is None + assert agent.parallel_tool_calls is True + assert agent.generation_kwargs == {} + assert agent.timeout_ms == 30000 + + def test_init_with_custom_parameters(self, tools): + def my_callback(chunk): + pass + + agent = MistralAgent( + agent_id="ag-custom-456", + api_key=Secret.from_token("custom-key"), + streaming_callback=my_callback, + tools=tools, + tool_choice="auto", + parallel_tool_calls=False, + generation_kwargs={"max_tokens": 500, "random_seed": 42}, + timeout_ms=60000, + ) + + assert agent.agent_id == "ag-custom-456" + assert agent.streaming_callback is my_callback + assert len(agent.tools) == 1 + assert agent.tool_choice == "auto" + assert agent.parallel_tool_calls is False + assert agent.generation_kwargs == {"max_tokens": 500, "random_seed": 42} + assert agent.timeout_ms == 60000 + + def test_init_with_duplicate_tools_raises_error(self, mock_tool_function): + duplicate_tools = [ + Tool( + name="weather", + description="First", + parameters={}, + function=mock_tool_function, + 
+            ),
+            Tool(
+                name="weather",
+                description="Second",
+                parameters={},
+                function=mock_tool_function,
+            ),
+        ]
+
+        with pytest.raises(ValueError, match="Duplicate tool names"):
+            MistralAgent(
+                agent_id="ag-test",
+                api_key=Secret.from_token("key"),
+                tools=duplicate_tools,
+            )
+
+    def test_init_with_toolset(self, mock_tool_function, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+        tool1 = Tool(
+            name="tool1",
+            description="First tool",
+            parameters={},
+            function=mock_tool_function,
+        )
+        tool2 = Tool(
+            name="tool2",
+            description="Second tool",
+            parameters={},
+            function=mock_tool_function,
+        )
+        toolset = Toolset([tool1, tool2])
+        agent = MistralAgent(agent_id="ag-test", tools=[toolset])
+        assert agent.tools == [toolset]
+
+    def test_convert_simple_user_message(self, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+        agent = MistralAgent(agent_id="ag-test")
+
+        messages = [ChatMessage.from_user("Hello")]
+        result = agent._convert_messages(messages)
+
+        assert len(result) == 1
+        assert result[0]["role"] == "user"
+        assert result[0]["content"] == "Hello"
+
+    def test_convert_system_message(self, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+        agent = MistralAgent(agent_id="ag-test")
+
+        messages = [ChatMessage.from_system("You are helpful")]
+        result = agent._convert_messages(messages)
+
+        assert len(result) == 1
+        assert result[0]["role"] == "system"
+        assert result[0]["content"] == "You are helpful"
+
+    def test_convert_assistant_message(self, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+        agent = MistralAgent(agent_id="ag-test")
+
+        messages = [ChatMessage.from_assistant("I can help you")]
+        result = agent._convert_messages(messages)
+
+        assert len(result) == 1
+        assert result[0]["role"] == "assistant"
+        assert result[0]["content"] == "I can help you"
+
+    def test_convert_tool_message(self, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+        agent = MistralAgent(agent_id="ag-test")
+
+        tool_call = ToolCall(id="call_123", tool_name="weather", arguments={"city": "Paris"})
+        messages = [ChatMessage.from_tool(tool_result="Sunny, 22°C", origin=tool_call)]
+        result = agent._convert_messages(messages)
+
+        assert len(result) == 1
+        assert result[0]["role"] == "tool"
+        assert result[0]["content"] == "Sunny, 22°C"
+        assert result[0]["tool_call_id"] == "call_123"
+
+    @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up")
+    def test_run_basic(self, _mock_warm_up, chat_messages, mock_sdk_response, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+
+        agent = MistralAgent(agent_id="ag-test-123")
+        agent._client = MagicMock()
+        agent._client.agents.complete.return_value = mock_sdk_response
+
+        result = agent.run(chat_messages)
+
+        assert "replies" in result
+        assert len(result["replies"]) == 1
+
+        reply = result["replies"][0]
+        assert reply.text == "The capital of France is Paris."
+        assert reply.meta["model"] == "mistral-agent-model"
+        assert reply.meta["finish_reason"] == "stop"
+        assert reply.meta["usage"]["total_tokens"] == 40
+
+    @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up")
+    def test_run_empty_messages(self, _mock_warm_up, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+        agent = MistralAgent(agent_id="ag-test")
+        result = agent.run([])
+        assert result == {"replies": []}
+
+    @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up")
+    def test_run_with_tool_calls(
+        self,
+        _mock_warm_up,
+        mock_sdk_response_with_tool_call,
+        tools,
+        monkeypatch,
+    ):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+
+        agent = MistralAgent(agent_id="ag-test", tools=tools)
+        agent._client = MagicMock()
+        agent._client.agents.complete.return_value = mock_sdk_response_with_tool_call
+
+        result = agent.run([ChatMessage.from_user("What's the weather in Paris?")])
+
+        assert len(result["replies"]) == 1
+        reply = result["replies"][0]
+
+        assert reply.tool_calls is not None
+        assert len(reply.tool_calls) == 1
+        assert reply.tool_calls[0].tool_name == "weather"
+        assert reply.tool_calls[0].arguments == {"city": "Paris"}
+        assert reply.tool_calls[0].id == "call_123"
+        assert reply.meta["finish_reason"] == "tool_calls"
+
+    @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up")
+    def test_run_with_generation_kwargs(self, _mock_warm_up, chat_messages, mock_sdk_response, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+
+        agent = MistralAgent(agent_id="ag-test", generation_kwargs={"max_tokens": 100})
+        agent._client = MagicMock()
+        agent._client.agents.complete.return_value = mock_sdk_response
+
+        # Run with additional kwargs
+        agent.run(chat_messages, generation_kwargs={"random_seed": 42})
+
+        # Verify the call was made with merged kwargs
+        call_kwargs = agent._client.agents.complete.call_args.kwargs
+        assert call_kwargs["max_tokens"] == 100
+        assert call_kwargs["random_seed"] == 42
+
+    @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up")
+    def test_run_with_tool_choice(self, _mock_warm_up, chat_messages, mock_sdk_response, tools, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+
+        agent = MistralAgent(agent_id="ag-test", tools=tools, tool_choice="any")
+        agent._client = MagicMock()
+        agent._client.agents.complete.return_value = mock_sdk_response
+
+        agent.run(chat_messages)
+
+        call_kwargs = agent._client.agents.complete.call_args.kwargs
+        assert call_kwargs["tool_choice"] == "any"
+        assert "tools" in call_kwargs
+        assert call_kwargs["parallel_tool_calls"] is True
+
+    @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up")
+    def test_run_with_streaming(self, _mock_warm_up, chat_messages, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+        collected_chunks = []
+
+        def callback(chunk: StreamingChunk):
+            collected_chunks.append(chunk)
+
+        # Create mock streaming response
+        stream_chunks = [
+            MockStreamChunk(
+                data=MockStreamData(
+                    choices=[MockStreamChoice(delta=MockDelta(content="The "))],
+                    model="agent-model",
+                )
+            ),
+            MockStreamChunk(
+                data=MockStreamData(
+                    choices=[MockStreamChoice(delta=MockDelta(content="capital "))],
+                    model="agent-model",
+                )
+            ),
+            MockStreamChunk(
+                data=MockStreamData(
+                    choices=[MockStreamChoice(delta=MockDelta(content="is Paris."), finish_reason="stop")],
+                    model="agent-model",
+                    usage=MockUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
+                )
+            ),
+        ]
+
+        agent = 
MistralAgent(agent_id="ag-test", streaming_callback=callback) + agent._client = MagicMock() + agent._client.agents.stream.return_value = iter(stream_chunks) + + result = agent.run(chat_messages) + + # Verify streaming callback was called + assert len(collected_chunks) == 3 + assert collected_chunks[0].content == "The " + assert collected_chunks[1].content == "capital " + assert collected_chunks[2].content == "is Paris." + + # Verify final message + assert len(result["replies"]) == 1 + assert result["replies"][0].text == "The capital is Paris." + assert result["replies"][0].meta["finish_reason"] == "stop" + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-key") + + agent = MistralAgent( + agent_id="ag-test", + tool_choice="auto", + generation_kwargs={"max_tokens": 500}, + ) + data = agent.to_dict() + + assert data["type"] == "haystack_integrations.components.agents.mistral.agent.MistralAgent" + assert data["init_parameters"]["agent_id"] == "ag-test" + assert data["init_parameters"]["tool_choice"] == "auto" + assert data["init_parameters"]["generation_kwargs"] == {"max_tokens": 500} + assert data["init_parameters"]["tools"] is None + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-key") + + data = { + "type": "haystack_integrations.components.agents.mistral.agent.MistralAgent", + "init_parameters": { + "agent_id": "ag-restored", + "api_key": { + "env_vars": ["MISTRAL_API_KEY"], + "strict": True, + "type": "env_var", + }, + "streaming_callback": None, + "tools": None, + "tool_choice": "auto", + "parallel_tool_calls": False, + "generation_kwargs": {"max_tokens": 200}, + "timeout_ms": 45000, + }, + } + + agent = MistralAgent.from_dict(data) + + assert agent.agent_id == "ag-restored" + assert agent.tool_choice == "auto" + assert agent.parallel_tool_calls is False + assert agent.generation_kwargs == {"max_tokens": 200} + assert agent.timeout_ms == 45000 + + def test_pipeline_serialization(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-key") + + pipeline = Pipeline() + pipeline.add_component("agent", MistralAgent(agent_id="ag-pipeline")) + + # Serialize to dict + pipeline_dict = pipeline.to_dict() + + assert "agent" in pipeline_dict["components"] + assert ( + pipeline_dict["components"]["agent"]["type"] + == "haystack_integrations.components.agents.mistral.agent.MistralAgent" + ) + + # Deserialize + restored_pipeline = Pipeline.from_dict(pipeline_dict) + restored_agent = restored_pipeline.get_component("agent") + + assert restored_agent.agent_id == "ag-pipeline" + + def test_pipeline_yaml_roundtrip(self, monkeypatch): + """Test YAML serialization and deserialization.""" + monkeypatch.setenv("MISTRAL_API_KEY", "test-key") + + pipeline = Pipeline() + pipeline.add_component( + "agent", + MistralAgent(agent_id="ag-yaml", generation_kwargs={"max_tokens": 100}), + ) + + yaml_str = pipeline.dumps() + restored = Pipeline.loads(yaml_str) + + agent = restored.get_component("agent") + assert agent.agent_id == "ag-yaml" + assert agent.generation_kwargs == {"max_tokens": 100} diff --git a/integrations/mistral/tests/test_mistral_agent_async.py b/integrations/mistral/tests/test_mistral_agent_async.py new file mode 100644 index 0000000000..95dce143ea --- /dev/null +++ b/integrations/mistral/tests/test_mistral_agent_async.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from 
haystack_integrations.components.agents.mistral.agent import MistralAgent + + +@pytest.mark.asyncio +class TestMistralAgentAsync: + @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up") + async def test_run_async_basic(self, _mock_warm_up, chat_messages, mock_sdk_response, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-key") + + agent = MistralAgent(agent_id="ag-test-123") + agent._client = MagicMock() + agent._client.agents.complete_async = AsyncMock(return_value=mock_sdk_response) + + result = await agent.run_async(chat_messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert reply.text == "The capital of France is Paris." + assert reply.meta["model"] == "mistral-agent-model" + assert reply.meta["finish_reason"] == "stop" + assert reply.meta["usage"]["total_tokens"] == 40 + + @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up") + async def test_run_async_empty_messages(self, _mock_warm_up, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "test-key") + + agent = MistralAgent(agent_id="ag-test") + result = await agent.run_async([]) + + assert result == {"replies": []}
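+
+    # The test below is an added, hedged sketch covering the async streaming
+    # path; the duck-typed chunk classes mirror the shapes assumed by
+    # MistralAgent._handle_async_streaming and are local assumptions, not
+    # part of the mistralai SDK.
+    @patch("haystack_integrations.components.agents.mistral.agent.MistralAgent.warm_up")
+    async def test_run_async_streaming(self, _mock_warm_up, chat_messages, monkeypatch):
+        monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
+        collected = []
+
+        def callback(chunk):
+            collected.append(chunk)
+
+        class _Delta:
+            def __init__(self, content):
+                self.content = content
+                self.tool_calls = None
+
+        class _Choice:
+            def __init__(self, content, finish_reason=None):
+                self.delta = _Delta(content)
+                self.index = 0
+                self.finish_reason = finish_reason
+
+        class _Data:
+            def __init__(self, content, finish_reason=None):
+                self.choices = [_Choice(content, finish_reason)]
+                self.model = "agent-model"
+                self.usage = None
+
+        class _Chunk:
+            def __init__(self, content, finish_reason=None):
+                self.data = _Data(content, finish_reason)
+
+        async def stream():
+            # Yield two chunks, the last carrying the finish reason
+            for chunk in (_Chunk("Hello "), _Chunk("world", "stop")):
+                yield chunk
+
+        agent = MistralAgent(agent_id="ag-test", streaming_callback=callback)
+        agent._client = MagicMock()
+        agent._client.agents.stream_async = AsyncMock(return_value=stream())
+
+        result = await agent.run_async(chat_messages)
+
+        # The sync callback was invoked once per chunk
+        assert [c.content for c in collected] == ["Hello ", "world"]
+        assert result["replies"][0].text == "Hello world"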