-
Notifications
You must be signed in to change notification settings - Fork 387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(zjow): envpoool env example in new pipeline #746
Changes from 250 commits
eac9434
6fda31b
6b9def4
6f49d0a
cdafb55
b5788e2
b181597
b534448
1d91b6d
ef5f1d5
87822ba
c86d897
0f06aa6
1973d01
7a832ee
1c111c2
3a05437
54b1a09
1caefff
646b005
6a8d535
6fb3534
c54f220
557102e
95f995c
c5e9a52
01b82c7
0d60070
3f3fb68
49cab88
dc5aa8c
ccb2fcf
84bef89
c937f3b
27ff425
02bc7f0
6fb854f
bbf7e2d
a76408c
fd7f922
70009ae
772c354
a7513d8
12d6291
fec830a
471aff4
8f523e7
0f5015e
151079c
d5cdb1e
511dfad
d69b165
b95f340
b6be677
516780b
98e4d46
93008aa
83861f8
883ce54
c7f5ad6
36a7dfa
1941951
af7272a
eafeada
d21839c
68a738e
fdc6408
c67622d
5cf69d6
995e39c
7e03fc1
f222d42
ca63569
2e8978b
aa3367d
4b7aa50
134e3e5
4c08017
efc807e
eed925f
f37f65b
59cc61b
4b2ffcd
8b04a11
b1aab8d
0584404
f651f68
92bfff3
a72de14
4e59519
f104d81
308e25a
522b0ff
1e87b1d
06f4046
7fc7032
bb74395
5ea9233
59b7080
d96ce90
e6e100b
f69c448
98877de
dbec6a7
d2d7e8e
168fd41
de2d180
1a2d4dd
498c094
1221565
1018dda
f197844
2e578ad
0bc6923
5473706
db8176b
43b0c3e
44a3047
d16fa86
eb86c63
eab7912
98a9017
b52d8f1
8b15b52
c915f33
dc61317
ea5f1e7
35a21b4
f95e8eb
603fa5e
6b874dc
ddd6550
d25228e
df7963d
a7c3cf4
ea979e8
ff7f639
ed5b1a3
05f8c47
df20033
4cc8eac
5ecc9dc
1e5ec1a
c621c35
41786e3
ebcefb4
420ef72
d9d93dd
aa1f39d
653a00b
57e7325
0754dd9
0919f06
d958f49
ab0fdda
0c1f2b6
9336a0a
ced06f8
a8822fd
8d152e0
2e2db04
e063d77
afb6355
0863b0b
c934ef6
a1f3e94
da9d2c1
af3d101
92d9504
1f0704c
5a08ec7
1774224
65b9f08
02f90cf
c9e736a
4d125c9
d204c95
eea5573
0f807d8
984c8ba
584bd7a
b87eabc
795ec5d
948c99b
554edb2
d0047ed
8c57293
c7509cb
9ff7d4b
0a1a2cc
b798c2e
fb5045b
6a4d83e
52ded5a
9aa23f7
c6e90a4
3addb8b
1d9f9af
ef99434
27cb8bd
5a41f63
1b7cf2a
aab3847
83ece4f
1980d51
96c0bbf
8419a38
7adbc77
7b581eb
fe30fbe
dc0ea3a
c7b7645
2f9a41f
871fdc0
e6e6828
8d79f66
e1c137a
ab33001
340d50e
e5af078
20cbef1
532b5b8
b31a7ca
35069ae
a06bd3f
e5ea2fd
d7c4983
e22df12
c03a17b
48c1333
dda0ffc
878bbb3
73b73dc
c068721
7daf239
ed0f490
8981236
3a1d98c
83ca217
97360c0
25fab56
ab93b39
cd762b6
d3c9bf8
1bd96e0
35a2c67
fca097f
3687f8b
4fb85b0
48ee6da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,11 @@ | |
from easydict import EasyDict | ||
from copy import deepcopy | ||
import numpy as np | ||
import torch | ||
import treetensor.torch as ttorch | ||
import treetensor.numpy as tnp | ||
from collections import namedtuple | ||
import enum | ||
from typing import Any, Union, List, Tuple, Dict, Callable, Optional | ||
from ditk import logging | ||
try: | ||
|
@@ -17,17 +21,28 @@ | |
from ding.torch_utils import to_ndarray | ||
|
||
|
||
@ENV_MANAGER_REGISTRY.register('env_pool') | ||
class EnvState(enum.IntEnum):
    """
    Overview:
        Integer-valued labels for the state of an environment.
        NOTE(review): nothing visible in this diff reads these members, so the exact
        meaning of each state is inferred from its name only - confirm against callers.
    """
    VOID = 0
    INIT = 1
    RUN = 2
    RESET = 3
    DONE = 4
    ERROR = 5
    NEED_RESET = 6
|
||
|
||
@ENV_MANAGER_REGISTRY.register('envpool') | ||
class PoolEnvManager: | ||
''' | ||
""" | ||
Overview: | ||
PoolEnvManager supports old pipeline of DI-engine. | ||
Envpool now supports Atari, Classic Control, Toy Text, ViZDoom. | ||
Here we list some commonly used env_ids as follows. | ||
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>. | ||
|
||
- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5" | ||
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1" | ||
''' | ||
""" | ||
|
||
@classmethod | ||
def default_config(cls) -> EasyDict: | ||
|
@@ -39,10 +54,17 @@ def default_config(cls) -> EasyDict: | |
# Async mode: batch_size < env_num | ||
env_num=8, | ||
batch_size=8, | ||
image_observation=True, | ||
episodic_life=False, | ||
reward_clip=False, | ||
gray_scale=True, | ||
stack_num=4, | ||
frame_skip=4, | ||
) | ||
|
||
def __init__(self, cfg: EasyDict) -> None: | ||
self._cfg = cfg | ||
self._cfg = self.default_config() | ||
self._cfg.update(cfg) | ||
self._env_num = cfg.env_num | ||
self._batch_size = cfg.batch_size | ||
self._ready_obs = {} | ||
|
@@ -55,6 +77,7 @@ def launch(self) -> None: | |
seed = 0 | ||
else: | ||
seed = self._seed | ||
|
||
self._envs = envpool.make( | ||
task_id=self._cfg.env_id, | ||
env_type="gym", | ||
|
@@ -65,8 +88,10 @@ def launch(self) -> None: | |
reward_clip=self._cfg.reward_clip, | ||
stack_num=self._cfg.stack_num, | ||
gray_scale=self._cfg.gray_scale, | ||
frame_skip=self._cfg.frame_skip | ||
frame_skip=self._cfg.frame_skip, | ||
) | ||
self._action_space = self._envs.action_space | ||
self._observation_space = self._envs.observation_space | ||
self._closed = False | ||
self.reset() | ||
|
||
|
@@ -77,6 +102,8 @@ def reset(self) -> None: | |
obs, _, _, info = self._envs.recv() | ||
env_id = info['env_id'] | ||
obs = obs.astype(np.float32) | ||
if self._cfg.image_observation: | ||
obs /= 255.0 | ||
self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs) | ||
if len(self._ready_obs) == self._env_num: | ||
break | ||
|
@@ -91,6 +118,8 @@ def step(self, action: dict) -> Dict[int, namedtuple]: | |
|
||
obs, rew, done, info = self._envs.recv() | ||
obs = obs.astype(np.float32) | ||
if self._cfg.image_observation: | ||
obs /= 255.0 | ||
rew = rew.astype(np.float32) | ||
env_id = info['env_id'] | ||
timesteps = {} | ||
|
@@ -124,3 +153,152 @@ def env_num(self) -> int: | |
@property
def ready_obs(self) -> Dict[int, Any]:
    """
    Overview:
        Return the observation dict currently stored on the manager (``self._ready_obs``).
    """
    current_obs = self._ready_obs
    return current_obs
|
||
@property
def observation_space(self) -> 'gym.spaces.Space':  # noqa
    """
    Overview:
        Return the cached observation space. If it has not been cached yet, the manager is
        launched once (``launch`` stores the space) and then closed again before returning.
    """
    if not hasattr(self, '_observation_space'):
        # The space is only known after the underlying envs were created at least once.
        self.launch()
        self.close()
    return self._observation_space
|
||
@property
def action_space(self) -> 'gym.spaces.Space':  # noqa
    """
    Overview:
        Return the cached action space. If it has not been cached yet, the manager is
        launched once (``launch`` stores the space) and then closed again before returning.
    """
    if not hasattr(self, '_action_space'):
        # The space is only known after the underlying envs were created at least once.
        self.launch()
        self.close()
    return self._action_space
|
||
|
||
@ENV_MANAGER_REGISTRY.register('envpool_v2') | ||
class PoolEnvManagerV2: | ||
""" | ||
Overview: | ||
PoolEnvManagerV2 supports new pipeline of DI-engine. | ||
Envpool now supports Atari, Classic Control, Toy Text, ViZDoom. | ||
Here we list some commonly used env_ids as follows. | ||
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>. | ||
|
||
- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5" | ||
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1" | ||
""" | ||
|
||
@classmethod
def default_config(cls) -> EasyDict:
    """
    Overview:
        Return an independent ``EasyDict`` copy of the class-level ``config``.
    """
    defaults = deepcopy(cls.config)
    return EasyDict(defaults)

# Class-level defaults; ``default_config`` hands out deep copies of this dict.
config = {
    'type': 'envpool_v2',
    'env_num': 8,
    'batch_size': 8,
    # When True, observations are divided by 255.0 after being cast to float32.
    'image_observation': True,
    # The remaining keys are forwarded unchanged to ``envpool.make`` in ``launch``.
    'episodic_life': False,
    'reward_clip': False,
    'gray_scale': True,
    'stack_num': 4,
    'frame_skip': 4,
}
|
||
def __init__(self, cfg: EasyDict) -> None:
    """
    Overview:
        Store the merged configuration and initialise bookkeeping flags. No environment is
        created here; that happens in ``launch``.
    Arguments:
        - cfg (:obj:`EasyDict`): User config. Its keys override the values from ``default_config``.
    """
    super().__init__()
    self._cfg = self.default_config()
    self._cfg.update(cfg)
    # Read from the merged config rather than the raw ``cfg`` so that the defaults
    # for ``env_num`` / ``batch_size`` apply when the user config omits them.
    self._env_num = self._cfg.env_num
    self._batch_size = self._cfg.batch_size

    # The manager starts closed; ``launch`` asserts this before creating the envs.
    self._closed = True
    # ``None`` means "no seed set yet"; ``launch`` falls back to 0 in that case.
    self._seed = None
|
||
def launch(self) -> None:
    """
    Overview:
        Create the envpool environments from the stored config, cache their action and
        observation spaces, mark the manager as open and reset the environments.
    Returns:
        - Whatever ``reset`` returns.
          NOTE(review): the ``-> None`` annotation does not match this - confirm and fix together with ``reset``.
    """
    assert self._closed, "Please first close the env manager"
    # Fall back to seed 0 when ``seed`` was never called.
    seed = 0 if self._seed is None else self._seed

    make_kwargs = dict(
        task_id=self._cfg.env_id,
        env_type="gym",
        num_envs=self._env_num,
        batch_size=self._batch_size,
        seed=seed,
        episodic_life=self._cfg.episodic_life,
        reward_clip=self._cfg.reward_clip,
        stack_num=self._cfg.stack_num,
        gray_scale=self._cfg.gray_scale,
        frame_skip=self._cfg.frame_skip,
    )
    self._envs = envpool.make(**make_kwargs)
    self._action_space = self._envs.action_space
    self._observation_space = self._envs.observation_space
    self._closed = False
    return self.reset()
|
||
def reset(self) -> None: | ||
self._envs.async_reset() | ||
ready_obs = {} | ||
while True: | ||
obs, _, _, info = self._envs.recv() | ||
env_id = info['env_id'] | ||
obs = obs.astype(np.float32) | ||
if self._cfg.image_observation: | ||
obs /= 255.0 | ||
ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, ready_obs) | ||
Review comment: use the assignment operation directly; don't use [inline code lost in extraction - presumably `deep_merge_dicts`, the call on the line above]. Reply: fixed. |
||
if len(ready_obs) == self._env_num: | ||
break | ||
self._eval_episode_return = [0. for _ in range(self._env_num)] | ||
|
||
return ready_obs | ||
|
||
def send_action(self, action, env_id) -> None:
    """
    Overview:
        Forward ``action`` and ``env_id`` to the underlying envpool object via ``send``.
        Nothing is returned; results are fetched separately with ``receive_data``.
    Arguments:
        - action: Passed through unchanged as the first argument of ``self._envs.send``.
        - env_id: Passed through unchanged as the second argument of ``self._envs.send``.
    """
    # The previous ``Dict[int, namedtuple]`` return annotation was wrong: no value is returned.
    self._envs.send(action, env_id)
|
||
def receive_data(self):
    """
    Overview:
        Receive one batch from the underlying envpool object and convert it for downstream use.
    Returns:
        - A tuple ``(next_obs, rew, done, info)``: ``next_obs`` and ``rew`` are cast to float32,
          ``next_obs`` is additionally divided by 255.0 when ``image_observation`` is enabled,
          ``done`` and ``info`` are returned exactly as received.
    """
    raw_obs, raw_rew, done, info = self._envs.recv()
    next_obs = raw_obs.astype(np.float32)
    if self._cfg.image_observation:
        next_obs /= 255.0

    return next_obs, raw_rew.astype(np.float32), done, info
|
||
def close(self) -> None:
    """
    Overview:
        Mark the manager as closed. Calling it on an already closed manager is a no-op.
    """
    # Envpool has no `close` API, so only the flag is updated.
    if not self._closed:
        self._closed = True
|
||
@property
def closed(self) -> bool:
    """
    Overview:
        Whether the manager is currently closed (``True`` after ``__init__`` and after
        ``close``, ``False`` once ``launch`` has run).
    """
    # The previous ``-> None`` annotation was wrong: the stored boolean flag is returned.
    return self._closed
|
||
def seed(self, seed: int, dynamic_seed=False) -> None:
    """
    Overview:
        Remember the seed to be used by the next ``launch`` and emit a warning that
        dynamic seeding is not supported.
    Arguments:
        - seed (:obj:`int`): Base seed, stored as-is.
        - dynamic_seed: Accepted but never read by this method.
    """
    # Envpool seeds the i-th environment with ``i + seed`` itself, so no extra transformation is done here.
    self._seed = seed
    logging.warning("envpool doesn't support dynamic_seed in different episode")
|
||
@property
def env_num(self) -> int:
    """
    Overview:
        Return the number of environments stored on the manager at construction time.
    """
    count = self._env_num
    return count
|
||
@property
def observation_space(self) -> 'gym.spaces.Space':  # noqa
    """
    Overview:
        Return the cached observation space. If it has not been cached yet, the manager is
        launched once (``launch`` stores the space), closed again, and ``_ready_obs`` is
        set to an empty dict before returning.
    """
    if not hasattr(self, '_observation_space'):
        # The space is only known after the underlying envs were created at least once.
        self.launch()
        self.close()
        # NOTE(review): no other visible method of this class reads ``_ready_obs`` - confirm it is still needed.
        self._ready_obs = {}
    return self._observation_space
|
||
@property
def action_space(self) -> 'gym.spaces.Space':  # noqa
    """
    Overview:
        Return the cached action space. If it has not been cached yet, the manager is
        launched once (``launch`` stores the space), closed again, and ``_ready_obs`` is
        set to an empty dict before returning.
    """
    if not hasattr(self, '_action_space'):
        # The space is only known after the underlying envs were created at least once.
        self.launch()
        self.close()
        # NOTE(review): no other visible method of this class reads ``_ready_obs`` - confirm it is still needed.
        self._ready_obs = {}
    return self._action_space
Review comment: add an envpool test for this file. Reply: fixed. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import datetime | ||
from easydict import EasyDict | ||
from ditk import logging | ||
from ding.model import DQN | ||
from ding.policy import DQNFastPolicy | ||
from ding.envs.env_manager.envpool_env_manager import PoolEnvManagerV2 | ||
from ding.data import DequeBuffer | ||
from ding.config import compile_config | ||
from ding.framework import task, ding_init | ||
from ding.framework.context import OnlineRLContext | ||
from ding.framework.middleware import envpool_evaluator, data_pusher, \ | ||
eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger, \ | ||
termination_checker, wandb_online_logger, epoch_timer, EnvpoolStepCollector, EnvpoolOffPolicyLearner | ||
from ding.utils import set_pkg_seed | ||
from dizoo.atari.config.serial import pong_dqn_envpool_config | ||
|
||
|
||
def main(cfg): | ||
logging.getLogger().setLevel(logging.INFO) | ||
cfg.exp_name = 'Pong-v5-DQN-envpool-' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") | ||
Review comment: why modify this? |
||
|
||
collector_env_cfg = EasyDict( | ||
{ | ||
'env_id': cfg.env.env_id, | ||
'env_num': cfg.env.collector_env_num, | ||
'batch_size': cfg.env.collector_batch_size, | ||
# env wrappers | ||
'episodic_life': True, # collector: True | ||
'reward_clip': False, # collector: True | ||
'gray_scale': cfg.env.get('gray_scale', True), | ||
'stack_num': cfg.env.get('stack_num', 4), | ||
Review comment: move some of these keys to the default config and the user config. |
||
} | ||
) | ||
cfg.env["collector_env_cfg"] = collector_env_cfg | ||
evaluator_env_cfg = EasyDict( | ||
{ | ||
'env_id': cfg.env.env_id, | ||
'env_num': cfg.env.evaluator_env_num, | ||
'batch_size': cfg.env.evaluator_batch_size, | ||
# env wrappers | ||
'episodic_life': False, # evaluator: False | ||
'reward_clip': False, # evaluator: False | ||
'gray_scale': cfg.env.get('gray_scale', True), | ||
'stack_num': cfg.env.get('stack_num', 4), | ||
} | ||
) | ||
cfg.env["evaluator_env_cfg"] = evaluator_env_cfg | ||
cfg = compile_config(cfg, PoolEnvManagerV2, DQNFastPolicy, save_cfg=task.router.node_id == 0) | ||
ding_init(cfg) | ||
with task.start(async_mode=False, ctx=OnlineRLContext()): | ||
collector_env = PoolEnvManagerV2(cfg.env.collector_env_cfg) | ||
evaluator_env = PoolEnvManagerV2(cfg.env.evaluator_env_cfg) | ||
collector_env.seed(cfg.seed) | ||
evaluator_env.seed(cfg.seed) | ||
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) | ||
|
||
model = DQN(**cfg.policy.model) | ||
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) | ||
policy = DQNFastPolicy(cfg.policy, model=model) | ||
|
||
# Consider the case with multiple processes | ||
if task.router.is_active: | ||
# You can use labels to distinguish between workers with different roles, | ||
# here we use node_id to distinguish. | ||
if task.router.node_id == 0: | ||
task.add_role(task.role.LEARNER) | ||
elif task.router.node_id == 1: | ||
task.add_role(task.role.EVALUATOR) | ||
else: | ||
task.add_role(task.role.COLLECTOR) | ||
|
||
# Sync their context and model between each worker. | ||
task.use(ContextExchanger(skip_n_iter=1)) | ||
task.use(ModelExchanger(model)) | ||
task.use(epoch_timer()) | ||
task.use(envpool_evaluator(cfg, policy.eval_mode, evaluator_env)) | ||
task.use(eps_greedy_handler(cfg)) | ||
task.use( | ||
EnvpoolStepCollector( | ||
cfg, | ||
policy.collect_mode, | ||
collector_env, | ||
random_collect_size=cfg.policy.random_collect_size \ | ||
if hasattr(cfg.policy, 'random_collect_size') else 0, | ||
) | ||
) | ||
task.use(data_pusher(cfg, buffer_)) | ||
task.use(EnvpoolOffPolicyLearner(cfg, policy, buffer_)) | ||
task.use(online_logger(train_show_freq=10)) | ||
Review comment: why use two loggers? |
||
task.use( | ||
wandb_online_logger( | ||
metric_list=policy._monitor_vars_learn(), | ||
model=policy._model, | ||
exp_config=cfg, | ||
anonymous=True, | ||
project_name=cfg.exp_name, | ||
wandb_sweep=False, | ||
) | ||
) | ||
#task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000)) | ||
task.use(termination_checker(max_env_step=10000000)) | ||
Review comment: use int(1e7). |
||
task.run() | ||
|
||
|
||
if __name__ == "__main__":
    import argparse

    # Command-line overrides for the seed and the collector environment settings.
    cli = argparse.ArgumentParser()
    cli.add_argument("--seed", type=int, default=0, help="random seed")
    cli.add_argument("--collector_env_num", type=int, default=8, help="collector env number")
    cli.add_argument("--collector_batch_size", type=int, default=8, help="collector batch size")
    args = cli.parse_args()

    # Write the parsed values into the imported Pong config before starting the run.
    pong_dqn_envpool_config.seed = args.seed
    pong_dqn_envpool_config.env.collector_env_num = args.collector_env_num
    pong_dqn_envpool_config.env.collector_batch_size = args.collector_batch_size

    main(pong_dqn_envpool_config)
Review comment: why merge the config here? We have already merged the env manager config in the compile_config function.
Reply: There are two env manager configs, one for the evaluator and one for the collector. It's too complicated to use compile_config with auto=True.
I suggest using compile_config with auto=False.