Skip to content

Commit fca4d4a

Browse files
[ENVIRONMENT] Connect4 Environment (#101)
* Configuring the connect4_env using the atari template * fixing laging behaviour of environment * Adding rendering logic to the connect 4 environment * Making changes as requested * fixing atari example * addin unittest for connect4_env --------- Co-authored-by: Davide Testuggine <[email protected]>
1 parent e7e1928 commit fca4d4a

File tree

13 files changed

+865
-18
lines changed

13 files changed

+865
-18
lines changed

.github/workflows/docker-build.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ jobs:
7979
dockerfile: src/envs/atari_env/server/Dockerfile
8080
- name: git-env
8181
dockerfile: src/envs/git_env/server/Dockerfile
82+
- name: my-env # Add your environment here
83+
dockerfile: src/envs/connect4_env/server/Dockerfile
8284
- name: textarena-env
8385
dockerfile: src/envs/textarena_env/server/Dockerfile
8486

examples/OpenEnv_Tutorial.ipynb

Lines changed: 256 additions & 16 deletions
Large diffs are not rendered by default.

examples/atari_simple.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,24 @@
2525
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
2626

2727
from envs.atari_env import AtariEnv, AtariAction
28-
28+
# import envs
29+
# print(envs.__path__)
2930

3031
def main():
3132
"""Run a simple Atari episode."""
3233
# Connect to the Atari environment server
3334
print("Connecting to Atari environment...")
3435
env = AtariEnv.from_docker_image("ghcr.io/meta-pytorch/openenv-atari-env:latest")
35-
36+
37+
3638
try:
3739
# Reset the environment
3840
print("\nResetting environment...")
3941
result = env.reset()
4042
print(f"Screen shape: {result.observation.screen_shape}")
43+
44+
45+
4146
print(f"Legal actions: {result.observation.legal_actions}")
4247
print(f"Lives: {result.observation.lives}")
4348

examples/connect4.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import sys, os
2+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
3+
4+
5+
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
from matplotlib.animation import FuncAnimation
9+
from matplotlib.patches import Circle
10+
from envs.connect4_env import Connect4Action, Connect4Env
11+
12+
13+
def render_connect4_board(board, ax, player_colors={1: "red", 2: "yellow", -1: "yellow"}, show=True):
14+
"""
15+
Render a Connect 4 board using matplotlib.
16+
17+
Args:
18+
board: 2D list, numpy array, or board object (6x7) with values:
19+
0 -> empty, 1 -> player 1, 2 -> player 2 (or -1 for player 2)
20+
player_colors: dict mapping player numbers to colors.
21+
show: If True, calls plt.show(). If False, returns the figure.
22+
23+
Returns:
24+
The matplotlib figure and axis (if show=False).
25+
"""
26+
# Extract board data if it's an object with board attribute
27+
if hasattr(board, 'board'):
28+
b_map = np.array(board.board)
29+
elif hasattr(board, '__array__'):
30+
b_map = np.array(board)
31+
else:
32+
b_map = np.array(board)
33+
34+
# Handle different player value representations
35+
# Some environments use 1 and 2, others use 1 and -1
36+
rows, cols = b_map.shape
37+
38+
ax.set_xlim(0, cols)
39+
ax.set_ylim(0, rows)
40+
ax.set_aspect("equal")
41+
ax.axis("off")
42+
43+
# Draw the blue board background
44+
rect = plt.Rectangle((0, 0), cols, rows, color="#0055FF", zorder=0)
45+
ax.add_patch(rect)
46+
47+
# Draw circular holes
48+
for r in range(rows):
49+
for c in range(cols):
50+
center = (c + 0.5, rows - 1 - r + 0.5) # Fixed: removed extra -1
51+
val = b_map[r, c]
52+
53+
# Handle different value representations
54+
if val == 1:
55+
color = player_colors[1]
56+
elif val == 2 or val == -1:
57+
color = player_colors.get(2, player_colors.get(-1, "yellow"))
58+
else:
59+
color = "white"
60+
61+
circ = Circle(center, 0.4, color=color, ec="black", lw=1.5)
62+
ax.add_patch(circ)
63+
64+
plt.tight_layout()
65+
if show:
66+
plt.show()
67+
else:
68+
return ax
69+
70+
71+
def main(render=True):
72+
print("Connecting to Connect4 environment...")
73+
env = Connect4Env(base_url="http://localhost:8000")
74+
75+
try:
76+
print("\nResetting environment...")
77+
result = env.reset()
78+
79+
frames = []
80+
rewards = []
81+
steps = []
82+
83+
# Collect all frames
84+
board = np.array(result.observation.board).reshape(6, 7)
85+
frames.append(board.copy())
86+
rewards.append(result.reward or 0)
87+
steps.append(0)
88+
89+
for step in range(100):
90+
if result.done:
91+
break
92+
93+
action_id = int(np.random.choice(result.observation.legal_actions))
94+
result = env.step(Connect4Action(column=action_id))
95+
96+
board = np.array(result.observation.board).reshape(6, 7)
97+
frames.append(board.copy())
98+
rewards.append(result.reward or 0)
99+
steps.append(step + 1)
100+
101+
if result.done:
102+
print(f"Game finished at step {step + 1} with reward {result.reward}")
103+
break
104+
105+
if render:
106+
# Create a single figure and update it
107+
fig, ax = plt.subplots(figsize=(7, 6))
108+
109+
def animate_frame(i):
110+
ax.clear()
111+
# Use the render function but don't show immediately
112+
render_connect4_board(frames[i], ax=ax, show=False)
113+
ax.set_title(f"Step: {steps[i]}, Reward: {rewards[i]:.2f}\nTotal: {sum(rewards[:i+1]):.2f}",
114+
fontsize=12, pad=20)
115+
return ax.patches
116+
117+
# Create animation
118+
ani = FuncAnimation(fig, animate_frame, frames=len(frames),
119+
interval=700, repeat=False, blit=False)
120+
121+
plt.tight_layout()
122+
plt.show(block=True)
123+
124+
finally:
125+
env.close()
126+
print("Environment closed.")
127+
128+
129+
if __name__ == "__main__":
130+
main(render=True)

src/envs/connect4_env/README.md

Whitespace-only changes.

src/envs/connect4_env/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Connect4 Environment for OpenEnv.
9+
10+
This module provides OpenEnv integration for the classic Connect4 board game.
11+
12+
Example:
13+
>>> from envs.Connect4_env import Connect4Env, Connect4Action
14+
>>>
15+
>>> # Connect to a running server or start via Docker
16+
>>> env = Connect4Env.from_docker_image("Connect4-env:latest")
17+
>>>
18+
>>> # Reset and interact
19+
>>> result = env.reset()
20+
>>> result = env.step(Connect4Action(column=2))
21+
>>> print(result.reward, result.done)
22+
>>>
23+
>>> # Cleanup
24+
>>> env.close()
25+
"""
26+
27+
from .client import Connect4Env
28+
from .models import Connect4Action, Connect4Observation, Connect4State
29+
30+
__all__ = ["Connect4Env", "Connect4Action", "Connect4Observation", "Connect4State"]

src/envs/connect4_env/client.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Connect4 Environment HTTP Client.
9+
10+
This module provides the client for connecting to a Connect4 Environment server
11+
over HTTP.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
from typing import Any, Dict, TYPE_CHECKING
17+
18+
from core.client_types import StepResult
19+
from core.http_env_client import HTTPEnvClient
20+
21+
from .models import Connect4Action, Connect4Observation, Connect4State
22+
23+
if TYPE_CHECKING:
24+
from core.containers.runtime import ContainerProvider
25+
26+
27+
class Connect4Env(HTTPEnvClient[Connect4Action, Connect4Observation]):
28+
"""
29+
HTTP client for Connect4 Environment.
30+
31+
This client connects to a Connect4Environment HTTP server and provides
32+
methods to interact with it: reset(), step(), and state access.
33+
34+
Example:
35+
>>> client = Connect4Env(base_url="http://localhost:8000")
36+
>>> result = client.reset()
37+
>>> print(result.observation.board)
38+
>>>
39+
>>> # Take an action
40+
>>> result = client.step(Connect4Action(column=3))
41+
>>> print(result.reward, result.done)
42+
"""
43+
44+
def _step_payload(self, action: Connect4Action) -> Dict[str, Any]:
45+
"""
46+
Convert Connect4Action to JSON payload for step request.
47+
48+
Args:
49+
action: Connect4Action instance.
50+
51+
Returns:
52+
Dictionary representation suitable for JSON encoding.
53+
"""
54+
return {
55+
"column": action.column, # column index to drop piece
56+
}
57+
58+
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Connect4Observation]:
59+
"""
60+
Parse server response into StepResult[Connect4Observation].
61+
62+
Args:
63+
payload: JSON response from server.
64+
65+
Returns:
66+
StepResult with Connect4Observation.
67+
"""
68+
obs_data = payload.get("observation", {})
69+
70+
observation = Connect4Observation(
71+
board=obs_data.get("board", [[0]*7 for _ in range(6)]),
72+
legal_actions=obs_data.get("legal_actions", []),
73+
done=payload.get("done", False),
74+
reward=payload.get("reward", 0.0),
75+
metadata=obs_data.get("metadata", {}),
76+
)
77+
78+
return StepResult(
79+
observation=observation,
80+
reward=payload.get("reward", 0.0),
81+
done=payload.get("done", False),
82+
)
83+
84+
def _parse_state(self, payload: Dict[str, Any]) -> Connect4State:
85+
"""
86+
Parse server response into Connect4State object.
87+
88+
Args:
89+
payload: JSON response from /state endpoint.
90+
91+
Returns:
92+
Connect4State object with environment state information.
93+
"""
94+
return Connect4State(
95+
episode_id=payload.get("episode_id", ""),
96+
board=payload.get("board", [[0]*7 for _ in range(6)]),
97+
next_player=payload.get("next_player", 1),
98+
step_count=payload.get("step_count", 0),
99+
)

src/envs/connect4_env/models.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Data models for Connect4 Environment.
9+
10+
This module defines the Action, Observation, and State types for Connect4 games
11+
via the OpenEnv interface.
12+
"""
13+
14+
from __future__ import annotations
15+
from dataclasses import dataclass, field
16+
import numpy as np
17+
from typing import List
18+
19+
from core.env_server import Action, Observation, State
20+
21+
22+
@dataclass
23+
class Connect4Action(Action):
24+
"""
25+
Action for Connect4 environment.
26+
27+
Attributes:
28+
column: The column index (0 to 6) where the piece will be placed.
29+
"""
30+
column: int
31+
32+
33+
@dataclass(kw_only=True)
34+
class Connect4Observation(Observation):
35+
"""
36+
Observation for Connect4 environment.
37+
38+
Attributes:
39+
board: The current board as a 2D list (6 rows x 7 columns).
40+
1 = current player, -1 = opponent, 0 = empty.
41+
legal_actions: List of column indices that are valid moves.
42+
done: Whether the game is over.
43+
reward: Reward for the last action.
44+
"""
45+
46+
board: List[List[int]]
47+
legal_actions: List[int]
48+
done: bool = False
49+
reward: float = 0.0
50+
metadata: dict = field(default_factory=dict)
51+
52+
53+
54+
@dataclass(kw_only=True)
55+
class Connect4State(State):
56+
"""
57+
State for Connect4 environment.
58+
59+
Attributes:
60+
episode_id: Unique ID for the current game.
61+
board: Current board state (rows x columns), 0 = empty, 1 = player, -1 = opponent.
62+
next_player: Whose turn it is (1 or -1).
63+
step_count: Number of steps taken in the game.
64+
"""
65+
episode_id: str
66+
board: List[List[int]] = field(default_factory=lambda: np.zeros((6,7), dtype=int).tolist())
67+
next_player: int = 1
68+
step_count: int = 0
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
ARG BASE_IMAGE=openenv-base:latest
2+
FROM ${BASE_IMAGE}
3+
4+
# Install any additional dependencies
5+
RUN pip install --no-cache-dir \
6+
gymnasium>=0.29.0 \
7+
ale-py>=0.8.0 \
8+
numpy>=1.24.0
9+
# Copy environment code
10+
COPY src/core/ /app/src/core/
11+
COPY src/envs/connect4_env/ /app/src/envs/connect4_env/
12+
13+
# Health check
14+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
15+
CMD curl -f http://localhost:8000/health || exit 1
16+
17+
# Run server
18+
CMD ["uvicorn", "envs.connect4_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]

0 commit comments

Comments
 (0)