From be7964844b2d5be714c577cecaf8672a461a7139 Mon Sep 17 00:00:00 2001 From: brosoul Date: Sat, 21 Dec 2024 17:08:22 +0800 Subject: [PATCH 1/8] refact: refact tos cache dir same with huggingface --- python/aibrix/aibrix/config.py | 11 +++++++++-- python/aibrix/aibrix/downloader/s3.py | 4 +++- python/aibrix/aibrix/downloader/tos.py | 6 ++++-- python/aibrix/aibrix/downloader/utils.py | 4 ++-- python/aibrix/tests/downloader/test_utils.py | 10 ++++++---- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/python/aibrix/aibrix/config.py b/python/aibrix/aibrix/config.py index db99a37a..fdc4f85a 100644 --- a/python/aibrix/aibrix/config.py +++ b/python/aibrix/aibrix/config.py @@ -14,7 +14,14 @@ DEFAULT_METRIC_COLLECTOR_TIMEOUT = 1 - -DOWNLOAD_CACHE_DIR = ".cache" +""" +DOWNLOAD CACHE DIR would be like: +. +└── .cache + └── huggingface | s3 | tos + ├── .gitignore + └── download +""" +DOWNLOAD_CACHE_DIR = ".cache/%s/download" EXCLUDE_METRICS_HTTP_ENDPOINTS = ["/metrics/"] diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index bb5cd351..75baa041 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -44,6 +44,7 @@ def _parse_bucket_info_from_uri(uri: str, scheme: str = "s3") -> Tuple[str, str] class S3BaseDownloader(BaseDownloader): + _source = "s3_base" def __init__( self, scheme: str, @@ -144,7 +145,7 @@ def download( # check if file exist etag = meta_data.get("ETag", "") file_size = meta_data.get("ContentLength", 0) - meta_data_file = meta_file(local_path=local_path, file_name=_file_name) + meta_data_file = meta_file(local_path=local_path, file_name=_file_name, source=self._source) if not need_to_download(local_file, meta_data_file, file_size, etag): return @@ -185,6 +186,7 @@ def download_progress(bytes_transferred): class S3Downloader(S3BaseDownloader): + _source = "s3" def __init__( self, model_uri, diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index c2d4abb3..dbb30c95 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -46,6 +46,7 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]: class TOSDownloaderV1(BaseDownloader): + _source = "tos" def __init__( self, model_uri, @@ -133,7 +134,7 @@ def download( # check if file exist etag = meta_data.etag file_size = meta_data.content_length - meta_data_file = meta_file(local_path=local_path, file_name=_file_name) + meta_data_file = meta_file(local_path=local_path, file_name=_file_name, source=self._source) if not need_to_download(local_file, meta_data_file, file_size, etag): return @@ -173,6 +174,7 @@ def download_progress( class TOSDownloaderV2(S3BaseDownloader): + _source = "tos" def __init__( self, model_uri, @@ -180,7 +182,7 @@ def __init__( enable_progress_bar: bool = False, ): super().__init__( - scheme="s3", + scheme="tos", model_uri=model_uri, model_name=model_name, enable_progress_bar=enable_progress_bar, diff --git a/python/aibrix/aibrix/downloader/utils.py b/python/aibrix/aibrix/downloader/utils.py index 47a4d733..e3e64af7 100644 --- a/python/aibrix/aibrix/downloader/utils.py +++ b/python/aibrix/aibrix/downloader/utils.py @@ -22,10 +22,10 @@ logger = init_logger(__name__) -def meta_file(local_path: Union[Path, str], file_name: str) -> Path: +def meta_file(local_path: Union[Path, str], file_name: str, source: str) -> Path: return ( Path(local_path) - .joinpath(DOWNLOAD_CACHE_DIR) + .joinpath(DOWNLOAD_CACHE_DIR % source) 
.joinpath(f"{file_name}.metadata") .absolute() ) diff --git a/python/aibrix/tests/downloader/test_utils.py b/python/aibrix/tests/downloader/test_utils.py index 6fba0a10..4e0abc1d 100644 --- a/python/aibrix/tests/downloader/test_utils.py +++ b/python/aibrix/tests/downloader/test_utils.py @@ -41,8 +41,9 @@ def prepare_file_and_meta_data(file_path, meta_path, file_size, etag): def test_meta_file(): with tempfile.TemporaryDirectory() as tmp_dir: - meta_file_path = meta_file(tmp_dir, "test") - assert str(meta_file_path).endswith(f"{DOWNLOAD_CACHE_DIR}/test.metadata") + source = "s3" + meta_file_path = meta_file(tmp_dir, "test", source=source) + assert str(meta_file_path).endswith(f"{DOWNLOAD_CACHE_DIR % source}/test.metadata") def test_save_load_meta_data(): @@ -61,6 +62,7 @@ def test_save_load_meta_data(): def test_check_file_exist(): + source = "s3" with tempfile.TemporaryDirectory() as tmp_dir: # prepare file and meta data file_size = 10 @@ -68,7 +70,7 @@ def test_check_file_exist(): etag = "here_is_etag_value_xyz" file_path = Path(tmp_dir).joinpath(file_name) # create meta data file - meta_path = meta_file(tmp_dir, file_name) + meta_path = meta_file(tmp_dir, file_name, source=source) prepare_file_and_meta_data(file_path, meta_path, file_size, etag) assert check_file_exist(file_path, meta_path, file_size, etag) @@ -80,7 +82,7 @@ def test_check_file_exist(): not_exist_file_name = "not_exist" not_exist_file_path = Path(tmp_dir).joinpath(not_exist_file_name) - not_exist_meta_path = meta_file(tmp_dir, not_exist_file_name) + not_exist_meta_path = meta_file(tmp_dir, not_exist_file_name, source=source) assert not check_file_exist( not_exist_file_path, not_exist_meta_path, file_size, etag ) From d9a814ca0b7ecc6e4a0d01918f8b432535758bad Mon Sep 17 00:00:00 2001 From: brosoul Date: Sat, 21 Dec 2024 18:13:29 +0800 Subject: [PATCH 2/8] fix: style --- python/aibrix/aibrix/downloader/s3.py | 6 +++++- python/aibrix/aibrix/downloader/tos.py | 6 +++++- python/aibrix/tests/downloader/test_utils.py | 4 +++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index 75baa041..484c87c0 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -45,6 +45,7 @@ def _parse_bucket_info_from_uri(uri: str, scheme: str = "s3") -> Tuple[str, str] class S3BaseDownloader(BaseDownloader): _source = "s3_base" + def __init__( self, scheme: str, @@ -145,7 +146,9 @@ def download( # check if file exist etag = meta_data.get("ETag", "") file_size = meta_data.get("ContentLength", 0) - meta_data_file = meta_file(local_path=local_path, file_name=_file_name, source=self._source) + meta_data_file = meta_file( + local_path=local_path, file_name=_file_name, source=self._source + ) if not need_to_download(local_file, meta_data_file, file_size, etag): return @@ -187,6 +190,7 @@ def download_progress(bytes_transferred): class S3Downloader(S3BaseDownloader): _source = "s3" + def __init__( self, model_uri, diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index dbb30c95..24bd79c8 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -47,6 +47,7 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]: class TOSDownloaderV1(BaseDownloader): _source = "tos" + def __init__( self, model_uri, @@ -134,7 +135,9 @@ def download( # check if file exist etag = meta_data.etag file_size = meta_data.content_length - meta_data_file = 
meta_file(local_path=local_path, file_name=_file_name, source=self._source) + meta_data_file = meta_file( + local_path=local_path, file_name=_file_name, source=self._source + ) if not need_to_download(local_file, meta_data_file, file_size, etag): return @@ -175,6 +178,7 @@ def download_progress( class TOSDownloaderV2(S3BaseDownloader): _source = "tos" + def __init__( self, model_uri, diff --git a/python/aibrix/tests/downloader/test_utils.py b/python/aibrix/tests/downloader/test_utils.py index 4e0abc1d..6f2c7801 100644 --- a/python/aibrix/tests/downloader/test_utils.py +++ b/python/aibrix/tests/downloader/test_utils.py @@ -43,7 +43,9 @@ def test_meta_file(): with tempfile.TemporaryDirectory() as tmp_dir: source = "s3" meta_file_path = meta_file(tmp_dir, "test", source=source) - assert str(meta_file_path).endswith(f"{DOWNLOAD_CACHE_DIR % source}/test.metadata") + assert str(meta_file_path).endswith( + f"{DOWNLOAD_CACHE_DIR % source}/test.metadata" + ) def test_save_load_meta_data(): From 7cb3a689936618104009644a5d9eae2e08c245f7 Mon Sep 17 00:00:00 2001 From: brosoul Date: Sat, 21 Dec 2024 18:46:34 +0800 Subject: [PATCH 3/8] feat: add donwload lock --- python/aibrix/aibrix/config.py | 1 + python/aibrix/aibrix/downloader/entity.py | 136 ++++++++++++++++++++++ python/aibrix/aibrix/downloader/s3.py | 31 ++--- python/aibrix/aibrix/downloader/tos.py | 36 +++--- python/aibrix/poetry.lock | 4 +- python/aibrix/pyproject.toml | 1 + 6 files changed, 176 insertions(+), 33 deletions(-) create mode 100644 python/aibrix/aibrix/downloader/entity.py diff --git a/python/aibrix/aibrix/config.py b/python/aibrix/aibrix/config.py index fdc4f85a..1e959677 100644 --- a/python/aibrix/aibrix/config.py +++ b/python/aibrix/aibrix/config.py @@ -23,5 +23,6 @@ └── download """ DOWNLOAD_CACHE_DIR = ".cache/%s/download" +DOWNLOAD_FILE_LOCK_CHECK_TIMEOUT = 10 EXCLUDE_METRICS_HTTP_ENDPOINTS = ["/metrics/"] diff --git a/python/aibrix/aibrix/downloader/entity.py b/python/aibrix/aibrix/downloader/entity.py new file mode 100644 index 00000000..07ad8b37 --- /dev/null +++ b/python/aibrix/aibrix/downloader/entity.py @@ -0,0 +1,136 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import contextlib +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Generator, List + +from filelock import BaseFileLock, FileLock, Timeout + +from aibrix.config import DOWNLOAD_CACHE_DIR, DOWNLOAD_FILE_LOCK_CHECK_TIMEOUT +from aibrix.logger import init_logger + +logger = init_logger(__name__) + + +class RemoteSource(Enum): + S3 = "s3" + TOS = "tos" + HUGGINGFACE = "huggingface" + + +class DownloadStatus(Enum): + DOWNLOADING = "downloading" + DOWNLOADED = "downloaded" + NO_OPETATION = "no_operation" # Interrupted from downloading + UNKNOWN = "unknown" + + +@dataclass +class DownloadFile: + file_path: Path + lock_path: Path + metadata_path: Path + + @property + def status(self): + if self.file_path.exists() and self.metadata_path.exists(): + return DownloadStatus.DOWNLOADED + + try: + # Downloading process will acquire the lock, + # and other process will raise Timeout Exception + lock = FileLock(self.lock_path) + lock.acquire(blocking=False) + except Timeout: + return DownloadStatus.DOWNLOADING + except Exception: + return DownloadStatus.UNKNOWN + else: + return DownloadStatus.NO_OPETATION + + @contextlib.contextmanager + def download_lock(self) -> Generator[BaseFileLock, None, None]: + """Download process should acquire the lock to prevent other process from downloading the same file. + Same implementation as WeakFileLock in huggingface_hub""" + lock = FileLock(self.lock_path, timeout=DOWNLOAD_FILE_LOCK_CHECK_TIMEOUT) + while True: + try: + lock.acquire() + except Timeout: + logger.info(f"still waiting to acquire lock on {self.lock_path}, status is {self.status}") + else: + break + + yield lock + + try: + return lock.release() + except OSError: + try: + Path(self.lock_path).unlink() + except OSError: + pass + + +@dataclass +class DownloadModel: + model_source: RemoteSource + local_dir: Path + model_name: str + download_files: List[DownloadFile] + + def __post_init__(self): + if len(self.download_files) == 0: + logger.warning(f"No download files found for model {self.model_name}") + + @property + def status(self): + all_status = [] + for file in self.download_files: + file_status = file.status + if file_status == DownloadStatus.DOWNLOADING: + return DownloadStatus.DOWNLOADING + elif file_status == DownloadStatus.NO_OPETATION: + return DownloadStatus.NO_OPETATION + elif file_status == DownloadStatus.UNKNOWN: + return DownloadStatus.UNKNOWN + else: + all_status.append(file.file_status) + if all(status == DownloadStatus.DOWNLOADED for status in all_status): + return DownloadStatus.DOWNLOADED + return DownloadStatus.UNKNOWN + + @classmethod + def infer_from_model_path(cls, local_path: Path, model_name: str) -> "DownloadModel": + # TODO, infer downloadfiles from cached dir + pass + + @classmethod + def infer_from_local_path(cls, remote_path: Path) -> List["DownloadModel"]: + # TODO + pass + + +def get_local_download_paths( + model_base_dir: Path, filename: str, source_type: RemoteSource +) -> DownloadFile: + file_path = model_base_dir.joinpath(filename) + cache_dir = model_base_dir.joinpath(DOWNLOAD_CACHE_DIR % source_type.value) + lock_path = cache_dir.joinpath(f"{filename}.lock") + metadata_path = cache_dir.joinpath(f"{filename}.metadata") + return DownloadFile(file_path, lock_path, metadata_path) diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index 484c87c0..2ac3bc3f 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -18,6 +18,7 @@ from typing 
import Dict, List, Optional, Tuple from urllib.parse import urlparse +from aibrix.downloader.entity import RemoteSource, get_local_download_paths import boto3 from boto3.s3.transfer import TransferConfig from botocore.config import MAX_POOL_CONNECTIONS, Config @@ -44,7 +45,7 @@ def _parse_bucket_info_from_uri(uri: str, scheme: str = "s3") -> Tuple[str, str] class S3BaseDownloader(BaseDownloader): - _source = "s3_base" + _source: RemoteSource = RemoteSource.S3 def __init__( self, @@ -147,7 +148,7 @@ def download( etag = meta_data.get("ETag", "") file_size = meta_data.get("ContentLength", 0) meta_data_file = meta_file( - local_path=local_path, file_name=_file_name, source=self._source + local_path=local_path, file_name=_file_name, source=self._source.value ) if not need_to_download(local_file, meta_data_file, file_size, etag): @@ -175,21 +176,23 @@ def download( def download_progress(bytes_transferred): pbar.update(bytes_transferred) - - self.client.download_file( - Bucket=bucket_name, - Key=bucket_path, - Filename=str( - local_file - ), # S3 client does not support Path, convert it to str - Config=config, - Callback=download_progress if self.enable_progress_bar else None, - ) - save_meta_data(meta_data_file, etag) + + download_file = get_local_download_paths(local_path, _file_name, self._source) + with download_file.download_lock(): + self.client.download_file( + Bucket=bucket_name, + Key=bucket_path, + Filename=str( + local_file + ), # S3 client does not support Path, convert it to str + Config=config, + Callback=download_progress if self.enable_progress_bar else None, + ) + save_meta_data(meta_data_file, etag) class S3Downloader(S3BaseDownloader): - _source = "s3" + _source: RemoteSource = RemoteSource.S3 def __init__( self, diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index 24bd79c8..e73f0cba 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -18,6 +18,7 @@ from typing import Dict, List, Optional, Tuple from urllib.parse import urlparse +from aibrix.downloader.entity import RemoteSource, get_local_download_paths import tos from tos import DataTransferType from tqdm import tqdm @@ -46,7 +47,7 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]: class TOSDownloaderV1(BaseDownloader): - _source = "tos" + _source: RemoteSource = RemoteSource.TOS def __init__( self, @@ -136,7 +137,7 @@ def download( etag = meta_data.etag file_size = meta_data.content_length meta_data_file = meta_file( - local_path=local_path, file_name=_file_name, source=self._source + local_path=local_path, file_name=_file_name, source=self._source.value ) if not need_to_download(local_file, meta_data_file, file_size, etag): @@ -151,7 +152,6 @@ def download( # download file total_length = meta_data.content_length - nullcontext with tqdm( desc=_file_name, total=total_length, unit="b", unit_scale=True ) if self.enable_progress_bar else nullcontext() as pbar: @@ -161,23 +161,25 @@ def download_progress( ): pbar.update(rw_once_bytes) - self.client.download_file( - bucket=bucket_name, - key=bucket_path, - file_path=str( - local_file - ), # TOS client does not support Path, convert it to str - task_num=task_num, - data_transfer_listener=download_progress - if self.enable_progress_bar - else None, - **download_kwargs, - ) - save_meta_data(meta_data_file, etag) + download_file = get_local_download_paths(local_path, _file_name, self._source) + with download_file.download_lock(): + self.client.download_file( + 
bucket=bucket_name, + key=bucket_path, + file_path=str( + local_file + ), # TOS client does not support Path, convert it to str + task_num=task_num, + data_transfer_listener=download_progress + if self.enable_progress_bar + else None, + **download_kwargs, + ) + save_meta_data(meta_data_file, etag) class TOSDownloaderV2(S3BaseDownloader): - _source = "tos" + _source: RemoteSource = RemoteSource.TOS def __init__( self, diff --git a/python/aibrix/poetry.lock b/python/aibrix/poetry.lock index 5ac84ae4..25e10fb7 100644 --- a/python/aibrix/poetry.lock +++ b/python/aibrix/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -3285,4 +3285,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "dc205fd0f31f54b35ed660d4ed566ea09a46000b7da11da526160196e1b9e686" \ No newline at end of file +content-hash = "5151c5c47bf53acfe735bb4fc92395b8fad2033a14bb40e14e82854e59ce631c" diff --git a/python/aibrix/pyproject.toml b/python/aibrix/pyproject.toml index 3c172a89..7bd9e7bc 100644 --- a/python/aibrix/pyproject.toml +++ b/python/aibrix/pyproject.toml @@ -56,6 +56,7 @@ incdbscan = "^0.1.0" aiohttp = "^3.11.7" dash = "^2.18.2" matplotlib = "^3.9.2" +filelock = "^3.16.1" [tool.poetry.group.dev.dependencies] From 1c55b9ddd9de8a806b4ccdc7fbd6712b594e4c92 Mon Sep 17 00:00:00 2001 From: brosoul Date: Sat, 21 Dec 2024 18:48:04 +0800 Subject: [PATCH 4/8] fix style --- python/aibrix/aibrix/downloader/entity.py | 12 ++++++++---- python/aibrix/aibrix/downloader/s3.py | 8 +++++--- python/aibrix/aibrix/downloader/tos.py | 6 ++++-- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/python/aibrix/aibrix/downloader/entity.py b/python/aibrix/aibrix/downloader/entity.py index 07ad8b37..455edbdb 100644 --- a/python/aibrix/aibrix/downloader/entity.py +++ b/python/aibrix/aibrix/downloader/entity.py @@ -72,7 +72,9 @@ def download_lock(self) -> Generator[BaseFileLock, None, None]: try: lock.acquire() except Timeout: - logger.info(f"still waiting to acquire lock on {self.lock_path}, status is {self.status}") + logger.info( + f"still waiting to acquire lock on {self.lock_path}, status is {self.status}" + ) else: break @@ -93,7 +95,7 @@ class DownloadModel: local_dir: Path model_name: str download_files: List[DownloadFile] - + def __post_init__(self): if len(self.download_files) == 0: logger.warning(f"No download files found for model {self.model_name}") @@ -116,10 +118,12 @@ def status(self): return DownloadStatus.UNKNOWN @classmethod - def infer_from_model_path(cls, local_path: Path, model_name: str) -> "DownloadModel": + def infer_from_model_path( + cls, local_path: Path, model_name: str + ) -> "DownloadModel": # TODO, infer downloadfiles from cached dir pass - + @classmethod def infer_from_local_path(cls, remote_path: Path) -> List["DownloadModel"]: # TODO diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index 2ac3bc3f..b46868d0 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -18,7 +18,6 @@ from typing import Dict, List, Optional, Tuple from urllib.parse import urlparse -from aibrix.downloader.entity import RemoteSource, get_local_download_paths import boto3 from boto3.s3.transfer import TransferConfig from botocore.config import MAX_POOL_CONNECTIONS, Config @@ -26,6 +25,7 @@ from 
aibrix import envs from aibrix.downloader.base import BaseDownloader +from aibrix.downloader.entity import RemoteSource, get_local_download_paths from aibrix.downloader.utils import ( infer_model_name, meta_file, @@ -176,8 +176,10 @@ def download( def download_progress(bytes_transferred): pbar.update(bytes_transferred) - - download_file = get_local_download_paths(local_path, _file_name, self._source) + + download_file = get_local_download_paths( + local_path, _file_name, self._source + ) with download_file.download_lock(): self.client.download_file( Bucket=bucket_name, diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index e73f0cba..12b5b5f3 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -18,13 +18,13 @@ from typing import Dict, List, Optional, Tuple from urllib.parse import urlparse -from aibrix.downloader.entity import RemoteSource, get_local_download_paths import tos from tos import DataTransferType from tqdm import tqdm from aibrix import envs from aibrix.downloader.base import BaseDownloader +from aibrix.downloader.entity import RemoteSource, get_local_download_paths from aibrix.downloader.s3 import S3BaseDownloader from aibrix.downloader.utils import ( infer_model_name, @@ -161,7 +161,9 @@ def download_progress( ): pbar.update(rw_once_bytes) - download_file = get_local_download_paths(local_path, _file_name, self._source) + download_file = get_local_download_paths( + local_path, _file_name, self._source + ) with download_file.download_lock(): self.client.download_file( bucket=bucket_name, From 46dd15b3cfc05c5407f983fa7ec56ab230ac4985 Mon Sep 17 00:00:00 2001 From: brosoul Date: Sat, 21 Dec 2024 23:52:06 +0800 Subject: [PATCH 5/8] feat: add DownloadModel and DownloadFile --- python/aibrix/aibrix/downloader/entity.py | 104 ++++++++++++++++------ 1 file changed, 77 insertions(+), 27 deletions(-) diff --git a/python/aibrix/aibrix/downloader/entity.py b/python/aibrix/aibrix/downloader/entity.py index 455edbdb..1a9181e3 100644 --- a/python/aibrix/aibrix/downloader/entity.py +++ b/python/aibrix/aibrix/downloader/entity.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Generator, List +from typing import Generator, List, Optional from filelock import BaseFileLock, FileLock, Timeout @@ -36,7 +36,7 @@ class RemoteSource(Enum): class DownloadStatus(Enum): DOWNLOADING = "downloading" DOWNLOADED = "downloaded" - NO_OPETATION = "no_operation" # Interrupted from downloading + NO_OPERATION = "no_operation" # Interrupted from downloading UNKNOWN = "unknown" @@ -58,15 +58,18 @@ def status(self): lock.acquire(blocking=False) except Timeout: return DownloadStatus.DOWNLOADING - except Exception: + except Exception as e: + logger.warning( + f"failed to acquire lock failed for unknown error, error: {e}" + ) return DownloadStatus.UNKNOWN else: - return DownloadStatus.NO_OPETATION + return DownloadStatus.NO_OPERATION @contextlib.contextmanager def download_lock(self) -> Generator[BaseFileLock, None, None]: - """Download process should acquire the lock to prevent other process from downloading the same file. - Same implementation as WeakFileLock in huggingface_hub""" + """A filelock that download process should be acquired. 
+ Same implementation as WeakFileLock in huggingface_hub.""" lock = FileLock(self.lock_path, timeout=DOWNLOAD_FILE_LOCK_CHECK_TIMEOUT) while True: try: @@ -92,7 +95,7 @@ def download_lock(self) -> Generator[BaseFileLock, None, None]: @dataclass class DownloadModel: model_source: RemoteSource - local_dir: Path + local_path: Path model_name: str download_files: List[DownloadFile] @@ -102,39 +105,86 @@ def __post_init__(self): @property def status(self): - all_status = [] - for file in self.download_files: - file_status = file.status - if file_status == DownloadStatus.DOWNLOADING: - return DownloadStatus.DOWNLOADING - elif file_status == DownloadStatus.NO_OPETATION: - return DownloadStatus.NO_OPETATION - elif file_status == DownloadStatus.UNKNOWN: - return DownloadStatus.UNKNOWN - else: - all_status.append(file.file_status) + all_status = [file.status for file in self.download_files] if all(status == DownloadStatus.DOWNLOADED for status in all_status): return DownloadStatus.DOWNLOADED + + if any(status == DownloadStatus.DOWNLOADING for status in all_status): + return DownloadStatus.DOWNLOADING + + if any(status == DownloadStatus.NO_OPERATION for status in all_status): + return DownloadStatus.NO_OPERATION + return DownloadStatus.UNKNOWN @classmethod def infer_from_model_path( - cls, local_path: Path, model_name: str - ) -> "DownloadModel": - # TODO, infer downloadfiles from cached dir - pass + cls, local_path: Path, model_name: str, source: RemoteSource + ) -> Optional["DownloadModel"]: + assert source is not None + + model_base_dir = Path(local_path).joinpath(model_name) + cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/") + cache_dir = Path(model_base_dir).joinpath(cache_sub_dir) + lock_files = list(Path(cache_dir).glob("*.lock")) + + download_files = [] + for lock_file in lock_files: + lock_name = lock_file.name + lock_suffix = ".lock" + if lock_name.endswith(lock_suffix): + filename = lock_name[: -len(lock_suffix)] + download_file = get_local_download_paths( + model_base_dir=model_base_dir, filename=filename, source=source + ) + download_files.append(download_file) + + return cls( + model_source=source, + local_path=local_path, + model_name=model_name, + download_files=download_files, + ) @classmethod - def infer_from_local_path(cls, remote_path: Path) -> List["DownloadModel"]: - # TODO - pass + def infer_from_local_path(cls, local_path: Path) -> List["DownloadModel"]: + models: List["DownloadModel"] = [] + for source in RemoteSource: + cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/") + cache_dirs = list(Path(local_path).glob(f"**/{cache_sub_dir}")) + if not cache_dirs: + continue + + for cache_dir in cache_dirs: + relative_path = cache_dir.relative_to(local_path) + relative_str = str(relative_path).strip("/") + model_name = relative_str.rstrip(cache_sub_dir).strip("/") + download_model = cls.infer_from_model_path( + local_path, model_name, source + ) + if download_model is None: + continue + models.append(download_model) + + return models + + def __str__(self): + return ( + "DownloadModel(\n" + + f"\tmodel_source={self.model_source.value},\n" + + f"\tmodel_name={self.model_name},\n" + + f"\tstatus={self.status.value},\n" + + f"\tlocal_path={self.local_path},\n" + + f"\tdownload_files_count={len(self.download_files)}\n)" + ) def get_local_download_paths( - model_base_dir: Path, filename: str, source_type: RemoteSource + model_base_dir: Path, filename: str, source: RemoteSource ) -> DownloadFile: file_path = model_base_dir.joinpath(filename) - cache_dir = 
model_base_dir.joinpath(DOWNLOAD_CACHE_DIR % source_type.value) + sub_cache_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/") + cache_dir = model_base_dir.joinpath(sub_cache_dir) lock_path = cache_dir.joinpath(f"{filename}.lock") metadata_path = cache_dir.joinpath(f"{filename}.metadata") return DownloadFile(file_path, lock_path, metadata_path) From f517d7e90fa7c4e8767c173d8ef84c3812c2dd4c Mon Sep 17 00:00:00 2001 From: brosoul Date: Sun, 22 Dec 2024 00:02:57 +0800 Subject: [PATCH 6/8] fix style --- python/aibrix/aibrix/downloader/entity.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/aibrix/aibrix/downloader/entity.py b/python/aibrix/aibrix/downloader/entity.py index 1a9181e3..62e9de7c 100644 --- a/python/aibrix/aibrix/downloader/entity.py +++ b/python/aibrix/aibrix/downloader/entity.py @@ -59,9 +59,7 @@ def status(self): except Timeout: return DownloadStatus.DOWNLOADING except Exception as e: - logger.warning( - f"failed to acquire lock failed for unknown error, error: {e}" - ) + logger.warning(f"Failed to acquire lock failed for error: {e}") return DownloadStatus.UNKNOWN else: return DownloadStatus.NO_OPERATION @@ -76,7 +74,7 @@ def download_lock(self) -> Generator[BaseFileLock, None, None]: lock.acquire() except Timeout: logger.info( - f"still waiting to acquire lock on {self.lock_path}, status is {self.status}" + f"Still waiting to acquire download lock on {self.lock_path}" ) else: break From 98dcb6ece7a925a11521403f85c9c351f89c16f6 Mon Sep 17 00:00:00 2001 From: brosoul Date: Sun, 22 Dec 2024 01:08:52 +0800 Subject: [PATCH 7/8] fix --- python/aibrix/aibrix/downloader/entity.py | 14 +- python/aibrix/tests/downloader/test_entity.py | 201 ++++++++++++++++++ 2 files changed, 213 insertions(+), 2 deletions(-) create mode 100644 python/aibrix/tests/downloader/test_entity.py diff --git a/python/aibrix/aibrix/downloader/entity.py b/python/aibrix/aibrix/downloader/entity.py index 62e9de7c..b84446ab 100644 --- a/python/aibrix/aibrix/downloader/entity.py +++ b/python/aibrix/aibrix/downloader/entity.py @@ -110,7 +110,10 @@ def status(self): if any(status == DownloadStatus.DOWNLOADING for status in all_status): return DownloadStatus.DOWNLOADING - if any(status == DownloadStatus.NO_OPERATION for status in all_status): + if all( + status in [DownloadStatus.DOWNLOADED, DownloadStatus.NO_OPERATION] + for status in all_status + ): return DownloadStatus.NO_OPERATION return DownloadStatus.UNKNOWN @@ -147,6 +150,12 @@ def infer_from_model_path( @classmethod def infer_from_local_path(cls, local_path: Path) -> List["DownloadModel"]: models: List["DownloadModel"] = [] + + def remove_suffix(input_string, suffix): + if not input_string.endswith(suffix): + return input_string + return input_string[: -len(suffix)] + for source in RemoteSource: cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/") cache_dirs = list(Path(local_path).glob(f"**/{cache_sub_dir}")) @@ -156,7 +165,8 @@ def infer_from_local_path(cls, local_path: Path) -> List["DownloadModel"]: for cache_dir in cache_dirs: relative_path = cache_dir.relative_to(local_path) relative_str = str(relative_path).strip("/") - model_name = relative_str.rstrip(cache_sub_dir).strip("/") + model_name = remove_suffix(relative_str, cache_sub_dir).strip("/") + download_model = cls.infer_from_model_path( local_path, model_name, source ) diff --git a/python/aibrix/tests/downloader/test_entity.py b/python/aibrix/tests/downloader/test_entity.py new file mode 100644 index 00000000..dfbdc764 --- /dev/null +++ 
b/python/aibrix/tests/downloader/test_entity.py @@ -0,0 +1,201 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import tempfile +from pathlib import Path +from typing import Generator, List, Tuple + +from filelock import FileLock + +from aibrix.config import DOWNLOAD_CACHE_DIR +from aibrix.downloader.entity import ( + DownloadModel, + DownloadStatus, + RemoteSource, + get_local_download_paths, +) + + +@contextlib.contextmanager +def prepare_model_dir( + local_path: Path, + model_name: str, + source: RemoteSource, + files_with_status: List[Tuple[str, DownloadStatus]], +) -> Generator: + model_base_dir = local_path.joinpath(model_name) + cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/") + model_cache_dir = model_base_dir.joinpath(cache_sub_dir) + + model_base_dir.mkdir(parents=True, exist_ok=True) + model_cache_dir.mkdir(parents=True, exist_ok=True) + + def _parepare_file( + model_base_dir: Path, cache_dir: Path, file_name: str, status: DownloadStatus + ): + file_path = model_base_dir.joinpath(file_name) + lock_path = cache_dir.joinpath(f"{file_name}.lock") + meta_path = cache_dir.joinpath(f"{file_name}.metadata") + + if status == DownloadStatus.DOWNLOADED: + file_path.touch() + lock_path.touch() + meta_path.touch() + elif status == DownloadStatus.DOWNLOADING: + lock_path.touch() + elif status == DownloadStatus.NO_OPERATION: + lock_path.touch() + + # acquire the donwloading lock + all_locks = [] + for file_name, status in files_with_status: + _parepare_file(model_base_dir, model_cache_dir, file_name, status) + if status == DownloadStatus.DOWNLOADING: + lock_path = model_cache_dir.joinpath(f"{file_name}.lock") + lock = FileLock(lock_path) + lock.acquire(blocking=False) + all_locks.append(lock) + + yield + + # release all the locks + for lock in all_locks: + lock.release(force=True) + + +def test_prepare_model_dir(): + with tempfile.TemporaryDirectory() as local_dir: + local_path = Path(local_dir) + model_name = "model/name" + source = RemoteSource.S3 + files_with_status = [ + ("file1", DownloadStatus.DOWNLOADED), + ("file2", DownloadStatus.DOWNLOADING), + ("file3", DownloadStatus.NO_OPERATION), + ] + with prepare_model_dir(local_path, model_name, source, files_with_status): + model_base_dir = local_path.joinpath(model_name) + # file1 and file1.lock, file1.metadata should be exist + assert len(list(model_base_dir.glob("file1"))) == 1 + assert len(list(model_base_dir.glob("**/file1.lock"))) == 1 + assert len(list(model_base_dir.glob("**/file1.metadata"))) == 1 + + # file2.lock should be exist + assert len(list(model_base_dir.glob("file2"))) == 0 + assert len(list(model_base_dir.glob("**/file2.lock"))) == 1 + assert len(list(model_base_dir.glob("**/file2.metadata"))) == 0 + + # file3.lock should be exist + assert len(list(model_base_dir.glob("file3"))) == 0 + assert len(list(model_base_dir.glob("**/file3.lock"))) == 1 + assert len(list(model_base_dir.glob("**/file2.metadata"))) == 0 + + +def test_get_local_download_paths(): + with 
tempfile.TemporaryDirectory() as local_dir: + local_path = Path(local_dir) + model_name = "model/name" + source = RemoteSource.S3 + files_with_status = [ + ("file1", DownloadStatus.DOWNLOADED), + ("file2", DownloadStatus.DOWNLOADING), + ("file3", DownloadStatus.NO_OPERATION), + ] + model_base_dir = local_path.joinpath(model_name) + with prepare_model_dir(local_path, model_name, source, files_with_status): + for filename, status in files_with_status: + download_file = get_local_download_paths( + model_base_dir=model_base_dir, + filename=filename, + source=source, + ) + assert download_file.status == status, f"{filename} status not match" + + +def test_infer_from_local_path(): + with tempfile.TemporaryDirectory() as local_dir: + local_path = Path(local_dir) + # S3 model with 3 files + s3_model_name = "s3_model/name" + s3_source = RemoteSource.S3 + s3_files_with_status = [ + ("s3_file1", DownloadStatus.DOWNLOADED), + ("s3_file2", DownloadStatus.DOWNLOADING), + ("s3_file3", DownloadStatus.NO_OPERATION), + ] + # TOS model with 2 files + tos_model_name = "tos_model_name" + tos_source = RemoteSource.TOS + tos_files_with_status = [ + ("tos_file1", DownloadStatus.DOWNLOADED), + ("tos_file2", DownloadStatus.DOWNLOADED), + ] + # HuggingFace with 1 file + hf_model_name = "hf/model_name" + hf_source = RemoteSource.HUGGINGFACE + hf_files_with_status = [ + ("hf_file1", DownloadStatus.DOWNLOADED), + ] + with prepare_model_dir( + local_path, s3_model_name, s3_source, s3_files_with_status + ), prepare_model_dir( + local_path, tos_model_name, tos_source, tos_files_with_status + ), prepare_model_dir( + local_path, hf_model_name, hf_source, hf_files_with_status + ): + download_models = DownloadModel.infer_from_local_path(local_path) + assert len(download_models) == 3 + for download_model in download_models: + if download_model.model_name == s3_model_name: + assert download_model.model_source == s3_source + assert len(download_model.download_files) == 3 + elif download_model.model_name == tos_model_name: + assert download_model.model_source == tos_source + assert len(download_model.download_files) == 2 + elif download_model.model_name == hf_model_name: + assert download_model.model_source == hf_source + assert len(download_model.download_files) == 1 + else: + raise ValueError(f"Unknown model name {download_model.model_name}") + + +def test_download_model_status(): + # All the file with downloaded will be model downloaded + with tempfile.TemporaryDirectory() as local_dir: + local_path = Path(local_dir) + model_name = "model/name" + source = RemoteSource.S3 + files_with_status = [ + ("file1", DownloadStatus.DOWNLOADED), + ("file2", DownloadStatus.DOWNLOADED), + ("file3", DownloadStatus.DOWNLOADED), + ] + with prepare_model_dir(local_path, model_name, source, files_with_status): + download_model = DownloadModel.infer_from_local_path(local_path)[0] + assert download_model.status == DownloadStatus.DOWNLOADED + + # The model will only be in the NO_OPERATION state + # if the file status is only in the DOWNLOADED or NO_OPERATION state + with tempfile.TemporaryDirectory() as local_dir: + local_path = Path(local_dir) + model_name = "model/name" + source = RemoteSource.S3 + files_with_status = [ + ("file1", DownloadStatus.DOWNLOADED), + ("file2", DownloadStatus.NO_OPERATION), + ("file3", DownloadStatus.DOWNLOADED), + ] + with prepare_model_dir(local_path, model_name, source, files_with_status): + download_model = DownloadModel.infer_from_local_path(local_path)[0] + assert download_model.status == 
DownloadStatus.NO_OPERATION From e7020475b5d3fc612629b9c3e00ae76eb4e97e5e Mon Sep 17 00:00:00 2001 From: brosoul Date: Mon, 23 Dec 2024 10:45:48 +0800 Subject: [PATCH 8/8] refact: extend DownloadStatus to FileDownloadStatus and ModelDownloadStatus --- python/aibrix/aibrix/downloader/entity.py | 31 ++++++---- python/aibrix/tests/downloader/test_entity.py | 58 ++++++++++--------- 2 files changed, 50 insertions(+), 39 deletions(-) diff --git a/python/aibrix/aibrix/downloader/entity.py b/python/aibrix/aibrix/downloader/entity.py index b84446ab..e76cdb1f 100644 --- a/python/aibrix/aibrix/downloader/entity.py +++ b/python/aibrix/aibrix/downloader/entity.py @@ -33,7 +33,14 @@ class RemoteSource(Enum): HUGGINGFACE = "huggingface" -class DownloadStatus(Enum): +class FileDownloadStatus(Enum): + DOWNLOADING = "downloading" + DOWNLOADED = "downloaded" + NO_OPERATION = "no_operation" # Interrupted from downloading + UNKNOWN = "unknown" + + +class ModelDownloadStatus(Enum): DOWNLOADING = "downloading" DOWNLOADED = "downloaded" NO_OPERATION = "no_operation" # Interrupted from downloading @@ -49,7 +56,7 @@ class DownloadFile: @property def status(self): if self.file_path.exists() and self.metadata_path.exists(): - return DownloadStatus.DOWNLOADED + return FileDownloadStatus.DOWNLOADED try: # Downloading process will acquire the lock, @@ -57,12 +64,12 @@ def status(self): lock = FileLock(self.lock_path) lock.acquire(blocking=False) except Timeout: - return DownloadStatus.DOWNLOADING + return FileDownloadStatus.DOWNLOADING except Exception as e: logger.warning(f"Failed to acquire lock failed for error: {e}") - return DownloadStatus.UNKNOWN + return FileDownloadStatus.UNKNOWN else: - return DownloadStatus.NO_OPERATION + return FileDownloadStatus.NO_OPERATION @contextlib.contextmanager def download_lock(self) -> Generator[BaseFileLock, None, None]: @@ -104,19 +111,19 @@ def __post_init__(self): @property def status(self): all_status = [file.status for file in self.download_files] - if all(status == DownloadStatus.DOWNLOADED for status in all_status): - return DownloadStatus.DOWNLOADED + if all(status == FileDownloadStatus.DOWNLOADED for status in all_status): + return ModelDownloadStatus.DOWNLOADED - if any(status == DownloadStatus.DOWNLOADING for status in all_status): - return DownloadStatus.DOWNLOADING + if any(status == FileDownloadStatus.DOWNLOADING for status in all_status): + return ModelDownloadStatus.DOWNLOADING if all( - status in [DownloadStatus.DOWNLOADED, DownloadStatus.NO_OPERATION] + status in [FileDownloadStatus.DOWNLOADED, FileDownloadStatus.NO_OPERATION] for status in all_status ): - return DownloadStatus.NO_OPERATION + return ModelDownloadStatus.NO_OPERATION - return DownloadStatus.UNKNOWN + return ModelDownloadStatus.UNKNOWN @classmethod def infer_from_model_path( diff --git a/python/aibrix/tests/downloader/test_entity.py b/python/aibrix/tests/downloader/test_entity.py index dfbdc764..ab1600ff 100644 --- a/python/aibrix/tests/downloader/test_entity.py +++ b/python/aibrix/tests/downloader/test_entity.py @@ -21,7 +21,8 @@ from aibrix.config import DOWNLOAD_CACHE_DIR from aibrix.downloader.entity import ( DownloadModel, - DownloadStatus, + FileDownloadStatus, + ModelDownloadStatus, RemoteSource, get_local_download_paths, ) @@ -32,7 +33,7 @@ def prepare_model_dir( local_path: Path, model_name: str, source: RemoteSource, - files_with_status: List[Tuple[str, DownloadStatus]], + files_with_status: List[Tuple[str, FileDownloadStatus]], ) -> Generator: model_base_dir = 
local_path.joinpath(model_name) cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/") @@ -42,26 +43,29 @@ def prepare_model_dir( model_cache_dir.mkdir(parents=True, exist_ok=True) def _parepare_file( - model_base_dir: Path, cache_dir: Path, file_name: str, status: DownloadStatus + model_base_dir: Path, + cache_dir: Path, + file_name: str, + status: FileDownloadStatus, ): file_path = model_base_dir.joinpath(file_name) lock_path = cache_dir.joinpath(f"{file_name}.lock") meta_path = cache_dir.joinpath(f"{file_name}.metadata") - if status == DownloadStatus.DOWNLOADED: + if status == FileDownloadStatus.DOWNLOADED: file_path.touch() lock_path.touch() meta_path.touch() - elif status == DownloadStatus.DOWNLOADING: + elif status == FileDownloadStatus.DOWNLOADING: lock_path.touch() - elif status == DownloadStatus.NO_OPERATION: + elif status == FileDownloadStatus.NO_OPERATION: lock_path.touch() # acquire the donwloading lock all_locks = [] for file_name, status in files_with_status: _parepare_file(model_base_dir, model_cache_dir, file_name, status) - if status == DownloadStatus.DOWNLOADING: + if status == FileDownloadStatus.DOWNLOADING: lock_path = model_cache_dir.joinpath(f"{file_name}.lock") lock = FileLock(lock_path) lock.acquire(blocking=False) @@ -80,9 +84,9 @@ def test_prepare_model_dir(): model_name = "model/name" source = RemoteSource.S3 files_with_status = [ - ("file1", DownloadStatus.DOWNLOADED), - ("file2", DownloadStatus.DOWNLOADING), - ("file3", DownloadStatus.NO_OPERATION), + ("file1", FileDownloadStatus.DOWNLOADED), + ("file2", FileDownloadStatus.DOWNLOADING), + ("file3", FileDownloadStatus.NO_OPERATION), ] with prepare_model_dir(local_path, model_name, source, files_with_status): model_base_dir = local_path.joinpath(model_name) @@ -108,9 +112,9 @@ def test_get_local_download_paths(): model_name = "model/name" source = RemoteSource.S3 files_with_status = [ - ("file1", DownloadStatus.DOWNLOADED), - ("file2", DownloadStatus.DOWNLOADING), - ("file3", DownloadStatus.NO_OPERATION), + ("file1", FileDownloadStatus.DOWNLOADED), + ("file2", FileDownloadStatus.DOWNLOADING), + ("file3", FileDownloadStatus.NO_OPERATION), ] model_base_dir = local_path.joinpath(model_name) with prepare_model_dir(local_path, model_name, source, files_with_status): @@ -130,22 +134,22 @@ def test_infer_from_local_path(): s3_model_name = "s3_model/name" s3_source = RemoteSource.S3 s3_files_with_status = [ - ("s3_file1", DownloadStatus.DOWNLOADED), - ("s3_file2", DownloadStatus.DOWNLOADING), - ("s3_file3", DownloadStatus.NO_OPERATION), + ("s3_file1", FileDownloadStatus.DOWNLOADED), + ("s3_file2", FileDownloadStatus.DOWNLOADING), + ("s3_file3", FileDownloadStatus.NO_OPERATION), ] # TOS model with 2 files tos_model_name = "tos_model_name" tos_source = RemoteSource.TOS tos_files_with_status = [ - ("tos_file1", DownloadStatus.DOWNLOADED), - ("tos_file2", DownloadStatus.DOWNLOADED), + ("tos_file1", FileDownloadStatus.DOWNLOADED), + ("tos_file2", FileDownloadStatus.DOWNLOADED), ] # HuggingFace with 1 file hf_model_name = "hf/model_name" hf_source = RemoteSource.HUGGINGFACE hf_files_with_status = [ - ("hf_file1", DownloadStatus.DOWNLOADED), + ("hf_file1", FileDownloadStatus.DOWNLOADED), ] with prepare_model_dir( local_path, s3_model_name, s3_source, s3_files_with_status @@ -177,13 +181,13 @@ def test_download_model_status(): model_name = "model/name" source = RemoteSource.S3 files_with_status = [ - ("file1", DownloadStatus.DOWNLOADED), - ("file2", DownloadStatus.DOWNLOADED), - ("file3", DownloadStatus.DOWNLOADED), + 
("file1", FileDownloadStatus.DOWNLOADED), + ("file2", FileDownloadStatus.DOWNLOADED), + ("file3", FileDownloadStatus.DOWNLOADED), ] with prepare_model_dir(local_path, model_name, source, files_with_status): download_model = DownloadModel.infer_from_local_path(local_path)[0] - assert download_model.status == DownloadStatus.DOWNLOADED + assert download_model.status == ModelDownloadStatus.DOWNLOADED # The model will only be in the NO_OPERATION state # if the file status is only in the DOWNLOADED or NO_OPERATION state @@ -192,10 +196,10 @@ def test_download_model_status(): model_name = "model/name" source = RemoteSource.S3 files_with_status = [ - ("file1", DownloadStatus.DOWNLOADED), - ("file2", DownloadStatus.NO_OPERATION), - ("file3", DownloadStatus.DOWNLOADED), + ("file1", FileDownloadStatus.DOWNLOADED), + ("file2", FileDownloadStatus.NO_OPERATION), + ("file3", FileDownloadStatus.DOWNLOADED), ] with prepare_model_dir(local_path, model_name, source, files_with_status): download_model = DownloadModel.infer_from_local_path(local_path)[0] - assert download_model.status == DownloadStatus.NO_OPERATION + assert download_model.status == ModelDownloadStatus.NO_OPERATION