diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 8cba8b47..e5195938 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -79,6 +79,8 @@ jobs: dockerfile: src/envs/atari_env/server/Dockerfile - name: git-env dockerfile: src/envs/git_env/server/Dockerfile + - name: textarena-env + dockerfile: src/envs/textarena_env/server/Dockerfile steps: - name: Checkout code diff --git a/examples/textarena_simple.py b/examples/textarena_simple.py new file mode 100644 index 00000000..a65ef1ff --- /dev/null +++ b/examples/textarena_simple.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Quickstart example for the generic TextArena environment.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Add project src/ to import path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from envs.textarena_env import TextArenaEnv, TextArenaAction + + +def main() -> None: + + print("=" * 60) + print("๐Ÿ’ฌ TextArena Hello World - GuessTheNumber-v0") + print("=" * 60) + + env = TextArenaEnv.from_docker_image( + "textarena-env:latest", + env_vars={ + "TEXTARENA_ENV_ID": "GuessTheNumber-v0", + "TEXTARENA_NUM_PLAYERS": "1", + }, + ports={8000: 8000}, + ) + + try: + print("\n๐Ÿ“ Resetting environment...") + result = env.reset() + print(f" Prompt:\n{result.observation.prompt}\n") + + # Simple heuristic: if prompt mentions a range, start with midpoint + guess = "[10]" + + for step in range(5): + print(f"๐ŸŽฏ Step {step + 1}: sending guess {guess}") + result = env.step(TextArenaAction(message=guess)) + + for message in result.observation.messages: + print(f" [{message.category}] {message.content}") + + if result.done: + break + + # Basic update: look for 'higher' 
or 'lower' hints + feedback = " ".join(msg.content for msg in result.observation.messages) + if "higher" in feedback: + guess = "[15]" + elif "lower" in feedback: + guess = "[5]" + else: + guess = "[10]" + + print("\nโœ… Episode finished!") + print(f" Reward: {result.reward}") + print(f" Done: {result.done}") + + state = env.state() + print("\n๐Ÿ“Š Server State Snapshot:") + print(f" Episode ID: {state.episode_id}") + print(f" Step count: {state.step_count}") + print(f" Env ID: {state.env_id}") + + except Exception as exc: # pragma: no cover - demonstration script + print(f"\nโŒ Error: {exc}") + print("\nMake sure you have built the Docker image first:") + print(" docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest .") + print("\nAlternatively run the server manually:") + print(" python -m envs.textarena_env.server.app") + + finally: + env.close() + print("\n๐Ÿ‘‹ Done!") + + +if __name__ == "__main__": + main() + diff --git a/examples/textarena_wordle_inference.py b/examples/textarena_wordle_inference.py new file mode 100644 index 00000000..9524a5ae --- /dev/null +++ b/examples/textarena_wordle_inference.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +"""Play TextArena Wordle with a hosted LLM via Hugging Face Inference Providers. + +This script mirrors the structure of the Kuhn Poker inference sample but targets +the Wordle environment. We deploy the generic TextArena server (wrapped in +OpenEnv) inside a local Docker container and query a single hosted model using +the OpenAI-compatible API provided by Hugging Face's router. + +Prerequisites +------------- +1. Build the TextArena Docker image:: + + docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest . + +2. Set your Hugging Face token:: + + export HF_TOKEN=your_token_here + +3. Run this script:: + + python examples/textarena_wordle_inference.py + +By default we ask the hosted model named by ``MODEL`` to play ``Wordle-v0``. 
Adjust the +``MODEL`` constant if you'd like to experiment with another provider-compatible +model. +""" + +from __future__ import annotations + +import os +import re +from typing import Iterable, List + +from openai import OpenAI + +from envs.textarena_env import TextArenaAction, TextArenaEnv +from envs.textarena_env.models import TextArenaMessage + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +API_BASE_URL = "https://router.huggingface.co/v1" +API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") + +MODEL = "openai/gpt-oss-120b:novita" +MAX_TURNS = 8 +VERBOSE = True + +SYSTEM_PROMPT = ( + "You are an expert Wordle solver." + " Always respond with a single guess inside square brackets, e.g. [crane]." + " Use lowercase letters, exactly one five-letter word per reply." + " Reason about prior feedback before choosing the next guess." + " Words must be 5 letters long and real English words." + " Do not not include any other text in your response." + " Do not repeat the same guess twice." 
+) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def format_history(messages: Iterable[TextArenaMessage]) -> str: + """Convert TextArena message history into plain text for the model.""" + + lines: List[str] = [] + for message in messages: + tag = message.category or "MESSAGE" + lines.append(f"[{tag}] {message.content}") + return "\n".join(lines) + + +def extract_guess(text: str) -> str: + """Return the first Wordle-style guess enclosed in square brackets.""" + + match = re.search(r"\[[A-Za-z]{5}\]", text) + if match: + return match.group(0).lower() + # Fallback: remove whitespace and ensure lowercase, then wrap + cleaned = re.sub(r"[^a-zA-Z]", "", text).lower() + if len(cleaned) >= 5: + return f"[{cleaned[:5]}]" + return "[dunno]" + + +def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str: + """Combine the TextArena prompt and feedback history for the model.""" + + history = format_history(messages) + return ( + f"Current prompt:\n{prompt_text}\n\n" + f"Conversation so far:\n{history}\n\n" + "Reply with your next guess enclosed in square brackets." 
+ ) + + +# --------------------------------------------------------------------------- +# Gameplay +# --------------------------------------------------------------------------- + +def play_wordle(env: TextArenaEnv, client: OpenAI) -> None: + result = env.reset() + observation = result.observation + + if VERBOSE: + print("๐Ÿ“œ Initial Prompt:\n" + observation.prompt) + + for turn in range(1, MAX_TURNS + 1): + if result.done: + break + + user_prompt = make_user_prompt(observation.prompt, observation.messages) + + response = client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ], + max_tokens=2048, + temperature=0.7, + ) + + raw_output = response.choices[0].message.content.strip() + guess = extract_guess(raw_output) + + if VERBOSE: + print(f"\n๐ŸŽฏ Turn {turn}: model replied with -> {raw_output}") + print(f" Parsed guess: {guess}") + + result = env.step(TextArenaAction(message=guess)) + observation = result.observation + + if VERBOSE: + print(" Feedback messages:") + for message in observation.messages: + print(f" [{message.category}] {message.content}") + + print("\nโœ… Game finished") + print(f" Reward: {result.reward}") + print(f" Done: {result.done}") + + +# --------------------------------------------------------------------------- +# Entrypoint +# --------------------------------------------------------------------------- + +def main() -> None: + if not API_KEY: + raise SystemExit("HF_TOKEN (or API_KEY) must be set to query the model.") + + client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) + + env = TextArenaEnv.from_docker_image( + "textarena-env:latest", + env_vars={ + "TEXTARENA_ENV_ID": "Wordle-v0", + "TEXTARENA_NUM_PLAYERS": "1", + }, + ports={8000: 8000}, + ) + + try: + play_wordle(env, client) + finally: + env.close() + + +if __name__ == "__main__": + main() + + diff --git a/src/core/containers/runtime/providers.py 
b/src/core/containers/runtime/providers.py index 637b3be5..a8022ddc 100644 --- a/src/core/containers/runtime/providers.py +++ b/src/core/containers/runtime/providers.py @@ -169,8 +169,12 @@ def start_container( cmd.append(image) # Run container - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - self._container_id = result.stdout.strip() + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + self._container_id = result.stdout.strip() + except subprocess.CalledProcessError as e: + error_msg = f"Failed to start Docker container.\nCommand: {' '.join(cmd)}\nExit code: {e.returncode}\nStderr: {e.stderr}\nStdout: {e.stdout}" + raise RuntimeError(error_msg) from e # Wait a moment for container to start time.sleep(1) diff --git a/src/envs/textarena_env/README.md b/src/envs/textarena_env/README.md new file mode 100644 index 00000000..819a0c8c --- /dev/null +++ b/src/envs/textarena_env/README.md @@ -0,0 +1,46 @@ +# TextArena Environment + +Generic wrapper for any [TextArena](https://www.textarena.ai/docs/overview) game inside OpenEnv. This module exposes the TextArena `Env` interface through the standard HTTP server/client APIs used by other OpenEnv environments, enabling quick experimentation with the full suite of word, reasoning, and multi-agent games. + +## Features +- Works with any registered TextArena game (e.g. `Wordle-v0`, `GuessTheNumber-v0`, `Chess-v0`, ...). +- Transparent access to TextArena message streams, rewards, and state snapshots. +- Docker image for easy deployment with Python 3.11 and preinstalled dependencies. +- Example client demonstrating end-to-end interaction. + +## Docker + +Build the image from the project root: + +```bash +docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest . +``` + +Run it with your desired game (default is `Wordle-v0`). 
Environment configuration is handled via env vars: + +```bash +docker run -p 8000:8000 \ + -e TEXTARENA_ENV_ID=GuessTheNumber-v0 \ + -e TEXTARENA_NUM_PLAYERS=1 \ + textarena-env:latest +``` + +Additional keyword arguments for the game can be passed using the `TEXTARENA_KW_` prefix: the rest of the variable name is lower-cased and used as the argument name, and the value is forwarded as a string. For example, to pass `hardcore="true"`: + +```bash +docker run -p 8000:8000 \ + -e TEXTARENA_ENV_ID=Wordle-v0 \ + -e TEXTARENA_KW_hardcore=true \ + textarena-env:latest +``` + +## Python Example + +The repository ships with a simple client script that starts the server in a Docker container, connects to it, and plays a few turns. Run it from the repo root: + +```bash +python examples/textarena_simple.py +``` + +The script uses `TextArenaEnv.from_docker_image` to start a container from the `textarena-env:latest` image, so build the image first (see the Docker section above). Review the source (`examples/textarena_simple.py`) for more details and to customize the gameplay loop. + diff --git a/src/envs/textarena_env/__init__.py b/src/envs/textarena_env/__init__.py new file mode 100644 index 00000000..49314f7f --- /dev/null +++ b/src/envs/textarena_env/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""TextArena environment integration for OpenEnv.""" + +from .client import TextArenaEnv +from .models import ( + TextArenaAction, + TextArenaMessage, + TextArenaObservation, + TextArenaState, +) +from .rewards import RewardProvider, build_reward_providers + +__all__ = [ + "TextArenaEnv", + "TextArenaAction", + "TextArenaObservation", + "TextArenaState", + "TextArenaMessage", + "RewardProvider", + "build_reward_providers", +] diff --git a/src/envs/textarena_env/client.py b/src/envs/textarena_env/client.py new file mode 100644 index 00000000..9f464206 --- /dev/null +++ b/src/envs/textarena_env/client.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""HTTP client for the generic TextArena environment.""" + +from __future__ import annotations + +from typing import Any, Dict, TYPE_CHECKING + +from core.client_types import StepResult +from core.http_env_client import HTTPEnvClient + +from .models import ( + TextArenaAction, + TextArenaMessage, + TextArenaObservation, + TextArenaState, +) + +if TYPE_CHECKING: + from core.containers.runtime import ContainerProvider + + +class TextArenaEnv(HTTPEnvClient[TextArenaAction, TextArenaObservation]): + """HTTP client for the TextArena environment server.""" + + def _step_payload(self, action: TextArenaAction) -> Dict[str, Any]: + return {"message": action.message} + + def _parse_result( + self, payload: Dict[str, Any] + ) -> StepResult[TextArenaObservation]: + obs_data = payload.get("observation", {}) + messages_payload = obs_data.get("messages", []) + messages = [ + TextArenaMessage( + sender_id=item.get("sender_id", -1), + content=item.get("content", ""), + category=item.get("category", "MESSAGE"), + ) + for item in messages_payload + if isinstance(item, dict) + ] + + observation = TextArenaObservation( + prompt=obs_data.get("prompt", ""), + messages=messages, + current_player_id=obs_data.get("current_player_id", 0), + legal_players=obs_data.get("legal_players", []), + info=obs_data.get("info", {}), + reward=payload.get("reward"), + done=payload.get("done", False), + metadata=obs_data.get("metadata", {}), + ) + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> TextArenaState: + return TextArenaState( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + env_id=payload.get("env_id", "unknown"), + num_players=payload.get("num_players", 1), + 
max_turns=payload.get("max_turns"), + turn=payload.get("turn", 0), + last_reward=payload.get("last_reward", 0.0), + last_info=payload.get("last_info", {}), + raw_state=payload.get("raw_state", {}), + ) + diff --git a/src/envs/textarena_env/models.py b/src/envs/textarena_env/models.py new file mode 100644 index 00000000..4fea2c17 --- /dev/null +++ b/src/envs/textarena_env/models.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Common data models for the TextArena environment wrapper.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from core.env_server.types import Action, Observation, State + + +@dataclass +class TextArenaMessage: + """Single message observed by a player.""" + + sender_id: int + content: str + category: str + + +@dataclass(kw_only=True) +class TextArenaAction(Action): + """Action issued by the agent for TextArena games.""" + + message: str + + +@dataclass(kw_only=True) +class TextArenaObservation(Observation): + """Observation returned from any TextArena game.""" + + prompt: str + messages: List[TextArenaMessage] = field(default_factory=list) + current_player_id: int = 0 + legal_players: List[int] = field(default_factory=list) + info: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(kw_only=True) +class TextArenaState(State): + """Structured state snapshot for the server.""" + + env_id: str + num_players: int + max_turns: Optional[int] = None + turn: int = 0 + last_reward: float = 0.0 + last_info: Dict[str, Any] = field(default_factory=dict) + raw_state: Dict[str, Any] = field(default_factory=dict) + diff --git a/src/envs/textarena_env/rewards.py b/src/envs/textarena_env/rewards.py new file mode 100644 index 00000000..964a57a6 --- /dev/null +++ 
b/src/envs/textarena_env/rewards.py @@ -0,0 +1,133 @@ +"""Reward provider utilities for TextArena environments.""" + +from __future__ import annotations + +import re +from typing import Dict, List, Protocol, Tuple + +from .models import TextArenaAction, TextArenaObservation + + +class RewardProvider(Protocol): + """Interface for computing auxiliary reward signals.""" + + def reset(self) -> None: + """Clear any internal state before a new episode.""" + + def compute( + self, *, action: TextArenaAction, observation: TextArenaObservation + ) -> Dict[str, float]: + """Return a mapping of reward names to float values for the step.""" + + +def build_reward_providers(env_id: str) -> List[RewardProvider]: + """Instantiate reward providers appropriate for the given environment.""" + + providers: List[RewardProvider] = [] + if env_id == "Wordle-v0": + providers.append(_WordleRewardProvider()) + return providers + + +_WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]") + + +def extract_guess(text: str) -> str: + """Normalize a Wordle guess string from arbitrary text.""" + + match = _WORDLE_GUESS_PATTERN.search(text) + if match: + return match.group(0).lower() + + cleaned = re.sub(r"[^a-z]", "", text.lower()) + if len(cleaned) >= 5: + return f"[{cleaned[:5]}]" + return "[dunno]" + + +def extract_wordle_feedback(observation: TextArenaObservation) -> str: + """Pull the latest feedback text from a Wordle observation.""" + + for message in reversed(observation.messages): + content = message.content.strip() + if "Feedback:" in content: + return content.split("Feedback:", 1)[-1].strip() + return "" + + +def extract_feedback_counts(feedback: str) -> Tuple[int, int]: + """Return counts of green (G) and yellow (Y) markers from feedback.""" + + if not feedback: + return (0, 0) + + segments = [ + segment.strip() for segment in feedback.split("\n\n") if segment.strip() + ] + if not segments: + return (0, 0) + + latest_segment = segments[-1] + lines = [line.strip() for line in 
latest_segment.splitlines() if line.strip()] + latest_line = lines[-1] if lines else latest_segment + + green_count = latest_line.count("G") + yellow_count = latest_line.count("Y") + return (green_count, yellow_count) + + +class _WordleRewardProvider: + """Reward provider that mirrors the GRPO Wordle heuristics.""" + + SIGNAL_MAP = { + "greens": "wordle.greens", + "yellows": "wordle.yellows", + "repetitions": "wordle.repetitions", + "correct": "wordle.correct", + } + + def __init__(self) -> None: + self._guess_history: Dict[str, int] = {} + + def reset(self) -> None: + self._guess_history.clear() + + def compute( + self, *, action: TextArenaAction, observation: TextArenaObservation + ) -> Dict[str, float]: + guess = extract_guess(action.message) + feedback = extract_wordle_feedback(observation) + + normalized_guess = guess if guess and guess != "[dunno]" else "" + previous_occurrences = ( + self._guess_history.get(normalized_guess, 0) if normalized_guess else 0 + ) + + green_score = 0.0 + yellow_score = 0.0 + if feedback: + green_count, yellow_count = extract_feedback_counts(feedback) + green_score = green_count / 5.0 + yellow_score = yellow_count / 5.0 + + repetition_score = 1.0 - previous_occurrences + correct_score = float(observation.reward or 0.0) + + if normalized_guess: + self._guess_history[normalized_guess] = previous_occurrences + 1 + + return { + self.SIGNAL_MAP["greens"]: float(green_score), + self.SIGNAL_MAP["yellows"]: float(yellow_score), + self.SIGNAL_MAP["repetitions"]: float(repetition_score), + self.SIGNAL_MAP["correct"]: float(correct_score), + } + + +__all__ = [ + "RewardProvider", + "build_reward_providers", + "extract_feedback_counts", + "extract_guess", + "extract_wordle_feedback", +] diff --git a/src/envs/textarena_env/server/Dockerfile b/src/envs/textarena_env/server/Dockerfile new file mode 100644 index 00000000..5df60823 --- /dev/null +++ b/src/envs/textarena_env/server/Dockerfile @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Use the shared OpenEnv base image (Python 3.11) +ARG BASE_IMAGE=openenv-base:latest +FROM ${BASE_IMAGE} + +# Install system libraries required by TextArena (cv2 needs libGL, glib) +RUN apt-get update && apt-get install -y --no-install-recommends \ + libgl1 \ + libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Install TextArena and Python dependencies +RUN pip install --no-cache-dir \ + textarena==0.6.1 \ + nltk==3.9.2 + +# Copy OpenEnv core and TextArena environment sources +COPY src/core/ /app/src/core/ +COPY src/envs/textarena_env/ /app/src/envs/textarena_env/ + +# Optional: health check to ensure server responsiveness +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the TextArena FastAPI server +CMD ["uvicorn", "envs.textarena_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] + diff --git a/src/envs/textarena_env/server/__init__.py b/src/envs/textarena_env/server/__init__.py new file mode 100644 index 00000000..22d17ab5 --- /dev/null +++ b/src/envs/textarena_env/server/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Server components for the generic TextArena environment.""" + +from .environment import TextArenaEnvironment + +__all__ = ["TextArenaEnvironment"] + diff --git a/src/envs/textarena_env/server/app.py b/src/envs/textarena_env/server/app.py new file mode 100644 index 00000000..59dea784 --- /dev/null +++ b/src/envs/textarena_env/server/app.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""FastAPI application entrypoint for the TextArena environment.""" + +from __future__ import annotations + +import os + +from core.env_server.http_server import create_app + +from ..models import TextArenaAction, TextArenaObservation +from .environment import TextArenaEnvironment + + +def _parse_env_kwargs(prefix: str = "TEXTARENA_KW_") -> dict[str, str]: + """Collect arbitrary environment kwargs from the process environment.""" + + env_kwargs: dict[str, str] = {} + for key, value in os.environ.items(): + if key.startswith(prefix): + env_key = key[len(prefix) :].lower() + env_kwargs[env_key] = value + return env_kwargs + + +env_id = os.getenv("TEXTARENA_ENV_ID", "Wordle-v0") +num_players = int(os.getenv("TEXTARENA_NUM_PLAYERS", "1")) +max_turns_env = os.getenv("TEXTARENA_MAX_TURNS") +max_turns = int(max_turns_env) if max_turns_env is not None else None +download_nltk = os.getenv("TEXTARENA_DOWNLOAD_NLTK", "1") in {"1", "true", "True"} + +extra_kwargs = _parse_env_kwargs() + +environment = TextArenaEnvironment( + env_id=env_id, + num_players=num_players, + max_turns=max_turns, + download_nltk=download_nltk, + env_kwargs=extra_kwargs, +) + +app = create_app(environment, TextArenaAction, TextArenaObservation, env_name="textarena_env") + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) + diff --git a/src/envs/textarena_env/server/environment.py b/src/envs/textarena_env/server/environment.py new file mode 100644 index 00000000..808a1465 --- /dev/null +++ b/src/envs/textarena_env/server/environment.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Server implementation for the generic TextArena environment.""" + +from __future__ import annotations + +import sys +from typing import Any, Dict, Iterable, List, Optional +from uuid import uuid4 + +import nltk + +from core.env_server.interfaces import Environment + +from ..models import TextArenaAction, TextArenaMessage, TextArenaObservation, TextArenaState +from ..rewards import RewardProvider, build_reward_providers + + +_TEXTARENA_MODULE: Any | None = None +_TEXTARENA_IMPORT_ERROR: Exception | None = None + + +def _import_textarena() -> Any: + """Import ``textarena`` lazily and cache the module reference.""" + + global _TEXTARENA_MODULE, _TEXTARENA_IMPORT_ERROR + + if _TEXTARENA_MODULE is not None: + return _TEXTARENA_MODULE + + if _TEXTARENA_IMPORT_ERROR is not None: + raise _TEXTARENA_IMPORT_ERROR + + if sys.version_info < (3, 10): + _TEXTARENA_IMPORT_ERROR = RuntimeError( + "TextArena environments require Python 3.10 or newer; " + f"current interpreter is {sys.version_info.major}.{sys.version_info.minor}" + ) + raise _TEXTARENA_IMPORT_ERROR + + try: + import textarena as ta # type: ignore[import] + except Exception as exc: # pragma: no cover - surfaced to caller + _TEXTARENA_IMPORT_ERROR = exc + raise + + _TEXTARENA_MODULE = ta + return ta + + +class TextArenaEnvironment(Environment): + """Wrap any TextArena game behind the OpenEnv ``Environment`` API.""" + + def __init__( + self, + env_id: str = "Wordle-v0", + *, + num_players: int = 1, + max_turns: Optional[int] = None, + download_nltk: bool = True, + env_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + ta = _import_textarena() + + if download_nltk: + nltk.download("words", quiet=True) + nltk.download("averaged_perceptron_tagger_eng", quiet=True) + + self.env_id = env_id + self.num_players = num_players + self.max_turns = max_turns + self._env_kwargs = env_kwargs or {} + + self._ta_env = ta.make(env_id=env_id, **self._env_kwargs) + + self._state = TextArenaState( + 
env_id=env_id, + num_players=num_players, + max_turns=max_turns, + ) + + self._reward_providers: List[RewardProvider] = build_reward_providers(env_id) + self._last_reward_signals: Dict[str, float] = {} + + # ------------------------------------------------------------------ + # Environment interface + # ------------------------------------------------------------------ + def reset(self) -> TextArenaObservation: + self._ta_env.reset(num_players=self.num_players) + + for provider in self._reward_providers: + provider.reset() + + self._state.episode_id = str(uuid4()) + self._state.step_count = 0 + self._state.turn = 0 + self._state.last_reward = 0.0 + self._state.last_info = {} + self._state.raw_state = self._snapshot_state() + self._last_reward_signals = {} + + observation = self._build_observation() + observation.reward = 0.0 + observation.done = False + + return observation + + def step(self, action: TextArenaAction) -> TextArenaObservation: # type: ignore[override] + if not isinstance(action, TextArenaAction): + raise TypeError(f"Expected TextArenaAction, received {type(action)!r}") + + done, info = self._ta_env.step(action.message) + + self._state.step_count += 1 + self._state.turn = getattr(self._ta_env.state, "turn", self._state.turn + 1) + self._state.last_info = info or {} + + observation = self._build_observation() + observation.done = done + + reward = self._extract_reward() + observation.reward = reward + self._state.last_reward = reward + + reward_signals = self._compute_reward_signals(action=action, observation=observation) + if reward_signals: + observation.info.setdefault("reward_signals", {}).update(reward_signals) + observation.metadata.setdefault("reward_signals", {}).update(reward_signals) + self._last_reward_signals = reward_signals + if reward_signals: + self._state.last_info = {**(self._state.last_info or {}), "reward_signals": reward_signals} + self._state.raw_state = self._snapshot_state() + + return observation + + @property + def state(self) 
-> TextArenaState: + return self._state + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _build_observation(self) -> TextArenaObservation: + player_id, messages = self._ta_env.get_observation() + + ta_messages = self._convert_messages(messages) + prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"] + if not prompt_lines: + # Fallback to most recent message history for prompt + prompt_lines = [msg.content for msg in ta_messages] + + info: Dict[str, Any] = {} + info.update(getattr(self._ta_env.state, "step_info", {})) + + observation = TextArenaObservation( + prompt="\n".join(prompt_lines).strip(), + messages=ta_messages, + current_player_id=player_id, + legal_players=self._legal_players(), + info=info, + metadata={ + "env_id": self.env_id, + "turn": getattr(self._ta_env.state, "turn", 0), + "raw_messages": [ + { + "sender_id": msg.sender_id, + "content": msg.content, + "category": msg.category, + } + for msg in ta_messages + ], + }, + ) + + return observation + + def _legal_players(self) -> List[int]: + role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {} + players = [pid for pid in role_mapping.keys() if isinstance(pid, int) and pid >= 0] + return sorted(players) + + def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]: + converted: List[TextArenaMessage] = [] + for entry in messages: + if isinstance(entry, tuple) and len(entry) == 3: + sender, content, category = entry + elif isinstance(entry, tuple) and len(entry) == 2: + sender, content = entry + category = "MESSAGE" + else: + sender, content, category = -1, str(entry), "MESSAGE" + + category_name = getattr(category, "name", str(category)) + converted.append( + TextArenaMessage( + sender_id=int(sender) if isinstance(sender, (int, float)) else -1, + content=str(content), + category=category_name, + ) + ) + + return converted + + def 
_extract_reward(self) -> float: + rewards = getattr(self._ta_env.state, "rewards", None) + if isinstance(rewards, dict): + # Use current player reward if available, otherwise default to player 0. + player_id = getattr(self._ta_env.state, "current_player_id", 0) + if player_id in rewards: + return float(rewards[player_id]) + if 0 in rewards: + return float(rewards[0]) + return 0.0 + + def _snapshot_state(self) -> Dict[str, Any]: + state = self._ta_env.state + snapshot: Dict[str, Any] = { + "turn": getattr(state, "turn", 0), + "game_state": getattr(state, "game_state", {}), + "logs": list(getattr(state, "logs", [])), + "rewards": getattr(state, "rewards", None), + "done": getattr(state, "done", False), + "role_mapping": getattr(state, "role_mapping", {}), + "game_info": getattr(state, "game_info", {}), + "step_info": getattr(state, "step_info", {}), + } + if self._last_reward_signals: + snapshot["reward_signals"] = dict(self._last_reward_signals) + return snapshot + + def _compute_reward_signals( + self, *, action: TextArenaAction, observation: TextArenaObservation + ) -> Dict[str, float]: + if not self._reward_providers: + return {} + + aggregated: Dict[str, float] = {} + for provider in self._reward_providers: + try: + result = provider.compute(action=action, observation=observation) + except Exception: # pragma: no cover - defensive + continue + for key, value in result.items(): + aggregated[key] = float(value) + return aggregated + diff --git a/src/envs/textarena_env/server/run_local.sh b/src/envs/textarena_env/server/run_local.sh new file mode 100755 index 00000000..8efa35f0 --- /dev/null +++ b/src/envs/textarena_env/server/run_local.sh @@ -0,0 +1,7 @@ +export TEXTARENA_ENV_ID="Wordle-v0" +export TEXTARENA_NUM_PLAYERS=1 + +# Run the server +exec uvicorn envs.textarena_env.server.app:app --host 0.0.0.0 --port 8001 + +