forked from instadeepai/Mava
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_maddpg_scaling.py
107 lines (92 loc) · 3.1 KB
/
run_maddpg_scaling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example running MADDPG on debug MPE environments, while using 4 executors."""
import functools
from datetime import datetime
from typing import Any
import launchpad as lp
import sonnet as snt
from absl import app, flags
from launchpad.nodes.python.local_multi_processing import PythonProcess
from mava.systems.tf import maddpg
from mava.utils import lp_utils
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils
# Global absl flag registry; values are populated when app.run() parses argv.
FLAGS = flags.FLAGS
# Which debugging MPE scenario to run (e.g. "simple_spread").
flags.DEFINE_string(
"env_name",
"simple_spread",
"Debugging environment name (str).",
)
# Action space type for the environment: "continuous" or "discrete".
flags.DEFINE_string(
"action_space",
"continuous",
"Environment action space type (str).",
)
# Unique run identifier; defaults to the import-time timestamp.
# NOTE(review): str(datetime.now()) contains spaces and colons, which end up
# in directory names — confirm this is acceptable on the target filesystem.
flags.DEFINE_string(
"mava_id",
str(datetime.now()),
"Experiment identifier that can be used to continue experiments.",
)
# Root directory under which logs and checkpoints are written.
flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.")
def main(_: Any) -> None:
    """Build and launch a distributed MADDPG program with 4 executors.

    Reads configuration from absl FLAGS, assembles the environment, network
    and logger factories, constructs the launchpad program topology, and
    launches it as local multi-processing.
    """
    # Factory producing the debugging MPE environment selected via flags.
    env_factory = functools.partial(
        debugging_utils.make_environment,
        env_name=FLAGS.env_name,
        action_space=FLAGS.action_space,
    )

    # Default MADDPG networks, wrapped so keyword overrides can be applied.
    net_factory = lp_utils.partial_kwargs(maddpg.make_default_networks)

    # The checkpointer appends "Checkpoints" to this path on its own.
    experiment_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Terminal + TensorBoard logging, emitted every `seconds_per_log` seconds.
    seconds_per_log = 10
    log_factory = functools.partial(
        logger_utils.make_logger,
        directory=FLAGS.base_dir,
        to_terminal=True,
        to_tensorboard=True,
        time_stamp=FLAGS.mava_id,
        time_delta=seconds_per_log,
    )

    # Assemble the distributed program: 1 trainer, evaluator, 4 executors.
    program = maddpg.MADDPG(
        environment_factory=env_factory,
        network_factory=net_factory,
        logger_factory=log_factory,
        num_executors=4,
        policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        checkpoint_subpath=experiment_dir,
        max_gradient_norm=40.0,
    ).build()

    # Hide GPUs from the evaluator and executors so only the trainer can
    # claim them; CUDA_VISIBLE_DEVICES=-1 disables CUDA for those processes.
    hide_gpu = {"CUDA_VISIBLE_DEVICES": str(-1)}
    local_resources = {
        "trainer": [],
        "evaluator": PythonProcess(env=hide_gpu),
        "executor": PythonProcess(env=hide_gpu),
    }

    # Launch every node locally, each in its own process.
    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=local_resources,
    )
# Script entry point: let absl parse flags, then run main.
if __name__ == "__main__":
app.run(main)