-
Notifications
You must be signed in to change notification settings - Fork 88
[ENVIRONMENT] Connect4 Environment #101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
3628bf7
308e310
964be28
b57b48d
00eb136
a7dabb0
ed2dc3a
5652e8a
c0f7833
4a6bf6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,19 +25,24 @@ | |
| sys.path.insert(0, str(Path(__file__).parent.parent / "src")) | ||
|
|
||
| from envs.atari_env import AtariEnv, AtariAction | ||
|
|
||
| # import envs | ||
| # print(envs.__path__) | ||
|
|
||
| def main(): | ||
| """Run a simple Atari episode.""" | ||
| # Connect to the Atari environment server | ||
| print("Connecting to Atari environment...") | ||
| env = AtariEnv.from_docker_image("ghcr.io/meta-pytorch/openenv-atari-env:latest") | ||
|
|
||
| # env = AtariEnv.from_docker_image("ghcr.io/meta-pytorch/openenv-atari-env:latest") | ||
| env = AtariEnv(base_url="http://localhost:8000") | ||
|
|
||
| try: | ||
| # Reset the environment | ||
| print("\nResetting environment...") | ||
| result = env.reset() | ||
| print(f"Screen shape: {result.observation.screen_shape}") | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. newlines? |
||
|
|
||
|
|
||
| print(f"Legal actions: {result.observation.legal_actions}") | ||
| print(f"Lives: {result.observation.lives}") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,130 @@ | ||||||
| import sys, os | ||||||
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| import numpy as np | ||||||
| import matplotlib.pyplot as plt | ||||||
| from matplotlib.animation import FuncAnimation | ||||||
| from matplotlib.patches import Circle | ||||||
| from envs.connect4_env import Connect4Action, Connect4Env | ||||||
|
|
||||||
|
|
||||||
| def render_connect4_board(board, ax, player_colors={1: "red", 2: "yellow", -1: "yellow"}, show=True): | ||||||
|
||||||
| """ | ||||||
| Render a Connect 4 board using matplotlib. | ||||||
| Args: | ||||||
| board: 2D list, numpy array, or board object (6x7) with values: | ||||||
| 0 -> empty, 1 -> player 1, 2 -> player 2 (or -1 for player 2) | ||||||
| player_colors: dict mapping player numbers to colors. | ||||||
| show: If True, calls plt.show(). If False, returns the figure. | ||||||
| Returns: | ||||||
| The matplotlib figure and axis (if show=False). | ||||||
| """ | ||||||
| # Extract board data if it's an object with board attribute | ||||||
| if hasattr(board, 'board'): | ||||||
| b_map = np.array(board.board) | ||||||
| elif hasattr(board, '__array__'): | ||||||
| b_map = np.array(board) | ||||||
| else: | ||||||
| b_map = np.array(board) | ||||||
|
|
||||||
| # Handle different player value representations | ||||||
| # Some environments use 1 and 2, others use 1 and -1 | ||||||
| rows, cols = b_map.shape | ||||||
|
|
||||||
| ax.set_xlim(0, cols) | ||||||
| ax.set_ylim(0, rows) | ||||||
| ax.set_aspect("equal") | ||||||
| ax.axis("off") | ||||||
|
|
||||||
| # Draw the blue board background | ||||||
| rect = plt.Rectangle((0, 0), cols, rows, color="#0055FF", zorder=0) | ||||||
| ax.add_patch(rect) | ||||||
|
|
||||||
| # Draw circular holes | ||||||
| for r in range(rows): | ||||||
| for c in range(cols): | ||||||
| center = (c + 0.5, rows - 1 - r + 0.5) # Fixed: removed extra -1 | ||||||
| val = b_map[r, c] | ||||||
|
|
||||||
| # Handle different value representations | ||||||
| if val == 1: | ||||||
| color = player_colors[1] | ||||||
| elif val == 2 or val == -1: | ||||||
| color = player_colors.get(2, player_colors.get(-1, "yellow")) | ||||||
| else: | ||||||
| color = "white" | ||||||
|
|
||||||
| circ = Circle(center, 0.4, color=color, ec="black", lw=1.5) | ||||||
| ax.add_patch(circ) | ||||||
|
|
||||||
| plt.tight_layout() | ||||||
| if show: | ||||||
| plt.show() | ||||||
| else: | ||||||
| return ax | ||||||
|
|
||||||
|
|
||||||
| def main(render=True): | ||||||
| print("Connecting to Connect4 environment...") | ||||||
| env = Connect4Env(base_url="http://localhost:8000") | ||||||
|
|
||||||
| try: | ||||||
| print("\nResetting environment...") | ||||||
| result = env.reset() | ||||||
|
|
||||||
| frames = [] | ||||||
| rewards = [] | ||||||
| steps = [] | ||||||
|
|
||||||
| # Collect all frames | ||||||
| board = np.array(result.observation.board).reshape(6, 7) | ||||||
| frames.append(board.copy()) | ||||||
| rewards.append(result.reward or 0) | ||||||
| steps.append(0) | ||||||
|
|
||||||
| for step in range(100): | ||||||
| if result.done: | ||||||
| break | ||||||
|
|
||||||
| action_id = int(np.random.choice(result.observation.legal_actions)) | ||||||
| result = env.step(Connect4Action(action_id)) | ||||||
|
||||||
| result = env.step(Connect4Action(action_id)) | |
| result = env.step(Connect4Action(column=action_id)) |
Copilot
AI
Oct 31, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variable ani is not used.
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,31 @@ | ||||||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||
| # All rights reserved. | ||||||||
| # | ||||||||
| # This source code is licensed under the BSD-style license found in the | ||||||||
| # LICENSE file in the root directory of this source tree. | ||||||||
|
|
||||||||
| """ | ||||||||
| Connect4 Environment for OpenEnv. | ||||||||
| This module provides OpenEnv integration for Connect4 2600 games via the | ||||||||
| Arcade Learning Environment (ALE). | ||||||||
|
||||||||
| This module provides OpenEnv integration for Connect4 2600 games via the | |
| Arcade Learning Environment (ALE). | |
| This module provides OpenEnv integration for the classic Connect4 board game. |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,99 @@ | ||||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
| # All rights reserved. | ||||||
| # | ||||||
| # This source code is licensed under the BSD-style license found in the | ||||||
| # LICENSE file in the root directory of this source tree. | ||||||
|
|
||||||
| """ | ||||||
| Connect4 Environment HTTP Client. | ||||||
| This module provides the client for connecting to a Connect4 Environment server | ||||||
| over HTTP. | ||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| from typing import Any, Dict, TYPE_CHECKING | ||||||
|
|
||||||
| from core.client_types import StepResult | ||||||
| from core.http_env_client import HTTPEnvClient | ||||||
|
|
||||||
| from .models import Connect4Action, Connect4Observation, Connect4State | ||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
| from core.containers.runtime import ContainerProvider | ||||||
|
||||||
| from core.containers.runtime import ContainerProvider |
Outdated
Copilot
AI
Oct 31, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The example uses 'action_id=3' but Connect4Action expects 'column' as the parameter name according to the model definition. This should be 'Connect4Action(column=3)'.
| >>> result = client.step(Connect4Action(action_id=3)) | |
| >>> result = client.step(Connect4Action(column=3)) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """ | ||
| Data models for Connect4 Environment. | ||
|
|
||
| This module defines the Action, Observation, and State types for Connect4 games | ||
| via the OpenEnv interface. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
| from dataclasses import dataclass, field | ||
| import numpy as np | ||
| from typing import List | ||
|
|
||
| from core.env_server import Action, Observation, State | ||
|
|
||
|
|
||
| @dataclass | ||
| class Connect4Action(Action): | ||
| """ | ||
| Action for Connect4 environment. | ||
|
|
||
| Attributes: | ||
| column: The column index (0 to 6) where the piece will be placed. | ||
| """ | ||
| column: int | ||
|
|
||
|
|
||
| @dataclass(kw_only=True) | ||
| class Connect4Observation(Observation): | ||
| """ | ||
| Observation for Connect4 environment. | ||
|
|
||
| Attributes: | ||
| board: The current board as a 2D list (6 rows x 7 columns). | ||
| 1 = current player, -1 = opponent, 0 = empty. | ||
| legal_actions: List of column indices that are valid moves. | ||
| done: Whether the game is over. | ||
| reward: Reward for the last action. | ||
| """ | ||
|
|
||
| board: List[List[int]] | ||
| legal_actions: List[int] | ||
| done: bool = False | ||
| reward: float = 0.0 | ||
| metadata: dict = field(default_factory=dict) | ||
|
|
||
|
|
||
|
|
||
| @dataclass(kw_only=True) | ||
| class Connect4State(State): | ||
| """ | ||
| State for Connect4 environment. | ||
|
|
||
| Attributes: | ||
| episode_id: Unique ID for the current game. | ||
| board: Current board state (rows x columns), 0 = empty, 1 = player, -1 = opponent. | ||
| next_player: Whose turn it is (1 or -1). | ||
| step_count: Number of steps taken in the game. | ||
| """ | ||
| episode_id: str | ||
| board: List[List[int]] = field(default_factory=lambda: np.zeros((6,7), dtype=int).tolist()) | ||
| next_player: int = 1 | ||
| step_count: int = 0 |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,18 @@ | ||||||
| ARG BASE_IMAGE=openenv-base:latest | ||||||
| FROM ${BASE_IMAGE} | ||||||
|
|
||||||
| # Install any additional dependencies | ||||||
| RUN pip install --no-cache-dir \ | ||||||
| gymnasium>=0.29.0 \ | ||||||
| ale-py>=0.8.0 \ | ||||||
| numpy>=1.24.0 | ||||||
| # Copy environment code | ||||||
| COPY src/core/ /app/src/core/ | ||||||
| COPY src/envs/connect4_env/ /app/src/envs/connect4_env/ | ||||||
|
|
||||||
| # Health check | ||||||
| HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ | ||||||
| CMD curl -f http://localhost:8000/health || exit 1 | ||||||
|
|
||||||
| # Run server | ||||||
| CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] | ||||||
|
||||||
| CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] | |
| CMD ["uvicorn", "envs.connect4_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove debug code?