diff --git a/python/aibrix/aibrix/downloader/base.py b/python/aibrix/aibrix/downloader/base.py index 14c660d6..c584582d 100644 --- a/python/aibrix/aibrix/downloader/base.py +++ b/python/aibrix/aibrix/downloader/base.py @@ -78,13 +78,16 @@ def download_directory(self, local_path: Path): used to download the directory. """ directory_list = self._directory_list(self.bucket_path) + # filter the directory path + files = [file for file in directory_list if not file.endswith("/")] + if self.allow_file_suffix is None: logger.info(f"All files from {self.bucket_path} will be downloaded.") - filtered_files = directory_list + filtered_files = files else: filtered_files = [ file - for file in directory_list + for file in files if any(file.endswith(suffix) for suffix in self.allow_file_suffix) ] diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index 48ddb3b8..1841eea3 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -121,9 +121,7 @@ def _directory_list(self, path: str) -> List[str]: Bucket=self.bucket_name, Delimiter="/", Prefix=path ) contents = objects_out.get("Contents", []) - files = [content.get("Key") for content in contents] - # filter the directory path - return [file for file in files if not file.endswith("/")] + return [content.get("Key") for content in contents] def _support_range_download(self) -> bool: return True