Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AI Runtime exist model check #198

Merged
merged 6 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/aibrix/aibrix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@


DEFAULT_METRIC_COLLECTOR_TIMEOUT = 1

DOWNLOAD_CACHE_DIR = ".cache"
2 changes: 2 additions & 0 deletions python/aibrix/aibrix/downloader/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def download(
token=self.hf_token,
endpoint=self.hf_endpoint,
local_dir_use_symlinks=False,
force_download=envs.DOWNLOADER_FORCE_DOWNLOAD,
)

def download_directory(self, local_path: Path):
Expand All @@ -111,4 +112,5 @@ def download_directory(self, local_path: Path):
max_workers=envs.DOWNLOADER_NUM_THREADS,
endpoint=self.hf_endpoint,
local_dir_use_symlinks=False,
force_download=envs.DOWNLOADER_FORCE_DOWNLOAD,
)
25 changes: 20 additions & 5 deletions python/aibrix/aibrix/downloader/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data
from aibrix.logger import init_logger

logger = init_logger(__name__)


def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
Expand Down Expand Up @@ -110,15 +114,21 @@ def download(
bucket_name: Optional[str] = None,
enable_range: bool = True,
):
# check if file exist
try:
meta_data = self.client.head_object(Bucket=bucket_name, Key=bucket_path)
except Exception as e:
raise ValueError(f"S3 bucket path {bucket_path} not exist for {e}.")

_file_name = bucket_path.split("/")[-1]
# S3 client does not support Path, convert it to str
local_file = str(local_path.joinpath(_file_name).absolute())
local_file = local_path.joinpath(_file_name).absolute()

# check if file exist
etag = meta_data.get("ETag", "")
file_size = meta_data.get("ContentLength", 0)
meta_data_file = meta_file(local_path=local_path, file_name=_file_name)

if not need_to_download(local_file, meta_data_file, file_size, etag):
return

# construct TransferConfig
config_kwargs = {
Expand All @@ -134,15 +144,20 @@ def download(

# download file
total_length = int(meta_data.get("ContentLength", 0))
with tqdm(total=total_length, unit="b", unit_scale=True) as pbar:
with tqdm(
desc=_file_name, total=total_length, unit="b", unit_scale=True
) as pbar:

def download_progress(bytes_transferred):
pbar.update(bytes_transferred)

self.client.download_file(
Bucket=bucket_name,
Key=bucket_path,
Filename=local_file,
Filename=str(
local_file
), # S3 client does not support Path, convert it to str
Config=config,
Callback=download_progress,
)
save_meta_data(meta_data_file, etag)
25 changes: 20 additions & 5 deletions python/aibrix/aibrix/downloader/tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data
from aibrix.logger import init_logger

tos_logger = logging.getLogger("tos")
tos_logger.setLevel(logging.WARNING)
logger = init_logger(__name__)


def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
Expand Down Expand Up @@ -100,15 +103,22 @@ def download(
bucket_name: Optional[str] = None,
enable_range: bool = True,
):
# check if file exist
try:
meta_data = self.client.head_object(bucket=bucket_name, key=bucket_path)
except Exception as e:
raise ValueError(f"TOS bucket path {bucket_path} not exist for {e}.")

_file_name = bucket_path.split("/")[-1]
# TOS client does not support Path, convert it to str
local_file = str(local_path.joinpath(_file_name).absolute())
local_file = local_path.joinpath(_file_name).absolute()

# check if file exist
etag = meta_data.etag
file_size = meta_data.content_length
meta_data_file = meta_file(local_path=local_path, file_name=_file_name)

if not need_to_download(local_file, meta_data_file, file_size, etag):
return

task_num = envs.DOWNLOADER_NUM_THREADS if enable_range else 1

download_kwargs = {}
Expand All @@ -117,7 +127,9 @@ def download(

# download file
total_length = meta_data.content_length
with tqdm(total=total_length, unit="b", unit_scale=True) as pbar:
with tqdm(
desc=_file_name, total=total_length, unit="b", unit_scale=True
) as pbar:

def download_progress(
consumed_bytes, total_bytes, rw_once_bytes, type: DataTransferType
Expand All @@ -127,8 +139,11 @@ def download_progress(
self.client.download_file(
bucket=bucket_name,
key=bucket_path,
file_path=local_file,
file_path=str(
local_file
), # TOS client does not support Path, convert it to str
task_num=task_num,
data_transfer_listener=download_progress,
**download_kwargs,
)
save_meta_data(meta_data_file, etag)
95 changes: 95 additions & 0 deletions python/aibrix/aibrix/downloader/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import Union

from aibrix import envs
from aibrix.config import DOWNLOAD_CACHE_DIR
from aibrix.logger import init_logger


logger = init_logger(__name__)


def meta_file(local_path: Union[Path, str], file_name: str) -> Path:
return (
Path(local_path)
.joinpath(DOWNLOAD_CACHE_DIR)
.joinpath(f"{file_name}.metadata")
.absolute()
)


def save_meta_data(file_path: Union[Path, str], etag: str):
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(etag)


def load_meta_data(file_path: Union[Path, str]):
if Path(file_path).exists():
with open(file_path, "r") as f:
return f.read()
return None


def check_file_exist(
local_file: Union[Path, str],
meta_file: Union[Path, str],
expected_file_size: int,
expected_etag: str,
) -> bool:
if expected_file_size is None or expected_file_size <= 0:
return False

if expected_etag is None or expected_etag == "":
return False

if not Path(local_file).exists():
return False

file_size = os.path.getsize(local_file)
if file_size != expected_file_size:
return False

if not Path(meta_file).exists():
return False

etag = load_meta_data(meta_file)
return etag == expected_etag


def need_to_download(
local_file: Union[Path, str],
meta_data_file: Union[Path, str],
expected_file_size: int,
expected_etag: str,
) -> bool:
_file_name = Path(local_file).name
if not envs.DOWNLOADER_FORCE_DOWNLOAD and envs.DOWNLOADER_CHECK_FILE_EXIST:
if check_file_exist(
local_file, meta_data_file, expected_file_size, expected_etag
):
logger.info(f"File {_file_name} exist in local, skip download.")
return False
else:
logger.info(f"File {_file_name} not exist in local, start to download...")
else:
logger.info(
f"File {_file_name} start downloading directly "
f"for DOWNLOADER_FORCE_DOWNLOAD={envs.DOWNLOADER_FORCE_DOWNLOAD}, "
f"DOWNLOADER_CHECK_FILE_EXIST={envs.DOWNLOADER_CHECK_FILE_EXIST}"
)
return True
3 changes: 3 additions & 0 deletions python/aibrix/aibrix/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def _parse_int_or_none(value: Optional[str]) -> Optional[int]:
os.getenv("DOWNLOADER_ALLOW_FILE_SUFFIX")
)

DOWNLOADER_FORCE_DOWNLOAD = _is_true(os.getenv("DOWNLOADER_FORCE_DOWNLOAD", "0"))
DOWNLOADER_CHECK_FILE_EXIST = _is_true(os.getenv("DOWNLOADER_CHECK_FILE_EXIST", "1"))

# Downloader Regex
DOWNLOADER_S3_REGEX = r"^s3://"
DOWNLOADER_TOS_REGEX = r"^tos://"
Expand Down
138 changes: 138 additions & 0 deletions python/aibrix/tests/downloader/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from pathlib import Path
import tempfile
from unittest import mock

from aibrix import envs
from aibrix.config import DOWNLOAD_CACHE_DIR
from aibrix.downloader.utils import (
check_file_exist,
load_meta_data,
meta_file,
need_to_download,
save_meta_data,
)


def prepare_file_and_meta_data(file_path, meta_path, file_size, etag):
save_meta_data(meta_path, etag)
# create file
with open(file_path, "wb") as f:
f.write(os.urandom(file_size))


def test_meta_file():
with tempfile.TemporaryDirectory() as tmp_dir:
meta_file_path = meta_file(tmp_dir, "test")
assert str(meta_file_path).endswith(f"{DOWNLOAD_CACHE_DIR}/test.metadata")


def test_save_load_meta_data():
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = Path(tmp_dir).joinpath("test.metadata")
etag = "here_is_etag_value_xyz"
save_meta_data(file_path, etag)
assert file_path.exists()

load_etag = load_meta_data(file_path)
assert etag == load_etag

not_exist_file = Path(tmp_dir).joinpath("not_exist.metadata")
not_exist_etag = load_meta_data(not_exist_file)
assert not_exist_etag is None


def test_check_file_exist():
with tempfile.TemporaryDirectory() as tmp_dir:
# prepare file and meta data
file_size = 10
file_name = "test"
etag = "here_is_etag_value_xyz"
file_path = Path(tmp_dir).joinpath(file_name)
# create meta data file
meta_path = meta_file(tmp_dir, file_name)
prepare_file_and_meta_data(file_path, meta_path, file_size, etag)

assert check_file_exist(file_path, meta_path, file_size, etag)
assert not check_file_exist(file_path, meta_path, None, etag)
assert not check_file_exist(file_path, meta_path, -1, etag)
assert not check_file_exist(file_path, meta_path, file_size, None)
# The remote etag has been updated
assert not check_file_exist(file_path, meta_path, file_size, "new_etag")

not_exist_file_name = "not_exist"
not_exist_file_path = Path(tmp_dir).joinpath(not_exist_file_name)
not_exist_meta_path = meta_file(tmp_dir, not_exist_file_name)
assert not check_file_exist(
not_exist_file_path, not_exist_meta_path, file_size, etag
)
# meta data not exist
assert not check_file_exist(file_path, not_exist_meta_path, file_size, etag)


CHECK_FUNC = "aibrix.downloader.utils.check_file_exist"


@mock.patch(CHECK_FUNC)
def test_need_to_download_not_call(mock_check: mock.Mock):
origin_force_download_env = envs.DOWNLOADER_FORCE_DOWNLOAD
origin_check_file_exist = envs.DOWNLOADER_CHECK_FILE_EXIST

envs.DOWNLOADER_FORCE_DOWNLOAD = True
envs.DOWNLOADER_CHECK_FILE_EXIST = False
need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_not_called()

envs.DOWNLOADER_FORCE_DOWNLOAD = True
envs.DOWNLOADER_CHECK_FILE_EXIST = True
need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_not_called()

envs.DOWNLOADER_FORCE_DOWNLOAD = False
envs.DOWNLOADER_CHECK_FILE_EXIST = False
need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_not_called()

# recover envs
envs.DOWNLOADER_FORCE_DOWNLOAD = origin_force_download_env
envs.DOWNLOADER_CHECK_FILE_EXIST = origin_check_file_exist


@mock.patch(CHECK_FUNC)
def test_need_to_download(mock_check: mock.Mock):
origin_force_download_env = envs.DOWNLOADER_FORCE_DOWNLOAD
origin_check_file_exist = envs.DOWNLOADER_CHECK_FILE_EXIST

envs.DOWNLOADER_FORCE_DOWNLOAD = False
envs.DOWNLOADER_CHECK_FILE_EXIST = True
file_exist = True
# file exist, no need to download
mock_check.return_value = file_exist
assert not need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_called_once()
mock_check.reset_mock()

# file not exist, need to download
mock_check.return_value = not file_exist
assert need_to_download("file", "file.metadata", 10, "etag")
mock_check.assert_called_once()
mock_check.reset_mock()

# recover envs
envs.DOWNLOADER_FORCE_DOWNLOAD = origin_force_download_env
envs.DOWNLOADER_CHECK_FILE_EXIST = origin_check_file_exist
Loading