
Commit

Merge pull request #112 from intelligent-environments-lab/gymnasium-migration

Gymnasium migration and Python version upgrade
kingsleynweye authored Mar 18, 2024
2 parents 5f82a2f + c408a9a commit 29b33f7
Showing 17 changed files with 567 additions and 587 deletions.
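The substantive change throughout is gymnasium's split of gym's single `done` flag into separate `terminated` and `truncated` signals, together with the new `reset` signature that returns an `(observations, info)` tuple. A minimal sketch of the before/after interaction loop this diff implies (the schema name is hypothetical, and sampling each `Box` assumes the list-of-spaces layout shown in the changed files):

from citylearn.citylearn import CityLearnEnv

env = CityLearnEnv('some_schema')  # hypothetical schema name, for illustration only

# gym (before):      observations = env.reset()
#                    observations, reward, done, info = env.step(actions)
# gymnasium (after): reset also returns an info dict; step returns a 5-tuple
observations, info = env.reset()
terminated, truncated = False, False

while not (terminated or truncated):
    actions = [space.sample() for space in env.action_space]  # assumes List[spaces.Box]
    observations, reward, terminated, truncated, info = env.step(actions)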
2 changes: 1 addition & 1 deletion .github/workflows/pypi_deploy.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
-python-version: [3.7.7]
+python-version: [3.12]

steps:
- uses: actions/checkout@v2
3 changes: 1 addition & 2 deletions .github/workflows/sphinx.yml
@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
-python-version: [3.7.7]
+python-version: [3.12]

steps:
- uses: actions/checkout@v2
@@ -26,7 +26,6 @@ jobs:
python -m pip install --upgrade pip
python -m pip install flake8
pip install -r requirements.txt
-pip install -r test_requirements.txt
pip install -r docs/requirements.txt
- name: Build HTML
run: |
1 change: 1 addition & 0 deletions .gitignore
@@ -122,6 +122,7 @@ venv/
ENV/
env.bak/
venv.bak/
+*env/

# Spyder project settings
.spyderproject
14 changes: 6 additions & 8 deletions citylearn/agents/base.py
@@ -1,13 +1,11 @@
import logging
from typing import Any, List, Mapping
-from gym import spaces
+from gymnasium import spaces
import numpy as np
from citylearn.base import Environment
from citylearn.citylearn import CityLearnEnv

LOGGER = logging.getLogger()
-logging.getLogger('matplotlib.font_manager').disabled = True
-logging.getLogger('matplotlib.pyplot').disabled = True

class Agent(Environment):
r"""Base agent class.
@@ -137,22 +135,22 @@ def learn(self, episodes: int = None, deterministic: bool = None, deterministic_

for episode in range(episodes):
deterministic = deterministic or (deterministic_finish and episode >= episodes - 1)
-observations = self.env.reset()
+observations, _ = self.env.reset()
self.episode_time_steps = self.episode_tracker.episode_time_steps
-done = False
+terminated = False
time_step = 0
rewards_list = []

-while not done:
+while not terminated:
actions = self.predict(observations, deterministic=deterministic)

# apply actions to citylearn_env
-next_observations, rewards, done, _ = self.env.step(actions)
+next_observations, rewards, terminated, truncated, _ = self.env.step(actions)
rewards_list.append(rewards)

# update
if not deterministic:
-self.update(observations, actions, rewards, next_observations, done=done)
+self.update(observations, actions, rewards, next_observations, terminated=terminated, truncated=truncated)
else:
pass

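Note that `learn` loops on `terminated` alone. That is safe here because `CityLearnEnv.truncated` is hard-coded to `False` (see citylearn/citylearn.py below), but a loop written against an arbitrary gymnasium environment should break on either flag. A hedged sketch of the general pattern, with `env` and `agent` as placeholders:

observations, _ = env.reset()
terminated, truncated = False, False

while not (terminated or truncated):  # a general gymnasium env can end either way
    actions = agent.predict(observations, deterministic=False)
    observations, rewards, terminated, truncated, _ = env.step(actions)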
8 changes: 5 additions & 3 deletions citylearn/agents/marlisa.py
@@ -115,7 +115,7 @@ def pca_compression(self, pca_compression: float):
def iterations(self, iterations: int):
self.__iterations = 2 if iterations is None else iterations

-def update(self, observations: List[List[float]], actions: List[List[float]], reward: List[float], next_observations: List[List[float]], done: bool):
+def update(self, observations: List[List[float]], actions: List[List[float]], reward: List[float], next_observations: List[List[float]], terminated: bool, truncated: bool):
r"""Update replay buffer.
Parameters
@@ -128,8 +128,10 @@ def update(self, observations: List[List[float]], actions: List[List[float]], re
Current time step reward.
next_observations : List[List[float]]
Current time step observations.
-done : bool
+terminated : bool
    Indication that episode has ended.
+truncated : bool
+    If episode truncates due to a time limit or a reason that is not defined as part of the task MDP.
"""

# Run once the regression model has been fitted
@@ -168,7 +170,7 @@ def update(self, observations: List[List[float]], actions: List[List[float]], re
else:
pass

-self.replay_buffer[i].push(o, a, r, n, done)
+self.replay_buffer[i].push(o, a, r, n, terminated)

else:
pass
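Only `terminated` is pushed into the replay buffer. That follows gymnasium convention: `terminated` marks a true end of the MDP, where the next state has zero value, while a truncated episode still has value past the cutoff, so truncation should not mask the bootstrap term. An illustrative buffer showing the convention (a sketch, not CityLearn's own implementation):

from collections import deque
import random

class ReplayBuffer:
    """Illustrative buffer; stores terminated, never truncated."""

    def __init__(self, capacity: int = 100_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, o, a, r, n, terminated: bool):
        # (1 - terminated) later masks the bootstrap term in the TD target
        self.buffer.append((o, a, r, n, terminated))

    def sample(self, batch_size: int):
        return random.sample(self.buffer, batch_size)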
6 changes: 4 additions & 2 deletions citylearn/agents/q_learning.py
@@ -110,7 +110,7 @@ def __exploit(self, observations: List[List[float]]) -> List[List[float]]:

return actions

-def update(self, observations: List[List[float]], actions: List[List[float]], reward: List[float], next_observations: List[List[float]], done: bool):
+def update(self, observations: List[List[float]], actions: List[List[float]], reward: List[float], next_observations: List[List[float]], terminated: bool, truncated: bool):
r"""Update Q-Table using Bellman equation.
Parameters
@@ -123,8 +123,10 @@ def update(self, observations: List[List[float]], actions: List[List[float]], re
Current time step reward.
next_observations : List[List[float]]
Current time step observations.
-done : bool
+terminated : bool
    Indication that episode has ended.
+truncated : bool
+    If episode truncates due to a time limit or a reason that is not defined as part of the task MDP.
"""

# Compute temporal difference target and error to update q-function
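In the tabular update itself, `terminated` enters the Bellman target by zeroing the bootstrap term. A sketch under the assumption of integer state and action indices (the names here are illustrative, not the module's own):

import numpy as np

def td_update(q: np.ndarray, s: int, a: int, r: float, s_next: int,
              terminated: bool, alpha: float = 0.1, gamma: float = 0.99) -> None:
    """One tabular Q-learning step; the bootstrap is cut only on termination."""
    bootstrap = 0.0 if terminated else gamma * np.max(q[s_next])
    q[s, a] += alpha * (r + bootstrap - q[s, a])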
8 changes: 5 additions & 3 deletions citylearn/agents/sac.py
@@ -52,7 +52,7 @@ def __init__(self, env: CityLearnEnv, **kwargs: Any):
self.r_norm_std = [None for _ in self.action_space]
self.set_networks()

-def update(self, observations: List[List[float]], actions: List[List[float]], reward: List[float], next_observations: List[List[float]], done: bool):
+def update(self, observations: List[List[float]], actions: List[List[float]], reward: List[float], next_observations: List[List[float]], terminated: bool, truncated: bool):
r"""Update replay buffer.
Parameters
@@ -65,8 +65,10 @@ def update(self, observations: List[List[float]], actions: List[List[float]], re
Current time step reward.
next_observations : List[List[float]]
Current time step observations.
-done : bool
+terminated : bool
    Indication that episode has ended.
+truncated : bool
+    If episode truncates due to a time limit or a reason that is not defined as part of the task MDP.
"""

# Run once the regression model has been fitted
@@ -83,7 +85,7 @@ def update(self, observations: List[List[float]], actions: List[List[float]], re
else:
pass

-self.replay_buffer[i].push(o, a, r, n, done)
+self.replay_buffer[i].push(o, a, r, n, terminated)

if self.time_step >= self.standardize_start_time_step and self.batch_size <= len(self.replay_buffer[i]):
if not self.normalized[i]:
2 changes: 1 addition & 1 deletion citylearn/building.py
@@ -1,5 +1,5 @@
from typing import Any, List, Mapping, Tuple, Union
-from gym import spaces
+from gymnasium import spaces
import numpy as np
import torch
from citylearn.base import Environment, EpisodeTracker
112 changes: 45 additions & 67 deletions citylearn/citylearn.py
@@ -5,15 +5,14 @@
import os
from pathlib import Path
from typing import Any, List, Mapping, Tuple, Union
-from gym import Env, spaces
+from gymnasium import Env, spaces
import numpy as np
import pandas as pd
from citylearn import __version__ as citylearn_version
from citylearn.base import Environment, EpisodeTracker
from citylearn.building import Building, DynamicsBuilding
from citylearn.cost_function import CostFunction
from citylearn.data import DataSet, EnergySimulation, CarbonIntensity, Pricing, TOLERANCE, Weather
-from citylearn.rendering import get_background, RenderBuilding, get_plots
from citylearn.reward_function import RewardFunction
from citylearn.utilities import read_json

@@ -99,6 +98,7 @@ def __init__(self,
):
self.schema = schema
self.__rewards = None
+self.buildings = []
self.random_seed = random_seed
root_directory, buildings, episode_time_steps, rolling_episode_split, random_episode_split, \
seconds_per_time_step, reward_function, central_agent, shared_observations, episode_tracker = self._load(
@@ -225,10 +225,16 @@ def shared_observations(self) -> List[str]:
return self.__shared_observations

@property
-def done(self) -> bool:
+def terminated(self) -> bool:
"""Check if simulation has reached completion."""

return self.time_step == self.time_steps - 1

+@property
+def truncated(self) -> bool:
+    """Check if episode truncates due to a time limit or a reason that is not defined as part of the task MDP."""
+
+    return False

@property
def observation_space(self) -> List[spaces.Box]:
@@ -740,6 +746,13 @@ def central_agent(self, central_agent: bool):
def shared_observations(self, shared_observations: List[str]):
self.__shared_observations = self.get_default_shared_observations() if shared_observations is None else shared_observations

+@Environment.random_seed.setter
+def random_seed(self, seed: int):
+    Environment.random_seed.fset(self, seed)
+
+    for b in self.buildings:
+        b.random_seed = self.random_seed

def get_metadata(self) -> Mapping[str, Any]:
return {
**super().get_metadata(),
@@ -772,7 +785,7 @@ def get_default_shared_observations() -> List[str]:
]


-def step(self, actions: List[List[float]]) -> Tuple[List[List[float]], List[float], bool, dict]:
+def step(self, actions: List[List[float]]) -> Tuple[List[List[float]], List[float], bool, bool, dict]:
"""Advance to next time step then apply actions to `buildings` and update variables.
Parameters
@@ -789,12 +802,15 @@ def step(self, actions: List[List[float]]) -> Tuple[List[List[float]], List[floa
:attr:`observations` current value.
reward: List[float]
:meth:`get_reward` current value.
-done: bool
+terminated: bool
    A boolean value for if the episode has ended, in which case further :meth:`step` calls will return undefined results.
    A `terminated` signal may be emitted for different reasons: maybe the task underlying the environment was solved successfully,
    a certain time limit was exceeded, or the physics simulation has entered an invalid observation.
+truncated: bool
+    A boolean value for if episode truncates due to a time limit or a reason that is not defined as part of the task MDP.
+    Will always return False in this base class.
info: dict
-    A dictionary that may contain additional information regarding the reason for a ``done`` signal.
+    A dictionary that may contain additional information regarding the reason for a `terminated` signal.
`info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
Override :meth:`get_info` to get custom key-value pairs in `info`.
"""
@@ -818,7 +834,7 @@ def step(self, actions: List[List[float]]) -> Tuple[List[List[float]], List[floa
self.__rewards.append(reward)

# store episode reward summary
-if self.done:
+if self.terminated:
rewards = np.array(self.__rewards[1:], dtype='float32')
self.__episode_rewards.append({
'min': rewards.min(axis=0).tolist(),
@@ -830,7 +846,7 @@ def step(self, actions: List[List[float]]) -> Tuple[List[List[float]], List[floa
else:
pass

-return self.observations, reward, self.done, self.get_info()
+return self.observations, reward, self.terminated, self.truncated, self.get_info()

def get_info(self) -> Mapping[Any, Any]:
"""Other information to return from the `citylearn.CityLearnEnv.step` function."""
@@ -1066,63 +1082,6 @@ def evaluate(self, control_condition: EvaluationCondition = None, baseline_condi
cost_functions = pd.concat([district_level, building_level], ignore_index=True, sort=False)

return cost_functions

-def render(self):
-    """Rendering function for The CityLearn Challenge 2023."""
-
-    canvas, canvas_size, draw_obj, color = get_background()
-    num_buildings = len(self.buildings)
-    profile_time_steps = 24
-    norm_min, norm_max = 0.0, 1.0
-    space_limits = []
-
-    for i, b in enumerate(self.buildings):
-        # current time step net electricity consumption and storage soc and indoor temperature
-        energy = b.net_electricity_consumption[b.time_step]\
-            /(b.non_periodic_normalized_observation_space_limits[1]['net_electricity_consumption'])
-        energy = max(min(energy, norm_max), norm_min)
-        electrical_storage_soc = b.electrical_storage.soc[b.time_step]
-        electrical_storage_soc = max(min(electrical_storage_soc, norm_max), norm_min)
-        dhw_storage_soc = b.dhw_storage.soc[b.time_step]
-        dhw_storage_soc = max(min(dhw_storage_soc, norm_max), norm_min)
-        indoor_temperature = b.indoor_dry_bulb_temperature[b.time_step]
-        indoor_temperature_delta = indoor_temperature - b.energy_simulation.indoor_dry_bulb_temperature_set_point[b.time_step]
-        space_limits.append(b.non_periodic_normalized_observation_space_limits)
-
-        # render
-        rbuilding = RenderBuilding(index=i, canvas_size=canvas_size, num_buildings=num_buildings, line_color=color)
-        rbuilding.draw_line(canvas, draw_obj, energy=energy, color=color)
-        rbuilding.draw_building(canvas, charge=electrical_storage_soc)
-
-    # time series data
-    nec = self.net_electricity_consumption[-profile_time_steps:]
-    nec_wo_storage = self.net_electricity_consumption_without_storage[-profile_time_steps:]
-    nec_wo_storage_and_partial_load = self.net_electricity_consumption_without_storage_and_partial_load[-profile_time_steps:]
-    nec_wo_storage_and_partial_load_and_pv = self.net_electricity_consumption_without_storage_and_partial_load_and_pv[-profile_time_steps:]
-    values = [nec, nec_wo_storage, nec_wo_storage_and_partial_load, nec_wo_storage_and_partial_load_and_pv]
-
-    # time series data y limits
-    nec_y_lim = (
-        sum(s[0]['net_electricity_consumption'] for s in space_limits),
-        sum(s[1]['net_electricity_consumption'] for s in space_limits)
-    )
-    nec_wo_storage_y_lim = (
-        sum(s[0]['net_electricity_consumption_without_storage'] for s in space_limits),
-        sum(s[1]['net_electricity_consumption_without_storage'] for s in space_limits)
-    )
-    nec_wo_storage_and_partial_load_y_lim = (
-        sum(s[0]['net_electricity_consumption_without_storage_and_partial_load'] for s in space_limits),
-        sum(s[1]['net_electricity_consumption_without_storage_and_partial_load'] for s in space_limits)
-    )
-    nec_wo_storage_and_partial_load_and_pv_y_lim = (
-        sum(s[0]['net_electricity_consumption_without_storage_and_partial_load_and_pv'] for s in space_limits),
-        sum(s[1]['net_electricity_consumption_without_storage_and_partial_load_and_pv'] for s in space_limits)
-    )
-    limits = [nec_y_lim, nec_wo_storage_y_lim, nec_wo_storage_and_partial_load_y_lim, nec_wo_storage_and_partial_load_and_pv_y_lim]
-    plot_image = get_plots(values, limits)
-    graphic_image = np.asarray(canvas)
-
-    return np.concatenate([graphic_image, plot_image], axis=1)

def next_time_step(self):
r"""Advance all buildings to next `time_step`."""
@@ -1132,18 +1091,37 @@ def next_time_step(self):

super().next_time_step()

-def reset(self) -> List[List[float]]:
+def reset(self, seed: int = None, options: Mapping[str, Any] = None) -> Tuple[List[List[float]], dict]:
r"""Reset `CityLearnEnv` to initial state.
+Parameters
+----------
+seed: int, optional
+    Use to update :code:`citylearn.CityLearnEnv.random_seed` if a value is provided.
+options: Mapping[str, Any], optional
+    Use to pass additional data to environment on reset. Not used in this base class
+    but included to conform to gymnasium interface.
Returns
-------
observations: List[List[float]]
    :attr:`observations`.
+info: dict
+    A dictionary that may contain additional information regarding the reason for a `terminated` signal.
+    `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
+    Override :meth:`get_info` to get custom key-value pairs in `info`.
"""

# object reset
super().reset()

+# update seed
+if seed is not None:
+    self.random_seed = seed
+
+else:
+    pass

# update time steps for time series
self.episode_tracker.next_episode(
self.episode_time_steps,
Expand All @@ -1165,7 +1143,7 @@ def reset(self) -> List[List[float]]:
self.__net_electricity_consumption_emission = []
self.update_variables()

-return self.observations
+return self.observations, self.get_info()

def update_variables(self):
# net electricity consumption
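Taken together with the new `random_seed` setter above, passing a seed to `reset` now reseeds the environment and, through the setter, every building. A hedged usage sketch (schema name hypothetical):

from citylearn.citylearn import CityLearnEnv

env = CityLearnEnv('some_schema')  # hypothetical schema name

# Seeding on reset updates env.random_seed; the setter added in this diff
# then forwards the value to each building in env.buildings.
observations, info = env.reset(seed=0)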
(Diff for the remaining 8 changed files not shown.)
