Skip to content

Commit 794d785

Browse files
committed
Rewards init moved to make_world in debug scenarios
1 parent e0a2c05 commit 794d785

File tree

6 files changed

+29
-12
lines changed

vmas/scenarios/debug/asym_joint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import math
@@ -141,6 +141,9 @@ def mass_collision_filter(e):
141141
)
142142
world.add_joint(joint)
143143

144+
self.rot_rew = torch.zeros(batch_dim, device=device)
145+
self.energy_rew = self.rot_rew.clone()
146+
144147
return world
145148

146149
def reset_world_at(self, env_index: int = None):
@@ -227,10 +230,7 @@ def reward(self, agent: Agent):
227230
is_first = agent == self.world.agents[0]
228231

229232
if is_first:
230-
self.rew = torch.zeros(
231-
self.world.batch_dim, device=self.world.device, dtype=torch.float32
232-
)
233-
self.rot_rew = self.rew.clone()
233+
self.rot_rew[:] = 0
234234

235235
# Rot shaping
236236
joint_dist_to_90_rot = get_line_angle_dist_0_180(

vmas/scenarios/debug/circle_trajectory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
from typing import Dict
@@ -53,6 +53,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
5353
)
5454
world.add_agent(self.agent)
5555

56+
self.pos_rew = torch.zeros(batch_dim, device=device)
57+
self.dot_product = self.pos_rew.clone()
58+
5659
return world
5760

5861
def process_action(self, agent: Agent):

vmas/scenarios/debug/goal.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import math
@@ -77,8 +77,12 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
7777
agent, world, controller_params, "standard"
7878
)
7979
agent.goal = self.goal
80+
agent.energy_rew = torch.zeros(batch_dim, device=device)
8081
world.add_agent(agent)
8182

83+
self.pos_rew = torch.zeros(batch_dim, device=device)
84+
self.time_rew = self.pos_rew.clone()
85+
8286
return world
8387

8488
def reset_world_at(self, env_index: int = None):
@@ -189,8 +193,8 @@ def reward(self, agent: Agent):
189193
is_first = agent == self.world.agents[0]
190194

191195
if is_first:
192-
self.pos_rew = torch.zeros(self.world.batch_dim, device=self.world.device)
193-
self.time_rew = torch.zeros(self.world.batch_dim, device=self.world.device)
196+
self.pos_rew[:] = 0
197+
self.time_rew[:] = 0
194198

195199
# Pos shaping
196200
goal_dist = torch.stack(

vmas/scenarios/debug/het_mass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import math
@@ -7,6 +7,7 @@
77
import numpy as np
88
import torch
99
from torch import Tensor
10+
1011
from vmas import render_interactively
1112
from vmas.simulator.core import Agent, World
1213
from vmas.simulator.scenario import BaseScenario
@@ -37,6 +38,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
3738
)
3839
world.add_agent(self.blue_agent)
3940

41+
self.max_speed = torch.zeros(batch_dim, device=device)
42+
self.energy_expenditure = self.max_speed.clone()
43+
4044
return world
4145

4246
def reset_world_at(self, env_index: int = None):

vmas/scenarios/debug/line_trajectory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
from typing import Dict
@@ -39,6 +39,10 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
3939
self.tangent = torch.zeros((world.batch_dim, world.dim_p), device=world.device)
4040
self.tangent[:, Y] = 1
4141

42+
self.pos_rew = torch.zeros(batch_dim, device=device)
43+
self.dot_product = self.pos_rew.clone()
44+
self.steady_rew = self.pos_rew.clone()
45+
4246
return world
4347

4448
def process_action(self, agent: Agent):

vmas/scenarios/debug/vel_control.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022.
1+
# Copyright (c) 2022-2023.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
from typing import Dict
@@ -84,6 +84,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
8484
self.landmark = Landmark("landmark 0", collide=False, movable=True)
8585
world.add_landmark(self.landmark)
8686

87+
self.energy_expenditure = torch.zeros(batch_dim, device=device)
88+
8789
return world
8890

8991
def reset_world_at(self, env_index: int = None):

0 commit comments

Comments (0)