diff --git a/src/game/graph.py b/src/game/graph.py index d4f6666..e9d3cd0 100644 --- a/src/game/graph.py +++ b/src/game/graph.py @@ -17,6 +17,7 @@ - Private state management for player mindsets and game setup """ +import asyncio from functools import partial from uuid import uuid4 @@ -200,7 +201,11 @@ def main(): langgraph_config = RunnableConfig( configurable={"thread_id": initial_state["game_id"]}, ) - result = app.invoke(initial_state, config=langgraph_config) + + async def _run_workflow(): + return await app.ainvoke(initial_state, config=langgraph_config) + + result = asyncio.run(_run_workflow()) print(result) diff --git a/src/game/metrics.py b/src/game/metrics.py index 6471581..aeaa54a 100644 --- a/src/game/metrics.py +++ b/src/game/metrics.py @@ -15,6 +15,7 @@ from __future__ import annotations +import asyncio from collections import Counter, defaultdict from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -736,7 +737,11 @@ def _run_single_multilingual_game( } run_config = RunnableConfig(configurable={"thread_id": game_id}) - app.invoke(initial_state, config=run_config) + + async def _run_game(): + await app.ainvoke(initial_state, config=run_config) + + asyncio.run(_run_game()) def run_multilingual_metrics_batch( diff --git a/src/game/nodes/player.py b/src/game/nodes/player.py index a9d66c9..3874b23 100644 --- a/src/game/nodes/player.py +++ b/src/game/nodes/player.py @@ -95,7 +95,7 @@ def _create_player_private_state_delta( } -def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: +async def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: """ Player node for generating speech. Calls LLM to infer identity and generate speech. @@ -120,7 +120,7 @@ def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: existing_player_mindset = get_normalized_player_mindset(existing_private_state) llm_client = _get_llm_client() - updated_mindset = llm_update_player_mindset( + updated_mindset = await llm_update_player_mindset( llm_client=llm_client, my_word=my_word, completed_speeches=state["completed_speeches"], @@ -142,7 +142,7 @@ def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: speech_plan = None # Generate speech using LLM - new_speech_text = llm_generate_speech( + new_speech_text = await llm_generate_speech( llm_client=llm_client, my_word=my_word, self_belief=updated_mindset_state.get("self_belief", {}), @@ -186,7 +186,7 @@ def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: } -def player_vote(state: GameState, player_id: str) -> Dict[str, Any]: +async def player_vote(state: GameState, player_id: str) -> Dict[str, Any]: """ Player node for casting a vote. Calls LLM to infer identity and decide vote target. @@ -211,7 +211,7 @@ def player_vote(state: GameState, player_id: str) -> Dict[str, Any]: existing_player_mindset = get_normalized_player_mindset(existing_private_state) llm_client = _get_llm_client() - updated_mindset = llm_update_player_mindset( + updated_mindset = await llm_update_player_mindset( llm_client=llm_client, my_word=my_word, completed_speeches=state["completed_speeches"], @@ -223,7 +223,7 @@ def player_vote(state: GameState, player_id: str) -> Dict[str, Any]: ) updated_mindset_state = normalize_mindset(updated_mindset) # Decide on a vote target using the LLM with bound voting tools - voted_target = llm_decide_vote( + voted_target = await llm_decide_vote( llm_client=llm_client, state=state, me=player_id, diff --git a/src/game/strategy/strategy_core.py b/src/game/strategy/strategy_core.py index dd589d1..2dad07d 100644 --- a/src/game/strategy/strategy_core.py +++ b/src/game/strategy/strategy_core.py @@ -5,6 +5,8 @@ for player mindset updates and speech generation. """ +import asyncio +import inspect from typing import Any, List, Dict, Sequence, cast from venv import logger @@ -40,6 +42,20 @@ from src.game.strategy.utils.text_utils import sanitize_speech_output +async def _invoke_async(target: Any, *args: Any, **kwargs: Any) -> Any: + """Awaitably invoke LangChain runnables, falling back to sync methods.""" + ainvoke = getattr(target, "ainvoke", None) + if callable(ainvoke): + result = ainvoke(*args, **kwargs) + return await result if inspect.isawaitable(result) else result + + invoke = getattr(target, "invoke", None) + if callable(invoke): + return await asyncio.to_thread(invoke, *args, **kwargs) + + raise AttributeError(f"Object {target!r} has neither ainvoke nor invoke.") + + def _to_mindset_model( mindset: PlayerMindset | PlayerMindsetModel | None, ) -> PlayerMindsetModel: @@ -64,7 +80,7 @@ def _mindset_model_to_state(model: PlayerMindsetModel) -> PlayerMindset: return cast(PlayerMindset, model.model_dump()) -def llm_update_player_mindset( +async def llm_update_player_mindset( llm_client: Any, my_word: str, completed_speeches: Sequence[Speech], @@ -127,7 +143,7 @@ def llm_update_player_mindset( SystemMessage(content=system_prompt), HumanMessage(content=user_context), ] - result = agent.invoke({"messages": messages}) + result = await _invoke_async(agent, {"messages": messages}) # Extract structured response from agent result structured = result.get("structured_response") @@ -154,7 +170,7 @@ def llm_update_player_mindset( return existing_state -def llm_generate_speech( +async def llm_generate_speech( llm_client: Any, my_word: str, self_belief: SelfBelief, @@ -197,7 +213,7 @@ def llm_generate_speech( HumanMessage(content=user_context), ] - response = llm_client.invoke(messages) + response = await _invoke_async(llm_client, messages) raw_text = response.content if hasattr(response, "content") else response return sanitize_speech_output(raw_text) @@ -223,7 +239,7 @@ def plan_player_speech( return planner.func() -def llm_decide_vote( +async def llm_decide_vote( llm_client: Any, state: GameState, me: str, @@ -274,13 +290,14 @@ def llm_decide_vote( ) try: - result = agent.invoke( + result = await _invoke_async( + agent, { "messages": [ SystemMessage(content=system_prompt), HumanMessage(content=vote_context), ] - } + }, ) structured = result.get("structured_response") if structured: diff --git a/tests/test_llm_strategy.py b/tests/test_llm_strategy.py index b4bf469..0450139 100644 --- a/tests/test_llm_strategy.py +++ b/tests/test_llm_strategy.py @@ -1,5 +1,6 @@ +import asyncio from typing import Dict -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from src.game.strategy import llm_update_player_mindset from src.game.strategy.builders.prompt_builder import ( @@ -206,39 +207,45 @@ def test_llm_update_player_mindset_success(): """Tests successful belief inference with structured output.""" # Mock the agent's invoke method to return structured response mock_agent = MagicMock() - mock_agent.invoke.return_value = { - "structured_response": make_player_mindset( - self_belief=make_self_belief("civilian", 0.9), - suspicions={"b": make_suspicion("spy", 0.7, "Suspicious speech")}, - ) - } + mock_agent.ainvoke = AsyncMock( + return_value={ + "structured_response": make_player_mindset( + self_belief=make_self_belief("civilian", 0.9), + suspicions={"b": make_suspicion("spy", 0.7, "Suspicious speech")}, + ) + } + ) # Mock create_agent to return our mock agent with patch("src.game.strategy.strategy_core.create_agent", return_value=mock_agent): mock_llm = MagicMock() test_state = mock_state_inference_en.copy() - result = llm_update_player_mindset(llm_client=mock_llm, **test_state) + result = asyncio.run( + llm_update_player_mindset(llm_client=mock_llm, **test_state) + ) assert result["self_belief"]["role"] == "civilian" assert result["suspicions"]["b"]["reason"] == "Suspicious speech" - mock_agent.invoke.assert_called_once() + mock_agent.ainvoke.assert_awaited_once() def test_llm_update_player_mindset_failure(): """Tests fallback behavior when structured output extraction fails for inference.""" # Mock the agent's invoke method to return None (simulating failure) mock_agent = MagicMock() - mock_agent.invoke.return_value = {"structured_response": None} + mock_agent.ainvoke = AsyncMock(return_value={"structured_response": None}) # Mock create_agent to return our mock agent with patch("src.game.strategy.strategy_core.create_agent", return_value=mock_agent): mock_llm = MagicMock() test_state = mock_state_inference_en.copy() - result = llm_update_player_mindset(llm_client=mock_llm, **test_state) + result = asyncio.run( + llm_update_player_mindset(llm_client=mock_llm, **test_state) + ) assert result["self_belief"]["role"] == "civilian" assert ( result["self_belief"]["confidence"] == mock_player_mindset["self_belief"]["confidence"] ) - mock_agent.invoke.assert_called_once() + mock_agent.ainvoke.assert_awaited_once() diff --git a/tests/test_player_nodes.py b/tests/test_player_nodes.py index b022ca4..307a6d2 100644 --- a/tests/test_player_nodes.py +++ b/tests/test_player_nodes.py @@ -1,7 +1,8 @@ +import asyncio from typing import Dict import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from src.game.nodes.player import player_speech, player_vote from src.game.state import ( GameState, @@ -102,8 +103,8 @@ def base_player_state(player_id): @patch("src.game.nodes.player._get_llm_client") -@patch("src.game.nodes.player.llm_generate_speech") -@patch("src.game.nodes.player.llm_update_player_mindset") +@patch("src.game.nodes.player.llm_generate_speech", new_callable=AsyncMock) +@patch("src.game.nodes.player.llm_update_player_mindset", new_callable=AsyncMock) def test_player_speech( mock_infer, mock_speech, mock_get_llm, player_id, base_player_state: GameState ): @@ -119,7 +120,7 @@ def test_player_speech( mock_speech.return_value = "This is a test speech." # Act: Call the player_speech node - update = player_speech(base_player_state, player_id) + update = asyncio.run(player_speech(base_player_state, player_id)) # Assert: Verify the output is correct assert "completed_speeches" in update @@ -134,13 +135,13 @@ def test_player_speech( # Verify mocks were called correctly mock_get_llm.assert_called_once() - mock_infer.assert_called_once() - mock_speech.assert_called_once() + mock_infer.assert_awaited_once() + mock_speech.assert_awaited_once() @patch("src.game.nodes.player._get_llm_client") -@patch("src.game.nodes.player.llm_update_player_mindset") -@patch("src.game.nodes.player.llm_decide_vote") +@patch("src.game.nodes.player.llm_update_player_mindset", new_callable=AsyncMock) +@patch("src.game.nodes.player.llm_decide_vote", new_callable=AsyncMock) def test_player_vote( mock_decide_vote, mock_infer, mock_get_llm, player_id, base_player_state: GameState ): @@ -162,7 +163,7 @@ def test_player_vote( } # Act: Call the player_vote node - update = player_vote(voting_state, player_id) + update = asyncio.run(player_vote(voting_state, player_id)) # Assert: Verify the output assert "current_votes" in update @@ -176,8 +177,8 @@ def test_player_vote( # Verify mocks mock_get_llm.assert_called_once() - mock_infer.assert_called_once() - mock_decide_vote.assert_called_once_with( + mock_infer.assert_awaited_once() + mock_decide_vote.assert_awaited_once_with( llm_client=mock_llm_client, state=voting_state, me=player_id, @@ -189,21 +190,21 @@ def test_player_vote( def test_player_speech_not_in_speaking_phase(base_player_state: GameState): """Tests that player_speech returns empty dict if not in speaking phase.""" state = base_player_state | {"game_phase": "voting"} - update = player_speech(state, "a") + update = asyncio.run(player_speech(state, "a")) assert update == {} def test_player_vote_not_in_voting_phase(base_player_state: GameState): """Tests that player_vote returns empty dict if not in voting phase.""" state = base_player_state | {"game_phase": "speaking"} - update = player_vote(state, "a") + update = asyncio.run(player_vote(state, "a")) assert update == {} def test_player_node_for_eliminated_player(base_player_state: GameState): """Tests that nodes do nothing for an eliminated player.""" state = base_player_state | {"eliminated_players": ["a"]} - speech_update = player_speech(state, "a") - vote_update = player_vote(state | {"game_phase": "voting"}, "a") + speech_update = asyncio.run(player_speech(state, "a")) + vote_update = asyncio.run(player_vote(state | {"game_phase": "voting"}, "a")) assert speech_update == {} assert vote_update == {}