Skip to content

Commit

Permalink
Add AI Runtime exist model check (#198)
Browse files Browse the repository at this point in the history
* add filename desc in download progress bar

* feat: add file exist check and force download options

* style: lint and format

* add default envs value

* refact: extract need_to_download from tos and s3

* test: add test case about download utils
  • Loading branch information
brosoul authored Sep 21, 2024
1 parent 2ce0ce2 commit 81541ab
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 10 deletions.
2 changes: 2 additions & 0 deletions python/aibrix/aibrix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@


DEFAULT_METRIC_COLLECTOR_TIMEOUT = 1

DOWNLOAD_CACHE_DIR = ".cache"
2 changes: 2 additions & 0 deletions python/aibrix/aibrix/downloader/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def download(
token=self.hf_token,
endpoint=self.hf_endpoint,
local_dir_use_symlinks=False,
force_download=envs.DOWNLOADER_FORCE_DOWNLOAD,
)

def download_directory(self, local_path: Path):
Expand All @@ -111,4 +112,5 @@ def download_directory(self, local_path: Path):
max_workers=envs.DOWNLOADER_NUM_THREADS,
endpoint=self.hf_endpoint,
local_dir_use_symlinks=False,
force_download=envs.DOWNLOADER_FORCE_DOWNLOAD,
)
25 changes: 20 additions & 5 deletions python/aibrix/aibrix/downloader/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data
from aibrix.logger import init_logger

logger = init_logger(__name__)


def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
Expand Down Expand Up @@ -110,15 +114,21 @@ def download(
bucket_name: Optional[str] = None,
enable_range: bool = True,
):
# check if file exist
try:
meta_data = self.client.head_object(Bucket=bucket_name, Key=bucket_path)
except Exception as e:
raise ValueError(f"S3 bucket path {bucket_path} not exist for {e}.")

_file_name = bucket_path.split("/")[-1]
# S3 client does not support Path, convert it to str
local_file = str(local_path.joinpath(_file_name).absolute())
local_file = local_path.joinpath(_file_name).absolute()

# check if file exist
etag = meta_data.get("ETag", "")
file_size = meta_data.get("ContentLength", 0)
meta_data_file = meta_file(local_path=local_path, file_name=_file_name)

if not need_to_download(local_file, meta_data_file, file_size, etag):
return

# construct TransferConfig
config_kwargs = {
Expand All @@ -134,15 +144,20 @@ def download(

# download file
total_length = int(meta_data.get("ContentLength", 0))
with tqdm(total=total_length, unit="b", unit_scale=True) as pbar:
with tqdm(
desc=_file_name, total=total_length, unit="b", unit_scale=True
) as pbar:

def download_progress(bytes_transferred):
pbar.update(bytes_transferred)

self.client.download_file(
Bucket=bucket_name,
Key=bucket_path,
Filename=local_file,
Filename=str(
local_file
), # S3 client does not support Path, convert it to str
Config=config,
Callback=download_progress,
)
save_meta_data(meta_data_file, etag)
25 changes: 20 additions & 5 deletions python/aibrix/aibrix/downloader/tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data
from aibrix.logger import init_logger

tos_logger = logging.getLogger("tos")
tos_logger.setLevel(logging.WARNING)
logger = init_logger(__name__)


def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
Expand Down Expand Up @@ -100,15 +103,22 @@ def download(
bucket_name: Optional[str] = None,
enable_range: bool = True,
):
# check if file exist
try:
meta_data = self.client.head_object(bucket=bucket_name, key=bucket_path)
except Exception as e:
raise ValueError(f"TOS bucket path {bucket_path} not exist for {e}.")

_file_name = bucket_path.split("/")[-1]
# TOS client does not support Path, convert it to str
local_file = str(local_path.joinpath(_file_name).absolute())
local_file = local_path.joinpath(_file_name).absolute()

# check if file exist
etag = meta_data.etag
file_size = meta_data.content_length
meta_data_file = meta_file(local_path=local_path, file_name=_file_name)

if not need_to_download(local_file, meta_data_file, file_size, etag):
return

task_num = envs.DOWNLOADER_NUM_THREADS if enable_range else 1

download_kwargs = {}
Expand All @@ -117,7 +127,9 @@ def download(

# download file
total_length = meta_data.content_length
with tqdm(total=total_length, unit="b", unit_scale=True) as pbar:
with tqdm(
desc=_file_name, total=total_length, unit="b", unit_scale=True
) as pbar:

def download_progress(
consumed_bytes, total_bytes, rw_once_bytes, type: DataTransferType
Expand All @@ -127,8 +139,11 @@ def download_progress(
self.client.download_file(
bucket=bucket_name,
key=bucket_path,
file_path=local_file,
file_path=str(
local_file
), # TOS client does not support Path, convert it to str
task_num=task_num,
data_transfer_listener=download_progress,
**download_kwargs,
)
save_meta_data(meta_data_file, etag)
95 changes: 95 additions & 0 deletions python/aibrix/aibrix/downloader/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import Union

from aibrix import envs
from aibrix.config import DOWNLOAD_CACHE_DIR
from aibrix.logger import init_logger


logger = init_logger(__name__)


def meta_file(local_path: Union[Path, str], file_name: str) -> Path:
return (
Path(local_path)
.joinpath(DOWNLOAD_CACHE_DIR)
.joinpath(f"{file_name}.metadata")
.absolute()
)


def save_meta_data(file_path: Union[Path, str], etag: str):
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(etag)


def load_meta_data(file_path: Union[Path, str]):
if Path(file_path).exists():
with open(file_path, "r") as f:
return f.read()
return None


def check_file_exist(
local_file: Union[Path, str],
meta_file: Union[Path, str],
expected_file_size: int,
expected_etag: str,
) -> bool:
if expected_file_size is None or expected_file_size <= 0:
return False

if expected_etag is None or expected_etag == "":
return False

if not Path(local_file).exists():
return False

file_size = os.path.getsize(local_file)
if file_size != expected_file_size:
return False

if not Path(meta_file).exists():
return False

etag = load_meta_data(meta_file)
return etag == expected_etag


def need_to_download(
local_file: Union[Path, str],
meta_data_file: Union[Path, str],
expected_file_size: int,
expected_etag: str,
) -> bool:
_file_name = Path(local_file).name
if not envs.DOWNLOADER_FORCE_DOWNLOAD and envs.DOWNLOADER_CHECK_FILE_EXIST:
if check_file_exist(
local_file, meta_data_file, expected_file_size, expected_etag
):
logger.info(f"File {_file_name} exist in local, skip download.")
return False
else:
logger.info(f"File {_file_name} not exist in local, start to download...")
else:
logger.info(
f"File {_file_name} start downloading directly "
f"for DOWNLOADER_FORCE_DOWNLOAD={envs.DOWNLOADER_FORCE_DOWNLOAD}, "
f"DOWNLOADER_CHECK_FILE_EXIST={envs.DOWNLOADER_CHECK_FILE_EXIST}"
)
return True
3 changes: 3 additions & 0 deletions python/aibrix/aibrix/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def _parse_int_or_none(value: Optional[str]) -> Optional[int]:
os.getenv("DOWNLOADER_ALLOW_FILE_SUFFIX")
)

DOWNLOADER_FORCE_DOWNLOAD = _is_true(os.getenv("DOWNLOADER_FORCE_DOWNLOAD", "0"))
DOWNLOADER_CHECK_FILE_EXIST = _is_true(os.getenv("DOWNLOADER_CHECK_FILE_EXIST", "1"))

# Downloader Regex
DOWNLOADER_S3_REGEX = r"^s3://"
DOWNLOADER_TOS_REGEX = r"^tos://"
Expand Down
138 changes: 138 additions & 0 deletions python/aibrix/tests/downloader/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from pathlib import Path
import tempfile
from unittest import mock

from aibrix import envs
from aibrix.config import DOWNLOAD_CACHE_DIR
from aibrix.downloader.utils import (
check_file_exist,
load_meta_data,
meta_file,
need_to_download,
save_meta_data,
)


def prepare_file_and_meta_data(file_path, meta_path, file_size, etag):
save_meta_data(meta_path, etag)
# create file
with open(file_path, "wb") as f:
f.write(os.urandom(file_size))


def test_meta_file():
with tempfile.TemporaryDirectory() as tmp_dir:
meta_file_path = meta_file(tmp_dir, "test")
assert str(meta_file_path).endswith(f"{DOWNLOAD_CACHE_DIR}/test.metadata")


def test_save_load_meta_data():
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = Path(tmp_dir).joinpath("test.metadata")
etag = "here_is_etag_value_xyz"
save_meta_data(file_path, etag)
assert file_path.exists()

load_etag = load_meta_data(file_path)
assert etag == load_etag

not_exist_file = Path(tmp_dir).joinpath("not_exist.metadata")
not_exist_etag = load_meta_data(not_exist_file)
assert not_exist_etag is None


def test_check_file_exist():
with tempfile.TemporaryDirectory() as tmp_dir:
# prepare file and meta data
file_size = 10
file_name = "test"
etag = "here_is_etag_value_xyz"
file_path = Path(tmp_dir).joinpath(file_name)
# create meta data file
meta_path = meta_file(tmp_dir, file_name)
prepare_file_and_meta_data(file_path, meta_path, file_size, etag)

assert check_file_exist(file_path, meta_path, file_size, etag)
assert not check_file_exist(file_path, meta_path, None, etag)
assert not check_file_exist(file_path, meta_path, -1, etag)
assert not check_file_exist(file_path, meta_path, file_size, None)
# The remote etag has been updated
assert not check_file_exist(file_path, meta_path, file_size, "new_etag")

not_exist_file_name = "not_exist"
not_exist_file_path = Path(tmp_dir).joinpath(not_exist_file_name)
not_exist_meta_path = meta_file(tmp_dir, not_exist_file_name)
assert not check_file_exist(
not_exist_file_path, not_exist_meta_path, file_size, etag
)
# meta data not exist
assert not check_file_exist(file_path, not_exist_meta_path, file_size, etag)


CHECK_FUNC = "aibrix.downloader.utils.check_file_exist"


@mock.patch(CHECK_FUNC)
def test_need_to_download_not_call(mock_check: mock.Mock):
origin_force_download_env = envs.DOWNLOADER_FORCE_DOWNLOAD
origin_check_file_exist = envs.DOWNLOADER_CHECK_FILE_EXIST

envs.DOWNLOADER_FORCE_DOWNLOAD = True
envs.DOWNLOADER_CHECK_FILE_EXIST = False
need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_not_called()

envs.DOWNLOADER_FORCE_DOWNLOAD = True
envs.DOWNLOADER_CHECK_FILE_EXIST = True
need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_not_called()

envs.DOWNLOADER_FORCE_DOWNLOAD = False
envs.DOWNLOADER_CHECK_FILE_EXIST = False
need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_not_called()

# recover envs
envs.DOWNLOADER_FORCE_DOWNLOAD = origin_force_download_env
envs.DOWNLOADER_CHECK_FILE_EXIST = origin_check_file_exist


@mock.patch(CHECK_FUNC)
def test_need_to_download(mock_check: mock.Mock):
origin_force_download_env = envs.DOWNLOADER_FORCE_DOWNLOAD
origin_check_file_exist = envs.DOWNLOADER_CHECK_FILE_EXIST

envs.DOWNLOADER_FORCE_DOWNLOAD = False
envs.DOWNLOADER_CHECK_FILE_EXIST = True
file_exist = True
# file exist, no need to download
mock_check.return_value = file_exist
assert not need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_called_once()
mock_check.reset_mock()

# file not exist, need to download
mock_check.return_value = not file_exist
assert need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_called_once()
mock_check.reset_mock()

# recover envs
envs.DOWNLOADER_FORCE_DOWNLOAD = origin_force_download_env
envs.DOWNLOADER_CHECK_FILE_EXIST = origin_check_file_exist

0 comments on commit 81541ab

Please sign in to comment.