Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/game/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Private state management for player mindsets and game setup
"""

import asyncio
from functools import partial
from uuid import uuid4

Expand Down Expand Up @@ -200,7 +201,11 @@ def main():
langgraph_config = RunnableConfig(
configurable={"thread_id": initial_state["game_id"]},
)
result = app.invoke(initial_state, config=langgraph_config)

async def _run_workflow():
return await app.ainvoke(initial_state, config=langgraph_config)

result = asyncio.run(_run_workflow())
print(result)


Expand Down
7 changes: 6 additions & 1 deletion src/game/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from __future__ import annotations

import asyncio
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
Expand Down Expand Up @@ -736,7 +737,11 @@ def _run_single_multilingual_game(
}

run_config = RunnableConfig(configurable={"thread_id": game_id})
app.invoke(initial_state, config=run_config)

async def _run_game():
await app.ainvoke(initial_state, config=run_config)

asyncio.run(_run_game())


def run_multilingual_metrics_batch(
Expand Down
12 changes: 6 additions & 6 deletions src/game/nodes/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _create_player_private_state_delta(
}


def player_speech(state: GameState, player_id: str) -> Dict[str, Any]:
async def player_speech(state: GameState, player_id: str) -> Dict[str, Any]:
"""
Player node for generating speech.
Calls LLM to infer identity and generate speech.
Expand All @@ -120,7 +120,7 @@ def player_speech(state: GameState, player_id: str) -> Dict[str, Any]:
existing_player_mindset = get_normalized_player_mindset(existing_private_state)

llm_client = _get_llm_client()
updated_mindset = llm_update_player_mindset(
updated_mindset = await llm_update_player_mindset(
llm_client=llm_client,
my_word=my_word,
completed_speeches=state["completed_speeches"],
Expand All @@ -142,7 +142,7 @@ def player_speech(state: GameState, player_id: str) -> Dict[str, Any]:
speech_plan = None

# Generate speech using LLM
new_speech_text = llm_generate_speech(
new_speech_text = await llm_generate_speech(
llm_client=llm_client,
my_word=my_word,
self_belief=updated_mindset_state.get("self_belief", {}),
Expand Down Expand Up @@ -186,7 +186,7 @@ def player_speech(state: GameState, player_id: str) -> Dict[str, Any]:
}


def player_vote(state: GameState, player_id: str) -> Dict[str, Any]:
async def player_vote(state: GameState, player_id: str) -> Dict[str, Any]:
"""
Player node for casting a vote.
Calls LLM to infer identity and decide vote target.
Expand All @@ -211,7 +211,7 @@ def player_vote(state: GameState, player_id: str) -> Dict[str, Any]:
existing_player_mindset = get_normalized_player_mindset(existing_private_state)

llm_client = _get_llm_client()
updated_mindset = llm_update_player_mindset(
updated_mindset = await llm_update_player_mindset(
llm_client=llm_client,
my_word=my_word,
completed_speeches=state["completed_speeches"],
Expand All @@ -223,7 +223,7 @@ def player_vote(state: GameState, player_id: str) -> Dict[str, Any]:
)
updated_mindset_state = normalize_mindset(updated_mindset)
# Decide on a vote target using the LLM with bound voting tools
voted_target = llm_decide_vote(
voted_target = await llm_decide_vote(
llm_client=llm_client,
state=state,
me=player_id,
Expand Down
31 changes: 24 additions & 7 deletions src/game/strategy/strategy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
for player mindset updates and speech generation.
"""

import asyncio
import inspect
from typing import Any, List, Dict, Sequence, cast
from venv import logger

Expand Down Expand Up @@ -40,6 +42,20 @@
from src.game.strategy.utils.text_utils import sanitize_speech_output


async def _invoke_async(target: Any, *args: Any, **kwargs: Any) -> Any:
"""Awaitably invoke LangChain runnables, falling back to sync methods."""
ainvoke = getattr(target, "ainvoke", None)
if callable(ainvoke):
result = ainvoke(*args, **kwargs)
return await result if inspect.isawaitable(result) else result

invoke = getattr(target, "invoke", None)
if callable(invoke):
return await asyncio.to_thread(invoke, *args, **kwargs)

raise AttributeError(f"Object {target!r} has neither ainvoke nor invoke.")


def _to_mindset_model(
mindset: PlayerMindset | PlayerMindsetModel | None,
) -> PlayerMindsetModel:
Expand All @@ -64,7 +80,7 @@ def _mindset_model_to_state(model: PlayerMindsetModel) -> PlayerMindset:
return cast(PlayerMindset, model.model_dump())


def llm_update_player_mindset(
async def llm_update_player_mindset(
llm_client: Any,
my_word: str,
completed_speeches: Sequence[Speech],
Expand Down Expand Up @@ -127,7 +143,7 @@ def llm_update_player_mindset(
SystemMessage(content=system_prompt),
HumanMessage(content=user_context),
]
result = agent.invoke({"messages": messages})
result = await _invoke_async(agent, {"messages": messages})

# Extract structured response from agent result
structured = result.get("structured_response")
Expand All @@ -154,7 +170,7 @@ def llm_update_player_mindset(
return existing_state


def llm_generate_speech(
async def llm_generate_speech(
llm_client: Any,
my_word: str,
self_belief: SelfBelief,
Expand Down Expand Up @@ -197,7 +213,7 @@ def llm_generate_speech(
HumanMessage(content=user_context),
]

response = llm_client.invoke(messages)
response = await _invoke_async(llm_client, messages)

raw_text = response.content if hasattr(response, "content") else response
return sanitize_speech_output(raw_text)
Expand All @@ -223,7 +239,7 @@ def plan_player_speech(
return planner.func()


def llm_decide_vote(
async def llm_decide_vote(
llm_client: Any,
state: GameState,
me: str,
Expand Down Expand Up @@ -274,13 +290,14 @@ def llm_decide_vote(
)

try:
result = agent.invoke(
result = await _invoke_async(
agent,
{
"messages": [
SystemMessage(content=system_prompt),
HumanMessage(content=vote_context),
]
}
},
)
structured = result.get("structured_response")
if structured:
Expand Down
31 changes: 19 additions & 12 deletions tests/test_llm_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from typing import Dict
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

from src.game.strategy import llm_update_player_mindset
from src.game.strategy.builders.prompt_builder import (
Expand Down Expand Up @@ -206,39 +207,45 @@ def test_llm_update_player_mindset_success():
"""Tests successful belief inference with structured output."""
# Mock the agent's invoke method to return structured response
mock_agent = MagicMock()
mock_agent.invoke.return_value = {
"structured_response": make_player_mindset(
self_belief=make_self_belief("civilian", 0.9),
suspicions={"b": make_suspicion("spy", 0.7, "Suspicious speech")},
)
}
mock_agent.ainvoke = AsyncMock(
return_value={
"structured_response": make_player_mindset(
self_belief=make_self_belief("civilian", 0.9),
suspicions={"b": make_suspicion("spy", 0.7, "Suspicious speech")},
)
}
)

# Mock create_agent to return our mock agent
with patch("src.game.strategy.strategy_core.create_agent", return_value=mock_agent):
mock_llm = MagicMock()
test_state = mock_state_inference_en.copy()
result = llm_update_player_mindset(llm_client=mock_llm, **test_state)
result = asyncio.run(
llm_update_player_mindset(llm_client=mock_llm, **test_state)
)

assert result["self_belief"]["role"] == "civilian"
assert result["suspicions"]["b"]["reason"] == "Suspicious speech"
mock_agent.invoke.assert_called_once()
mock_agent.ainvoke.assert_awaited_once()


def test_llm_update_player_mindset_failure():
"""Tests fallback behavior when structured output extraction fails for inference."""
# Mock the agent's invoke method to return None (simulating failure)
mock_agent = MagicMock()
mock_agent.invoke.return_value = {"structured_response": None}
mock_agent.ainvoke = AsyncMock(return_value={"structured_response": None})

# Mock create_agent to return our mock agent
with patch("src.game.strategy.strategy_core.create_agent", return_value=mock_agent):
mock_llm = MagicMock()
test_state = mock_state_inference_en.copy()
result = llm_update_player_mindset(llm_client=mock_llm, **test_state)
result = asyncio.run(
llm_update_player_mindset(llm_client=mock_llm, **test_state)
)

assert result["self_belief"]["role"] == "civilian"
assert (
result["self_belief"]["confidence"]
== mock_player_mindset["self_belief"]["confidence"]
)
mock_agent.invoke.assert_called_once()
mock_agent.ainvoke.assert_awaited_once()
31 changes: 16 additions & 15 deletions tests/test_player_nodes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from typing import Dict

import pytest
from unittest.mock import patch, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
from src.game.nodes.player import player_speech, player_vote
from src.game.state import (
GameState,
Expand Down Expand Up @@ -102,8 +103,8 @@ def base_player_state(player_id):


@patch("src.game.nodes.player._get_llm_client")
@patch("src.game.nodes.player.llm_generate_speech")
@patch("src.game.nodes.player.llm_update_player_mindset")
@patch("src.game.nodes.player.llm_generate_speech", new_callable=AsyncMock)
@patch("src.game.nodes.player.llm_update_player_mindset", new_callable=AsyncMock)
def test_player_speech(
mock_infer, mock_speech, mock_get_llm, player_id, base_player_state: GameState
):
Expand All @@ -119,7 +120,7 @@ def test_player_speech(
mock_speech.return_value = "This is a test speech."

# Act: Call the player_speech node
update = player_speech(base_player_state, player_id)
update = asyncio.run(player_speech(base_player_state, player_id))

# Assert: Verify the output is correct
assert "completed_speeches" in update
Expand All @@ -134,13 +135,13 @@ def test_player_speech(

# Verify mocks were called correctly
mock_get_llm.assert_called_once()
mock_infer.assert_called_once()
mock_speech.assert_called_once()
mock_infer.assert_awaited_once()
mock_speech.assert_awaited_once()


@patch("src.game.nodes.player._get_llm_client")
@patch("src.game.nodes.player.llm_update_player_mindset")
@patch("src.game.nodes.player.llm_decide_vote")
@patch("src.game.nodes.player.llm_update_player_mindset", new_callable=AsyncMock)
@patch("src.game.nodes.player.llm_decide_vote", new_callable=AsyncMock)
def test_player_vote(
mock_decide_vote, mock_infer, mock_get_llm, player_id, base_player_state: GameState
):
Expand All @@ -162,7 +163,7 @@ def test_player_vote(
}

# Act: Call the player_vote node
update = player_vote(voting_state, player_id)
update = asyncio.run(player_vote(voting_state, player_id))

# Assert: Verify the output
assert "current_votes" in update
Expand All @@ -176,8 +177,8 @@ def test_player_vote(

# Verify mocks
mock_get_llm.assert_called_once()
mock_infer.assert_called_once()
mock_decide_vote.assert_called_once_with(
mock_infer.assert_awaited_once()
mock_decide_vote.assert_awaited_once_with(
llm_client=mock_llm_client,
state=voting_state,
me=player_id,
Expand All @@ -189,21 +190,21 @@ def test_player_vote(
def test_player_speech_not_in_speaking_phase(base_player_state: GameState):
"""Tests that player_speech returns empty dict if not in speaking phase."""
state = base_player_state | {"game_phase": "voting"}
update = player_speech(state, "a")
update = asyncio.run(player_speech(state, "a"))
assert update == {}


def test_player_vote_not_in_voting_phase(base_player_state: GameState):
"""Tests that player_vote returns empty dict if not in voting phase."""
state = base_player_state | {"game_phase": "speaking"}
update = player_vote(state, "a")
update = asyncio.run(player_vote(state, "a"))
assert update == {}


def test_player_node_for_eliminated_player(base_player_state: GameState):
"""Tests that nodes do nothing for an eliminated player."""
state = base_player_state | {"eliminated_players": ["a"]}
speech_update = player_speech(state, "a")
vote_update = player_vote(state | {"game_phase": "voting"}, "a")
speech_update = asyncio.run(player_speech(state, "a"))
vote_update = asyncio.run(player_vote(state | {"game_phase": "voting"}, "a"))
assert speech_update == {}
assert vote_update == {}
Loading