forked from automl-edu/advanced-topics-in-deep-rl
Showing 12 changed files with 462 additions and 72 deletions.
@@ -1,30 +1,26 @@
# adrl
# Advanced Topics in Deep RL

This repository contains the lecture materials. There are two main directories:
- adrl contains constructors for our main three settings and small examples
- lecture contains the lecture PDFs and you will add your seminar contributions via PR there as well

You should install the repository as described below to run experiments in our settings. You should be able to interact with the continual environment as with any other env. For the multi-agent interface, see [Petting Zoo](https://pettingzoo.farama.org/), and for offline RL, [Minari](https://minari.farama.org/main/).

## Installation

Ideally, you'll follow these instructions to create a fresh conda environment and then install the package for usage. That should allow you to run the examples and use the constructor functions for all three settings. The dev option simply adds formatting support in case you're interested in using it.

```
git clone https://github.com/automl/adrl.git
cd adrl
conda create -n adrl python=3.8
conda create -n adrl python=3.10
conda activate adrl
# Install for usage
pip install .
make install
# Install for development
make install-dev
```

## Minimal Example

```
# Your code here
```
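
For instance, interacting with the continual setting could look like the following sketch (assuming the constructor `make_continual_rl_env` from the continual-RL module below is importable from the installed package; the exact import path is an assumption):

```
# Hypothetical usage sketch -- the import path is an assumption.
from adrl import make_continual_rl_env

env = make_continual_rl_env()
obs, info = env.reset()
for _ in range(1000):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```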

TODOs:
- comprehension traffic light ("Verständnisampel"), see fast.ai
- prioritize intuition
- one big visualization of all topics
- example project
- make clear that the components can be combined arbitrarily
@@ -0,0 +1,49 @@
import numpy as np
from gymnasium import Wrapper
from carl.envs import CARLLunarLander


class GravityChangeWrapper(Wrapper):
    """Changes the gravity of the wrapped CARL env every 10000 steps."""

    def __init__(self, env):
        super().__init__(env)
        self.n_steps = 0
        self.n_switches = 0

    def step(self, action):
        self.n_steps += 1
        state, reward, terminated, truncated, info = self.env.step(action)
        # Truncate at every 10000-step boundary so the gravity switch in
        # reset() can take effect.
        if self.n_steps % 10000 == 0:
            truncated = True
        return state, reward, terminated, truncated, info

    def reset(self, **kwargs):
        # Switch gravity once per completed 10000-step block: either flip the
        # sign of the current gravity or sample a new value uniformly.
        if self.n_steps // 10000 > self.n_switches:
            change_kind = np.random.choice(["flip", "random"])
            if change_kind == "flip":
                gravity = -self.env.context["GRAVITY_Y"]
            else:
                gravity = np.random.uniform(-20, 0)
            self.env.contexts[0] = {"GRAVITY_Y": gravity}
            self.env.context["GRAVITY_Y"] = gravity
            self.n_switches += 1
        return self.env.reset(**kwargs)


def make_continual_rl_env():
    contexts = {0: {"GRAVITY_Y": -10}}
    env = CARLLunarLander(contexts=contexts)
    env = GravityChangeWrapper(env)
    return env


if __name__ == "__main__":
    env = make_continual_rl_env()
    env.reset()
    for i in range(50000):
        _, _, te, tr, _ = env.step(env.action_space.sample())
        if te or tr:
            env.reset()
        if env.n_steps % 10000 == 0:
            print(f"Gravity is {env.env.context['GRAVITY_Y']}")
    env.close()
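
Since the wrapper switches gravity every 10000 steps, one way to inspect the non-stationarity is to track the average random-policy reward per gravity phase. A minimal sketch, reusing `make_continual_rl_env` from above (the phase length of 10000 steps mirrors the wrapper; everything else is illustrative):

```
# Sketch: mean random-policy reward per gravity phase.
from collections import defaultdict

env = make_continual_rl_env()
env.reset()
phase_rewards = defaultdict(list)
for _ in range(50000):
    _, reward, terminated, truncated, _ = env.step(env.action_space.sample())
    # Attribute the reward to the 10000-step phase it was collected in.
    phase_rewards[(env.n_steps - 1) // 10000].append(reward)
    if terminated or truncated:
        env.reset()
for phase, rewards in sorted(phase_rewards.items()):
    print(f"Phase {phase}: mean reward {sum(rewards) / len(rewards):.2f}")
env.close()
```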
@@ -0,0 +1,72 @@
import csv
import pathlib

from dacbench.benchmarks import SigmoidBenchmark


def read_instance_set(path):
    """Read a Sigmoid instance set from a csv file whose rows contain an
    integer instance id followed by the instance features."""
    instance_set = {}
    with open(path, "r") as f:
        reader = csv.reader(f)
        for row in reader:
            features = []
            inst_id = None
            for i, entry in enumerate(row):
                if i == 0:
                    try:
                        inst_id = int(entry)
                    except ValueError:
                        # Skip non-numeric ids, e.g. header rows.
                        continue
                else:
                    try:
                        features.append(float(entry))
                    except ValueError:
                        continue
            if inst_id is not None and len(features) > 0:
                instance_set[inst_id] = features
    return instance_set


def make_multi_agent_env():
    bench = SigmoidBenchmark()
    base_path = pathlib.Path(__file__).parent.resolve()
    bench.config.instance_set = read_instance_set(base_path / "sigmoid_train.csv")
    bench.config.test_set = read_instance_set(base_path / "sigmoid_test.csv")
    bench.config["multi_agent"] = True
    env = bench.get_environment()
    return env


if __name__ == "__main__":
    env = make_multi_agent_env()

    # Add one agent per action dimension
    env.register_agent(agent_id=0)
    env.register_agent(agent_id=1)

    env.reset()
    total_reward = 0
    terminated, truncated = False, False
    while not (terminated or truncated):
        # Each agent in turn observes the current state and acts.
        for agent in [0, 1]:
            observation, reward, terminated, truncated, info = env.last()
            action = env.action_spaces[agent].sample()
            env.step(action)
        # Accumulate the reward once all agents have acted.
        observation, reward, terminated, truncated, info = env.last()
        total_reward += reward

    print(f"The final reward was {total_reward}.")
    env.close()
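
The constructor above expects `sigmoid_train.csv` and `sigmoid_test.csv` next to the module, with each row holding an integer instance id followed by the instance features; that format is inferred from the parsing code, and the feature names below are purely illustrative. A small sketch of writing a compatible file:

```
# Hypothetical instance file in the format the reader above accepts:
# <instance id>,<feature 1>,<feature 2>,...
import csv

with open("sigmoid_train.csv", "w", newline="") as f:
    writer = csv.writer(f)
    # A non-numeric header row is simply skipped by the reader.
    writer.writerow(["id", "shift_1", "slope_1", "shift_2", "slope_2"])
    writer.writerow([0, 2.5, 1.0, 4.0, -1.5])
    writer.writerow([1, 3.0, -2.0, 1.5, 0.5])
```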
@@ -0,0 +1,85 @@
import numpy as np
import minari
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def collate_fn(batch):
    # Pad episodes of different lengths so they can be stacked into one batch.
    return {
        "id": torch.Tensor([x.id for x in batch]),
        "seed": torch.Tensor([x.seed for x in batch]),
        "total_steps": torch.Tensor([x.total_timesteps for x in batch]),
        "observations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.observations["observation"]) for x in batch],
            batch_first=True,
        ),
        "actions": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.actions) for x in batch], batch_first=True
        ),
        "rewards": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.rewards) for x in batch], batch_first=True
        ),
        "terminations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.terminations) for x in batch], batch_first=True
        ),
        "truncations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.truncations) for x in batch], batch_first=True
        ),
    }


class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = torch.as_tensor(x, dtype=torch.float32)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def make_offline_rl_dataset():
    dataset = minari.load_dataset("antmaze-umaze-v0", download=True)
    dataloader = DataLoader(
        dataset, batch_size=256, shuffle=True, collate_fn=collate_fn
    )
    env = dataset.recover_environment()
    return dataloader, env


if __name__ == "__main__":
    num_epochs = 3
    dataloader, env = make_offline_rl_dataset()

    observation_space = env.observation_space["observation"]
    action_space = env.action_space
    policy_net = PolicyNetwork(np.prod(observation_space.shape), action_space.shape[0])
    optimizer = torch.optim.Adam(policy_net.parameters())
    # Simple behavior cloning: regress the dataset actions with an MSE loss.
    loss_fn = nn.MSELoss()

    for epoch in range(num_epochs):
        for batch in dataloader:
            # Each episode has one more observation than actions (the final
            # state), so drop the last observation to align both sequences.
            a_pred = policy_net(batch["observations"][:, :-1])
            loss = loss_fn(a_pred, batch["actions"].float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch: {epoch}/{num_epochs}, Loss: {loss.item()}")

    state = env.reset()[0]
    te, tr = False, False
    r = 0
    while not (te or tr):
        action = policy_net(state["observation"])
        state, reward, te, tr, _ = env.step(action.detach().numpy())
        r += reward
    print(f"Total online evaluation reward: {r}")
    env.close()
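
The script above evaluates the cloned policy for a single episode; a short sketch of averaging the return over several evaluation episodes instead (the episode count is arbitrary, and `policy_net` is reused from the training script above):

```
# Sketch: mean return of the cloned policy over a few evaluation episodes.
eval_env = minari.load_dataset("antmaze-umaze-v0", download=True).recover_environment()
n_eval_episodes = 5
returns = []
for _ in range(n_eval_episodes):
    state, _ = eval_env.reset()
    terminated, truncated = False, False
    episode_return = 0.0
    while not (terminated or truncated):
        action = policy_net(state["observation"])
        state, reward, terminated, truncated, _ = eval_env.step(action.detach().numpy())
        episode_return += reward
    returns.append(episode_return)
print(f"Mean return over {n_eval_episodes} episodes: {sum(returns) / len(returns):.2f}")
eval_env.close()
```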