#Copyright(C) 2023 Intel Corporation
#SPDX-License-Identifier: Apache-2.0
#File : offline_train.py
# In this example, we train an offline RL agent (CQL) on a pre-collected replay buffer
# dataset and periodically evaluate it online across several network slicing scenarios.
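# A minimal example invocation via the fire CLI (argument values are illustrative and
# assume the corresponding ./dataset/*_buffer_new.h5 file and the network_slicing
# config exist in your setup):
#   python offline_train.py --agent_type=cql --env_name=network_slicing --dataset=sac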
import copy
import random
import sys

import fire
import numpy as np
import torch
import wandb
from gymnasium.wrappers import NormalizeObservation
from tqdm import tqdm

# Make the local packages importable before the project-specific imports below.
sys.path.append('../')
sys.path.append('../../')

from utils.buffer import ReplayBuffer
from utils.utils import *
from CleanRL_agents import SACAgent
from CORL_agents import CQL
from network_gym_client import load_config_file
from network_gym_client import Env as NetworkGymEnv
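# The five evaluation scenarios below sweep how the user population is split between
# the first two slices (6/20, 11/15, 13/13, 15/11, 20/6) while the third slice keeps
# 5 users on shared resources only; each scenario is replayed during the periodic
# online evaluation further down.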
slice_lists = [
    [
        {"num_users": 6, "dedicated_rbg": 0, "prioritized_rbg": 12, "shared_rbg": 25},
        {"num_users": 20, "dedicated_rbg": 0, "prioritized_rbg": 13, "shared_rbg": 25},
        {"num_users": 5, "dedicated_rbg": 0, "prioritized_rbg": 0, "shared_rbg": 25},
    ],
    [
        {"num_users": 11, "dedicated_rbg": 0, "prioritized_rbg": 12, "shared_rbg": 25},
        {"num_users": 15, "dedicated_rbg": 0, "prioritized_rbg": 13, "shared_rbg": 25},
        {"num_users": 5, "dedicated_rbg": 0, "prioritized_rbg": 0, "shared_rbg": 25},
    ],
    [
        {"num_users": 13, "dedicated_rbg": 0, "prioritized_rbg": 12, "shared_rbg": 25},
        {"num_users": 13, "dedicated_rbg": 0, "prioritized_rbg": 13, "shared_rbg": 25},
        {"num_users": 5, "dedicated_rbg": 0, "prioritized_rbg": 0, "shared_rbg": 25},
    ],
    [
        {"num_users": 15, "dedicated_rbg": 0, "prioritized_rbg": 12, "shared_rbg": 25},
        {"num_users": 11, "dedicated_rbg": 0, "prioritized_rbg": 13, "shared_rbg": 25},
        {"num_users": 5, "dedicated_rbg": 0, "prioritized_rbg": 0, "shared_rbg": 25},
    ],
    [
        {"num_users": 20, "dedicated_rbg": 0, "prioritized_rbg": 12, "shared_rbg": 25},
        {"num_users": 6, "dedicated_rbg": 0, "prioritized_rbg": 13, "shared_rbg": 25},
        {"num_users": 5, "dedicated_rbg": 0, "prioritized_rbg": 0, "shared_rbg": 25},
    ],
]

MODEL_SAVE_FREQ = 2000
LOG_INTERVAL = 10
NUM_OF_EVALUATE_EPISODES = 10
EVAL_EPI_PER_SESSION = 1

def main(agent_type: str,
         env_name: str,
         dataset: str,
         num_steps=60000,
         client_id=0,
         hidden_dim=64,
         steps_per_episode=10,
         episode_per_session=1,
         random_seed=1240,
         ):
    # client_id = 1
    # env_name = "network_slicing"
    storage_ver = 0
    config_json = load_config_file(env_name)
    config_json["env_config"]["random_seed"] = random_seed
    train_random_seed = random_seed
    config_json["rl_config"]["agent"] = agent_type
    config_json["env_config"]["steps_per_episode"] = steps_per_episode
    config_json["env_config"]["episodes_per_session"] = episode_per_session

    buffer = ReplayBuffer(max_size=1000000, obs_shape=15, n_actions=2)
    if dataset == "mixed":
        for d in ["sac", "baseline", "baseline_delay"]:
            buffer.load_buffer(f"./dataset/{d}_buffer_new.h5")
    else:
        buffer.load_buffer(f"./dataset/{dataset}_buffer_new.h5")
    # buffer.nomarlize_states()
    # Entropy target for automatic temperature tuning: -action_dim (= -2 here).
    target_entropy = -np.prod((2,)).item()
    # breakpoint()
    agent = CQL(state_dim=15, action_dim=2, hidden_dim=hidden_dim, target_entropy=target_entropy,
                q_n_hidden_layers=1, max_action=1, qf_lr=3e-4, policy_lr=5e-5, device="cuda:0", bc_steps=0)
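    # Hyperparameter notes (hedged; consult CORL_agents.CQL for the exact semantics):
    # bc_steps=0 is assumed to skip any behavior-cloning warm-up phase, and the smaller
    # policy_lr (5e-5 vs. qf_lr=3e-4) is a common CQL choice that keeps the actor from
    # overfitting to the conservative Q-estimates early in training.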
    run = wandb.init(project="network-slicing-offline",
                     name=f"{agent_type}-nn-{dataset}-ver{storage_ver}",
                     config=config_json)
    num_episodes = 0
    progress_bar = tqdm(range(num_steps))
    best_eval_reward = -np.inf

    # Training loop
    # Evaluate every MODEL_SAVE_FREQ (2000) steps, which is also the model-saving frequency.
    for step in progress_bar:
        batch = buffer.sample(64)
        # breakpoint()
        # Recompute the reward for the sampled transitions with new objective weights
        # (see utils.utils.vary_rewards for the exact definition) before the CQL update.
        new_reward = vary_rewards(batch[0], [0.5, 0.5, 0], 1, 4)
        batch = (batch[0], batch[1], new_reward, batch[3], batch[4])
        train_info = agent.learn(*batch)
        wandb.log(train_info)
        if (step + 1) % MODEL_SAVE_FREQ == 0:
            print("Step: {}, Saving model...".format(step))
            agent.save("./models/cql_dataset_{}_{}_ver{}_res.pt".format(dataset, train_random_seed, storage_ver))

            # Evaluate a frozen copy of the current policy online on every slice configuration.
            eval_agent = copy.deepcopy(agent)
            eval_agent.actor.eval()
            config_json["env_config"]["steps_per_episode"] = 52
            config_json["env_config"]["episodes_per_session"] = EVAL_EPI_PER_SESSION
            random_seed = 1
            avg_reward = 0
            for slice_list in slice_lists:
                print("evaluating env: {}".format(slice_list))
                config_json["env_config"]["slice_list"] = slice_list
                config_json["env_config"]["random_seed"] = random_seed
                eval_env = NetworkGymEnv(client_id, config_json, log=False)
                # normalized_eval_env = NormalizeObservation(eval_env)
                env_reward, eval_dict = evaluate(eval_agent, eval_env, n_episodes=1)
                avg_reward += env_reward
                wandb.log(eval_dict)
            avg_reward /= len(slice_lists)
            # Log the checkpoint to W&B; tag it "best" when it improves on the running best.
            art = wandb.Artifact(f"{agent_type}-nn-{wandb.run.id}", type="model")
            art.add_file("./models/cql_dataset_{}_{}_ver{}_res.pt".format(dataset, train_random_seed, storage_ver))
            if avg_reward > best_eval_reward:
                best_eval_reward = avg_reward
                wandb.log_artifact(art, aliases=["latest", "best"])
            else:
                wandb.log_artifact(art)
            print("Step: {}, Eval Reward: {}".format(step, avg_reward))
            wandb.log({"eval_avg_reward": avg_reward})
            # buffer.save_buffer("./dataset/offline_data_heavy_traffic.h5")
            storage_ver += 1
            progress_bar.set_description("Step: {}, Eval Reward: {}".format(step, avg_reward))

if __name__ == "__main__":
    fire.Fire(main)
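# A minimal sketch of reloading a saved checkpoint for evaluation only, assuming the
# CQL wrapper exposes a load() counterpart to the save() call used above (check
# CORL_agents.CQL before relying on this):
#   agent = CQL(state_dim=15, action_dim=2, hidden_dim=64, target_entropy=-2.0,
#               q_n_hidden_layers=1, max_action=1, qf_lr=3e-4, policy_lr=5e-5,
#               device="cuda:0", bc_steps=0)
#   agent.load("./models/cql_dataset_sac_1240_ver0_res.pt")
#   agent.actor.eval()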