[Feat] Add download status into runtime downloader #539

Merged
8 commits merged on Dec 23, 2024
Changes from 7 commits
12 changes: 10 additions & 2 deletions python/aibrix/aibrix/config.py
@@ -14,7 +14,15 @@


DEFAULT_METRIC_COLLECTOR_TIMEOUT = 1

DOWNLOAD_CACHE_DIR = ".cache"
"""
The DOWNLOAD_CACHE_DIR layout looks like:
.
└── .cache
└── huggingface | s3 | tos
├── .gitignore
└── download
"""
DOWNLOAD_CACHE_DIR = ".cache/%s/download"
DOWNLOAD_FILE_LOCK_CHECK_TIMEOUT = 10

EXCLUDE_METRICS_HTTP_ENDPOINTS = ["/metrics/"]
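Note (not part of the diff): the new DOWNLOAD_CACHE_DIR is a %-style template that the downloader fills in with the remote source name; a minimal sketch of how it expands, assuming the value above:

for source in ("huggingface", "s3", "tos"):
    print(".cache/%s/download" % source)  # same template as DOWNLOAD_CACHE_DIR
# .cache/huggingface/download
# .cache/s3/download
# .cache/tos/download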
198 changes: 198 additions & 0 deletions python/aibrix/aibrix/downloader/entity.py
@@ -0,0 +1,198 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import contextlib
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Generator, List, Optional

from filelock import BaseFileLock, FileLock, Timeout

from aibrix.config import DOWNLOAD_CACHE_DIR, DOWNLOAD_FILE_LOCK_CHECK_TIMEOUT
from aibrix.logger import init_logger

logger = init_logger(__name__)


class RemoteSource(Enum):
S3 = "s3"
TOS = "tos"
HUGGINGFACE = "huggingface"


class DownloadStatus(Enum):
DOWNLOADING = "downloading"
DOWNLOADED = "downloaded"
    NO_OPERATION = "no_operation"  # download was interrupted before completion
UNKNOWN = "unknown"


@dataclass
class DownloadFile:
file_path: Path
lock_path: Path
metadata_path: Path

@property
def status(self):
if self.file_path.exists() and self.metadata_path.exists():
return DownloadStatus.DOWNLOADED

try:
            # The downloading process holds this lock, so a non-blocking
            # acquire from any other process raises Timeout.
lock = FileLock(self.lock_path)
lock.acquire(blocking=False)
except Timeout:
return DownloadStatus.DOWNLOADING
except Exception as e:
            logger.warning(f"Failed to acquire lock due to error: {e}")
return DownloadStatus.UNKNOWN
        else:
            # Release the probe lock right away so the status check never blocks a real download.
            lock.release()
            return DownloadStatus.NO_OPERATION

@contextlib.contextmanager
def download_lock(self) -> Generator[BaseFileLock, None, None]:
"""A filelock that download process should be acquired.
Same implementation as WeakFileLock in huggingface_hub."""
lock = FileLock(self.lock_path, timeout=DOWNLOAD_FILE_LOCK_CHECK_TIMEOUT)
while True:
try:
lock.acquire()
except Timeout:
logger.info(
f"Still waiting to acquire download lock on {self.lock_path}"
)
else:
break

yield lock

try:
return lock.release()
except OSError:
try:
Path(self.lock_path).unlink()
except OSError:
pass


@dataclass
class DownloadModel:
model_source: RemoteSource
local_path: Path
model_name: str
download_files: List[DownloadFile]

def __post_init__(self):
if len(self.download_files) == 0:
logger.warning(f"No download files found for model {self.model_name}")

@property
def status(self):
all_status = [file.status for file in self.download_files]
if all(status == DownloadStatus.DOWNLOADED for status in all_status):
return DownloadStatus.DOWNLOADED

if any(status == DownloadStatus.DOWNLOADING for status in all_status):
return DownloadStatus.DOWNLOADING

if all(
status in [DownloadStatus.DOWNLOADED, DownloadStatus.NO_OPERATION]
for status in all_status
):
return DownloadStatus.NO_OPERATION

return DownloadStatus.UNKNOWN

@classmethod
def infer_from_model_path(
cls, local_path: Path, model_name: str, source: RemoteSource
) -> Optional["DownloadModel"]:
assert source is not None

model_base_dir = Path(local_path).joinpath(model_name)
cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/")
cache_dir = Path(model_base_dir).joinpath(cache_sub_dir)
lock_files = list(Path(cache_dir).glob("*.lock"))

download_files = []
for lock_file in lock_files:
lock_name = lock_file.name
lock_suffix = ".lock"
if lock_name.endswith(lock_suffix):
filename = lock_name[: -len(lock_suffix)]
download_file = get_local_download_paths(
model_base_dir=model_base_dir, filename=filename, source=source
)
download_files.append(download_file)

return cls(
model_source=source,
local_path=local_path,
model_name=model_name,
download_files=download_files,
)

@classmethod
def infer_from_local_path(cls, local_path: Path) -> List["DownloadModel"]:
models: List["DownloadModel"] = []

def remove_suffix(input_string, suffix):
if not input_string.endswith(suffix):
return input_string
return input_string[: -len(suffix)]

for source in RemoteSource:
cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/")
cache_dirs = list(Path(local_path).glob(f"**/{cache_sub_dir}"))
if not cache_dirs:
continue

for cache_dir in cache_dirs:
relative_path = cache_dir.relative_to(local_path)
relative_str = str(relative_path).strip("/")
model_name = remove_suffix(relative_str, cache_sub_dir).strip("/")

download_model = cls.infer_from_model_path(
local_path, model_name, source
)
if download_model is None:
continue
models.append(download_model)

return models

def __str__(self):
return (
"DownloadModel(\n"
+ f"\tmodel_source={self.model_source.value},\n"
+ f"\tmodel_name={self.model_name},\n"
+ f"\tstatus={self.status.value},\n"
+ f"\tlocal_path={self.local_path},\n"
+ f"\tdownload_files_count={len(self.download_files)}\n)"
)


def get_local_download_paths(
model_base_dir: Path, filename: str, source: RemoteSource
) -> DownloadFile:
file_path = model_base_dir.joinpath(filename)
sub_cache_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/")
cache_dir = model_base_dir.joinpath(sub_cache_dir)
lock_path = cache_dir.joinpath(f"{filename}.lock")
metadata_path = cache_dir.joinpath(f"{filename}.metadata")
return DownloadFile(file_path, lock_path, metadata_path)
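A hypothetical usage sketch of the new entity module; the /models path and the readiness message are illustrative only, not taken from this PR:

from pathlib import Path

from aibrix.downloader.entity import DownloadModel, DownloadStatus

# Scan a local model directory and report the download status of every model found.
models = DownloadModel.infer_from_local_path(Path("/models"))
for model in models:
    print(model)  # DownloadModel.__str__ prints source, name, status and file count
    if model.status == DownloadStatus.DOWNLOADED:
        print(f"{model.model_name} is ready to serve")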
31 changes: 21 additions & 10 deletions python/aibrix/aibrix/downloader/s3.py
@@ -25,6 +25,7 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.entity import RemoteSource, get_local_download_paths
from aibrix.downloader.utils import (
infer_model_name,
meta_file,
@@ -44,6 +45,8 @@ def _parse_bucket_info_from_uri(uri: str, scheme: str = "s3") -> Tuple[str, str]


class S3BaseDownloader(BaseDownloader):
_source: RemoteSource = RemoteSource.S3

def __init__(
self,
scheme: str,
@@ -144,7 +147,9 @@ def download(
# check if file exist
etag = meta_data.get("ETag", "")
file_size = meta_data.get("ContentLength", 0)
meta_data_file = meta_file(local_path=local_path, file_name=_file_name)
meta_data_file = meta_file(
local_path=local_path, file_name=_file_name, source=self._source.value
)

if not need_to_download(local_file, meta_data_file, file_size, etag):
return
@@ -172,19 +177,25 @@ def download(
def download_progress(bytes_transferred):
pbar.update(bytes_transferred)

self.client.download_file(
Bucket=bucket_name,
Key=bucket_path,
Filename=str(
local_file
), # S3 client does not support Path, convert it to str
Config=config,
Callback=download_progress if self.enable_progress_bar else None,
download_file = get_local_download_paths(
local_path, _file_name, self._source
)
save_meta_data(meta_data_file, etag)
with download_file.download_lock():
self.client.download_file(
Bucket=bucket_name,
Key=bucket_path,
Filename=str(
local_file
), # S3 client does not support Path, convert it to str
Config=config,
Callback=download_progress if self.enable_progress_bar else None,
)
save_meta_data(meta_data_file, etag)


class S3Downloader(S3BaseDownloader):
_source: RemoteSource = RemoteSource.S3

def __init__(
self,
model_uri,
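For context, the download_lock wrapped around client.download_file above is what drives the new status reporting; a rough sketch of the interplay, with an illustrative model directory and file name (not taken from this PR):

from pathlib import Path

from aibrix.downloader.entity import RemoteSource, get_local_download_paths

download_file = get_local_download_paths(
    model_base_dir=Path("/models/llama"),  # hypothetical path
    filename="model.safetensors",          # hypothetical file
    source=RemoteSource.S3,
)
with download_file.download_lock():
    ...  # transfer the object, then write the .metadata file

# While another process holds the lock, download_file.status reports DOWNLOADING;
# once the file and its .metadata both exist, it reports DOWNLOADED.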
40 changes: 25 additions & 15 deletions python/aibrix/aibrix/downloader/tos.py
@@ -24,6 +24,7 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.entity import RemoteSource, get_local_download_paths
from aibrix.downloader.s3 import S3BaseDownloader
from aibrix.downloader.utils import (
infer_model_name,
@@ -46,6 +47,8 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]


class TOSDownloaderV1(BaseDownloader):
_source: RemoteSource = RemoteSource.TOS

def __init__(
self,
model_uri,
@@ -133,7 +136,9 @@ def download(
# check if file exist
etag = meta_data.etag
file_size = meta_data.content_length
meta_data_file = meta_file(local_path=local_path, file_name=_file_name)
meta_data_file = meta_file(
local_path=local_path, file_name=_file_name, source=self._source.value
)

if not need_to_download(local_file, meta_data_file, file_size, etag):
return
@@ -147,7 +152,6 @@ def download(
# download file
total_length = meta_data.content_length

nullcontext
with tqdm(
desc=_file_name, total=total_length, unit="b", unit_scale=True
) if self.enable_progress_bar else nullcontext() as pbar:
@@ -157,30 +161,36 @@ def download_progress(
):
pbar.update(rw_once_bytes)

self.client.download_file(
bucket=bucket_name,
key=bucket_path,
file_path=str(
local_file
), # TOS client does not support Path, convert it to str
task_num=task_num,
data_transfer_listener=download_progress
if self.enable_progress_bar
else None,
**download_kwargs,
download_file = get_local_download_paths(
local_path, _file_name, self._source
)
save_meta_data(meta_data_file, etag)
with download_file.download_lock():
self.client.download_file(
bucket=bucket_name,
key=bucket_path,
file_path=str(
local_file
), # TOS client does not support Path, convert it to str
task_num=task_num,
data_transfer_listener=download_progress
if self.enable_progress_bar
else None,
**download_kwargs,
)
save_meta_data(meta_data_file, etag)


class TOSDownloaderV2(S3BaseDownloader):
_source: RemoteSource = RemoteSource.TOS

def __init__(
self,
model_uri,
model_name: Optional[str] = None,
enable_progress_bar: bool = False,
):
super().__init__(
scheme="s3",
scheme="tos",
model_uri=model_uri,
model_name=model_name,
enable_progress_bar=enable_progress_bar,
4 changes: 2 additions & 2 deletions python/aibrix/aibrix/downloader/utils.py
@@ -22,10 +22,10 @@
logger = init_logger(__name__)


def meta_file(local_path: Union[Path, str], file_name: str) -> Path:
def meta_file(local_path: Union[Path, str], file_name: str, source: str) -> Path:
return (
Path(local_path)
.joinpath(DOWNLOAD_CACHE_DIR)
.joinpath(DOWNLOAD_CACHE_DIR % source)
.joinpath(f"{file_name}.metadata")
.absolute()
)
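A quick illustration of the updated meta_file signature; the paths are made up:

from aibrix.downloader.utils import meta_file

path = meta_file(local_path="/models/llama", file_name="model.safetensors", source="s3")
# -> /models/llama/.cache/s3/download/model.safetensors.metadata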
4 changes: 2 additions & 2 deletions python/aibrix/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions python/aibrix/pyproject.toml
@@ -56,6 +56,7 @@ incdbscan = "^0.1.0"
aiohttp = "^3.11.7"
dash = "^2.18.2"
matplotlib = "^3.9.2"
filelock = "^3.16.1"


[tool.poetry.group.dev.dependencies]