diff --git a/navsim/planning/training/dataset.py b/navsim/planning/training/dataset.py index 76aa49f..b36bd86 100644 --- a/navsim/planning/training/dataset.py +++ b/navsim/planning/training/dataset.py @@ -120,13 +120,18 @@ def _load_scene_with_token(self, idx) -> Tuple[Dict[str, torch.Tensor], Dict[str """ token = self.tokens[idx] token_path = self._valid_cache_paths[token] - if 'training_cache' in str(token_path): - pdm_token_path = str(token_path).replace("training_cache", "train_pdm_cache") - pdm_token_path_parts = pdm_token_path.split('/') - pdm_token_path_parts.insert(-1, 'unknown') - pdm_token_path = '/'.join(pdm_token_path_parts) + "/metric_cache.pkl" - else: - pdm_token_path = token_path + token_path_str = str(token_path) + if 'training_cache' in token_path_str: + pdm_token_path = token_path_str.replace( + "training_cache", "train_pdm_cache" + ) + else: + pdm_token_path = token_path_str + + pdm_token_path_parts = pdm_token_path.split('/') + pdm_token_path_parts.insert(-1, 'unknown') + pdm_token_path = '/'.join(pdm_token_path_parts) + "/metric_cache.pkl" + features: Dict[str, torch.Tensor] = {} for builder in self._feature_builders: