From a8fa54d6563db9ee3b9745b588a5574cdeee392c Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 23 Oct 2025 09:55:49 +0200 Subject: [PATCH 01/13] add simple try except to subprocess --- src/core/containers/runtime/providers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/core/containers/runtime/providers.py b/src/core/containers/runtime/providers.py index 637b3be5..a8022ddc 100644 --- a/src/core/containers/runtime/providers.py +++ b/src/core/containers/runtime/providers.py @@ -169,8 +169,12 @@ def start_container( cmd.append(image) # Run container - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - self._container_id = result.stdout.strip() + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + self._container_id = result.stdout.strip() + except subprocess.CalledProcessError as e: + error_msg = f"Failed to start Docker container.\nCommand: {' '.join(cmd)}\nExit code: {e.returncode}\nStderr: {e.stderr}\nStdout: {e.stdout}" + raise RuntimeError(error_msg) from e # Wait a moment for container to start time.sleep(1) From bcecb3b0d09b582b489e334c891cea7708ca9ab6 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Sat, 25 Oct 2025 12:35:02 +0200 Subject: [PATCH 02/13] implement basic textarena wrapper server --- src/envs/textarena_env/server/Dockerfile | 32 +++ src/envs/textarena_env/server/__init__.py | 12 + src/envs/textarena_env/server/app.py | 53 +++++ src/envs/textarena_env/server/environment.py | 218 +++++++++++++++++++ 4 files changed, 315 insertions(+) create mode 100644 src/envs/textarena_env/server/Dockerfile create mode 100644 src/envs/textarena_env/server/__init__.py create mode 100644 src/envs/textarena_env/server/app.py create mode 100644 src/envs/textarena_env/server/environment.py diff --git a/src/envs/textarena_env/server/Dockerfile b/src/envs/textarena_env/server/Dockerfile new file mode 100644 index 00000000..5df60823 --- /dev/null +++ b/src/envs/textarena_env/server/Dockerfile @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Use the shared OpenEnv base image (Python 3.11) +ARG BASE_IMAGE=openenv-base:latest +FROM ${BASE_IMAGE} + +# Install system libraries required by TextArena (cv2 needs libGL, glib) +RUN apt-get update && apt-get install -y --no-install-recommends \ + libgl1 \ + libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Install TextArena and Python dependencies +RUN pip install --no-cache-dir \ + textarena==0.6.1 \ + nltk==3.9.2 + +# Copy OpenEnv core and TextArena environment sources +COPY src/core/ /app/src/core/ +COPY src/envs/textarena_env/ /app/src/envs/textarena_env/ + +# Optional: health check to ensure server responsiveness +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the TextArena FastAPI server +CMD ["uvicorn", "envs.textarena_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] + diff --git a/src/envs/textarena_env/server/__init__.py b/src/envs/textarena_env/server/__init__.py new file mode 100644 index 00000000..22d17ab5 --- /dev/null +++ b/src/envs/textarena_env/server/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Server components for the generic TextArena environment.""" + +from .environment import TextArenaEnvironment + +__all__ = ["TextArenaEnvironment"] + diff --git a/src/envs/textarena_env/server/app.py b/src/envs/textarena_env/server/app.py new file mode 100644 index 00000000..59dea784 --- /dev/null +++ b/src/envs/textarena_env/server/app.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""FastAPI application entrypoint for the TextArena environment.""" + +from __future__ import annotations + +import os + +from core.env_server.http_server import create_app + +from ..models import TextArenaAction, TextArenaObservation +from .environment import TextArenaEnvironment + + +def _parse_env_kwargs(prefix: str = "TEXTARENA_KW_") -> dict[str, str]: + """Collect arbitrary environment kwargs from the process environment.""" + + env_kwargs: dict[str, str] = {} + for key, value in os.environ.items(): + if key.startswith(prefix): + env_key = key[len(prefix) :].lower() + env_kwargs[env_key] = value + return env_kwargs + + +env_id = os.getenv("TEXTARENA_ENV_ID", "Wordle-v0") +num_players = int(os.getenv("TEXTARENA_NUM_PLAYERS", "1")) +max_turns_env = os.getenv("TEXTARENA_MAX_TURNS") +max_turns = int(max_turns_env) if max_turns_env is not None else None +download_nltk = os.getenv("TEXTARENA_DOWNLOAD_NLTK", "1") in {"1", "true", "True"} + +extra_kwargs = _parse_env_kwargs() + +environment = TextArenaEnvironment( + env_id=env_id, + num_players=num_players, + max_turns=max_turns, + download_nltk=download_nltk, + env_kwargs=extra_kwargs, +) + +app = create_app(environment, TextArenaAction, TextArenaObservation, env_name="textarena_env") + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) + diff --git a/src/envs/textarena_env/server/environment.py b/src/envs/textarena_env/server/environment.py new file mode 100644 index 00000000..a6aa5980 --- /dev/null +++ b/src/envs/textarena_env/server/environment.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Server implementation for the generic TextArena environment.""" + +from __future__ import annotations + +import sys +from typing import Any, Dict, Iterable, List, Optional +from uuid import uuid4 + +import nltk + +from core.env_server.interfaces import Environment + +from ..models import TextArenaAction, TextArenaMessage, TextArenaObservation, TextArenaState + + +_TEXTARENA_MODULE: Any | None = None +_TEXTARENA_IMPORT_ERROR: Exception | None = None + + +def _import_textarena() -> Any: + """Import ``textarena`` lazily and cache the module reference.""" + + global _TEXTARENA_MODULE, _TEXTARENA_IMPORT_ERROR + + if _TEXTARENA_MODULE is not None: + return _TEXTARENA_MODULE + + if _TEXTARENA_IMPORT_ERROR is not None: + raise _TEXTARENA_IMPORT_ERROR + + if sys.version_info < (3, 10): + _TEXTARENA_IMPORT_ERROR = RuntimeError( + "TextArena environments require Python 3.10 or newer; " + f"current interpreter is {sys.version_info.major}.{sys.version_info.minor}" + ) + raise _TEXTARENA_IMPORT_ERROR + + try: + import textarena as ta # type: ignore[import] + except Exception as exc: # pragma: no cover - surfaced to caller + _TEXTARENA_IMPORT_ERROR = exc + raise + + _TEXTARENA_MODULE = ta + return ta + + +class TextArenaEnvironment(Environment): + """Wrap any TextArena game behind the OpenEnv ``Environment`` API.""" + + def __init__( + self, + env_id: str = "Wordle-v0", + *, + num_players: int = 1, + max_turns: Optional[int] = None, + download_nltk: bool = True, + env_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + ta = _import_textarena() + + if download_nltk: + nltk.download("words", quiet=True) + nltk.download("averaged_perceptron_tagger_eng", quiet=True) + + self.env_id = env_id + self.num_players = num_players + self.max_turns = max_turns + self._env_kwargs = env_kwargs or {} + + self._ta_env = ta.make(env_id=env_id, **self._env_kwargs) + + self._state = TextArenaState( + env_id=env_id, + num_players=num_players, + max_turns=max_turns, + ) + + # ------------------------------------------------------------------ + # Environment interface + # ------------------------------------------------------------------ + def reset(self) -> TextArenaObservation: + self._ta_env.reset(num_players=self.num_players) + + self._state.episode_id = str(uuid4()) + self._state.step_count = 0 + self._state.turn = 0 + self._state.last_reward = 0.0 + self._state.last_info = {} + self._state.raw_state = self._snapshot_state() + + observation = self._build_observation() + observation.reward = 0.0 + observation.done = False + + return observation + + def step(self, action: TextArenaAction) -> TextArenaObservation: # type: ignore[override] + if not isinstance(action, TextArenaAction): + raise TypeError(f"Expected TextArenaAction, received {type(action)!r}") + + done, info = self._ta_env.step(action.message) + + self._state.step_count += 1 + self._state.turn = getattr(self._ta_env.state, "turn", self._state.turn + 1) + self._state.last_info = info or {} + + observation = self._build_observation() + observation.done = done + + reward = self._extract_reward() + observation.reward = reward + self._state.last_reward = reward + self._state.raw_state = self._snapshot_state() + + return observation + + @property + def state(self) -> TextArenaState: + return self._state + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _build_observation(self) -> TextArenaObservation: + player_id, messages = 
self._ta_env.get_observation() + + ta_messages = self._convert_messages(messages) + prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"] + if not prompt_lines: + # Fallback to most recent message history for prompt + prompt_lines = [msg.content for msg in ta_messages] + + info: Dict[str, Any] = {} + info.update(getattr(self._ta_env.state, "step_info", {})) + + observation = TextArenaObservation( + prompt="\n".join(prompt_lines).strip(), + messages=ta_messages, + current_player_id=player_id, + legal_players=self._legal_players(), + info=info, + metadata={ + "env_id": self.env_id, + "turn": getattr(self._ta_env.state, "turn", 0), + "raw_messages": [ + { + "sender_id": msg.sender_id, + "content": msg.content, + "category": msg.category, + } + for msg in ta_messages + ], + }, + ) + + return observation + + def _legal_players(self) -> List[int]: + role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {} + players = [pid for pid in role_mapping.keys() if isinstance(pid, int) and pid >= 0] + return sorted(players) + + def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]: + converted: List[TextArenaMessage] = [] + for entry in messages: + if isinstance(entry, tuple) and len(entry) == 3: + sender, content, category = entry + elif isinstance(entry, tuple) and len(entry) == 2: + sender, content = entry + category = "MESSAGE" + else: + sender, content, category = -1, str(entry), "MESSAGE" + + category_name = getattr(category, "name", str(category)) + converted.append( + TextArenaMessage( + sender_id=int(sender) if isinstance(sender, (int, float)) else -1, + content=str(content), + category=category_name, + ) + ) + + return converted + + def _extract_reward(self) -> float: + rewards = getattr(self._ta_env.state, "rewards", None) + if isinstance(rewards, dict): + # Use current player reward if available, otherwise default to player 0. + player_id = getattr(self._ta_env.state, "current_player_id", 0) + if player_id in rewards: + return float(rewards[player_id]) + if 0 in rewards: + return float(rewards[0]) + return 0.0 + + def _snapshot_state(self) -> Dict[str, Any]: + state = self._ta_env.state + snapshot: Dict[str, Any] = { + "turn": getattr(state, "turn", 0), + "game_state": getattr(state, "game_state", {}), + "logs": list(getattr(state, "logs", [])), + "rewards": getattr(state, "rewards", None), + "done": getattr(state, "done", False), + "role_mapping": getattr(state, "role_mapping", {}), + "game_info": getattr(state, "game_info", {}), + "step_info": getattr(state, "step_info", {}), + } + return snapshot + From 952173ef78187fad908c7db652a6234a36673ab3 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Sat, 25 Oct 2025 12:36:10 +0200 Subject: [PATCH 03/13] implement basic text arena client --- src/envs/textarena_env/__init__.py | 24 ++++++++++ src/envs/textarena_env/client.py | 76 ++++++++++++++++++++++++++++++ src/envs/textarena_env/models.py | 55 +++++++++++++++++++++ 3 files changed, 155 insertions(+) create mode 100644 src/envs/textarena_env/__init__.py create mode 100644 src/envs/textarena_env/client.py create mode 100644 src/envs/textarena_env/models.py diff --git a/src/envs/textarena_env/__init__.py b/src/envs/textarena_env/__init__.py new file mode 100644 index 00000000..61075679 --- /dev/null +++ b/src/envs/textarena_env/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""TextArena environment integration for OpenEnv.""" + +from .client import TextArenaEnv +from .models import ( + TextArenaAction, + TextArenaMessage, + TextArenaObservation, + TextArenaState, +) + +__all__ = [ + "TextArenaEnv", + "TextArenaAction", + "TextArenaObservation", + "TextArenaState", + "TextArenaMessage", +] + diff --git a/src/envs/textarena_env/client.py b/src/envs/textarena_env/client.py new file mode 100644 index 00000000..9f464206 --- /dev/null +++ b/src/envs/textarena_env/client.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""HTTP client for the generic TextArena environment.""" + +from __future__ import annotations + +from typing import Any, Dict, TYPE_CHECKING + +from core.client_types import StepResult +from core.http_env_client import HTTPEnvClient + +from .models import ( + TextArenaAction, + TextArenaMessage, + TextArenaObservation, + TextArenaState, +) + +if TYPE_CHECKING: + from core.containers.runtime import ContainerProvider + + +class TextArenaEnv(HTTPEnvClient[TextArenaAction, TextArenaObservation]): + """HTTP client for the TextArena environment server.""" + + def _step_payload(self, action: TextArenaAction) -> Dict[str, Any]: + return {"message": action.message} + + def _parse_result( + self, payload: Dict[str, Any] + ) -> StepResult[TextArenaObservation]: + obs_data = payload.get("observation", {}) + messages_payload = obs_data.get("messages", []) + messages = [ + TextArenaMessage( + sender_id=item.get("sender_id", -1), + content=item.get("content", ""), + category=item.get("category", "MESSAGE"), + ) + for item in messages_payload + if isinstance(item, dict) + ] + + observation = TextArenaObservation( + prompt=obs_data.get("prompt", ""), + messages=messages, + current_player_id=obs_data.get("current_player_id", 0), + legal_players=obs_data.get("legal_players", []), + info=obs_data.get("info", {}), + reward=payload.get("reward"), + done=payload.get("done", False), + metadata=obs_data.get("metadata", {}), + ) + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> TextArenaState: + return TextArenaState( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + env_id=payload.get("env_id", "unknown"), + num_players=payload.get("num_players", 1), + max_turns=payload.get("max_turns"), + turn=payload.get("turn", 0), + last_reward=payload.get("last_reward", 0.0), + last_info=payload.get("last_info", {}), + raw_state=payload.get("raw_state", {}), + ) + diff --git a/src/envs/textarena_env/models.py b/src/envs/textarena_env/models.py new file mode 100644 index 00000000..4fea2c17 --- /dev/null +++ b/src/envs/textarena_env/models.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Common data models for the TextArena environment wrapper.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from core.env_server.types import Action, Observation, State + + +@dataclass +class TextArenaMessage: + """Single message observed by a player.""" + + sender_id: int + content: str + category: str + + +@dataclass(kw_only=True) +class TextArenaAction(Action): + """Action issued by the agent for TextArena games.""" + + message: str + + +@dataclass(kw_only=True) +class TextArenaObservation(Observation): + """Observation returned from any TextArena game.""" + + prompt: str + messages: List[TextArenaMessage] = field(default_factory=list) + current_player_id: int = 0 + legal_players: List[int] = field(default_factory=list) + info: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(kw_only=True) +class TextArenaState(State): + """Structured state snapshot for the server.""" + + env_id: str + num_players: int + max_turns: Optional[int] = None + turn: int = 0 + last_reward: float = 0.0 + last_info: Dict[str, Any] = field(default_factory=dict) + raw_state: Dict[str, Any] = field(default_factory=dict) + From bcc072d0a6c687ea589a67b69f482f6f91375144 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Sat, 25 Oct 2025 12:36:24 +0200 Subject: [PATCH 04/13] add text arena examples and docs --- examples/textarena_simple.py | 87 +++++++++++++ examples/txtaarena_wordle_inference.py | 174 +++++++++++++++++++++++++ src/envs/textarena_env/README.md | 46 +++++++ 3 files changed, 307 insertions(+) create mode 100644 examples/textarena_simple.py create mode 100644 examples/txtaarena_wordle_inference.py create mode 100644 src/envs/textarena_env/README.md diff --git a/examples/textarena_simple.py b/examples/textarena_simple.py new file mode 100644 index 00000000..a65ef1ff --- /dev/null +++ b/examples/textarena_simple.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Quickstart example for the generic TextArena environment.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Add project src/ to import path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from envs.textarena_env import TextArenaEnv, TextArenaAction + + +def main() -> None: + + print("=" * 60) + print("πŸ’¬ TextArena Hello World - GuessTheNumber-v0") + print("=" * 60) + + env = TextArenaEnv.from_docker_image( + "textarena-env:latest", + env_vars={ + "TEXTARENA_ENV_ID": "GuessTheNumber-v0", + "TEXTARENA_NUM_PLAYERS": "1", + }, + ports={8000: 8000}, + ) + + try: + print("\nπŸ“ Resetting environment...") + result = env.reset() + print(f" Prompt:\n{result.observation.prompt}\n") + + # Simple heuristic: if prompt mentions a range, start with midpoint + guess = "[10]" + + for step in range(5): + print(f"🎯 Step {step + 1}: sending guess {guess}") + result = env.step(TextArenaAction(message=guess)) + + for message in result.observation.messages: + print(f" [{message.category}] {message.content}") + + if result.done: + break + + # Basic update: look for 'higher' or 'lower' hints + feedback = " ".join(msg.content for msg in result.observation.messages) + if "higher" in feedback: + guess = "[15]" + elif "lower" in feedback: + guess = "[5]" + else: + guess = "[10]" + + print("\nβœ… Episode finished!") + print(f" Reward: {result.reward}") + print(f" Done: {result.done}") + + state = env.state() + print("\nπŸ“Š Server State Snapshot:") + print(f" Episode ID: {state.episode_id}") + print(f" Step count: {state.step_count}") + print(f" Env ID: {state.env_id}") + + except Exception as exc: # pragma: no cover - demonstration script + print(f"\n❌ Error: {exc}") + print("\nMake sure you have built the Docker image first:") + print(" docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest .") + print("\nAlternatively run the server manually:") + print(" python -m envs.textarena_env.server.app") + + finally: + env.close() + print("\nπŸ‘‹ Done!") + + +if __name__ == "__main__": + main() + diff --git a/examples/txtaarena_wordle_inference.py b/examples/txtaarena_wordle_inference.py new file mode 100644 index 00000000..224d1910 --- /dev/null +++ b/examples/txtaarena_wordle_inference.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +"""Play TextArena Wordle with a hosted LLM via Hugging Face Inference Providers. + +This script mirrors the structure of the Kuhn Poker inference sample but targets +the Wordle environment. We deploy the generic TextArena server (wrapped in +OpenEnv) inside a local Docker container and query a single hosted model using +the OpenAI-compatible API provided by Hugging Face's router. + +Prerequisites +------------- +1. Build the TextArena Docker image:: + + docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest . + +2. Set your Hugging Face token:: + + export HF_TOKEN=your_token_here + +3. Run this script:: + + python examples/wordle_inference.py + +By default we ask the DeepSeek Terminus model to play ``Wordle-v0``. Adjust the +``MODEL`` constant if you'd like to experiment with another provider-compatible +model. 
+""" + +from __future__ import annotations + +import os +import re +from typing import Iterable, List + +from openai import OpenAI + +from envs.textarena_env import TextArenaAction, TextArenaEnv +from envs.textarena_env.models import TextArenaMessage + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +API_BASE_URL = "https://router.huggingface.co/v1" +API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") + +MODEL = "openai/gpt-oss-120b:novita" +MAX_TURNS = 8 +VERBOSE = True + +SYSTEM_PROMPT = ( + "You are an expert Wordle solver." + " Always respond with a single guess inside square brackets, e.g. [crane]." + " Use lowercase letters, exactly one five-letter word per reply." + " Reason about prior feedback before choosing the next guess." + " Words must be 5 letters long and real English words." + " Do not not include any other text in your response." + " Do not repeat the same guess twice." +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def format_history(messages: Iterable[TextArenaMessage]) -> str: + """Convert TextArena message history into plain text for the model.""" + + lines: List[str] = [] + for message in messages: + tag = message.category or "MESSAGE" + lines.append(f"[{tag}] {message.content}") + return "\n".join(lines) + + +def extract_guess(text: str) -> str: + """Return the first Wordle-style guess enclosed in square brackets.""" + + match = re.search(r"\[[A-Za-z]{5}\]", text) + if match: + return match.group(0).lower() + # Fallback: remove whitespace and ensure lowercase, then wrap + cleaned = re.sub(r"[^a-zA-Z]", "", text).lower() + if len(cleaned) >= 5: + return f"[{cleaned[:5]}]" + return "[crane]" + + +def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str: + """Combine the TextArena prompt and feedback history for the model.""" + + history = format_history(messages) + return ( + f"Current prompt:\n{prompt_text}\n\n" + f"Conversation so far:\n{history}\n\n" + "Reply with your next guess enclosed in square brackets." 
+ ) + + +# --------------------------------------------------------------------------- +# Gameplay +# --------------------------------------------------------------------------- + +def play_wordle(env: TextArenaEnv, client: OpenAI) -> None: + result = env.reset() + observation = result.observation + + if VERBOSE: + print("πŸ“œ Initial Prompt:\n" + observation.prompt) + + for turn in range(1, MAX_TURNS + 1): + if result.done: + break + + user_prompt = make_user_prompt(observation.prompt, observation.messages) + + response = client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ], + max_tokens=2048, + temperature=0.7, + ) + + raw_output = response.choices[0].message.content.strip() + guess = extract_guess(raw_output) + + if VERBOSE: + print(f"\n🎯 Turn {turn}: model replied with -> {raw_output}") + print(f" Parsed guess: {guess}") + + result = env.step(TextArenaAction(message=guess)) + observation = result.observation + + if VERBOSE: + print(" Feedback messages:") + for message in observation.messages: + print(f" [{message.category}] {message.content}") + + print("\nβœ… Game finished") + print(f" Reward: {result.reward}") + print(f" Done: {result.done}") + + +# --------------------------------------------------------------------------- +# Entrypoint +# --------------------------------------------------------------------------- + +def main() -> None: + if not API_KEY: + raise SystemExit("HF_TOKEN (or API_KEY) must be set to query the model.") + + client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) + + env = TextArenaEnv.from_docker_image( + "textarena-env:latest", + env_vars={ + "TEXTARENA_ENV_ID": "Wordle-v0", + "TEXTARENA_NUM_PLAYERS": "1", + }, + ports={8000: 8000}, + ) + + try: + play_wordle(env, client) + finally: + env.close() + + +if __name__ == "__main__": + main() + + diff --git a/src/envs/textarena_env/README.md b/src/envs/textarena_env/README.md new file mode 100644 index 00000000..819a0c8c --- /dev/null +++ b/src/envs/textarena_env/README.md @@ -0,0 +1,46 @@ +# TextArena Environment + +Generic wrapper for any [TextArena](https://www.textarena.ai/docs/overview) game inside OpenEnv. This module exposes the TextArena `Env` interface through the standard HTTP server/client APIs used by other OpenEnv environments, enabling quick experimentation with the full suite of word, reasoning, and multi-agent games. + +## Features +- Works with any registered TextArena game (e.g. `Wordle-v0`, `GuessTheNumber-v0`, `Chess-v0`, ...). +- Transparent access to TextArena message streams, rewards, and state snapshots. +- Docker image for easy deployment with PythonΒ 3.11 and preinstalled dependencies. +- Example client demonstrating end-to-end interaction. + +## Docker + +Build the container from the project root: + +```bash +docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest . +``` + +Run it with your desired game (default is `Wordle-v0`). Environment configuration is handled via env vars: + +```bash +docker run -p 8000:8000 \ + -e TEXTARENA_ENV_ID=GuessTheNumber-v0 \ + -e TEXTARENA_NUM_PLAYERS=1 \ + textarena-env:latest +``` + +Additional environment arguments can be passed using the `TEXTARENA_KW_` prefix. 
+For example, to enable `hardcore=True`:
+
+```bash
+docker run -p 8000:8000 \
+  -e TEXTARENA_ENV_ID=Wordle-v0 \
+  -e TEXTARENA_KW_hardcore=true \
+  textarena-env:latest
+```
+
+## Python Example
+
+The repository ships with a simple client script that connects to a running server (local or Docker) and plays a few turns. Run it from the repo root:
+
+```bash
+python examples/textarena_simple.py
+```
+
+The script uses `TextArenaEnv.from_docker_image` to automatically build/run the container if needed. Review the source (`examples/textarena_simple.py`) for more details and to customize the gameplay loop.

From 21f7ed36a42e16ac1c1518ae02c9d2748b6dc95f Mon Sep 17 00:00:00 2001
From: burtenshaw
Date: Sat, 25 Oct 2025 13:15:12 +0200
Subject: [PATCH 05/13] fall back to placeholder word when guess parsing fails

---
 examples/txtaarena_wordle_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/txtaarena_wordle_inference.py b/examples/txtaarena_wordle_inference.py
index 224d1910..9524a5ae 100644
--- a/examples/txtaarena_wordle_inference.py
+++ b/examples/txtaarena_wordle_inference.py
@@ -82,7 +82,7 @@ def extract_guess(text: str) -> str:
     cleaned = re.sub(r"[^a-zA-Z]", "", text).lower()
     if len(cleaned) >= 5:
         return f"[{cleaned[:5]}]"
-    return "[crane]"
+    return "[dunno]"

From 87c17cc9ba03cceff58381e3d1d194da1dd3abd0 Mon Sep 17 00:00:00 2001
From: Ben Burtenshaw
Date: Sat, 25 Oct 2025 11:20:36 +0000
Subject: [PATCH 06/13] first draft grpo script

---
 grpo.py                                    | 435 +++++++++++++++++++++
 src/envs/textarena_env/server/run_local.sh |   6 +
 2 files changed, 441 insertions(+)
 create mode 100644 grpo.py
 create mode 100755 src/envs/textarena_env/server/run_local.sh

diff --git a/grpo.py b/grpo.py
new file mode 100644
index 00000000..ac4bcdbc
--- /dev/null
+++ b/grpo.py
@@ -0,0 +1,435 @@
+#!/usr/bin/env python3
+"""
+GRPO training for Wordle using the TextArena OpenEnv environment.
+
+Usage:
+    # First, start the TextArena Wordle server (Docker or local):
+    TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 \
+    python -m src.envs.textarena_env.server.app
+
+    # Then run this training script:
+    python grpo.py
+"""
+
+import sys
+from pathlib import Path
+
+# Add src to path for imports
+sys.path.insert(0, str(Path(__file__).parent / "src"))
+
+import torch
+from typing import Iterable
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    get_linear_schedule_with_warmup,
+)
+
+
+model_id = "Qwen/Qwen3-0.6B"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
+).to(device)
+
+tokenizer.pad_token = tokenizer.eos_token
+model.config.pad_token_id = tokenizer.pad_token_id
+
+
+from peft import LoraConfig, get_peft_model
+
+lora_config = LoraConfig(
+    r=8,
+    lora_alpha=16,
+    lora_dropout=0.05,
+    task_type="CAUSAL_LM",
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+)
+
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()
+
+# Note: Gradient checkpointing can cause issues with LoRA + eval/train mode switching.
+# We use other memory optimizations instead (bfloat16, small batches, etc.)
+# Uncomment if you have memory issues and are not using LoRA: +# model.gradient_checkpointing_enable() + + +import numpy as np +from envs.textarena_env import TextArenaAction, TextArenaEnv +from envs.textarena_env.models import TextArenaMessage + +# Connect to the TextArena Wordle environment server (make sure it's running!) +# Start with: TEXTARENA_ENV_ID=Wordle-v0 python -m envs.textarena_env.server.app +env = TextArenaEnv(base_url="http://localhost:8000") + +MAX_TURNS = 8 + +SYSTEM_PROMPT = ( + "You are an expert Wordle solver." + " Always respond with a single guess inside square brackets, e.g. [crane]." + " Use lowercase letters, exactly one five-letter word per reply." + " Reason about prior feedback before choosing the next guess." + " Words must be 5 letters long and real English words." + " Do not include any other text in your response." + " Do not repeat the same guess twice." +) + + +max_train_steps = 500 # More steps to see actual learning +num_generations = 4 # REDUCED: Number of episodes to run per training step (was 8) +max_new_tokens = 8 # Allow generation of one bracketed guess plus reasoning tokens +max_episode_steps = 8 # Wordle has at most 8 turns in our configuration +temperature = 0.7 # Lower temperature for more focused action selection +top_k = 10 # Smaller top_k for more deterministic actions +learning_rate = 1e-5 # Higher learning rate +weight_decay = 0.0 +epsilon = 0.2 +gradient_accumulation_steps = 2 # INCREASED: Accumulate gradients to reduce memory +warmup_ratio = 0.1 +logging_frequency = 10 + + +import re +import gc +import torch.nn.functional as F +from contextlib import nullcontext + + +def format_history(messages: Iterable[TextArenaMessage]) -> str: + """Convert TextArena message history into plain text for the model.""" + + lines = [] + for message in messages: + tag = message.category or "MESSAGE" + content = message.content.strip() + if not content: + continue + lines.append(f"[{tag}] {content}") + return "\n".join(lines) + + +def extract_guess(text: str) -> str: + """Return the first Wordle-style guess enclosed in square brackets.""" + + match = re.search(r"\[[A-Za-z]{5}\]", text) + if match: + return match.group(0).lower() + + # Fallback: remove non-letters and enforce lowercase 5-letter word + cleaned = re.sub(r"[^a-zA-Z]", "", text).lower() + if len(cleaned) >= 5: + return f"[{cleaned[:5]}]" + return "[crane]" + + +def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str: + """Combine the TextArena prompt and feedback history for the model.""" + + history = format_history(messages) + prompt_section = prompt_text.strip() if prompt_text.strip() else "Wordle-v0" + history_section = history if history else "[PROMPT] Awaiting first feedback." + + return ( + f"Game prompt:\n{prompt_section}\n\n" + f"Conversation so far:\n{history_section}\n\n" + "Reply with your next guess enclosed in square brackets." 
+ ) + + +def run_wordle_episode(env: TextArenaEnv, model, tokenizer, device, max_steps): + """Run a single Wordle episode and collect prompts/completions for training.""" + + result = env.reset() + observation = result.observation + + episode_reward = 0.0 + all_prompt_ids = [] + all_completion_ids = [] + prompt_lengths = [] + seen_guesses = set() + turn = 0 + + while not result.done and turn < max_steps: + prompt_text = make_user_prompt(observation.prompt, observation.messages) + prompt_with_rules = f"{SYSTEM_PROMPT}\n\n{prompt_text}" + + messages = [{"role": "user", "content": prompt_with_rules}] + inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + temperature=temperature, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + ) + + prompt_length = inputs["input_ids"].shape[1] + completion_ids = outputs[0, prompt_length:] + completion_text = tokenizer.decode(completion_ids, skip_special_tokens=True) + + all_prompt_ids.extend(inputs["input_ids"][0].cpu().tolist()) + all_completion_ids.extend(completion_ids.cpu().tolist()) + prompt_lengths.append(prompt_length) + + guess = extract_guess(completion_text) + if guess in seen_guesses: + # Force a fallback to avoid duplicates + for fallback in ["[crane]", "[slate]", "[adieu]", "[roate]"]: + if fallback not in seen_guesses: + guess = fallback + break + seen_guesses.add(guess) + + result = env.step(TextArenaAction(message=guess)) + reward = result.reward or 0.0 + episode_reward += reward + observation = result.observation + turn += 1 + + del inputs, outputs, completion_ids + + return episode_reward, all_prompt_ids, all_completion_ids, prompt_lengths + + +def per_token_log_probs(logits, labels, use_float32=False): + """ + Compute log probabilities for each token without materialising full log-softmax. + + Args: + logits: Model logits (kept in bfloat16 by default for memory efficiency) + labels: Target token IDs + use_float32: If True, convert to float32 (more accurate but uses 2x memory) + + Note: bfloat16 is sufficient for RL training and saves significant memory. 
+ """ + if use_float32 and logits.dtype != torch.float32: + logits = logits.to(torch.float32) + + vocab_size = logits.size(-1) + # Use reshape instead of view for gradient checkpointing compatibility + flat_logits = logits.reshape(-1, vocab_size) + flat_labels = labels.reshape(-1) + per_token_loss = F.cross_entropy( + flat_logits, + flat_labels, + reduction="none", + ignore_index=tokenizer.pad_token_id, + ) + return (-per_token_loss).reshape_as(labels) + + +# Setup autocast context for mixed precision training +# We use bfloat16 throughout for memory efficiency (4x less than float32) +# bfloat16 has the same exponent range as float32, making it ideal for RL training +if device.type == "cuda": + autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +else: + autocast_ctx = nullcontext() + + +optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay +) +total_update_steps = max_train_steps // gradient_accumulation_steps +warmup_steps = max(1, int(total_update_steps * warmup_ratio)) +scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_update_steps) + +import trackio + +trackio.init(project="grpo-wordle") + +model.train() +global_step = 0 +running_reward = 0.0 +running_loss = 0.0 +logging_frequency = 10 + +print("Starting GRPO training with Wordle environment...") +print(f"Running {num_generations} episodes per training step") +print(f"Model dtype: {next(model.parameters()).dtype}") +print(f"Device: {device}") +print(f"Using bfloat16: {device.type == 'cuda'}") + +for step in range(1, max_train_steps + 1): + print(f"\nStep {step} of {max_train_steps}") + # Run multiple Wordle episodes to collect training data + model.eval() + episode_rewards = [] + all_sequences = [] + all_prompt_lengths = [] # Track prompt lengths for proper masking + + for episode_idx in range(0, num_generations): + ( + episode_reward, + prompt_ids, + completion_ids, + prompt_lengths, + ) = run_wordle_episode( + env, model, tokenizer, device, max_steps=max_episode_steps + ) + episode_rewards.append(episode_reward) + + # Combine prompt and completion into full sequence + full_sequence = prompt_ids + completion_ids + all_sequences.append(full_sequence) + all_prompt_lengths.append(sum(prompt_lengths)) + + # Clear memory after each episode to prevent accumulation + if device.type == "cuda" and episode_idx % 2 == 1: # Every 2 episodes + torch.cuda.empty_cache() + + # Clear memory after all episode generation + if device.type == "cuda": + torch.cuda.empty_cache() + gc.collect() + + model.train() + + # Pad sequences to same length + max_len = max(len(seq) for seq in all_sequences) + padded_sequences = [] + padded_completion_masks = [] + + for seq, prompt_len in zip(all_sequences, all_prompt_lengths): + # Pad sequence + padded = seq + [tokenizer.pad_token_id] * (max_len - len(seq)) + padded_sequences.append(padded) + + # Create completion mask: 1 for completion tokens, 0 for prompt and padding + # CRITICAL FIX: Only train on the completion tokens (actions), not the prompts + comp_mask = [0] * max_len + for i in range(prompt_len, len(seq)): + comp_mask[i] = 1 + padded_completion_masks.append(comp_mask) + + sequences = torch.tensor(padded_sequences, dtype=torch.long, device=device) + attention_mask = (sequences != tokenizer.pad_token_id).long() + completion_mask = torch.tensor( + padded_completion_masks, dtype=torch.long, device=device + ) + + # Convert episode rewards to tensor (use bfloat16 on GPU, float32 on CPU) + reward_dtype = torch.bfloat16 if 
device.type == "cuda" else torch.float32 + rewards = torch.tensor(episode_rewards, dtype=reward_dtype, device=device) + running_reward += rewards.mean().item() + + # Compute advantages (normalize rewards) - keep in bfloat16 + mean_reward = rewards.mean() + std_reward = rewards.std() + std_reward = std_reward if std_reward > 0 else 1.0 + advantages = (rewards - mean_reward) / std_reward + + # Prepare labels for loss computation + labels = sequences[:, 1:].clone() + labels[attention_mask[:, 1:] == 0] = tokenizer.pad_token_id + + # Compute old log probs (policy before update) + with torch.no_grad(): + with autocast_ctx if device.type == "cuda" else nullcontext(): + old_outputs = model( + input_ids=sequences, + attention_mask=attention_mask, + use_cache=False, + ) + old_log_probs = per_token_log_probs(old_outputs.logits[:, :-1], labels) + # Delete old_outputs to free memory + del old_outputs + + valid_mask = (completion_mask[:, 1:] == 1) & (labels != tokenizer.pad_token_id) + + # Compute new log probs and loss + # Note: With gradient_accumulation_steps > 1, we only zero grads at the start + if step % gradient_accumulation_steps == 1: + optimizer.zero_grad(set_to_none=True) + + with autocast_ctx if device.type == "cuda" else nullcontext(): + outputs = model( + input_ids=sequences, + attention_mask=attention_mask, + use_cache=False, + ) + log_probs = per_token_log_probs(outputs.logits[:, :-1], labels) + # Delete outputs immediately to free memory + del outputs + + # GRPO loss computation + ratio = (log_probs - old_log_probs).exp() + ratio = torch.where(valid_mask, ratio, torch.ones_like(ratio)) + clipped_ratio = ratio.clamp(1.0 - epsilon, 1.0 + epsilon) + + adv = advantages.unsqueeze(1) + loss_unclipped = ratio * adv + loss_clipped = clipped_ratio * adv + per_token_loss = -torch.min(loss_unclipped, loss_clipped) + per_token_loss = torch.where( + valid_mask, per_token_loss, torch.zeros_like(per_token_loss) + ) + + denom = valid_mask.sum().clamp(min=1) + loss = per_token_loss.sum() / denom + + # Scale loss by gradient accumulation steps + loss = loss / gradient_accumulation_steps + + # Backprop and update (only step optimizer every gradient_accumulation_steps) + loss.backward() + + if step % gradient_accumulation_steps == 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + scheduler.step() + + global_step += 1 + running_loss += loss.item() + + # Clear memory after training step + del sequences, attention_mask, completion_mask, rewards, advantages + del labels, old_log_probs, valid_mask, log_probs + del ratio, clipped_ratio, loss_unclipped, loss_clipped, per_token_loss, loss + + if device.type == "cuda": + torch.cuda.empty_cache() + gc.collect() + + # Logging + if step % logging_frequency == 0: + avg_reward = running_reward / logging_frequency + avg_loss = running_loss / logging_frequency + current_lr = scheduler.get_last_lr()[0] + wins = sum(1 for r in episode_rewards if r > 0) + losses = sum(1 for r in episode_rewards if r < 0) + ties = sum(1 for r in episode_rewards if r == 0) + print( + f"step={step:04d} | loss={avg_loss:.4f} | avg_reward={avg_reward:.4f} | lr={current_lr:.2e}" + ) + print(f" Episode rewards: {[f'{r:+.1f}' for r in episode_rewards]}") + print( + f" Win/Loss/Tie: {wins}/{losses}/{ties} (win rate: {wins/len(episode_rewards)*100:.1f}%)" + ) + running_reward = 0.0 + running_loss = 0.0 + trackio.log( + { + "step": step, + "loss": avg_loss, + "reward": avg_reward, + "win_rate": wins / len(episode_rewards), + } + ) + + +print("\nTraining 
complete!") +print("Remember to close the OpenSpiel environment server when done.") diff --git a/src/envs/textarena_env/server/run_local.sh b/src/envs/textarena_env/server/run_local.sh new file mode 100755 index 00000000..fbc0bb0f --- /dev/null +++ b/src/envs/textarena_env/server/run_local.sh @@ -0,0 +1,6 @@ +TEXTARENA_ENV_ID="Wordle-v0" TEXTARENA_NUM_PLAYERS=2 + +# Run the server +exec uvicorn envs.textarena_env.server.app:app --host 0.0.0.0 --port 8000 + + From b52e3c8a9707755c3e476a9fa90bd087bc5358e3 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Sun, 26 Oct 2025 20:04:41 +0100 Subject: [PATCH 07/13] rename inference example --- ...taarena_wordle_inference.py => textaarena_wordle_inference.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{txtaarena_wordle_inference.py => textaarena_wordle_inference.py} (100%) diff --git a/examples/txtaarena_wordle_inference.py b/examples/textaarena_wordle_inference.py similarity index 100% rename from examples/txtaarena_wordle_inference.py rename to examples/textaarena_wordle_inference.py From 3b72bcec24ceb5b385997a4b1c2b71c9e918949d Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 27 Oct 2025 06:51:12 +0100 Subject: [PATCH 08/13] fix typo in file name --- examples/textaarena_wordle_inference.py | 174 ------------------------ examples/textarena_simple.py | 9 +- 2 files changed, 1 insertion(+), 182 deletions(-) delete mode 100644 examples/textaarena_wordle_inference.py diff --git a/examples/textaarena_wordle_inference.py b/examples/textaarena_wordle_inference.py deleted file mode 100644 index 9524a5ae..00000000 --- a/examples/textaarena_wordle_inference.py +++ /dev/null @@ -1,174 +0,0 @@ -#!/usr/bin/env python3 -"""Play TextArena Wordle with a hosted LLM via Hugging Face Inference Providers. - -This script mirrors the structure of the Kuhn Poker inference sample but targets -the Wordle environment. We deploy the generic TextArena server (wrapped in -OpenEnv) inside a local Docker container and query a single hosted model using -the OpenAI-compatible API provided by Hugging Face's router. - -Prerequisites -------------- -1. Build the TextArena Docker image:: - - docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest . - -2. Set your Hugging Face token:: - - export HF_TOKEN=your_token_here - -3. Run this script:: - - python examples/wordle_inference.py - -By default we ask the DeepSeek Terminus model to play ``Wordle-v0``. Adjust the -``MODEL`` constant if you'd like to experiment with another provider-compatible -model. -""" - -from __future__ import annotations - -import os -import re -from typing import Iterable, List - -from openai import OpenAI - -from envs.textarena_env import TextArenaAction, TextArenaEnv -from envs.textarena_env.models import TextArenaMessage - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - -API_BASE_URL = "https://router.huggingface.co/v1" -API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") - -MODEL = "openai/gpt-oss-120b:novita" -MAX_TURNS = 8 -VERBOSE = True - -SYSTEM_PROMPT = ( - "You are an expert Wordle solver." - " Always respond with a single guess inside square brackets, e.g. [crane]." - " Use lowercase letters, exactly one five-letter word per reply." - " Reason about prior feedback before choosing the next guess." - " Words must be 5 letters long and real English words." - " Do not not include any other text in your response." 
- " Do not repeat the same guess twice." -) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def format_history(messages: Iterable[TextArenaMessage]) -> str: - """Convert TextArena message history into plain text for the model.""" - - lines: List[str] = [] - for message in messages: - tag = message.category or "MESSAGE" - lines.append(f"[{tag}] {message.content}") - return "\n".join(lines) - - -def extract_guess(text: str) -> str: - """Return the first Wordle-style guess enclosed in square brackets.""" - - match = re.search(r"\[[A-Za-z]{5}\]", text) - if match: - return match.group(0).lower() - # Fallback: remove whitespace and ensure lowercase, then wrap - cleaned = re.sub(r"[^a-zA-Z]", "", text).lower() - if len(cleaned) >= 5: - return f"[{cleaned[:5]}]" - return "[dunno]" - - -def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str: - """Combine the TextArena prompt and feedback history for the model.""" - - history = format_history(messages) - return ( - f"Current prompt:\n{prompt_text}\n\n" - f"Conversation so far:\n{history}\n\n" - "Reply with your next guess enclosed in square brackets." - ) - - -# --------------------------------------------------------------------------- -# Gameplay -# --------------------------------------------------------------------------- - -def play_wordle(env: TextArenaEnv, client: OpenAI) -> None: - result = env.reset() - observation = result.observation - - if VERBOSE: - print("πŸ“œ Initial Prompt:\n" + observation.prompt) - - for turn in range(1, MAX_TURNS + 1): - if result.done: - break - - user_prompt = make_user_prompt(observation.prompt, observation.messages) - - response = client.chat.completions.create( - model=MODEL, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": user_prompt}, - ], - max_tokens=2048, - temperature=0.7, - ) - - raw_output = response.choices[0].message.content.strip() - guess = extract_guess(raw_output) - - if VERBOSE: - print(f"\n🎯 Turn {turn}: model replied with -> {raw_output}") - print(f" Parsed guess: {guess}") - - result = env.step(TextArenaAction(message=guess)) - observation = result.observation - - if VERBOSE: - print(" Feedback messages:") - for message in observation.messages: - print(f" [{message.category}] {message.content}") - - print("\nβœ… Game finished") - print(f" Reward: {result.reward}") - print(f" Done: {result.done}") - - -# --------------------------------------------------------------------------- -# Entrypoint -# --------------------------------------------------------------------------- - -def main() -> None: - if not API_KEY: - raise SystemExit("HF_TOKEN (or API_KEY) must be set to query the model.") - - client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) - - env = TextArenaEnv.from_docker_image( - "textarena-env:latest", - env_vars={ - "TEXTARENA_ENV_ID": "Wordle-v0", - "TEXTARENA_NUM_PLAYERS": "1", - }, - ports={8000: 8000}, - ) - - try: - play_wordle(env, client) - finally: - env.close() - - -if __name__ == "__main__": - main() - - diff --git a/examples/textarena_simple.py b/examples/textarena_simple.py index a65ef1ff..52df7a5a 100644 --- a/examples/textarena_simple.py +++ b/examples/textarena_simple.py @@ -24,14 +24,7 @@ def main() -> None: print("πŸ’¬ TextArena Hello World - GuessTheNumber-v0") print("=" * 60) - env = TextArenaEnv.from_docker_image( - "textarena-env:latest", - env_vars={ - 
"TEXTARENA_ENV_ID": "GuessTheNumber-v0", - "TEXTARENA_NUM_PLAYERS": "1", - }, - ports={8000: 8000}, - ) + env = TextArenaEnv("https://huggingface.co/spaces/burtenshaw/textarena") try: print("\nπŸ“ Resetting environment...") From 09ad324221459e728846e0c644a175dacaef90e4 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 27 Oct 2025 11:26:16 +0100 Subject: [PATCH 09/13] add inference example with hf and gpt oss --- examples/textarena_wordle_inference.py | 174 +++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 examples/textarena_wordle_inference.py diff --git a/examples/textarena_wordle_inference.py b/examples/textarena_wordle_inference.py new file mode 100644 index 00000000..9524a5ae --- /dev/null +++ b/examples/textarena_wordle_inference.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +"""Play TextArena Wordle with a hosted LLM via Hugging Face Inference Providers. + +This script mirrors the structure of the Kuhn Poker inference sample but targets +the Wordle environment. We deploy the generic TextArena server (wrapped in +OpenEnv) inside a local Docker container and query a single hosted model using +the OpenAI-compatible API provided by Hugging Face's router. + +Prerequisites +------------- +1. Build the TextArena Docker image:: + + docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest . + +2. Set your Hugging Face token:: + + export HF_TOKEN=your_token_here + +3. Run this script:: + + python examples/wordle_inference.py + +By default we ask the DeepSeek Terminus model to play ``Wordle-v0``. Adjust the +``MODEL`` constant if you'd like to experiment with another provider-compatible +model. +""" + +from __future__ import annotations + +import os +import re +from typing import Iterable, List + +from openai import OpenAI + +from envs.textarena_env import TextArenaAction, TextArenaEnv +from envs.textarena_env.models import TextArenaMessage + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +API_BASE_URL = "https://router.huggingface.co/v1" +API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") + +MODEL = "openai/gpt-oss-120b:novita" +MAX_TURNS = 8 +VERBOSE = True + +SYSTEM_PROMPT = ( + "You are an expert Wordle solver." + " Always respond with a single guess inside square brackets, e.g. [crane]." + " Use lowercase letters, exactly one five-letter word per reply." + " Reason about prior feedback before choosing the next guess." + " Words must be 5 letters long and real English words." + " Do not not include any other text in your response." + " Do not repeat the same guess twice." 
+) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def format_history(messages: Iterable[TextArenaMessage]) -> str: + """Convert TextArena message history into plain text for the model.""" + + lines: List[str] = [] + for message in messages: + tag = message.category or "MESSAGE" + lines.append(f"[{tag}] {message.content}") + return "\n".join(lines) + + +def extract_guess(text: str) -> str: + """Return the first Wordle-style guess enclosed in square brackets.""" + + match = re.search(r"\[[A-Za-z]{5}\]", text) + if match: + return match.group(0).lower() + # Fallback: remove whitespace and ensure lowercase, then wrap + cleaned = re.sub(r"[^a-zA-Z]", "", text).lower() + if len(cleaned) >= 5: + return f"[{cleaned[:5]}]" + return "[dunno]" + + +def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str: + """Combine the TextArena prompt and feedback history for the model.""" + + history = format_history(messages) + return ( + f"Current prompt:\n{prompt_text}\n\n" + f"Conversation so far:\n{history}\n\n" + "Reply with your next guess enclosed in square brackets." + ) + + +# --------------------------------------------------------------------------- +# Gameplay +# --------------------------------------------------------------------------- + +def play_wordle(env: TextArenaEnv, client: OpenAI) -> None: + result = env.reset() + observation = result.observation + + if VERBOSE: + print("πŸ“œ Initial Prompt:\n" + observation.prompt) + + for turn in range(1, MAX_TURNS + 1): + if result.done: + break + + user_prompt = make_user_prompt(observation.prompt, observation.messages) + + response = client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ], + max_tokens=2048, + temperature=0.7, + ) + + raw_output = response.choices[0].message.content.strip() + guess = extract_guess(raw_output) + + if VERBOSE: + print(f"\n🎯 Turn {turn}: model replied with -> {raw_output}") + print(f" Parsed guess: {guess}") + + result = env.step(TextArenaAction(message=guess)) + observation = result.observation + + if VERBOSE: + print(" Feedback messages:") + for message in observation.messages: + print(f" [{message.category}] {message.content}") + + print("\nβœ… Game finished") + print(f" Reward: {result.reward}") + print(f" Done: {result.done}") + + +# --------------------------------------------------------------------------- +# Entrypoint +# --------------------------------------------------------------------------- + +def main() -> None: + if not API_KEY: + raise SystemExit("HF_TOKEN (or API_KEY) must be set to query the model.") + + client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) + + env = TextArenaEnv.from_docker_image( + "textarena-env:latest", + env_vars={ + "TEXTARENA_ENV_ID": "Wordle-v0", + "TEXTARENA_NUM_PLAYERS": "1", + }, + ports={8000: 8000}, + ) + + try: + play_wordle(env, client) + finally: + env.close() + + +if __name__ == "__main__": + main() + + From 67eec2c2ea1a9e3322074e7a2d8d9ee7c9136f3f Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 27 Oct 2025 19:12:10 +0100 Subject: [PATCH 10/13] Update examples/textarena_simple.py --- examples/textarena_simple.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/textarena_simple.py b/examples/textarena_simple.py index 52df7a5a..a65ef1ff 100644 --- a/examples/textarena_simple.py 
+++ b/examples/textarena_simple.py @@ -24,7 +24,14 @@ def main() -> None: print("πŸ’¬ TextArena Hello World - GuessTheNumber-v0") print("=" * 60) - env = TextArenaEnv("https://huggingface.co/spaces/burtenshaw/textarena") + env = TextArenaEnv.from_docker_image( + "textarena-env:latest", + env_vars={ + "TEXTARENA_ENV_ID": "GuessTheNumber-v0", + "TEXTARENA_NUM_PLAYERS": "1", + }, + ports={8000: 8000}, + ) try: print("\nπŸ“ Resetting environment...") From d3b10e849fc5aac71020181327f637603f176559 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 27 Oct 2025 20:46:31 +0100 Subject: [PATCH 11/13] add env to docker build --- .github/workflows/docker-build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 8cba8b47..e5195938 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -79,6 +79,8 @@ jobs: dockerfile: src/envs/atari_env/server/Dockerfile - name: git-env dockerfile: src/envs/git_env/server/Dockerfile + - name: textarena-env + dockerfile: src/envs/textarena_env/server/Dockerfile steps: - name: Checkout code From bc8204d07f37083e19f361c588bef21b88bf9a6e Mon Sep 17 00:00:00 2001 From: Ben Burtenshaw Date: Tue, 28 Oct 2025 08:59:11 +0000 Subject: [PATCH 12/13] delete extra grpo example --- grpo.py | 435 -------------------------------------------------------- 1 file changed, 435 deletions(-) delete mode 100644 grpo.py diff --git a/grpo.py b/grpo.py deleted file mode 100644 index ac4bcdbc..00000000 --- a/grpo.py +++ /dev/null @@ -1,435 +0,0 @@ -#!/usr/bin/env python3 -""" -GRPO training for Wordle using the TextArena OpenEnv environment. - -Usage: - # First, start the TextArena Wordle server (Docker or local): - TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 \ - python -m src.envs.textarena_env.server.app - - # Then run this training script: - python grpo.py -""" - -import sys -from pathlib import Path - -# Add src to path for imports -sys.path.insert(0, str(Path(__file__).parent / "src")) - -import torch -from typing import Iterable -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - get_linear_schedule_with_warmup, -) - - -model_id = "Qwen/Qwen3-0.6B" -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained( - model_id, - dtype=torch.bfloat16 if device.type == "cuda" else torch.float32, -).to(device) - -tokenizer.pad_token = tokenizer.eos_token -model.config.pad_token_id = tokenizer.pad_token_id - - -from peft import LoraConfig, get_peft_model - -lora_config = LoraConfig( - r=8, - lora_alpha=16, - lora_dropout=0.05, - task_type="CAUSAL_LM", - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], -) - -model = get_peft_model(model, lora_config) -model.print_trainable_parameters() - -# Note: Gradient checkpointing can cause issues with LoRA + eval/train mode switching -# We use other memory optimizations instead (bfloat16, small batches, etc.) -# Uncomment if you have memory issues and are not using LoRA: -# model.gradient_checkpointing_enable() - - -import numpy as np -from envs.textarena_env import TextArenaAction, TextArenaEnv -from envs.textarena_env.models import TextArenaMessage - -# Connect to the TextArena Wordle environment server (make sure it's running!) 
-# Start with: TEXTARENA_ENV_ID=Wordle-v0 PYTHONPATH=src python -m envs.textarena_env.server.app
-env = TextArenaEnv(base_url="http://localhost:8000")
-
-MAX_TURNS = 8
-
-SYSTEM_PROMPT = (
-    "You are an expert Wordle solver."
-    " Always respond with a single guess inside square brackets, e.g. [crane]."
-    " Use lowercase letters, exactly one five-letter word per reply."
-    " Reason about prior feedback before choosing the next guess."
-    " Words must be 5 letters long and real English words."
-    " Do not include any other text in your response."
-    " Do not repeat the same guess twice."
-)
-
-
-max_train_steps = 500  # More steps to see actual learning
-num_generations = 4  # REDUCED: Number of episodes to run per training step (was 8)
-max_new_tokens = 8  # Just enough tokens for a single bracketed guess such as [crane]
-max_episode_steps = 8  # Wordle has at most 8 turns in our configuration
-temperature = 0.7  # Lower temperature for more focused action selection
-top_k = 10  # Smaller top_k for more deterministic actions
-learning_rate = 1e-5  # Higher learning rate
-weight_decay = 0.0
-epsilon = 0.2
-gradient_accumulation_steps = 2  # INCREASED: Accumulate gradients to reduce memory
-warmup_ratio = 0.1
-logging_frequency = 10
-
-
-import re
-import gc
-import torch.nn.functional as F
-from contextlib import nullcontext
-
-
-def format_history(messages: Iterable[TextArenaMessage]) -> str:
-    """Convert TextArena message history into plain text for the model."""
-
-    lines = []
-    for message in messages:
-        tag = message.category or "MESSAGE"
-        content = message.content.strip()
-        if not content:
-            continue
-        lines.append(f"[{tag}] {content}")
-    return "\n".join(lines)
-
-
-def extract_guess(text: str) -> str:
-    """Return the first Wordle-style guess enclosed in square brackets."""
-
-    match = re.search(r"\[[A-Za-z]{5}\]", text)
-    if match:
-        return match.group(0).lower()
-
-    # Fallback: remove non-letters and enforce lowercase 5-letter word
-    cleaned = re.sub(r"[^a-zA-Z]", "", text).lower()
-    if len(cleaned) >= 5:
-        return f"[{cleaned[:5]}]"
-    return "[crane]"
-
-
-def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str:
-    """Combine the TextArena prompt and feedback history for the model."""
-
-    history = format_history(messages)
-    prompt_section = prompt_text.strip() if prompt_text.strip() else "Wordle-v0"
-    history_section = history if history else "[PROMPT] Awaiting first feedback."
-
-    return (
-        f"Game prompt:\n{prompt_section}\n\n"
-        f"Conversation so far:\n{history_section}\n\n"
-        "Reply with your next guess enclosed in square brackets."
- ) - - -def run_wordle_episode(env: TextArenaEnv, model, tokenizer, device, max_steps): - """Run a single Wordle episode and collect prompts/completions for training.""" - - result = env.reset() - observation = result.observation - - episode_reward = 0.0 - all_prompt_ids = [] - all_completion_ids = [] - prompt_lengths = [] - seen_guesses = set() - turn = 0 - - while not result.done and turn < max_steps: - prompt_text = make_user_prompt(observation.prompt, observation.messages) - prompt_with_rules = f"{SYSTEM_PROMPT}\n\n{prompt_text}" - - messages = [{"role": "user", "content": prompt_with_rules}] - inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_dict=True, - return_tensors="pt", - ).to(device) - - with torch.no_grad(): - outputs = model.generate( - **inputs, - max_new_tokens=max_new_tokens, - do_sample=True, - temperature=temperature, - top_k=top_k, - pad_token_id=tokenizer.pad_token_id, - ) - - prompt_length = inputs["input_ids"].shape[1] - completion_ids = outputs[0, prompt_length:] - completion_text = tokenizer.decode(completion_ids, skip_special_tokens=True) - - all_prompt_ids.extend(inputs["input_ids"][0].cpu().tolist()) - all_completion_ids.extend(completion_ids.cpu().tolist()) - prompt_lengths.append(prompt_length) - - guess = extract_guess(completion_text) - if guess in seen_guesses: - # Force a fallback to avoid duplicates - for fallback in ["[crane]", "[slate]", "[adieu]", "[roate]"]: - if fallback not in seen_guesses: - guess = fallback - break - seen_guesses.add(guess) - - result = env.step(TextArenaAction(message=guess)) - reward = result.reward or 0.0 - episode_reward += reward - observation = result.observation - turn += 1 - - del inputs, outputs, completion_ids - - return episode_reward, all_prompt_ids, all_completion_ids, prompt_lengths - - -def per_token_log_probs(logits, labels, use_float32=False): - """ - Compute log probabilities for each token without materialising full log-softmax. - - Args: - logits: Model logits (kept in bfloat16 by default for memory efficiency) - labels: Target token IDs - use_float32: If True, convert to float32 (more accurate but uses 2x memory) - - Note: bfloat16 is sufficient for RL training and saves significant memory. 
- """ - if use_float32 and logits.dtype != torch.float32: - logits = logits.to(torch.float32) - - vocab_size = logits.size(-1) - # Use reshape instead of view for gradient checkpointing compatibility - flat_logits = logits.reshape(-1, vocab_size) - flat_labels = labels.reshape(-1) - per_token_loss = F.cross_entropy( - flat_logits, - flat_labels, - reduction="none", - ignore_index=tokenizer.pad_token_id, - ) - return (-per_token_loss).reshape_as(labels) - - -# Setup autocast context for mixed precision training -# We use bfloat16 throughout for memory efficiency (4x less than float32) -# bfloat16 has the same exponent range as float32, making it ideal for RL training -if device.type == "cuda": - autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) -else: - autocast_ctx = nullcontext() - - -optimizer = torch.optim.AdamW( - model.parameters(), lr=learning_rate, weight_decay=weight_decay -) -total_update_steps = max_train_steps // gradient_accumulation_steps -warmup_steps = max(1, int(total_update_steps * warmup_ratio)) -scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_update_steps) - -import trackio - -trackio.init(project="grpo-wordle") - -model.train() -global_step = 0 -running_reward = 0.0 -running_loss = 0.0 -logging_frequency = 10 - -print("Starting GRPO training with Wordle environment...") -print(f"Running {num_generations} episodes per training step") -print(f"Model dtype: {next(model.parameters()).dtype}") -print(f"Device: {device}") -print(f"Using bfloat16: {device.type == 'cuda'}") - -for step in range(1, max_train_steps + 1): - print(f"\nStep {step} of {max_train_steps}") - # Run multiple Wordle episodes to collect training data - model.eval() - episode_rewards = [] - all_sequences = [] - all_prompt_lengths = [] # Track prompt lengths for proper masking - - for episode_idx in range(0, num_generations): - ( - episode_reward, - prompt_ids, - completion_ids, - prompt_lengths, - ) = run_wordle_episode( - env, model, tokenizer, device, max_steps=max_episode_steps - ) - episode_rewards.append(episode_reward) - - # Combine prompt and completion into full sequence - full_sequence = prompt_ids + completion_ids - all_sequences.append(full_sequence) - all_prompt_lengths.append(sum(prompt_lengths)) - - # Clear memory after each episode to prevent accumulation - if device.type == "cuda" and episode_idx % 2 == 1: # Every 2 episodes - torch.cuda.empty_cache() - - # Clear memory after all episode generation - if device.type == "cuda": - torch.cuda.empty_cache() - gc.collect() - - model.train() - - # Pad sequences to same length - max_len = max(len(seq) for seq in all_sequences) - padded_sequences = [] - padded_completion_masks = [] - - for seq, prompt_len in zip(all_sequences, all_prompt_lengths): - # Pad sequence - padded = seq + [tokenizer.pad_token_id] * (max_len - len(seq)) - padded_sequences.append(padded) - - # Create completion mask: 1 for completion tokens, 0 for prompt and padding - # CRITICAL FIX: Only train on the completion tokens (actions), not the prompts - comp_mask = [0] * max_len - for i in range(prompt_len, len(seq)): - comp_mask[i] = 1 - padded_completion_masks.append(comp_mask) - - sequences = torch.tensor(padded_sequences, dtype=torch.long, device=device) - attention_mask = (sequences != tokenizer.pad_token_id).long() - completion_mask = torch.tensor( - padded_completion_masks, dtype=torch.long, device=device - ) - - # Convert episode rewards to tensor (use bfloat16 on GPU, float32 on CPU) - reward_dtype = torch.bfloat16 if 
device.type == "cuda" else torch.float32 - rewards = torch.tensor(episode_rewards, dtype=reward_dtype, device=device) - running_reward += rewards.mean().item() - - # Compute advantages (normalize rewards) - keep in bfloat16 - mean_reward = rewards.mean() - std_reward = rewards.std() - std_reward = std_reward if std_reward > 0 else 1.0 - advantages = (rewards - mean_reward) / std_reward - - # Prepare labels for loss computation - labels = sequences[:, 1:].clone() - labels[attention_mask[:, 1:] == 0] = tokenizer.pad_token_id - - # Compute old log probs (policy before update) - with torch.no_grad(): - with autocast_ctx if device.type == "cuda" else nullcontext(): - old_outputs = model( - input_ids=sequences, - attention_mask=attention_mask, - use_cache=False, - ) - old_log_probs = per_token_log_probs(old_outputs.logits[:, :-1], labels) - # Delete old_outputs to free memory - del old_outputs - - valid_mask = (completion_mask[:, 1:] == 1) & (labels != tokenizer.pad_token_id) - - # Compute new log probs and loss - # Note: With gradient_accumulation_steps > 1, we only zero grads at the start - if step % gradient_accumulation_steps == 1: - optimizer.zero_grad(set_to_none=True) - - with autocast_ctx if device.type == "cuda" else nullcontext(): - outputs = model( - input_ids=sequences, - attention_mask=attention_mask, - use_cache=False, - ) - log_probs = per_token_log_probs(outputs.logits[:, :-1], labels) - # Delete outputs immediately to free memory - del outputs - - # GRPO loss computation - ratio = (log_probs - old_log_probs).exp() - ratio = torch.where(valid_mask, ratio, torch.ones_like(ratio)) - clipped_ratio = ratio.clamp(1.0 - epsilon, 1.0 + epsilon) - - adv = advantages.unsqueeze(1) - loss_unclipped = ratio * adv - loss_clipped = clipped_ratio * adv - per_token_loss = -torch.min(loss_unclipped, loss_clipped) - per_token_loss = torch.where( - valid_mask, per_token_loss, torch.zeros_like(per_token_loss) - ) - - denom = valid_mask.sum().clamp(min=1) - loss = per_token_loss.sum() / denom - - # Scale loss by gradient accumulation steps - loss = loss / gradient_accumulation_steps - - # Backprop and update (only step optimizer every gradient_accumulation_steps) - loss.backward() - - if step % gradient_accumulation_steps == 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - optimizer.step() - scheduler.step() - - global_step += 1 - running_loss += loss.item() - - # Clear memory after training step - del sequences, attention_mask, completion_mask, rewards, advantages - del labels, old_log_probs, valid_mask, log_probs - del ratio, clipped_ratio, loss_unclipped, loss_clipped, per_token_loss, loss - - if device.type == "cuda": - torch.cuda.empty_cache() - gc.collect() - - # Logging - if step % logging_frequency == 0: - avg_reward = running_reward / logging_frequency - avg_loss = running_loss / logging_frequency - current_lr = scheduler.get_last_lr()[0] - wins = sum(1 for r in episode_rewards if r > 0) - losses = sum(1 for r in episode_rewards if r < 0) - ties = sum(1 for r in episode_rewards if r == 0) - print( - f"step={step:04d} | loss={avg_loss:.4f} | avg_reward={avg_reward:.4f} | lr={current_lr:.2e}" - ) - print(f" Episode rewards: {[f'{r:+.1f}' for r in episode_rewards]}") - print( - f" Win/Loss/Tie: {wins}/{losses}/{ties} (win rate: {wins/len(episode_rewards)*100:.1f}%)" - ) - running_reward = 0.0 - running_loss = 0.0 - trackio.log( - { - "step": step, - "loss": avg_loss, - "reward": avg_reward, - "win_rate": wins / len(episode_rewards), - } - ) - - -print("\nTraining 
complete!") -print("Remember to close the OpenSpiel environment server when done.") From 941da1a89770528b0baf1391af303654470d659b Mon Sep 17 00:00:00 2001 From: Ben Burtenshaw Date: Tue, 28 Oct 2025 09:58:39 +0000 Subject: [PATCH 13/13] add wordle specific rewards to the environment --- src/envs/textarena_env/__init__.py | 4 +- src/envs/textarena_env/rewards.py | 133 +++++++++++++++++++ src/envs/textarena_env/server/environment.py | 34 +++++ src/envs/textarena_env/server/run_local.sh | 5 +- 4 files changed, 173 insertions(+), 3 deletions(-) create mode 100644 src/envs/textarena_env/rewards.py diff --git a/src/envs/textarena_env/__init__.py b/src/envs/textarena_env/__init__.py index 61075679..49314f7f 100644 --- a/src/envs/textarena_env/__init__.py +++ b/src/envs/textarena_env/__init__.py @@ -13,6 +13,7 @@ TextArenaObservation, TextArenaState, ) +from .rewards import RewardProvider, build_reward_providers __all__ = [ "TextArenaEnv", @@ -20,5 +21,6 @@ "TextArenaObservation", "TextArenaState", "TextArenaMessage", + "RewardProvider", + "build_reward_providers", ] - diff --git a/src/envs/textarena_env/rewards.py b/src/envs/textarena_env/rewards.py new file mode 100644 index 00000000..964a57a6 --- /dev/null +++ b/src/envs/textarena_env/rewards.py @@ -0,0 +1,133 @@ +"""Reward provider utilities for TextArena environments.""" + +from __future__ import annotations + +import re +from typing import Dict, List, Protocol, Tuple + +from .models import TextArenaAction, TextArenaObservation + + +class RewardProvider(Protocol): + """Interface for computing auxiliary reward signals.""" + + def reset(self) -> None: + """Clear any internal state before a new episode.""" + + def compute( + self, *, action: TextArenaAction, observation: TextArenaObservation + ) -> Dict[str, float]: + """Return a mapping of reward names to float values for the step.""" + + +def build_reward_providers(env_id: str) -> List[RewardProvider]: + """Instantiate reward providers appropriate for the given environment.""" + + providers: List[RewardProvider] = [] + if env_id == "Wordle-v0": + providers.append(_WordleRewardProvider()) + return providers + + +_WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]") + + +def extract_guess(text: str) -> str: + """Normalize a Wordle guess string from arbitrary text.""" + + match = _WORDLE_GUESS_PATTERN.search(text) + if match: + return match.group(0).lower() + + cleaned = re.sub(r"[^a-z]", "", text.lower()) + if len(cleaned) >= 5: + return f"[{cleaned[:5]}]" + return "[dunno]" + + +def extract_wordle_feedback(observation: TextArenaObservation) -> str: + """Pull the latest feedback text from a Wordle observation.""" + + for message in reversed(observation.messages): + content = message.content.strip() + if "Feedback:" in content: + return content.split("Feedback:", 1)[-1].strip() + return "" + + +def extract_feedback_counts(feedback: str) -> Tuple[int, int]: + """Return counts of green (G) and yellow (Y) markers from feedback.""" + + if not feedback: + return (0, 0) + + segments = [ + segment.strip() for segment in feedback.split("\n\n") if segment.strip() + ] + if not segments: + return (0, 0) + + latest_segment = segments[-1] + lines = [line.strip() for line in latest_segment.splitlines() if line.strip()] + latest_line = lines[-1] if lines else latest_segment + + green_count = latest_line.count("G") + yellow_count = latest_line.count("Y") + return (green_count, yellow_count) + + +class _WordleRewardProvider: + """Reward provider that mirrors the GRPO Wordle heuristics.""" + + SIGNAL_MAP = { + 
"greens": "wordle.greens", + "yellows": "wordle.yellows", + "repetitions": "wordle.repetitions", + "correct": "wordle.correct", + } + + def __init__(self) -> None: + self._guess_history: Dict[str, int] = {} + + def reset(self) -> None: + self._guess_history.clear() + + def compute( + self, *, action: TextArenaAction, observation: TextArenaObservation + ) -> Dict[str, float]: + guess = extract_guess(action.message) + feedback = extract_wordle_feedback(observation) + + normalized_guess = guess if guess and guess != "[dunno]" else "" + previous_occurrences = ( + self._guess_history.get(normalized_guess, 0) if normalized_guess else 0 + ) + + green_score = 0.0 + yellow_score = 0.0 + if feedback: + green_count, yellow_count = extract_feedback_counts(feedback) + green_score = green_count / 5.0 + yellow_score = yellow_count / 5.0 + + repetition_score = 1.0 - previous_occurrences + correct_score = float(observation.reward or 0.0) + + if normalized_guess: + self._guess_history[normalized_guess] = previous_occurrences + 1 + + return { + self.SIGNAL_MAP["greens"]: float(green_score), + self.SIGNAL_MAP["yellows"]: float(yellow_score), + self.SIGNAL_MAP["repetitions"]: float(repetition_score), + self.SIGNAL_MAP["correct"]: float(correct_score), + } + + +__all__ = [ + "RewardProvider", + "build_reward_providers", + "extract_feedback_counts", + "extract_guess", + "extract_wordle_feedback", +] diff --git a/src/envs/textarena_env/server/environment.py b/src/envs/textarena_env/server/environment.py index a6aa5980..808a1465 100644 --- a/src/envs/textarena_env/server/environment.py +++ b/src/envs/textarena_env/server/environment.py @@ -17,6 +17,7 @@ from core.env_server.interfaces import Environment from ..models import TextArenaAction, TextArenaMessage, TextArenaObservation, TextArenaState +from ..rewards import RewardProvider, build_reward_providers _TEXTARENA_MODULE: Any | None = None @@ -84,18 +85,25 @@ def __init__( max_turns=max_turns, ) + self._reward_providers: List[RewardProvider] = build_reward_providers(env_id) + self._last_reward_signals: Dict[str, float] = {} + # ------------------------------------------------------------------ # Environment interface # ------------------------------------------------------------------ def reset(self) -> TextArenaObservation: self._ta_env.reset(num_players=self.num_players) + for provider in self._reward_providers: + provider.reset() + self._state.episode_id = str(uuid4()) self._state.step_count = 0 self._state.turn = 0 self._state.last_reward = 0.0 self._state.last_info = {} self._state.raw_state = self._snapshot_state() + self._last_reward_signals = {} observation = self._build_observation() observation.reward = 0.0 @@ -119,6 +127,14 @@ def step(self, action: TextArenaAction) -> TextArenaObservation: # type: ignore reward = self._extract_reward() observation.reward = reward self._state.last_reward = reward + + reward_signals = self._compute_reward_signals(action=action, observation=observation) + if reward_signals: + observation.info.setdefault("reward_signals", {}).update(reward_signals) + observation.metadata.setdefault("reward_signals", {}).update(reward_signals) + self._last_reward_signals = reward_signals + if reward_signals: + self._state.last_info = {**(self._state.last_info or {}), "reward_signals": reward_signals} self._state.raw_state = self._snapshot_state() return observation @@ -214,5 +230,23 @@ def _snapshot_state(self) -> Dict[str, Any]: "game_info": getattr(state, "game_info", {}), "step_info": getattr(state, "step_info", {}), } + if 
self._last_reward_signals: + snapshot["reward_signals"] = dict(self._last_reward_signals) return snapshot + def _compute_reward_signals( + self, *, action: TextArenaAction, observation: TextArenaObservation + ) -> Dict[str, float]: + if not self._reward_providers: + return {} + + aggregated: Dict[str, float] = {} + for provider in self._reward_providers: + try: + result = provider.compute(action=action, observation=observation) + except Exception: # pragma: no cover - defensive + continue + for key, value in result.items(): + aggregated[key] = float(value) + return aggregated + diff --git a/src/envs/textarena_env/server/run_local.sh b/src/envs/textarena_env/server/run_local.sh index fbc0bb0f..8efa35f0 100755 --- a/src/envs/textarena_env/server/run_local.sh +++ b/src/envs/textarena_env/server/run_local.sh @@ -1,6 +1,7 @@ -TEXTARENA_ENV_ID="Wordle-v0" TEXTARENA_NUM_PLAYERS=2 +export TEXTARENA_ENV_ID="Wordle-v0" +export TEXTARENA_NUM_PLAYERS=1 # Run the server -exec uvicorn envs.textarena_env.server.app:app --host 0.0.0.0 --port 8000 +exec uvicorn envs.textarena_env.server.app:app --host 0.0.0.0 --port 8001
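
The sketch below shows how a client might consume the per-step "reward_signals" that the final patch above attaches to `observation.info`. It is a minimal illustration, not part of the patch series: the signal keys (`wordle.greens`, `wordle.yellows`, `wordle.repetitions`, `wordle.correct`) come from rewards.py, and `TextArenaEnv.from_docker_image`, `TextArenaAction`, and `result.observation.info` are used as shown in the examples above, but the shaping weights and the fixed opener guesses are arbitrary choices made for demonstration.

    from envs.textarena_env import TextArenaAction, TextArenaEnv

    # Illustrative shaping weights; these values are an assumption,
    # not something the patches define.
    WEIGHTS = {
        "wordle.correct": 1.0,       # solving the puzzle dominates the signal
        "wordle.greens": 0.2,        # partial credit for letters in the right spot
        "wordle.yellows": 0.1,       # smaller credit for misplaced letters
        "wordle.repetitions": 0.05,  # goes negative on repeats, acting as a penalty
    }

    env = TextArenaEnv.from_docker_image(
        "textarena-env:latest",
        env_vars={"TEXTARENA_ENV_ID": "Wordle-v0", "TEXTARENA_NUM_PLAYERS": "1"},
        ports={8000: 8000},
    )

    try:
        result = env.reset()
        # Fixed opener guesses, purely for demonstration; a real client
        # would generate guesses from a model as in the examples above.
        for guess in ["[crane]", "[slate]", "[adieu]"]:
            if result.done:
                break
            result = env.step(TextArenaAction(message=guess))
            signals = result.observation.info.get("reward_signals", {})
            # Fold the auxiliary signals into a single shaped scalar.
            shaped = (result.reward or 0.0) + sum(
                WEIGHTS.get(name, 0.0) * value for name, value in signals.items()
            )
            print(f"guess={guess} env_reward={result.reward} shaped={shaped:.3f}")
    finally:
        env.close()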