diff --git a/python/aibrix/aibrix/config.py b/python/aibrix/aibrix/config.py index bd0708c3..3376b234 100644 --- a/python/aibrix/aibrix/config.py +++ b/python/aibrix/aibrix/config.py @@ -14,3 +14,5 @@ DEFAULT_METRIC_COLLECTOR_TIMEOUT = 1 + +DOWNLOAD_CACHE_DIR = ".cache" diff --git a/python/aibrix/aibrix/downloader/huggingface.py b/python/aibrix/aibrix/downloader/huggingface.py index 618c3802..a8e5c437 100644 --- a/python/aibrix/aibrix/downloader/huggingface.py +++ b/python/aibrix/aibrix/downloader/huggingface.py @@ -99,6 +99,7 @@ def download( token=self.hf_token, endpoint=self.hf_endpoint, local_dir_use_symlinks=False, + force_download=envs.DOWNLOADER_FORCE_DOWNLOAD, ) def download_directory(self, local_path: Path): @@ -111,4 +112,5 @@ def download_directory(self, local_path: Path): max_workers=envs.DOWNLOADER_NUM_THREADS, endpoint=self.hf_endpoint, local_dir_use_symlinks=False, + force_download=envs.DOWNLOADER_FORCE_DOWNLOAD, ) diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index 60dd4a2e..bde725d0 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -23,6 +23,10 @@ from aibrix import envs from aibrix.downloader.base import BaseDownloader +from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data +from aibrix.logger import init_logger + +logger = init_logger(__name__) def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]: @@ -110,15 +114,21 @@ def download( bucket_name: Optional[str] = None, enable_range: bool = True, ): - # check if file exist try: meta_data = self.client.head_object(Bucket=bucket_name, Key=bucket_path) except Exception as e: raise ValueError(f"S3 bucket path {bucket_path} not exist for {e}.") _file_name = bucket_path.split("/")[-1] - # S3 client does not support Path, convert it to str - local_file = str(local_path.joinpath(_file_name).absolute()) + local_file = local_path.joinpath(_file_name).absolute() + + # check if file 
exist + etag = meta_data.get("ETag", "") + file_size = meta_data.get("ContentLength", 0) + meta_data_file = meta_file(local_path=local_path, file_name=_file_name) + + if not need_to_download(local_file, meta_data_file, file_size, etag): + return # construct TransferConfig config_kwargs = { @@ -134,7 +144,9 @@ def download( # download file total_length = int(meta_data.get("ContentLength", 0)) - with tqdm(total=total_length, unit="b", unit_scale=True) as pbar: + with tqdm( + desc=_file_name, total=total_length, unit="b", unit_scale=True + ) as pbar: def download_progress(bytes_transferred): pbar.update(bytes_transferred) @@ -142,7 +154,10 @@ def download_progress(bytes_transferred): self.client.download_file( Bucket=bucket_name, Key=bucket_path, - Filename=local_file, + Filename=str( + local_file + ), # S3 client does not support Path, convert it to str Config=config, Callback=download_progress, ) + save_meta_data(meta_data_file, etag) diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index f0ff8cdf..4d573def 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -23,9 +23,12 @@ from aibrix import envs from aibrix.downloader.base import BaseDownloader +from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data +from aibrix.logger import init_logger tos_logger = logging.getLogger("tos") tos_logger.setLevel(logging.WARNING) +logger = init_logger(__name__) def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]: @@ -100,15 +103,22 @@ def download( bucket_name: Optional[str] = None, enable_range: bool = True, ): - # check if file exist try: meta_data = self.client.head_object(bucket=bucket_name, key=bucket_path) except Exception as e: raise ValueError(f"TOS bucket path {bucket_path} not exist for {e}.") _file_name = bucket_path.split("/")[-1] - # TOS client does not support Path, convert it to str - local_file = str(local_path.joinpath(_file_name).absolute()) 
+ local_file = local_path.joinpath(_file_name).absolute() + + # check if file exist + etag = meta_data.etag + file_size = meta_data.content_length + meta_data_file = meta_file(local_path=local_path, file_name=_file_name) + + if not need_to_download(local_file, meta_data_file, file_size, etag): + return + task_num = envs.DOWNLOADER_NUM_THREADS if enable_range else 1 download_kwargs = {} @@ -117,7 +127,9 @@ def download( # download file total_length = meta_data.content_length - with tqdm(total=total_length, unit="b", unit_scale=True) as pbar: + with tqdm( + desc=_file_name, total=total_length, unit="b", unit_scale=True + ) as pbar: def download_progress( consumed_bytes, total_bytes, rw_once_bytes, type: DataTransferType @@ -127,8 +139,11 @@ def download_progress( self.client.download_file( bucket=bucket_name, key=bucket_path, - file_path=local_file, + file_path=str( + local_file + ), # TOS client does not support Path, convert it to str task_num=task_num, data_transfer_listener=download_progress, **download_kwargs, ) + save_meta_data(meta_data_file, etag) diff --git a/python/aibrix/aibrix/downloader/utils.py b/python/aibrix/aibrix/downloader/utils.py new file mode 100644 index 00000000..c008af95 --- /dev/null +++ b/python/aibrix/aibrix/downloader/utils.py @@ -0,0 +1,95 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os
+from pathlib import Path
+from typing import Optional, Union
+
+from aibrix import envs
+from aibrix.config import DOWNLOAD_CACHE_DIR
+from aibrix.logger import init_logger
+
+
+logger = init_logger(__name__)
+
+
+def meta_file(local_path: Union[Path, str], file_name: str) -> Path:
+    """Return the absolute path of the cached metadata file for ``file_name``."""
+    return Path(local_path, DOWNLOAD_CACHE_DIR, f"{file_name}.metadata").absolute()
+
+
+def save_meta_data(file_path: Union[Path, str], etag: str):
+    """Write ``etag`` to ``file_path``, creating parent directories on demand."""
+    meta_path = Path(file_path)
+    meta_path.parent.mkdir(parents=True, exist_ok=True)
+    meta_path.write_text(etag, encoding="utf-8")
+
+
+def load_meta_data(file_path: Union[Path, str]) -> Optional[str]:
+    """Return the etag stored in ``file_path``, or ``None`` if it is missing."""
+    try:
+        return Path(file_path).read_text(encoding="utf-8")
+    except FileNotFoundError:
+        return None
+
+
+def check_file_exist(
+    local_file: Union[Path, str],
+    meta_file: Union[Path, str],
+    expected_file_size: Optional[int],
+    expected_etag: Optional[str],
+) -> bool:
+    """Return True if ``local_file`` matches the remote size and cached etag."""
+    if expected_file_size is None or expected_file_size <= 0 or not expected_etag:
+        return False
+
+    try:
+        local_size = os.path.getsize(local_file)
+    except OSError:
+        # Missing or unreadable file: report "not cached" instead of raising.
+        return False
+    if local_size != expected_file_size:
+        return False
+
+    # A missing metadata file loads as None, which never equals the etag.
+    return load_meta_data(meta_file) == expected_etag
+
+
+def need_to_download(
+    local_file: Union[Path, str],
+    meta_data_file: Union[Path, str],
+    expected_file_size: int,
+    expected_etag: str,
+) -> bool:
+    """Return True if ``local_file`` has to be fetched from the remote store.
+
+    The local copy is reused only when DOWNLOADER_FORCE_DOWNLOAD is off,
+    DOWNLOADER_CHECK_FILE_EXIST is on and check_file_exist() succeeds.
+    """
+    _file_name = Path(local_file).name
+    if envs.DOWNLOADER_FORCE_DOWNLOAD or not envs.DOWNLOADER_CHECK_FILE_EXIST:
+        logger.info(
+            f"File {_file_name} start downloading directly "
+            f"for DOWNLOADER_FORCE_DOWNLOAD={envs.DOWNLOADER_FORCE_DOWNLOAD}, "
+            f"DOWNLOADER_CHECK_FILE_EXIST={envs.DOWNLOADER_CHECK_FILE_EXIST}"
+        )
+        return True
+
+    if check_file_exist(local_file, meta_data_file, expected_file_size, expected_etag):
+        logger.info(f"File {_file_name} exists in local, skip download.")
+        return False
+
+    logger.info(
+        f"File {_file_name} is missing or outdated in local, start to download..."
+    )
+    return True
diff --git a/python/aibrix/aibrix/envs.py b/python/aibrix/aibrix/envs.py
index 8d28889b..ba4db25b 100644
--- a/python/aibrix/aibrix/envs.py
+++ b/python/aibrix/aibrix/envs.py
@@ -53,6 +53,9 @@ def _parse_int_or_none(value: Optional[str]) -> Optional[int]:
     os.getenv("DOWNLOADER_ALLOW_FILE_SUFFIX")
 )
 
+DOWNLOADER_FORCE_DOWNLOAD = _is_true(os.getenv("DOWNLOADER_FORCE_DOWNLOAD", "0"))
+DOWNLOADER_CHECK_FILE_EXIST = _is_true(os.getenv("DOWNLOADER_CHECK_FILE_EXIST", "1"))
+
 # Downloader Regex
 DOWNLOADER_S3_REGEX = r"^s3://"
 DOWNLOADER_TOS_REGEX = r"^tos://"
diff --git a/python/aibrix/tests/downloader/test_utils.py b/python/aibrix/tests/downloader/test_utils.py
new file mode 100644
index 00000000..f24e0c3d
--- /dev/null
+++ b/python/aibrix/tests/downloader/test_utils.py
@@ -0,0 +1,138 @@
+# Copyright 2024 The Aibrix Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from pathlib import Path
+import tempfile
+from unittest import mock
+
+from aibrix import envs
+from aibrix.config import DOWNLOAD_CACHE_DIR
+from aibrix.downloader.utils import (
+    check_file_exist,
+    load_meta_data,
+    meta_file,
+    need_to_download,
+    save_meta_data,
+)
+
+
+def prepare_file_and_meta_data(file_path, meta_path, file_size, etag):
+    """Create a random ``file_size``-byte file and its etag metadata file."""
+    save_meta_data(meta_path, etag)
+    Path(file_path).write_bytes(os.urandom(file_size))
+
+
+def test_meta_file():
+    """The metadata file lives in the cache dir next to the downloads."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        meta_file_path = meta_file(tmp_dir, "test")
+        # Compare path components so the check does not depend on "/".
+        assert meta_file_path.is_absolute()
+        assert meta_file_path.parts[-2:] == (DOWNLOAD_CACHE_DIR, "test.metadata")
+
+
+def test_save_load_meta_data():
+    """A saved etag round-trips; a missing metadata file loads as None."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        file_path = Path(tmp_dir).joinpath("test.metadata")
+        etag = "here_is_etag_value_xyz"
+        save_meta_data(file_path, etag)
+        assert file_path.exists()
+
+        load_etag = load_meta_data(file_path)
+        assert etag == load_etag
+
+        # saving again replaces the previous etag
+        save_meta_data(file_path, "new_etag")
+        assert load_meta_data(file_path) == "new_etag"
+
+        not_exist_file = Path(tmp_dir).joinpath("not_exist.metadata")
+        not_exist_etag = load_meta_data(not_exist_file)
+        assert not_exist_etag is None
+
+
+def test_check_file_exist():
+    """Only a file matching both the remote size and etag counts as cached."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        # prepare file and meta data
+        file_size = 10
+        file_name = "test"
+        etag = "here_is_etag_value_xyz"
+        file_path = Path(tmp_dir).joinpath(file_name)
+        # create meta data file
+        meta_path = meta_file(tmp_dir, file_name)
+        prepare_file_and_meta_data(file_path, meta_path, file_size, etag)
+
+        assert check_file_exist(file_path, meta_path, file_size, etag)
+        assert not check_file_exist(file_path, meta_path, None, etag)
+        assert not check_file_exist(file_path, meta_path, -1, etag)
+        assert not check_file_exist(file_path, meta_path, file_size, None)
+        # The remote etag has been updated
+        assert not check_file_exist(file_path, meta_path, file_size, "new_etag")
+        # The remote size has changed
+        assert not check_file_exist(file_path, meta_path, file_size + 1, etag)
+
+        not_exist_file_name = "not_exist"
+        not_exist_file_path = Path(tmp_dir).joinpath(not_exist_file_name)
+        not_exist_meta_path = meta_file(tmp_dir, not_exist_file_name)
+        assert not check_file_exist(
+            not_exist_file_path, not_exist_meta_path, file_size, etag
+        )
+        # meta data not exist
+        assert not check_file_exist(file_path, not_exist_meta_path, file_size, etag)
+
+        # The local file has been truncated
+        file_path.write_bytes(b"")
+        assert not check_file_exist(file_path, meta_path, file_size, etag)
+
+
+CHECK_FUNC = "aibrix.downloader.utils.check_file_exist"
+
+
+def patch_envs(force_download, check_exist):
+    """Override both downloader flags; mock restores them even on failure."""
+    return mock.patch.multiple(
+        envs,
+        DOWNLOADER_FORCE_DOWNLOAD=force_download,
+        DOWNLOADER_CHECK_FILE_EXIST=check_exist,
+    )
+
+
+@mock.patch(CHECK_FUNC)
+def test_need_to_download_not_call(mock_check: mock.Mock):
+    """Forcing or disabling the check downloads without probing local files."""
+    # every flag combination except (force=False, check=True)
+    for force_download, check_exist in ((True, False), (True, True), (False, False)):
+        with patch_envs(force_download, check_exist):
+            assert need_to_download("file", "file.metadata", 10, "etag")
+        mock_check.assert_not_called()
+
+
+@mock.patch(CHECK_FUNC)
+def test_need_to_download(mock_check: mock.Mock):
+    """With the check enabled the result mirrors check_file_exist()."""
+    with patch_envs(force_download=False, check_exist=True):
+        # file exist, no need to download
+        mock_check.return_value = True
+        assert not need_to_download("file", "file.metadata", 10, "etag")
+        mock_check.assert_called_once()
+        mock_check.reset_mock()
+
+        # file not exist, need to download
+        mock_check.return_value = False
+        assert need_to_download("file", "file.metadata", 10, "etag")
+        mock_check.assert_called_once()
+        mock_check.reset_mock()