feature(zjow): envpoool env example in new pipeline #746

Closed
wants to merge 308 commits into from
Changes from 250 commits
Commits
eac9434
fix bug
zjowowen Mar 13, 2023
6fda31b
fix dtype error
zjowowen Mar 14, 2023
6b9def4
polish code
zjowowen Mar 15, 2023
6f49d0a
polish code
zjowowen Mar 15, 2023
cdafb55
Add dqn agent
zjowowen Mar 16, 2023
b5788e2
merge from main
zjowowen Mar 17, 2023
b181597
add config
zjowowen Mar 17, 2023
b534448
merge from main
zjowowen Mar 20, 2023
1d91b6d
add bonus/c51.py
zhangpaipai Mar 20, 2023
ef5f1d5
add c51 logit monitor
zhangpaipai Mar 27, 2023
87822ba
add sac dqn agent
zjowowen Mar 28, 2023
c86d897
add sac dqn agent demo in dizoo
zjowowen Mar 28, 2023
0f06aa6
merge from main
zjowowen Mar 28, 2023
1973d01
polish format
zjowowen Mar 28, 2023
7a832ee
pull zjow new-pipeline-agent
zhangpaipai Mar 28, 2023
1c111c2
polish code
zjowowen Mar 28, 2023
3a05437
polish code
zjowowen Mar 28, 2023
54b1a09
fix ddpg bug
zjowowen Mar 28, 2023
1caefff
merge nyz c51/dqn config and policy
zhangpaipai Mar 28, 2023
646b005
merge from main
zjowowen Mar 29, 2023
6a8d535
fix config
zjowowen Mar 29, 2023
6fb3534
remove mutistep_trainer
zhangpaipai Mar 29, 2023
c54f220
fix bug
zjowowen Mar 29, 2023
557102e
polish code
zjowowen Mar 29, 2023
95f995c
polish code
zjowowen Mar 29, 2023
c5e9a52
polish code
zjowowen Mar 30, 2023
01b82c7
add Hopper demo
zjowowen Mar 31, 2023
0d60070
polish code
zjowowen Mar 31, 2023
3f3fb68
add property best
zjowowen Apr 3, 2023
49cab88
merge from main
zjowowen Apr 3, 2023
dc5aa8c
add a2c pipeline
zjowowen Apr 3, 2023
ccb2fcf
add sac halfcheetah+walker2d
zhangpaipai Apr 5, 2023
84bef89
pull zjow new-pipeline-agent
zhangpaipai Apr 5, 2023
c937f3b
fix a2c pipeline bug
zjowowen Apr 6, 2023
27ff425
fix pipeline bug
zjowowen Apr 6, 2023
02bc7f0
fix bug
zjowowen Apr 6, 2023
6fb854f
change config
zjowowen Apr 6, 2023
bbf7e2d
merge from main
zjowowen Apr 7, 2023
a76408c
remove IMPALA pipeline
zjowowen Apr 7, 2023
fd7f922
format code
zjowowen Apr 7, 2023
70009ae
polish code
zjowowen Apr 7, 2023
772c354
polish c51 and add ddpg halfcheetah walker2d
zhangpaipai Apr 10, 2023
a7513d8
pull zjow new-pipeline-agent again
zhangpaipai Apr 10, 2023
12d6291
add dizoo/common for zjow to review
zhangpaipai Apr 11, 2023
fec830a
fix agent best method
zjowowen Apr 11, 2023
471aff4
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zhangpaipai Apr 11, 2023
8f523e7
reset dizoo
zjowowen Apr 11, 2023
0f5015e
delete common
zhangpaipai Apr 11, 2023
151079c
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zhangpaipai Apr 11, 2023
d5cdb1e
polish for zjow to review
zhangpaipai Apr 12, 2023
511dfad
merge from main
zjowowen Apr 13, 2023
d69b165
polish code
zjowowen Apr 13, 2023
b95f340
polish code
zjowowen Apr 13, 2023
b6be677
fix bug
zjowowen Apr 13, 2023
516780b
fix bug
zjowowen Apr 13, 2023
98e4d46
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zhangpaipai Apr 14, 2023
93008aa
polish c51
zhangpaipai Apr 14, 2023
83861f8
merge from main
zjowowen Apr 25, 2023
883ce54
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zhangpaipai Apr 26, 2023
c7f5ad6
add pg agent
zjowowen Apr 27, 2023
36a7dfa
merge from main
zjowowen Apr 27, 2023
1941951
add pendulum config
zhangpaipai Apr 28, 2023
af7272a
add c51_atari td3_pendulum,bipedalwalker ddpg_pendulum
zhangpaipai Apr 28, 2023
eafeada
polish code
zjowowen Apr 28, 2023
d21839c
merge from main
zjowowen Apr 28, 2023
68a738e
polish code
zjowowen Apr 28, 2023
fdc6408
polish code
zjowowen Apr 28, 2023
c67622d
merge zjow
zhangpaipai Apr 29, 2023
5cf69d6
add bipedalwalker_ddpg_config
zhangpaipai May 5, 2023
995e39c
merge from main
zjowowen May 9, 2023
7e03fc1
merge zjow
zhangpaipai May 9, 2023
f222d42
feature(zp): add c51
zjowowen May 9, 2023
ca63569
change config
zjowowen May 9, 2023
2e8978b
change bipedalwalker config and noframeskip
zhangpaipai May 11, 2023
aa3367d
polish c51-atari name
zhangpaipai May 15, 2023
4b7aa50
add pong spaceinvaders and qbert for dqn
ruoyuGao May 15, 2023
134e3e5
merge from main
zjowowen May 16, 2023
4c08017
git fetch
zjowowen May 16, 2023
efc807e
polish code
zjowowen May 16, 2023
eed925f
polish code; add env mode
zjowowen May 16, 2023
f37f65b
add rew_clip in ding_env_wrapper
zhangpaipai May 16, 2023
59cc61b
polish dqn atari
zhangpaipai May 19, 2023
4b2ffcd
merge from main
zjowowen May 23, 2023
8b04a11
merge from new-pipeline-agent
zjowowen May 23, 2023
b1aab8d
add a2c continuous action space
zjowowen May 23, 2023
0584404
add a2c continuous action space
zjowowen May 23, 2023
f651f68
add a2c continuous for mujoco
zjowowen May 23, 2023
92bfff3
add a2c continuous for mujoco
zjowowen May 23, 2023
a72de14
add a2c continuous for mujoco
zjowowen May 23, 2023
4e59519
add a2c mujoco config; add ppo atari config
zjowowen May 24, 2023
f104d81
add a2c mujoco config; add ppo atari config
zjowowen May 24, 2023
308e25a
fix a2c deploy bug
zjowowen May 24, 2023
522b0ff
Add bipedalwalker a2c
zjowowen May 25, 2023
1e87b1d
polish code
zjowowen May 25, 2023
06f4046
polish code
zjowowen May 25, 2023
7fc7032
polish code
zjowowen May 25, 2023
bb74395
polish code
zjowowen May 25, 2023
5ea9233
polish code
zjowowen May 29, 2023
59b7080
add pendulum a2c+pg
zhangpaipai May 30, 2023
d96ce90
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zhangpaipai May 30, 2023
e6e100b
add pg bipedalwalker+mujoco
zhangpaipai May 30, 2023
f69c448
polish code for wandb sweep
zjowowen May 30, 2023
98877de
polish code for wandb sweep
zjowowen May 30, 2023
dbec6a7
polish code for wandb sweep
zjowowen May 30, 2023
d2d7e8e
polish code for a2c mujoco
zjowowen May 30, 2023
168fd41
add pg pendulum new pipeline
zhangpaipai May 31, 2023
de2d180
fix scalar action bug in random collect
zjowowen May 31, 2023
1a2d4dd
polish pg algorithm
zhangpaipai Jun 1, 2023
498c094
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zhangpaipai Jun 1, 2023
1221565
add bonus pg config
zhangpaipai Jun 1, 2023
1018dda
polish pg config
zhangpaipai Jun 1, 2023
f197844
polish config
zjowowen Jun 1, 2023
2e578ad
merge from main
zjowowen Jun 1, 2023
0bc6923
polish code
zjowowen Jun 1, 2023
5473706
change pendulum pg config
zhangpaipai Jun 1, 2023
db8176b
fix continuous action dim=1 bug
zjowowen Jun 1, 2023
43b0c3e
merge from main
zjowowen Jun 1, 2023
44a3047
merge from origin main
zjowowen Jun 1, 2023
d16fa86
Add ppof lr scheduler
zjowowen Jun 5, 2023
eb86c63
polish config
zjowowen Jun 6, 2023
eab7912
fix random collect bug for dqn
zjowowen Jun 6, 2023
98a9017
polish ppo qbert spaceinvader config
zjowowen Jun 7, 2023
b52d8f1
remove mujoco wrapper
zjowowen Jun 9, 2023
8b15b52
polish a2c mujoco config; add ppo offpolicy agent pipeline
zjowowen Jun 9, 2023
c915f33
merge from main
zjowowen Jun 9, 2023
dc61317
Add wandb monitor evaluate return std
zjowowen Jun 9, 2023
ea5f1e7
polish deploy method
zjowowen Jun 9, 2023
35a21b4
format code
zjowowen Jun 9, 2023
f95e8eb
polish code
zjowowen Jun 13, 2023
603fa5e
polish pg pendulum+hopper config
zhangpaipai Jun 13, 2023
6b874dc
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zhangpaipai Jun 13, 2023
ddd6550
fix data shape bug
zjowowen Jun 13, 2023
d25228e
merge from remote
zjowowen Jun 13, 2023
df7963d
merge from main
zjowowen Jun 13, 2023
a7c3cf4
fix ppo offpolicy deploy bug
zjowowen Jun 13, 2023
ea979e8
fix mujoco reward action env clip bug
zjowowen Jun 13, 2023
ff7f639
fix mujoco reward action env clip bug
zjowowen Jun 13, 2023
ed5b1a3
fix deploy env mode bug
zjowowen Jun 14, 2023
05f8c47
fix env reset bug for deployment and evaluation
zjowowen Jun 14, 2023
df20033
Add ppo offpolicy atari config
zjowowen Jun 25, 2023
4cc8eac
merge from main
zjowowen Jun 27, 2023
5ecc9dc
polish config
zjowowen Jun 29, 2023
1e5ec1a
merge from main
zjowowen Jul 9, 2023
c621c35
polish config code
zjowowen Jul 10, 2023
41786e3
polish code; add SQL
zjowowen Jul 10, 2023
ebcefb4
polish code
zjowowen Jul 10, 2023
420ef72
polish code
zjowowen Jul 10, 2023
d9d93dd
polish code
zjowowen Jul 10, 2023
aa1f39d
polish code
zjowowen Jul 10, 2023
653a00b
change config path
zjowowen Jul 11, 2023
57e7325
add compatibility fix for nstep
zjowowen Jul 11, 2023
0754dd9
polish code
zjowowen Jul 11, 2023
0919f06
Add ppo offpolicy continuous policy
zjowowen Jul 12, 2023
d958f49
polish config
zjowowen Jul 14, 2023
ab0fdda
add ppo offpolicy general action modeling
zjowowen Jul 17, 2023
0c1f2b6
add dependencies
zjowowen Jul 17, 2023
9336a0a
polish config
zjowowen Jul 18, 2023
ced06f8
polish deploy
zjowowen Jul 18, 2023
a8822fd
Add array video helper
zjowowen Jul 18, 2023
8d152e0
polish deploy
zjowowen Jul 18, 2023
2e2db04
merge from main
zjowowen Jul 19, 2023
e063d77
polish config
zjowowen Jul 19, 2023
afb6355
polish setup
zjowowen Jul 20, 2023
0863b0b
fix config bug
zjowowen Jul 22, 2023
c934ef6
polish code
zjowowen Jul 25, 2023
a1f3e94
polish code
zjowowen Jul 25, 2023
da9d2c1
polish code
zjowowen Jul 25, 2023
af3d101
merge from main
zjowowen Jul 26, 2023
92d9504
fix bug in evaluator
zjowowen Jul 26, 2023
1f0704c
polish code
zjowowen Jul 27, 2023
5a08ec7
polish code
zjowowen Jul 27, 2023
1774224
merge from main
zjowowen Aug 3, 2023
65b9f08
Add priority in collector
zjowowen Aug 8, 2023
02f90cf
merge from main
zjowowen Aug 8, 2023
c9e736a
polish code
zjowowen Aug 8, 2023
4d125c9
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zjowowen Aug 8, 2023
d204c95
polish example
zjowowen Aug 8, 2023
eea5573
polish example
zjowowen Aug 8, 2023
0f807d8
add wandb logger
zjowowen Aug 8, 2023
984c8ba
polish code
zjowowen Aug 10, 2023
584bd7a
polish code
zjowowen Aug 11, 2023
b87eabc
merge from main
zjowowen Aug 15, 2023
795ec5d
merge from new-agent-pipeline
zjowowen Aug 15, 2023
948c99b
polish code
zjowowen Aug 15, 2023
554edb2
Merge branch 'envpool-pr' of https://github.com/zjowowen/DI-engine in…
zjowowen Aug 15, 2023
d0047ed
change timer gpu to false
zjowowen Aug 15, 2023
8c57293
Merge branch 'envpool-pr' of https://github.com/zjowowen/DI-engine in…
zjowowen Aug 15, 2023
c7509cb
polish config
zjowowen Aug 15, 2023
9ff7d4b
add sweep main file for new pipeline
zjowowen Aug 16, 2023
0a1a2cc
polish code
zjowowen Aug 16, 2023
b798c2e
Merge branch 'envpool-pr' of https://github.com/zjowowen/DI-engine in…
zjowowen Aug 16, 2023
fb5045b
polish code
zjowowen Aug 16, 2023
6a4d83e
polish code
zjowowen Aug 16, 2023
52ded5a
Add main file
zjowowen Aug 18, 2023
9aa23f7
add test
zjowowen Aug 18, 2023
c6e90a4
add test
zjowowen Aug 18, 2023
3addb8b
merge from main
zjowowen Aug 21, 2023
1d9f9af
Merge branch 'new-pipeline-agent' of https://github.com/zjowowen/DI-e…
zjowowen Aug 21, 2023
ef99434
add time logger
zjowowen Aug 21, 2023
27cb8bd
add new envmanager and collector
zjowowen Aug 22, 2023
5a41f63
fix bug in learner
zjowowen Aug 22, 2023
1b7cf2a
add nstep support for fast dqn
zjowowen Sep 6, 2023
aab3847
change data type
zjowowen Sep 7, 2023
83ece4f
polish code
zjowowen Sep 7, 2023
1980d51
add spaceinvaders envpool
zjowowen Sep 8, 2023
96c0bbf
fix import bug
zjowowen Sep 8, 2023
8419a38
merge file from main
zjowowen Oct 11, 2023
7adbc77
merge file from main
zjowowen Oct 11, 2023
7b581eb
merge file from main
zjowowen Oct 11, 2023
fe30fbe
merge file from main
zjowowen Oct 11, 2023
dc0ea3a
merge file from main
zjowowen Oct 11, 2023
c7b7645
merge file from main
zjowowen Oct 11, 2023
2f9a41f
merge file from main
zjowowen Oct 11, 2023
871fdc0
merge file from main
zjowowen Oct 11, 2023
e6e6828
change offline learner
zjowowen Oct 11, 2023
8d79f66
add dqn policy timer
zjowowen Oct 11, 2023
e1c137a
polish code
zjowowen Oct 11, 2023
ab33001
polish code
zjowowen Oct 12, 2023
340d50e
polish code
zjowowen Oct 12, 2023
e5af078
polish code
zjowowen Oct 12, 2023
20cbef1
add shrink model
zjowowen Oct 12, 2023
532b5b8
add large batch
zjowowen Oct 12, 2023
b31a7ca
add large batch
zjowowen Oct 12, 2023
35069ae
add large learning rate; add priority
zjowowen Oct 13, 2023
a06bd3f
Add update per collect 5 and target update 100
zjowowen Oct 16, 2023
e5ea2fd
Add qbert test 6 7
zjowowen Oct 25, 2023
d7c4983
polish qbert test 6 7
zjowowen Oct 25, 2023
e22df12
polish qbert test 6 7
zjowowen Oct 25, 2023
c03a17b
polish qbert test 8 9
zjowowen Oct 26, 2023
48c1333
polish qbert test 10~12
zjowowen Oct 26, 2023
dda0ffc
polish qbert test 13
zjowowen Oct 26, 2023
878bbb3
polish qbert test 14 15
zjowowen Oct 27, 2023
73b73dc
polish qbert test 16~18
zjowowen Oct 27, 2023
c068721
merge from main
zjowowen Nov 1, 2023
7daf239
polish code
zjowowen Nov 1, 2023
ed0f490
polish code
zjowowen Nov 1, 2023
8981236
polish code
zjowowen Nov 1, 2023
3a1d98c
polish code
zjowowen Nov 1, 2023
83ca217
polish code
zjowowen Nov 1, 2023
97360c0
polish code
zjowowen Nov 1, 2023
25fab56
polish code
zjowowen Nov 1, 2023
ab93b39
polish code
zjowowen Nov 1, 2023
cd762b6
polish code
zjowowen Nov 1, 2023
d3c9bf8
polish code
zjowowen Nov 1, 2023
1bd96e0
polish pr
zjowowen Nov 16, 2023
35a2c67
fix bug
zjowowen Nov 16, 2023
fca097f
Merge branch 'main' of https://github.com/zjowowen/DI-engine into dis…
zjowowen Nov 16, 2023
3687f8b
polish code
zjowowen Nov 20, 2023
4fb85b0
polish code
zjowowen Nov 23, 2023
48ee6da
polish code
zjowowen Nov 23, 2023
8 changes: 5 additions & 3 deletions ding/entry/utils.py
@@ -1,4 +1,4 @@
-from typing import Optional, Callable, List, Any
+from typing import Optional, Callable, List, Any, Dict

from ding.policy import PolicyFactory
from ding.worker import IMetric, MetricSerialEvaluator
@@ -46,7 +46,8 @@ def random_collect(
collector_env: 'BaseEnvManager', # noqa
commander: 'BaseSerialCommander', # noqa
replay_buffer: 'IBuffer', # noqa
-postprocess_data_fn: Optional[Callable] = None
+postprocess_data_fn: Optional[Callable] = None,
+collect_kwargs: Optional[Dict] = None,
) -> None: # noqa
assert policy_cfg.random_collect_size > 0
if policy_cfg.get('transition_with_policy_data', False):
@@ -55,7 +56,8 @@ def random_collect(
action_space = collector_env.action_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
-collect_kwargs = commander.step()
+if collect_kwargs is None:
+    collect_kwargs = commander.step()
if policy_cfg.collect.collector.type == 'episode':
new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
else:
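The change to `random_collect` above lets a caller pass `collect_kwargs` explicitly and only falls back to `commander.step()` when none is given. A minimal sketch of that fallback pattern follows. `StubCommander`, `resolve_collect_kwargs`, and the `eps` value are hypothetical names chosen for illustration, not DI-engine APIs:

```python
from typing import Dict, Optional


class StubCommander:
    """Hypothetical stand-in for a commander; counts how often step() is called."""

    def __init__(self) -> None:
        self.calls = 0

    def step(self) -> Dict[str, float]:
        self.calls += 1
        return {"eps": 0.9}  # assumed placeholder kwargs


def resolve_collect_kwargs(commander: StubCommander, collect_kwargs: Optional[Dict] = None) -> Dict:
    # Only ask the commander when the caller did not supply kwargs.
    if collect_kwargs is None:
        collect_kwargs = commander.step()
    return collect_kwargs


commander = StubCommander()
from_commander = resolve_collect_kwargs(commander)
from_caller = resolve_collect_kwargs(commander, {"eps": 1.0})
```

Checking `is None` rather than truthiness means an explicitly passed empty dict is respected and does not trigger the fallback.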
188 changes: 183 additions & 5 deletions ding/envs/env_manager/envpool_env_manager.py
@@ -2,7 +2,11 @@
from easydict import EasyDict
from copy import deepcopy
import numpy as np
import torch
import treetensor.torch as ttorch
import treetensor.numpy as tnp
from collections import namedtuple
import enum
from typing import Any, Union, List, Tuple, Dict, Callable, Optional
from ditk import logging
try:
@@ -17,17 +21,28 @@
from ding.torch_utils import to_ndarray


-@ENV_MANAGER_REGISTRY.register('env_pool')
class EnvState(enum.IntEnum):
VOID = 0
INIT = 1
RUN = 2
RESET = 3
DONE = 4
ERROR = 5
NEED_RESET = 6


+@ENV_MANAGER_REGISTRY.register('envpool')
class PoolEnvManager:
-'''
+"""
Overview:
PoolEnvManager supports old pipeline of DI-engine.
Envpool now supports Atari, Classic Control, Toy Text, ViZDoom.
Here we list some commonly used env_ids as follows.
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>.

- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"
-'''
+"""

@classmethod
def default_config(cls) -> EasyDict:
@@ -39,10 +54,17 @@ def default_config(cls) -> EasyDict:
# Async mode: batch_size < env_num
env_num=8,
batch_size=8,
image_observation=True,
episodic_life=False,
reward_clip=False,
gray_scale=True,
stack_num=4,
frame_skip=4,
)

def __init__(self, cfg: EasyDict) -> None:
-self._cfg = cfg
+self._cfg = self.default_config()
+self._cfg.update(cfg)
Review comment (Member):

Why merge the config here? The env manager config is already merged in the compile_config function.

Reply (Collaborator, Author):

There are two env manager configs, one for the evaluator and one for the collector, so using compile_config with auto=True gets too complicated. I suggest using compile_config with auto=False.
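The exchange above concerns building two env manager configs (collector and evaluator) from one set of defaults. A minimal sketch of the merge performed in `__init__`, written with plain dicts instead of `EasyDict` so that it runs without DI-engine installed. `DEFAULT_CONFIG` and `merge_with_defaults` are hypothetical names; the default values are copied from the diff:

```python
from copy import deepcopy

# Defaults copied from the `config` dict in the diff.
DEFAULT_CONFIG = dict(
    env_num=8,
    batch_size=8,
    image_observation=True,
    episodic_life=False,
    reward_clip=False,
    gray_scale=True,
    stack_num=4,
    frame_skip=4,
)


def merge_with_defaults(user_cfg: dict) -> dict:
    # Start from a fresh copy so the shared defaults are never mutated.
    cfg = deepcopy(DEFAULT_CONFIG)
    cfg.update(user_cfg)  # user keys override defaults
    return cfg


collector_cfg = merge_with_defaults({"env_id": "Pong-v5", "episodic_life": True})
evaluator_cfg = merge_with_defaults({"env_id": "Pong-v5", "env_num": 4, "batch_size": 4})
```

Each manager instance ends up with its own fully populated config, which is why the merge can live in the constructor when the collector and evaluator settings differ.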

self._env_num = cfg.env_num
self._batch_size = cfg.batch_size
self._ready_obs = {}
@@ -55,6 +77,7 @@ def launch(self) -> None:
seed = 0
else:
seed = self._seed

self._envs = envpool.make(
task_id=self._cfg.env_id,
env_type="gym",
@@ -65,8 +88,10 @@ def launch(self) -> None:
reward_clip=self._cfg.reward_clip,
stack_num=self._cfg.stack_num,
gray_scale=self._cfg.gray_scale,
-frame_skip=self._cfg.frame_skip
+frame_skip=self._cfg.frame_skip,
)
self._action_space = self._envs.action_space
self._observation_space = self._envs.observation_space
self._closed = False
self.reset()

@@ -77,6 +102,8 @@ def reset(self) -> None:
obs, _, _, info = self._envs.recv()
env_id = info['env_id']
obs = obs.astype(np.float32)
if self._cfg.image_observation:
obs /= 255.0
self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs)
if len(self._ready_obs) == self._env_num:
break
@@ -91,6 +118,8 @@ def step(self, action: dict) -> Dict[int, namedtuple]:

obs, rew, done, info = self._envs.recv()
obs = obs.astype(np.float32)
if self._cfg.image_observation:
obs /= 255.0
rew = rew.astype(np.float32)
env_id = info['env_id']
timesteps = {}
@@ -124,3 +153,152 @@ def env_num(self) -> int:
@property
def ready_obs(self) -> Dict[int, Any]:
return self._ready_obs

@property
def observation_space(self) -> 'gym.spaces.Space': # noqa
try:
return self._observation_space
except AttributeError:
self.launch()
self.close()
return self._observation_space

@property
def action_space(self) -> 'gym.spaces.Space': # noqa
try:
return self._action_space
except AttributeError:
self.launch()
self.close()
return self._action_space


@ENV_MANAGER_REGISTRY.register('envpool_v2')
class PoolEnvManagerV2:
"""
Overview:
PoolEnvManagerV2 supports new pipeline of DI-engine.
Envpool now supports Atari, Classic Control, Toy Text, ViZDoom.
Here we list some commonly used env_ids as follows.
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>.

- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"
"""

@classmethod
def default_config(cls) -> EasyDict:
return EasyDict(deepcopy(cls.config))

config = dict(
type='envpool_v2',
env_num=8,
batch_size=8,
image_observation=True,
episodic_life=False,
reward_clip=False,
gray_scale=True,
stack_num=4,
frame_skip=4,
)

def __init__(self, cfg: EasyDict) -> None:
super().__init__()
self._cfg = self.default_config()
self._cfg.update(cfg)
self._env_num = cfg.env_num
self._batch_size = cfg.batch_size

self._closed = True
self._seed = None

def launch(self) -> None:
assert self._closed, "Please first close the env manager"
if self._seed is None:
seed = 0
else:
seed = self._seed

self._envs = envpool.make(
task_id=self._cfg.env_id,
env_type="gym",
num_envs=self._env_num,
batch_size=self._batch_size,
seed=seed,
episodic_life=self._cfg.episodic_life,
reward_clip=self._cfg.reward_clip,
stack_num=self._cfg.stack_num,
gray_scale=self._cfg.gray_scale,
frame_skip=self._cfg.frame_skip,
)
self._action_space = self._envs.action_space
self._observation_space = self._envs.observation_space
self._closed = False
return self.reset()

def reset(self) -> None:
self._envs.async_reset()
ready_obs = {}
while True:
obs, _, _, info = self._envs.recv()
env_id = info['env_id']
obs = obs.astype(np.float32)
if self._cfg.image_observation:
obs /= 255.0
ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, ready_obs)
Review comment (Member):

Use direct assignment here. deep_merge_dicts is a complex function and should not be called this many times.

Reply (Collaborator, Author):

Fixed.
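One way to read the suggestion above: each `recv()` call returns a flat batch of observations together with their `env_id`s, so the per-env results can be written into the dict by plain assignment. The sketch below uses a hypothetical `FakeAsyncPool` that mimics only the `async_reset`/`recv` call shape seen in the diff; it is not the envpool API itself, and the observations are dummy values:

```python
from typing import Dict, List, Tuple


class FakeAsyncPool:
    """Hypothetical stand-in: hands back up to `batch_size` envs per recv() call."""

    def __init__(self, env_num: int, batch_size: int) -> None:
        self.env_num = env_num
        self.batch_size = batch_size
        self._pending: List[int] = []

    def async_reset(self) -> None:
        self._pending = list(range(self.env_num))

    def recv(self) -> Tuple[List[List[float]], None, None, Dict[str, List[int]]]:
        ids = self._pending[:self.batch_size]
        self._pending = self._pending[self.batch_size:]
        obs = [[float(i)] for i in ids]  # dummy observation per env
        return obs, None, None, {"env_id": ids}


def collect_reset_obs(pool: FakeAsyncPool) -> Dict[int, List[float]]:
    pool.async_reset()
    ready_obs: Dict[int, List[float]] = {}
    while len(ready_obs) < pool.env_num:
        obs, _, _, info = pool.recv()
        for i, o in zip(info["env_id"], obs):
            ready_obs[i] = o  # direct assignment, no deep merge needed
    return ready_obs


ready = collect_reset_obs(FakeAsyncPool(env_num=8, batch_size=3))
```

Because the keys are integer env ids and the values are whole observations, there is no nested structure to merge, so assignment gives the same result as a deep merge.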

if len(ready_obs) == self._env_num:
break
self._eval_episode_return = [0. for _ in range(self._env_num)]

return ready_obs

def send_action(self, action, env_id) -> Dict[int, namedtuple]:
self._envs.send(action, env_id)

def receive_data(self):
next_obs, rew, done, info = self._envs.recv()
next_obs = next_obs.astype(np.float32)
if self._cfg.image_observation:
next_obs /= 255.0
rew = rew.astype(np.float32)

return next_obs, rew, done, info

def close(self) -> None:
if self._closed:
return
# Envpool has no `close` API
self._closed = True

@property
def closed(self) -> None:
return self._closed

def seed(self, seed: int, dynamic_seed=False) -> None:
# The i-th environment seed in Envpool will be set with i+seed, so we don't do extra transformation here
self._seed = seed
logging.warning("envpool doesn't support dynamic_seed in different episode")

@property
def env_num(self) -> int:
return self._env_num

@property
def observation_space(self) -> 'gym.spaces.Space': # noqa
try:
return self._observation_space
except AttributeError:
self.launch()
self.close()
self._ready_obs = {}
return self._observation_space

@property
def action_space(self) -> 'gym.spaces.Space': # noqa
try:
return self._action_space
except AttributeError:
self.launch()
self.close()
self._ready_obs = {}
return self._action_space
Review comment (Member):

Add an envpooltest for this file.

Reply (Collaborator, Author):

Fixed.
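A test for this file could, for instance, check that the space properties launch the manager on first access and leave it closed afterwards. The sketch below exercises that pattern on a hypothetical `LazySpaceManager` stub so that it runs without envpool installed. The stub, its placeholder shape, and `launch_count` are assumptions for illustration; a real test would construct `PoolEnvManagerV2` with an actual `env_id` instead.

```python
class LazySpaceManager:
    """Hypothetical stub mirroring the launch-on-demand properties in the diff."""

    def __init__(self) -> None:
        self._closed = True
        self.launch_count = 0

    def launch(self) -> None:
        assert self._closed, "Please first close the env manager"
        self.launch_count += 1
        self._observation_space = (4, 84, 84)  # assumed placeholder shape
        self._closed = False

    def close(self) -> None:
        self._closed = True

    @property
    def observation_space(self):
        try:
            return self._observation_space
        except AttributeError:
            # First access: launch once to discover the space, then close again.
            self.launch()
            self.close()
            return self._observation_space


def test_observation_space_is_lazy() -> None:
    manager = LazySpaceManager()
    assert manager.observation_space == (4, 84, 84)
    assert manager.observation_space == (4, 84, 84)  # cached, no relaunch
    assert manager.launch_count == 1
    assert manager._closed


test_observation_space_is_lazy()
manager = LazySpaceManager()
space = manager.observation_space
```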

118 changes: 118 additions & 0 deletions ding/example/dqn_nstep_envpool.py
@@ -0,0 +1,118 @@
import datetime
from easydict import EasyDict
from ditk import logging
from ding.model import DQN
from ding.policy import DQNFastPolicy
from ding.envs.env_manager.envpool_env_manager import PoolEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import envpool_evaluator, data_pusher, \
eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger, \
termination_checker, wandb_online_logger, epoch_timer, EnvpoolStepCollector, EnvpoolOffPolicyLearner
from ding.utils import set_pkg_seed
from dizoo.atari.config.serial import pong_dqn_envpool_config


def main(cfg):
logging.getLogger().setLevel(logging.INFO)
cfg.exp_name = 'Pong-v5-DQN-envpool-' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
Review comment (Member):

Why modify this?


collector_env_cfg = EasyDict(
{
'env_id': cfg.env.env_id,
'env_num': cfg.env.collector_env_num,
'batch_size': cfg.env.collector_batch_size,
# env wrappers
'episodic_life': True, # collector: True
'reward_clip': False, # collector: True
'gray_scale': cfg.env.get('gray_scale', True),
'stack_num': cfg.env.get('stack_num', 4),
Review comment (Member):

Move some of these keys into the default config and the user config.

}
)
cfg.env["collector_env_cfg"] = collector_env_cfg
evaluator_env_cfg = EasyDict(
{
'env_id': cfg.env.env_id,
'env_num': cfg.env.evaluator_env_num,
'batch_size': cfg.env.evaluator_batch_size,
# env wrappers
'episodic_life': False, # evaluator: False
'reward_clip': False, # evaluator: False
'gray_scale': cfg.env.get('gray_scale', True),
'stack_num': cfg.env.get('stack_num', 4),
}
)
cfg.env["evaluator_env_cfg"] = evaluator_env_cfg
cfg = compile_config(cfg, PoolEnvManagerV2, DQNFastPolicy, save_cfg=task.router.node_id == 0)
ding_init(cfg)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = PoolEnvManagerV2(cfg.env.collector_env_cfg)
evaluator_env = PoolEnvManagerV2(cfg.env.evaluator_env_cfg)
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = DQN(**cfg.policy.model)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
policy = DQNFastPolicy(cfg.policy, model=model)

# Consider the case with multiple processes
if task.router.is_active:
# You can use labels to distinguish between workers with different roles,
# here we use node_id to distinguish.
if task.router.node_id == 0:
task.add_role(task.role.LEARNER)
elif task.router.node_id == 1:
task.add_role(task.role.EVALUATOR)
else:
task.add_role(task.role.COLLECTOR)

# Sync their context and model between each worker.
task.use(ContextExchanger(skip_n_iter=1))
task.use(ModelExchanger(model))
task.use(epoch_timer())
task.use(envpool_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(eps_greedy_handler(cfg))
task.use(
EnvpoolStepCollector(
cfg,
policy.collect_mode,
collector_env,
random_collect_size=cfg.policy.random_collect_size \
if hasattr(cfg.policy, 'random_collect_size') else 0,
)
)
task.use(data_pusher(cfg, buffer_))
task.use(EnvpoolOffPolicyLearner(cfg, policy, buffer_))
task.use(online_logger(train_show_freq=10))
Review comment (Member):

Why use two loggers?

task.use(
wandb_online_logger(
metric_list=policy._monitor_vars_learn(),
model=policy._model,
exp_config=cfg,
anonymous=True,
project_name=cfg.exp_name,
wandb_sweep=False,
)
)
#task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
task.use(termination_checker(max_env_step=10000000))
Review comment (Member):

Use int(1e7).
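The suggestion above works because `1e7` is a float literal that converts exactly to the integer 10000000, and it is easier to read than counting zeros. A quick check:

```python
# int(1e7) yields a plain Python int equal to the literal used in the diff.
max_env_step = int(1e7)
is_int = isinstance(max_env_step, int)
same_value = max_env_step == 10000000
```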

task.run()


if __name__ == "__main__":

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--collector_env_num", type=int, default=8, help="collector env number")
parser.add_argument("--collector_batch_size", type=int, default=8, help="collector batch size")
arg = parser.parse_args()

pong_dqn_envpool_config.env.collector_env_num = arg.collector_env_num
pong_dqn_envpool_config.env.collector_batch_size = arg.collector_batch_size
pong_dqn_envpool_config.seed = arg.seed

main(pong_dqn_envpool_config)