
Commit e0a2c05

Rewards init moved to make_world in main scenarios
1 parent 4b9771a commit e0a2c05

File tree

12 files changed: +128 -100 lines

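Every file in this diff follows the same pattern: each reward buffer is allocated once in make_world and then zeroed in place in reward(), rather than being reallocated on every step, so info() can return the buffers without guarding against AttributeError. Below is a minimal sketch of that pattern using a hypothetical single-agent scenario; it is not code from this commit.

import torch

from vmas.simulator.core import Agent, World
from vmas.simulator.scenario import BaseScenario


class PreallocatedRewardScenario(BaseScenario):
    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        world = World(batch_dim, device)
        world.add_agent(Agent(name="agent_0"))

        # Allocate the per-environment reward buffer once, at world creation,
        # so it exists even if info() runs before the first reward() call.
        self.pos_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)
        return world

    def reset_world_at(self, env_index: int = None):
        pass

    def reward(self, agent: Agent):
        # Zero the existing buffer in place instead of reallocating it each step.
        self.pos_rew[:] = 0
        # ... accumulate the actual reward terms into self.pos_rew here ...
        return self.pos_rew

    def observation(self, agent: Agent):
        return agent.state.pos

    def info(self, agent: Agent):
        # No try/except AttributeError needed: the buffer always exists.
        return {"pos_rew": self.pos_rew}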

vmas/scenarios/balance.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44

@@ -72,6 +72,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
7272
)
7373
world.add_landmark(floor)
7474

75+
self.pos_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)
76+
self.ground_rew = self.pos_rew.clone()
77+
7578
return world
7679

7780
def reset_world_at(self, env_index: int = None):
@@ -201,12 +204,8 @@ def reward(self, agent: Agent):
201204
is_first = agent == self.world.agents[0]
202205

203206
if is_first:
204-
self.pos_rew = torch.zeros(
205-
self.world.batch_dim, device=self.world.device, dtype=torch.float32
206-
)
207-
self.ground_rew = torch.zeros(
208-
self.world.batch_dim, device=self.world.device, dtype=torch.float32
209-
)
207+
self.pos_rew[:] = 0
208+
self.ground_rew[:] = 0
210209

211210
self.on_the_ground = (
212211
self.package.state.pos[:, Y] <= -self.world.y_semidim

vmas/scenarios/ball_passage.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44

@@ -73,6 +73,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
7373

7474
self.create_passage_map(world)
7575

76+
self.pos_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)
77+
self.collision_rew = self.pos_rew.clone()
78+
7679
return world
7780

7881
def reset_world_at(self, env_index: int = None):
@@ -213,8 +216,8 @@ def reward(self, agent: Agent):
213216
self.rew = torch.zeros(
214217
self.world.batch_dim, device=self.world.device, dtype=torch.float32
215218
)
216-
self.pos_rew = self.rew.clone()
217-
self.collision_rew = self.rew.clone()
219+
self.pos_rew[:] = 0
220+
self.collision_rew[:] = 0
218221

219222
ball_passed = self.ball.state.pos[:, Y] > 0
220223

@@ -328,32 +331,35 @@ def removed(i):
328331
def spawn_passage_map(self, env_index):
329332
if not self.fixed_passage:
330333
order = torch.randperm(len(self.passages)).tolist()
331-
self.passages = [self.passages[i] for i in order]
332-
for i, passage in enumerate(self.passages):
334+
self.passages_to_place = [self.passages[i] for i in order]
335+
else:
336+
self.passages_to_place = self.passages
337+
for i, passage in enumerate(self.passages_to_place):
333338
if not passage.collide:
334339
passage.is_rendering[:] = False
335340
passage.neighbour = False
336341
try:
337-
passage.neighbour += not self.passages[i - 1].collide
342+
passage.neighbour += not self.passages_to_place[i - 1].collide
338343
except IndexError:
339344
pass
340345
try:
341-
passage.neighbour += not self.passages[i + 1].collide
346+
passage.neighbour += not self.passages_to_place[i + 1].collide
342347
except IndexError:
343348
pass
349+
pos = torch.tensor(
350+
[
351+
-1
352+
- self.agent_radius
353+
+ self.passage_length / 2
354+
+ self.passage_length * i,
355+
0.0,
356+
],
357+
dtype=torch.float32,
358+
device=self.world.device,
359+
)
344360
passage.neighbour *= passage.collide
345361
passage.set_pos(
346-
torch.tensor(
347-
[
348-
-1
349-
- self.agent_radius
350-
+ self.passage_length / 2
351-
+ self.passage_length * i,
352-
0.0,
353-
],
354-
dtype=torch.float32,
355-
device=self.world.device,
356-
),
362+
pos,
357363
batch_index=env_index,
358364
)
359365
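The spawn_passage_map hunk above also stops permuting self.passages in place: the shuffled order now goes into a separate self.passages_to_place list, so the canonical passage list is left untouched across resets. A minimal sketch of that list-handling idea, with assumed names (items, shuffle) rather than code from this file:

import torch

def placement_order(items: list, shuffle: bool) -> list:
    # Derive a (possibly shuffled) placement order without mutating the original list.
    if shuffle:
        order = torch.randperm(len(items)).tolist()
        return [items[i] for i in order]
    return list(items)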

vmas/scenarios/ball_trajectory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
from typing import Dict
@@ -72,6 +72,10 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
7272
)
7373
world.add_joint(self.joints[i])
7474

75+
self.pos_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)
76+
self.speed_rew = self.pos_rew.clone()
77+
self.dist_rew = self.pos_rew.clone()
78+
7579
return world
7680

7781
def reset_world_at(self, env_index: int = None):

vmas/scenarios/buzz_wire.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
from typing import Dict
@@ -90,6 +90,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
9090

9191
self.build_path_line(world)
9292

93+
self.pos_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)
94+
self.collision_rew = self.pos_rew.clone()
95+
9396
return world
9497

9598
def reset_world_at(self, env_index: int = None):
@@ -204,8 +207,8 @@ def reward(self, agent: Agent):
204207
self.rew = torch.zeros(
205208
self.world.batch_dim, device=self.world.device, dtype=torch.float32
206209
)
207-
self.pos_rew = self.rew.clone()
208-
self.collision_rew = self.rew.clone()
210+
self.pos_rew[:] = 0
211+
self.collision_rew[:] = 0
209212
self.collided = torch.full(
210213
(self.world.batch_dim,), False, device=self.world.device
211214
)

vmas/scenarios/discovery.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44

@@ -85,6 +85,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
8585
),
8686
],
8787
)
88+
agent.collision_rew = torch.zeros(batch_dim, device=device)
89+
agent.covering_reward = agent.collision_rew.clone()
8890
world.add_agent(agent)
8991

9092
self._targets = []
@@ -99,6 +101,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
99101
world.add_landmark(target)
100102
self._targets.append(target)
101103

104+
self.covered_targets = torch.zeros(batch_dim, self.n_targets, device=device)
105+
self.shared_covering_rew = torch.zeros(batch_dim, device=device)
106+
102107
return world
103108

104109
def reset_world_at(self, env_index: int = None):
@@ -140,17 +145,13 @@ def reward(self, agent: Agent):
140145
)
141146
self.covered_targets = self.agents_per_target >= self._agents_per_target
142147

143-
self.shared_covering_rew = torch.zeros(
144-
self.world.batch_dim, device=self.world.device
145-
)
148+
self.shared_covering_rew[:] = 0
146149
for a in self.world.agents:
147150
self.shared_covering_rew += self.agent_reward(a)
148151
self.shared_covering_rew[self.shared_covering_rew != 0] /= 2
149152

150153
# Avoid collisions with each other
151-
agent.collision_rew = torch.zeros(
152-
self.world.batch_dim, device=self.world.device
153-
)
154+
agent.collision_rew[:] = 0
154155
for a in self.world.agents:
155156
if a != agent:
156157
agent.collision_rew[
@@ -206,9 +207,7 @@ def get_outside_pos(self, env_index):
206207
def agent_reward(self, agent):
207208
agent_index = self.world.agents.index(agent)
208209

209-
agent.covering_reward = torch.zeros(
210-
self.world.batch_dim, device=self.world.device
211-
)
210+
agent.covering_reward[:] = 0
212211
targets_covered_by_agent = (
213212
self.agents_targets_dists[:, agent_index] < self._covering_range
214213
)
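In discovery (and in give_way below), some of the buffers are per-agent, so they are attached to the Agent objects themselves in make_world and zeroed in place in reward(). A minimal sketch of that variant, again a hypothetical scenario rather than code from this commit:

import torch

from vmas.simulator.core import Agent, World
from vmas.simulator.scenario import BaseScenario


class PerAgentBufferScenario(BaseScenario):
    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        world = World(batch_dim, device)
        for i in range(2):
            agent = Agent(name=f"agent_{i}", collide=True)
            # One buffer per agent, created once and reused on every reward() call.
            agent.collision_rew = torch.zeros(batch_dim, device=device)
            world.add_agent(agent)
        return world

    def reset_world_at(self, env_index: int = None):
        pass

    def reward(self, agent: Agent):
        agent.collision_rew[:] = 0
        for other in self.world.agents:
            if other is not agent:
                # Penalize overlaps without allocating a new tensor each step.
                agent.collision_rew[self.world.is_overlapping(other, agent)] -= 1.0
        return agent.collision_rew

    def observation(self, agent: Agent):
        return agent.state.pos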

vmas/scenarios/dropout.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import math
55
from typing import Dict
66

77
import torch
88
from torch import Tensor
9+
910
from vmas import render_interactively
1011
from vmas.simulator.core import Agent, Landmark, Sphere, World
1112
from vmas.simulator.scenario import BaseScenario
@@ -37,6 +38,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
3738
)
3839
world.add_landmark(goal)
3940

41+
self.pos_rew = torch.zeros(batch_dim, device=device)
42+
self.energy_rew = self.pos_rew.clone()
43+
4044
return world
4145

4246
def reset_world_at(self, env_index: int = None):
@@ -98,7 +102,7 @@ def reward(self, agent: Agent):
98102
dim=-1,
99103
)
100104

101-
self.pos_rew = torch.zeros(self.world.batch_dim, device=self.world.device)
105+
self.pos_rew[:] = 0
102106
self.pos_rew[self.any_eaten * ~self.world.landmarks[0].eaten] = 1
103107

104108
if is_last:
@@ -132,11 +136,7 @@ def observation(self, agent: Agent):
132136
)
133137

134138
def info(self, agent: Agent) -> Dict[str, Tensor]:
135-
try:
136-
info = {"pos_rew": self.pos_rew, "energy_rew": self.energy_rew}
137-
# When reset is called before reward()
138-
except AttributeError:
139-
info = {}
139+
info = {"pos_rew": self.pos_rew, "energy_rew": self.energy_rew}
140140
return info
141141

142142
def done(self):

vmas/scenarios/flocking.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
from typing import Dict, Callable
@@ -61,6 +61,11 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
6161
)
6262
world.add_landmark(self._target)
6363

64+
self.collision_rew = torch.zeros(batch_dim, device=device)
65+
self.velocity_rew = self.collision_rew.clone()
66+
self.separation_rew = self.collision_rew.clone()
67+
self.cohesion_rew = self.collision_rew.clone()
68+
6469
return world
6570

6671
def reset_world_at(self, env_index: int = None):
@@ -75,7 +80,7 @@ def reset_world_at(self, env_index: int = None):
7580

7681
def reward(self, agent: Agent):
7782
# Avoid collisions with each other
78-
self.collision_rew = torch.zeros(self.world.batch_dim, device=self.world.device)
83+
self.collision_rew[:] = 0
7984
for a in self.world.agents:
8085
if a != agent:
8186
self.collision_rew[self.world.is_overlapping(a, agent)] -= 1.0
@@ -112,16 +117,14 @@ def observation(self, agent: Agent):
112117
)
113118

114119
def info(self, agent: Agent) -> Dict[str, Tensor]:
115-
try:
116-
info = {
117-
"collision_rew": self.collision_rew,
118-
"velocity_rew": self.velocity_rew,
119-
"separation_rew": self.separation_rew,
120-
"cohesion_rew": self.cohesion_rew,
121-
}
122-
# When reset is called before reward()
123-
except AttributeError:
124-
info = {}
120+
121+
info = {
122+
"collision_rew": self.collision_rew,
123+
"velocity_rew": self.velocity_rew,
124+
"separation_rew": self.separation_rew,
125+
"cohesion_rew": self.cohesion_rew,
126+
}
127+
125128
return info
126129

127130

vmas/scenarios/give_way.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import math
@@ -121,6 +121,14 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
121121

122122
self.spawn_map(world)
123123

124+
for agent in world.agents:
125+
agent.energy_rew = torch.zeros(batch_dim, device=device)
126+
agent.agent_collision_rew = agent.energy_rew.clone()
127+
agent.obstacle_collision_rew = agent.agent_collision_rew.clone()
128+
129+
self.pos_rew = torch.zeros(batch_dim, device=device)
130+
self.final_rew = self.pos_rew.clone()
131+
124132
return world
125133

126134
def reset_world_at(self, env_index: int = None):
@@ -222,10 +230,8 @@ def reward(self, agent: Agent):
222230
green_agent = self.world.agents[-1]
223231

224232
if is_first:
225-
self.pos_rew = torch.zeros(
226-
self.world.batch_dim, device=self.world.device, dtype=torch.float32
227-
)
228-
self.final_rew = torch.zeros(self.world.batch_dim, device=self.world.device)
233+
self.pos_rew[:] = 0
234+
self.final_rew[:] = 0
229235

230236
self.blue_distance = torch.linalg.vector_norm(
231237
blue_agent.state.pos - blue_agent.goal.state.pos,
@@ -253,12 +259,8 @@ def reward(self, agent: Agent):
253259
self.final_rew[self.goal_reached] = self.final_reward
254260
self.reached_goal += self.goal_reached
255261

256-
agent.agent_collision_rew = torch.zeros(
257-
(self.world.batch_dim,), device=self.world.device
258-
)
259-
agent.obstacle_collision_rew = torch.zeros(
260-
(self.world.batch_dim,), device=self.world.device
261-
)
262+
agent.agent_collision_rew[:] = 0
263+
agent.obstacle_collision_rew[:] = 0
262264
for a in self.world.agents:
263265
if a != agent:
264266
agent.agent_collision_rew[
