diff --git a/CLAUDE.md b/CLAUDE.md index b68518f..433997d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -155,13 +155,16 @@ START → host_setup → host_stage_switch **Quality Scoring**: ```python -from src.game.metrics import metrics_collector +from src.game.dependencies import build_dependencies + +deps = build_dependencies() +collector = deps.metrics # Get quality score -deterministic_score = metrics_collector.compute_quality_score() +deterministic_score = collector.compute_quality_score() # Or use LLM-based evaluation -llm_score = metrics_collector.compute_quality_score(method="llm", llm=client) +llm_score = collector.compute_quality_score(method="llm", llm=client) ``` **Metrics History**: Track prompts and configurations in `docs/metrics-history.md` @@ -194,8 +197,8 @@ llm_score = metrics_collector.compute_quality_score(method="llm", llm=client) ### Working with Player-Specific Hooks (callbacks) for Metrics When implementing player-specific behaviors that need to track metrics per player: -- Use the `metrics_collector.on_player_speech(player_name, is_spy, round_num, speech)` hook within player speech nodes to collect speech diversity metrics -- Use the `metrics_collector.on_vote_cast()` hook in player vote nodes to collect voting pattern data. +- Access the injected collector via the dependencies bundle (e.g., the `metrics` argument supplied to LangGraph nodes). +- Use the `metrics.on_player_speech()` and related hooks within player speech/vote nodes to collect lexical diversity and voting pattern data. - Metrics collection respects the `metrics.enabled` flag in `config.yaml` and will be no-ops when metrics are disabled. ## LangGraph Development Notes @@ -206,4 +209,4 @@ When implementing player-specific behaviors that need to track metrics per playe **Error Handling**: LangGraph nodes should handle exceptions gracefully to prevent workflow crashes -**See**: [ARCHITECTURE.md](ARCHITECTURE.md) for detailed system design and [README.md](README.md) for project overview \ No newline at end of file +**See**: [ARCHITECTURE.md](ARCHITECTURE.md) for detailed system design and [README.md](README.md) for project overview diff --git a/README.md b/README.md index 180dc0d..32d74d8 100644 --- a/README.md +++ b/README.md @@ -174,14 +174,18 @@ Metrics are streamed to memory during play and automatically persisted when a ga - Per-game summaries: `logs/metrics/{game_id}.json` - Rolling aggregate + functional quality score: `logs/metrics/overall.json` -You can also access the live collector from code: +You can also access the live collector from code by building a dependency bundle +for each game instance: ```python -from src.game.metrics import metrics_collector +from src.game.dependencies import build_dependencies -audit = metrics_collector.get_overall_metrics() -score = metrics_collector.compute_quality_score() # deterministic -# metrics_collector.compute_quality_score(method="llm", llm=client) for LLM-based review +deps = build_dependencies() +collector = deps.metrics + +audit = collector.get_overall_metrics() +score = collector.compute_quality_score() # deterministic +# collector.compute_quality_score(method="llm", llm=client) for LLM-based review ``` These outputs are ready to feed into downstream prompt-evaluation or offline analysis pipelines. diff --git a/README_zh.md b/README_zh.md index 60c7318..62f271a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -175,14 +175,17 @@ game: - 单局摘要:`logs/metrics/{game_id}.json` - 全局聚合与函数版总分:`logs/metrics/overall.json` -在代码中可直接访问实时指标: +在代码中可通过依赖容器访问实时指标: ```python -from src.game.metrics import metrics_collector +from src.game.dependencies import build_dependencies -report = metrics_collector.get_overall_metrics() -score = metrics_collector.compute_quality_score() # 函数评分 -# metrics_collector.compute_quality_score(method="llm", llm=client) 可获取 LLM 评价 +deps = build_dependencies() +collector = deps.metrics + +report = collector.get_overall_metrics() +score = collector.compute_quality_score() # 函数评分 +# collector.compute_quality_score(method="llm", llm=client) 可获取 LLM 评价 ``` 这些数据可作为后续提示词评估或离线分析的直接输入。 diff --git a/src/game/config.py b/src/game/config.py index 7ec0f70..e27e200 100644 --- a/src/game/config.py +++ b/src/game/config.py @@ -8,7 +8,7 @@ Configuration precedence: 1. Built-in defaults defined in ``DEFAULT_CONFIG``. -2. Values provided in ``config.yaml`` (or a custom path passed to ``get_config``), +2. Values provided in ``config.yaml`` (or a custom path passed to ``load_config``), merged over the defaults. 3. Pydantic model defaults for any fields still unset after the merge. """ @@ -262,46 +262,33 @@ def validate_config(self) -> bool: return False -# Global configuration instance -_config_instance: GameConfig | None = None logger = get_logger(__name__) -def get_config(config_path: str | Path | None = None) -> GameConfig: - """ - Get the global configuration instance. +def default_config_path() -> Path: + """Return the default config file location inside the repository.""" + return Path(__file__).resolve().parents[2] / "config.yaml" - Args: - config_path: Path to configuration file. If None, uses default location. - Returns: - GameConfig instance +def load_config(config_path: str | Path | None = None) -> GameConfig: """ - global _config_instance - - if _config_instance is None: - if config_path is None: - project_root = Path(__file__).resolve().parents[2] - config_path = project_root / "config.yaml" + Build a new GameConfig instance from the provided path. - _config_instance = GameConfig(config_path) - - return _config_instance + Args: + config_path: Optional override path. When omitted, uses ``config.yaml`` at + the project root. + """ + resolved_path = ( + Path(config_path).expanduser() if config_path else default_config_path() + ) + return GameConfig(resolved_path) def reload_config(config_path: str | Path | None = None) -> GameConfig: """ - Reload the configuration from file. - - Args: - config_path: Path to configuration file. If None, uses default location. - - Returns: - GameConfig instance + Compatibility shim for legacy callers. Returns a freshly loaded config. """ - global _config_instance - _config_instance = None - return get_config(config_path) + return load_config(config_path) def calculate_spy_count(total_players: int) -> int: diff --git a/src/game/dependencies.py b/src/game/dependencies.py new file mode 100644 index 0000000..fe44bff --- /dev/null +++ b/src/game/dependencies.py @@ -0,0 +1,44 @@ +""" +Lightweight dependency container for wiring runtime services into the game. + +Instead of relying on module-level singletons (e.g., global config instances or +metrics collectors), we bundle the required collaborators into a simple data +class and pass them explicitly where needed. This makes it trivial to spin up +multiple, isolated games for tests or concurrent executions. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from .config import GameConfig, load_config +from .metrics import GameMetrics + + +@dataclass(slots=True) +class GameDependencies: + """Container object that holds the runtime services a game needs.""" + + config: GameConfig + metrics: GameMetrics + + +def build_dependencies( + *, + config: GameConfig | None = None, + metrics: GameMetrics | None = None, + config_path: str | Path | None = None, +) -> GameDependencies: + """ + Construct a ``GameDependencies`` instance. + + Args: + config: Optional pre-built ``GameConfig``. + metrics: Optional ``GameMetrics`` instance (useful for sharing collectors). + config_path: Optional config path when ``config`` is not supplied. + """ + cfg = config or load_config(config_path) + collector = metrics or GameMetrics() + return GameDependencies(config=cfg, metrics=collector) diff --git a/src/game/graph.py b/src/game/graph.py index 1c12ad6..ef7844f 100644 --- a/src/game/graph.py +++ b/src/game/graph.py @@ -32,11 +32,26 @@ from src.game.nodes.transition import check_votes_and_transition from src.game.state import GameState, votes_ready, next_alive_player from src.tools import save_graph_image -from src.game.config import get_config +from src.game.dependencies import GameDependencies, build_dependencies logger = get_logger(__name__) +def _resolve_dependencies( + *, + dependencies: GameDependencies | None = None, + config=None, + metrics=None, +) -> GameDependencies: + if dependencies is not None and (config is not None or metrics is not None): + raise ValueError( + "Provide either `dependencies` or individual `config`/`metrics`, not both." + ) + if dependencies is not None: + return dependencies + return build_dependencies(config=config, metrics=metrics) + + def route_from_stage(state: GameState) -> list[str] | str: """Route to appropriate nodes based on current game phase. @@ -84,7 +99,14 @@ def should_continue(state: GameState) -> str: return "end" if state.get("winner") else "continue" -def build_workflow_with_players(players: list[str], *, checkpointer=None): +def build_workflow_with_players( + players: list[str], + *, + dependencies: GameDependencies | None = None, + config=None, + metrics=None, + checkpointer=None, +): """Build the complete LangGraph workflow for a specific set of players. This function constructs the entire state machine with all nodes and edges @@ -102,20 +124,46 @@ def build_workflow_with_players(players: list[str], *, checkpointer=None): - Player nodes: speech and vote nodes for each player - Transition nodes: vote counting and phase transitions """ + deps = _resolve_dependencies( + dependencies=dependencies, + config=config, + metrics=metrics, + ) + cfg = deps.config + collector = deps.metrics + workflow = StateGraph(GameState) # Register nodes - workflow.add_node("host_setup", host_setup) + workflow.add_node( + "host_setup", partial(host_setup, game_config=cfg, metrics=collector) + ) workflow.add_node( "host_stage_switch", host_stage_switch ) # Responsible for writing phase/next_* pointers - workflow.add_node("host_result", host_result) + workflow.add_node("host_result", partial(host_result, metrics=collector)) workflow.add_node("check_votes_and_transition", check_votes_and_transition) for pid in players: - workflow.add_node(f"player_speech_{pid}", partial(player_speech, player_id=pid)) - workflow.add_node(f"player_vote_{pid}", partial(player_vote, player_id=pid)) + workflow.add_node( + f"player_speech_{pid}", + partial( + player_speech, + player_id=pid, + game_config=cfg, + metrics=collector, + ), + ) + workflow.add_node( + f"player_vote_{pid}", + partial( + player_vote, + player_id=pid, + game_config=cfg, + metrics=collector, + ), + ) # Basic skeleton workflow.add_edge(START, "host_setup") @@ -161,27 +209,40 @@ def build_workflow_with_players(players: list[str], *, checkpointer=None): return app -def build_workflow(config=None): +def build_workflow( + *, + dependencies: GameDependencies | None = None, + config=None, + metrics=None, +): """Build workflow for LangGraph Server - accepts RunnableConfig parameter. For LangGraph Server, we build a workflow using the player count from config.yaml. The frontend will get the actual player list from the game state. """ - # Load configuration to get the configured player count - game_config = get_config() + deps = _resolve_dependencies( + dependencies=dependencies, + config=config, + metrics=metrics, + ) + game_config = deps.config # Generate player names based on configuration players = game_config.generate_player_names() logger.info("Building workflow with %d players: %s", len(players), players) - return build_workflow_with_players(players) + return build_workflow_with_players( + players, + dependencies=deps, + ) def main(): """Main execution function using configuration.""" # Load configuration - config = get_config() + deps = build_dependencies() + config = deps.config # Generate player names based on configuration players = config.generate_player_names() @@ -192,7 +253,7 @@ def main(): logger.info("Vocabulary pairs: %d", len(config.vocabulary)) # Build and run the workflow - app = build_workflow_with_players(players) + app = build_workflow_with_players(players, dependencies=deps) save_graph_image(app, filename="artifacts/agent_with_router.png") initial_state = { diff --git a/src/game/metrics.py b/src/game/metrics.py index c3b8f26..14e2eff 100644 --- a/src/game/metrics.py +++ b/src/game/metrics.py @@ -24,7 +24,6 @@ import json import re -import sys from argparse import ArgumentParser from pathlib import Path from threading import Lock @@ -33,14 +32,6 @@ from .state import PlayerMindset -if __name__ == "__main__": - # When executed as a script (e.g. `python -m src.game.metrics`), ensure - # the canonical package-qualified module name shares the same module - # instance so that imports like `from ..metrics import metrics_collector` - # reference the identical singleton. - sys.modules.setdefault("src.game.metrics", sys.modules[__name__]) - - def _safe_mean(values: Iterable[Optional[float]]) -> Optional[float]: """Return the mean of non-null values or None if nothing is available.""" filtered = [v for v in values if v is not None] @@ -693,8 +684,6 @@ def aggregate_from_summaries( return {"metrics": metrics, "quality_score": quality_score} -# Global collector used by the rest of the codebase. -metrics_collector = GameMetrics() logger = get_logger(__name__) MULTILINGUAL_VOCABULARY_BATCH: List[tuple[str, tuple[str, str]]] = [ @@ -712,6 +701,8 @@ def _run_single_multilingual_game( language_tag: str, civilian_word: str, spy_word: str, + collector: "GameMetrics", + config: "GameConfig", ) -> None: """Execute a single multilingual game with the provided vocabulary pair.""" @@ -721,7 +712,11 @@ def _run_single_multilingual_game( from .graph import build_workflow_with_players player_list = list(players) - app = build_workflow_with_players(player_list) + app = build_workflow_with_players( + player_list, + config=config, + metrics=collector, + ) game_id = f"metrics-{idx}-{language_tag}" logger.info( @@ -751,7 +746,11 @@ async def _run_game(): def run_multilingual_metrics_batch( - *, concurrent: bool = False, max_workers: Optional[int] = None + *, + concurrent: bool = False, + max_workers: Optional[int] = None, + config: "GameConfig | None" = None, + metrics: "GameMetrics | None" = None, ) -> Dict[str, Any]: """Run five games over distinct vocabulary pairs and log resulting metrics. @@ -761,14 +760,16 @@ def run_multilingual_metrics_batch( running concurrently. Defaults to the number of vocabulary pairs. """ - from .config import get_config + from .config import load_config + + collector = metrics or GameMetrics() - previous_state = metrics_collector.enabled - metrics_collector.set_enabled(True) - metrics_collector.reset() + previous_state = collector.enabled + collector.set_enabled(True) + collector.reset() - config = get_config() - players = tuple(config.generate_player_names()) + cfg = config or load_config() + players = tuple(cfg.generate_player_names()) batch = list(MULTILINGUAL_VOCABULARY_BATCH) if concurrent: @@ -784,6 +785,8 @@ def run_multilingual_metrics_batch( language_tag, civilian_word, spy_word, + collector, + cfg, ) for idx, (language_tag, (civilian_word, spy_word)) in enumerate( batch, start=1 @@ -794,11 +797,17 @@ def run_multilingual_metrics_batch( else: for idx, (language_tag, (civilian_word, spy_word)) in enumerate(batch, start=1): _run_single_multilingual_game( - players, idx, language_tag, civilian_word, spy_word + players, + idx, + language_tag, + civilian_word, + spy_word, + collector, + cfg, ) - overall_metrics = metrics_collector.get_overall_metrics() - quality_score = metrics_collector.compute_quality_score() + overall_metrics = collector.get_overall_metrics() + quality_score = collector.compute_quality_score() logger.info( "Overall metrics:\n%s", @@ -809,7 +818,7 @@ def run_multilingual_metrics_batch( ) result = {"metrics": overall_metrics, "quality_score": quality_score} - metrics_collector.set_enabled(previous_state) + collector.set_enabled(previous_state) return result @@ -943,7 +952,6 @@ def main(): __all__ = [ "GameMetrics", - "metrics_collector", "MULTILINGUAL_VOCABULARY_BATCH", "run_multilingual_metrics_batch", "load_saved_game_summaries", diff --git a/src/game/nodes/host.py b/src/game/nodes/host.py index 243b3ad..eff95b6 100644 --- a/src/game/nodes/host.py +++ b/src/game/nodes/host.py @@ -1,7 +1,5 @@ -from typing import Dict, Any, cast +from typing import Dict, Any, cast, TYPE_CHECKING -from ..config import get_config -from ..metrics import metrics_collector from ..state import GameState, next_alive_player, generate_phase_id from ..rules import ( assign_roles_and_words, @@ -11,25 +9,32 @@ from ..logger import get_logger from .helpers import get_assigned_word +if TYPE_CHECKING: + from ..config import GameConfig + from ..metrics import GameMetrics + logger = get_logger(__name__) -def host_setup(state: GameState) -> Dict[str, Any]: +def host_setup( + state: GameState, *, game_config: "GameConfig", metrics: "GameMetrics" +) -> Dict[str, Any]: """Initializes the game, assigning roles and words.""" - config = get_config() - desired_state = config.metrics_enabled or metrics_collector.enabled - metrics_collector.set_enabled(desired_state) + desired_state = game_config.metrics_enabled or metrics.enabled + metrics.set_enabled(desired_state) player_list = state.get("players") if not player_list: - player_list = config.generate_player_names() + player_list = game_config.generate_player_names() # Pass the existing host_private_state to assign_roles_and_words # This allows custom words from frontend to be used if provided host_private_state = state.get("host_private_state", {}) assignments = assign_roles_and_words( - player_list, host_private_state=host_private_state + player_list, + word_list=game_config.vocabulary, + host_private_state=host_private_state, ) logger.info("Host initializing game with %d players", len(player_list)) @@ -38,8 +43,8 @@ def host_setup(state: GameState) -> Dict[str, Any]: assigned_word = get_assigned_word(private_state) logger.debug("Assigned word for %s: %s", player_id, assigned_word) - if metrics_collector.enabled: - metrics_collector.on_game_start( + if metrics.enabled: + metrics.on_game_start( game_id=state.get("game_id"), players=player_list, player_roles=assignments["host_private_state"]["player_roles"], @@ -86,7 +91,7 @@ def host_stage_switch(state: GameState) -> Dict[str, Any]: return {} -def host_result(state: GameState) -> Dict[str, Any]: +def host_result(state: GameState, *, metrics: "GameMetrics") -> Dict[str, Any]: """ Calculates the result of a round, eliminates a player, and checks for a winner. This node is the aggregation point after voting. @@ -112,8 +117,8 @@ def host_result(state: GameState) -> Dict[str, Any]: if winner: logger.info("Winner determined: %s", winner) - if metrics_collector.enabled: - metrics_collector.on_game_end( + if metrics.enabled: + metrics.on_game_end( game_id=state.get("game_id"), winner=winner, ) diff --git a/src/game/nodes/player.py b/src/game/nodes/player.py index 00a5c86..3ee2c7e 100644 --- a/src/game/nodes/player.py +++ b/src/game/nodes/player.py @@ -21,11 +21,9 @@ """ from datetime import datetime -from typing import Dict, Any +from typing import Dict, Any, TYPE_CHECKING from src.tools.llm import create_llm -from ..config import get_config -from ..metrics import metrics_collector from ..state import ( GameState, alive_players, @@ -50,6 +48,10 @@ get_normalized_player_mindset, ) +if TYPE_CHECKING: + from ..config import GameConfig + from ..metrics import GameMetrics + logger = get_logger(__name__) @@ -98,7 +100,13 @@ def _create_player_private_state_delta( } -async def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: +async def player_speech( + state: GameState, + player_id: str, + *, + game_config: "GameConfig", + metrics: "GameMetrics", +) -> Dict[str, Any]: """ Player node for generating speech. Calls LLM to infer identity and generate speech. @@ -119,7 +127,6 @@ async def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: logger.debug("Player %s assigned word: %s", player_id, my_word) # Generate playerMindset using LLM - config = get_config() existing_player_mindset = get_normalized_player_mindset(existing_private_state) llm_client = _get_llm_client() @@ -130,7 +137,7 @@ async def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: players=state["players"], alive=alive_players(state), me=player_id, - rules=config.get_game_rules(), + rules=game_config.get_game_rules(), existing_player_mindset=existing_player_mindset, ) @@ -166,15 +173,15 @@ async def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: # Prepare the state updates based on the generated speech and PlayerMindset speech_record: Speech = create_speech_record(state, player_id, new_speech_text) - if metrics_collector.enabled: - metrics_collector.on_player_mindset_update( + if metrics.enabled: + metrics.on_player_mindset_update( game_id=state.get("game_id"), round_number=state["current_round"], phase=state["game_phase"], player_id=player_id, mindset=updated_mindset_state, ) - metrics_collector.on_speech( + metrics.on_speech( game_id=state.get("game_id"), round_number=state["current_round"], player_id=player_id, @@ -191,7 +198,13 @@ async def player_speech(state: GameState, player_id: str) -> Dict[str, Any]: } -async def player_vote(state: GameState, player_id: str) -> Dict[str, Any]: +async def player_vote( + state: GameState, + player_id: str, + *, + game_config: "GameConfig", + metrics: "GameMetrics", +) -> Dict[str, Any]: """ Player node for casting a vote. Calls LLM to infer identity and decide vote target. @@ -212,7 +225,6 @@ async def player_vote(state: GameState, player_id: str) -> Dict[str, Any]: logger.debug("Player %s assigned word: %s", player_id, my_word) # Generate playerMindset using LLM - config = get_config() existing_player_mindset = get_normalized_player_mindset(existing_private_state) llm_client = _get_llm_client() @@ -223,7 +235,7 @@ async def player_vote(state: GameState, player_id: str) -> Dict[str, Any]: players=state["players"], alive=alive_players(state), me=player_id, - rules=config.get_game_rules(), + rules=game_config.get_game_rules(), existing_player_mindset=existing_player_mindset, ) updated_mindset_state = normalize_mindset(updated_mindset) @@ -243,8 +255,8 @@ async def player_vote(state: GameState, player_id: str) -> Dict[str, Any]: # Prepare the state updates based on the decided vote and PlayerMindset ts = int(datetime.now().timestamp() * 1000) - if metrics_collector.enabled: - metrics_collector.on_player_mindset_update( + if metrics.enabled: + metrics.on_player_mindset_update( game_id=state.get("game_id"), round_number=state["current_round"], phase=state["game_phase"], diff --git a/src/game/rules.py b/src/game/rules.py index e7e474b..26c4416 100644 --- a/src/game/rules.py +++ b/src/game/rules.py @@ -22,7 +22,7 @@ from collections import Counter from typing import List, Dict, Any -from .config import get_config, calculate_spy_count +from .config import DEFAULT_CONFIG, calculate_spy_count from .logger import get_logger from .state import ( GameState, @@ -71,8 +71,8 @@ def assign_roles_and_words( elif word_list: civilian_word, spy_word = random.choice(word_list) else: - word_list = get_config().vocabulary - civilian_word, spy_word = random.choice(word_list) + default_vocab = DEFAULT_CONFIG["game"]["vocabulary"] + civilian_word, spy_word = random.choice(default_vocab) # 2. Prepare private states player_private_states: Dict[str, PlayerPrivateState] = {} diff --git a/tests/test_host_nodes.py b/tests/test_host_nodes.py index 3ddac02..6d0a7f7 100644 --- a/tests/test_host_nodes.py +++ b/tests/test_host_nodes.py @@ -1,7 +1,21 @@ import pytest +from src.game.config import load_config +from src.game.metrics import GameMetrics from src.game.nodes.host import host_setup, host_stage_switch, host_result +@pytest.fixture +def game_config(): + return load_config() + + +@pytest.fixture +def metrics(): + collector = GameMetrics() + collector.set_enabled(False) + return collector + + @pytest.fixture def base_state(): """A base game state fixture for tests.""" @@ -18,9 +32,9 @@ def base_state(): } -def test_host_setup(base_state): +def test_host_setup(base_state, game_config, metrics): """Tests that host_setup initializes the game correctly.""" - update = host_setup(base_state) + update = host_setup(base_state, game_config=game_config, metrics=metrics) assert update["current_round"] == 1 assert update["game_phase"] == "speaking" assert "host_private_state" in update @@ -52,7 +66,7 @@ def test_host_stage_switch(base_state): assert "phase_id" in update_done -def test_host_result_elimination_and_advance(base_state): +def test_host_result_elimination_and_advance(base_state, metrics): """Tests a standard round result: one player is eliminated and the game advances.""" # Scenario: 5 players, 1 spy (b), 4 civilians (a,c,d,e) # Eliminate a civilian ('a'). Game should continue. @@ -74,7 +88,7 @@ def test_host_result_elimination_and_advance(base_state): } }, } - update = host_result(voting_state) + update = host_result(voting_state, metrics=metrics) assert update["game_phase"] == "speaking" assert update["current_round"] == 2 @@ -82,7 +96,7 @@ def test_host_result_elimination_and_advance(base_state): assert update["current_votes"] == {} -def test_host_result_spy_win(base_state): +def test_host_result_spy_win(base_state, metrics): """Tests the condition for a spy victory.""" voting_state = base_state | { "game_phase": "voting", @@ -99,13 +113,13 @@ def test_host_result_spy_win(base_state): }, } # After 'c' is eliminated, 1 spy ('b') and 1 civilian ('a') will remain. Spies win. - update = host_result(voting_state) + update = host_result(voting_state, metrics=metrics) assert update["game_phase"] == "result" assert update["eliminated_players"] == ["c"] assert update["winner"] == "spies" -def test_host_result_civilian_win(base_state): +def test_host_result_civilian_win(base_state, metrics): """Tests the condition for a civilian victory.""" voting_state = base_state | { "game_phase": "voting", @@ -121,7 +135,7 @@ def test_host_result_civilian_win(base_state): } }, } - update = host_result(voting_state) + update = host_result(voting_state, metrics=metrics) assert update["game_phase"] == "result" assert update["eliminated_players"] == ["b"] assert update["winner"] == "civilians" diff --git a/tests/test_player_nodes.py b/tests/test_player_nodes.py index 307a6d2..d1c5605 100644 --- a/tests/test_player_nodes.py +++ b/tests/test_player_nodes.py @@ -3,6 +3,8 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch +from src.game.config import load_config +from src.game.metrics import GameMetrics from src.game.nodes.player import player_speech, player_vote from src.game.state import ( GameState, @@ -41,6 +43,18 @@ def make_player_private_state( } +@pytest.fixture +def game_config(): + return load_config() + + +@pytest.fixture +def metrics(): + collector = GameMetrics() + collector.set_enabled(False) + return collector + + @pytest.fixture def player_id(): return "a" @@ -106,7 +120,13 @@ def base_player_state(player_id): @patch("src.game.nodes.player.llm_generate_speech", new_callable=AsyncMock) @patch("src.game.nodes.player.llm_update_player_mindset", new_callable=AsyncMock) def test_player_speech( - mock_infer, mock_speech, mock_get_llm, player_id, base_player_state: GameState + mock_infer, + mock_speech, + mock_get_llm, + player_id, + base_player_state: GameState, + game_config, + metrics, ): """Tests the player_speech node with mocked LLM calls.""" # Arrange: Configure mocks to return predictable values @@ -120,7 +140,14 @@ def test_player_speech( mock_speech.return_value = "This is a test speech." # Act: Call the player_speech node - update = asyncio.run(player_speech(base_player_state, player_id)) + update = asyncio.run( + player_speech( + base_player_state, + player_id, + game_config=game_config, + metrics=metrics, + ) + ) # Assert: Verify the output is correct assert "completed_speeches" in update @@ -143,7 +170,13 @@ def test_player_speech( @patch("src.game.nodes.player.llm_update_player_mindset", new_callable=AsyncMock) @patch("src.game.nodes.player.llm_decide_vote", new_callable=AsyncMock) def test_player_vote( - mock_decide_vote, mock_infer, mock_get_llm, player_id, base_player_state: GameState + mock_decide_vote, + mock_infer, + mock_get_llm, + player_id, + base_player_state: GameState, + game_config, + metrics, ): """Tests the player_vote node with mocked LLM calls.""" # Arrange: Configure mocks @@ -163,7 +196,14 @@ def test_player_vote( } # Act: Call the player_vote node - update = asyncio.run(player_vote(voting_state, player_id)) + update = asyncio.run( + player_vote( + voting_state, + player_id, + game_config=game_config, + metrics=metrics, + ) + ) # Assert: Verify the output assert "current_votes" in update @@ -187,24 +227,45 @@ def test_player_vote( ) -def test_player_speech_not_in_speaking_phase(base_player_state: GameState): +def test_player_speech_not_in_speaking_phase( + base_player_state: GameState, game_config, metrics +): """Tests that player_speech returns empty dict if not in speaking phase.""" state = base_player_state | {"game_phase": "voting"} - update = asyncio.run(player_speech(state, "a")) + update = asyncio.run( + player_speech(state, "a", game_config=game_config, metrics=metrics) + ) assert update == {} -def test_player_vote_not_in_voting_phase(base_player_state: GameState): +def test_player_vote_not_in_voting_phase( + base_player_state: GameState, + game_config, + metrics, +): """Tests that player_vote returns empty dict if not in voting phase.""" state = base_player_state | {"game_phase": "speaking"} - update = asyncio.run(player_vote(state, "a")) + update = asyncio.run( + player_vote(state, "a", game_config=game_config, metrics=metrics) + ) assert update == {} -def test_player_node_for_eliminated_player(base_player_state: GameState): +def test_player_node_for_eliminated_player( + base_player_state: GameState, game_config, metrics +): """Tests that nodes do nothing for an eliminated player.""" state = base_player_state | {"eliminated_players": ["a"]} - speech_update = asyncio.run(player_speech(state, "a")) - vote_update = asyncio.run(player_vote(state | {"game_phase": "voting"}, "a")) + speech_update = asyncio.run( + player_speech(state, "a", game_config=game_config, metrics=metrics) + ) + vote_update = asyncio.run( + player_vote( + state | {"game_phase": "voting"}, + "a", + game_config=game_config, + metrics=metrics, + ) + ) assert speech_update == {} assert vote_update == {}