|
1 | | -# Copyright (c) 2022. |
| 1 | +# Copyright (c) 2022-2023. |
2 | 2 | # ProrokLab (https://www.proroklab.org/) |
3 | 3 | # All rights reserved. |
4 | 4 |
|
@@ -73,6 +73,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): |
73 | 73 |
|
74 | 74 | self.create_passage_map(world) |
75 | 75 |
|
| 76 | + self.pos_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32) |
| 77 | + self.collision_rew = self.pos_rew.clone() |
| 78 | + |
76 | 79 | return world |
77 | 80 |
|
78 | 81 | def reset_world_at(self, env_index: int = None): |
@@ -213,8 +216,8 @@ def reward(self, agent: Agent): |
213 | 216 | self.rew = torch.zeros( |
214 | 217 | self.world.batch_dim, device=self.world.device, dtype=torch.float32 |
215 | 218 | ) |
216 | | - self.pos_rew = self.rew.clone() |
217 | | - self.collision_rew = self.rew.clone() |
| 219 | + self.pos_rew[:] = 0 |
| 220 | + self.collision_rew[:] = 0 |
218 | 221 |
|
219 | 222 | ball_passed = self.ball.state.pos[:, Y] > 0 |
220 | 223 |
|
@@ -328,32 +331,35 @@ def removed(i): |
328 | 331 | def spawn_passage_map(self, env_index): |
329 | 332 | if not self.fixed_passage: |
330 | 333 | order = torch.randperm(len(self.passages)).tolist() |
331 | | - self.passages = [self.passages[i] for i in order] |
332 | | - for i, passage in enumerate(self.passages): |
| 334 | + self.passages_to_place = [self.passages[i] for i in order] |
| 335 | + else: |
| 336 | + self.passages_to_place = self.passages |
| 337 | + for i, passage in enumerate(self.passages_to_place): |
333 | 338 | if not passage.collide: |
334 | 339 | passage.is_rendering[:] = False |
335 | 340 | passage.neighbour = False |
336 | 341 | try: |
337 | | - passage.neighbour += not self.passages[i - 1].collide |
| 342 | + passage.neighbour += not self.passages_to_place[i - 1].collide |
338 | 343 | except IndexError: |
339 | 344 | pass |
340 | 345 | try: |
341 | | - passage.neighbour += not self.passages[i + 1].collide |
| 346 | + passage.neighbour += not self.passages_to_place[i + 1].collide |
342 | 347 | except IndexError: |
343 | 348 | pass |
| 349 | + pos = torch.tensor( |
| 350 | + [ |
| 351 | + -1 |
| 352 | + - self.agent_radius |
| 353 | + + self.passage_length / 2 |
| 354 | + + self.passage_length * i, |
| 355 | + 0.0, |
| 356 | + ], |
| 357 | + dtype=torch.float32, |
| 358 | + device=self.world.device, |
| 359 | + ) |
344 | 360 | passage.neighbour *= passage.collide |
345 | 361 | passage.set_pos( |
346 | | - torch.tensor( |
347 | | - [ |
348 | | - -1 |
349 | | - - self.agent_radius |
350 | | - + self.passage_length / 2 |
351 | | - + self.passage_length * i, |
352 | | - 0.0, |
353 | | - ], |
354 | | - dtype=torch.float32, |
355 | | - device=self.world.device, |
356 | | - ), |
| 362 | + pos, |
357 | 363 | batch_index=env_index, |
358 | 364 | ) |
359 | 365 |
|
|
0 commit comments