Skip to content

Commit

Permalink
[Core] Add Downloader implementation for runtime (#96)
Browse files Browse the repository at this point in the history
* feat: add HuggingFace downloader implementation

* feat: add TOS downloader implementation

* fix: fix downloading a single file from TOS

* feat: add S3 downloader implementation

* feat: add progress bars to the S3 and TOS downloaders

* ci: add ruff format check

* style: remove unused import

* fix: HuggingFace downloader model name initialization

* refactor: make the TOS model URI use the tos:// protocol

* refactor: use bucket path

* style: fix code style

* refactor: remove the .complete file
  • Loading branch information
brosoul authored Aug 29, 2024
1 parent 5e306fa commit c834ae4
Show file tree
Hide file tree
Showing 8 changed files with 937 additions and 48 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
run: |
cd python/aibrix
python -m ruff check .
python -m ruff format --check .
- name: Run isort
run: |
cd python/aibrix
Expand Down
50 changes: 35 additions & 15 deletions python/aibrix/aibrix/downloader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,19 @@ class BaseDownloader(ABC):

model_uri: str
model_name: str
required_envs: List[str] = field(default_factory=list)
optional_envs: List[str] = field(default_factory=list)
allow_file_suffix: List[str] = field(default_factory=list)
bucket_path: str
bucket_name: Optional[str]
allow_file_suffix: Optional[List[str]] = field(
default_factory=lambda: envs.DOWNLOADER_ALLOW_FILE_SUFFIX
)

def __post_init__(self):
# ensure downloader required envs are set
self._check_config()
# valid downloader config
self._valid_config()
self.model_name_path = self.model_name.replace("/", "_")

@abstractmethod
def _check_config(self):
def _valid_config(self):
pass

@abstractmethod
Expand All @@ -59,7 +61,13 @@ def _support_range_download(self) -> bool:
pass

@abstractmethod
def download(self, path: str, local_path: Path, enable_range: bool = True):
def download(
self,
local_path: Path,
bucket_path: str,
bucket_name: Optional[str] = None,
enable_range: bool = True,
):
pass

def download_directory(self, local_path: Path):
Expand All @@ -68,9 +76,9 @@ def download_directory(self, local_path: Path):
directory method for ``Downloader``. Otherwise, the following logic will be
used to download the directory.
"""
directory_list = self._directory_list(self.model_uri)
if len(self.allow_file_suffix) == 0:
logger.info("All files from {self.model_uri} will be downloaded.")
directory_list = self._directory_list(self.bucket_path)
if self.allow_file_suffix is None:
logger.info(f"All files from {self.bucket_path} will be downloaded.")
filtered_files = directory_list
else:
filtered_files = [
Expand All @@ -91,7 +99,11 @@ def download_directory(self, local_path: Path):
executor = ThreadPoolExecutor(num_threads)
futures = [
executor.submit(
self.download, path=file, local_path=local_path, enable_range=False
self.download,
local_path=local_path,
bucket_path=file,
bucket_name=self.bucket_name,
enable_range=False,
)
for file in filtered_files
]
Expand All @@ -108,7 +120,7 @@ def download_directory(self, local_path: Path):
st = time.perf_counter()
for file in filtered_files:
# use range download to speedup download
self.download(file, local_path, True)
self.download(local_path, file, self.bucket_name, True)
duration = time.perf_counter() - st
logger.info(
f"Downloader {self.__class__.__name__} download "
Expand All @@ -117,7 +129,7 @@ def download_directory(self, local_path: Path):
f"duration: {duration:.2f} seconds."
)

def download_model(self, local_path: Optional[str]):
def download_model(self, local_path: Optional[str] = None):
if local_path is None:
local_path = envs.DOWNLOADER_LOCAL_DIR
Path(local_path).mkdir(parents=True, exist_ok=True)
Expand All @@ -129,11 +141,19 @@ def download_model(self, local_path: Optional[str]):
# TODO check local file exists

if self._is_directory():
self.download_directory(model_path)
self.download_directory(local_path=model_path)
else:
self.download(self.model_uri, model_path)
self.download(
local_path=model_path,
bucket_path=self.bucket_path,
bucket_name=self.bucket_name,
enable_range=self._support_range_download(),
)
return model_path

def __hash__(self):
    """Hash by the identifying download configuration.

    BUGFIX: the previous ``hash(tuple(self.__dict__))`` iterated the
    instance dict, which yields only the attribute *names* — every
    instance therefore hashed to the same value. Hash the identifying
    field values instead. Mutable fields (e.g. ``allow_file_suffix``,
    a list) are deliberately excluded because lists are unhashable.
    """
    return hash(
        (self.model_uri, self.model_name, self.bucket_path, self.bucket_name)
    )


def get_downloader(model_uri: str) -> BaseDownloader:
"""Get downloader for model_uri."""
Expand Down
94 changes: 83 additions & 11 deletions python/aibrix/aibrix/downloader/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,99 @@
# limitations under the License.

from pathlib import Path
from typing import List
from typing import List, Optional

from huggingface_hub import HfApi, hf_hub_download, snapshot_download

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.logger import init_logger

logger = init_logger(__name__)


def _parse_model_name_from_uri(model_uri: str) -> str:
return model_uri


class HuggingFaceDownloader(BaseDownloader):
def __init__(self, model_uri):
super().__init__(model_uri)
def __init__(self, model_uri: str, model_name: Optional[str] = None):
if model_name is None:
if envs.DOWNLOADER_MODEL_NAME is not None:
model_name = envs.DOWNLOADER_MODEL_NAME
else:
model_name = _parse_model_name_from_uri(model_uri)

self.hf_token = envs.DOWNLOADER_HF_TOKEN
self.hf_endpoint = envs.DOWNLOADER_HF_ENDPOINT
self.hf_revision = envs.DOWNLOADER_HF_REVISION

super().__init__(
model_uri=model_uri,
model_name=model_name,
bucket_path=model_uri,
bucket_name=None,
) # type: ignore

def _check_config(self):
pass
# Dependent on the attributes generated in the base class,
# so place it after the super().__init__() call.
self.allow_patterns = (
None
if self.allow_file_suffix is None
else [f"*.{suffix}" for suffix in self.allow_file_suffix]
)
logger.debug(
f"Downloader {self.__class__.__name__} initialized."
f"HF Settings are followed: \n"
f"hf_token={self.hf_token}, \n"
f"hf_endpoint={self.hf_endpoint}"
)

def _valid_config(self):
    """Validate the HuggingFace downloader configuration.

    Checks that the model URI looks like ``repo/name``, that no bucket
    name was set (HuggingFace has no bucket concept), and that a model
    name has been resolved.
    """
    uri_parts = self.model_uri.split("/")
    assert len(uri_parts) == 2, "Model uri must be in `repo/name` format."
    assert self.bucket_name is None, "Bucket name is empty in HuggingFace."
    assert self.model_name is not None, "Model name is not set."

def _is_directory(self) -> bool:
"""Check if model_uri is a directory."""
return False
"""Check if model_uri is a directory.
model_uri in `repo/name` format must be a directory.
"""
return True

def _directory_list(self, path: str) -> List[str]:
return []
hf_api = HfApi(endpoint=self.hf_endpoint, token=self.hf_token)
return hf_api.list_repo_files(repo_id=self.model_uri)

def _support_range_download(self) -> bool:
return True
return False

def download(
    self,
    local_path: Path,
    bucket_path: str,
    bucket_name: Optional[str] = None,
    enable_range: bool = True,
):
    """Download one file of the HF repo into ``local_path``.

    ``bucket_path`` is the filename inside the repo. ``bucket_name`` and
    ``enable_range`` exist only for interface compatibility with the base
    class; the HuggingFace backend does not use them.
    """
    request_kwargs = dict(
        repo_id=self.model_uri,
        filename=bucket_path,
        local_dir=local_path,
        revision=self.hf_revision,
        token=self.hf_token,
        endpoint=self.hf_endpoint,
        local_dir_use_symlinks=False,
    )
    hf_hub_download(**request_kwargs)

def download(self, path: str, local_path: Path, enable_range: bool = True):
pass
def download_directory(self, local_path: Path):
    """Download the whole repo snapshot into ``local_path``.

    Overrides the generic base-class implementation: ``snapshot_download``
    already handles file listing, suffix filtering (``allow_patterns``)
    and parallelism internally.
    """
    snapshot_kwargs = dict(
        local_dir=local_path,
        revision=self.hf_revision,
        token=self.hf_token,
        allow_patterns=self.allow_patterns,
        max_workers=envs.DOWNLOADER_NUM_THREADS,
        endpoint=self.hf_endpoint,
        local_dir_use_symlinks=False,
    )
    snapshot_download(self.model_uri, **snapshot_kwargs)
114 changes: 105 additions & 9 deletions python/aibrix/aibrix/downloader/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,125 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache
from pathlib import Path
from typing import List
from typing import List, Optional, Tuple
from urllib.parse import urlparse

import boto3
from boto3.s3.transfer import TransferConfig
from tqdm import tqdm

from aibrix import envs
from aibrix.downloader.base import BaseDownloader


def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
parsed = urlparse(uri, scheme="s3")
bucket_name = parsed.netloc
bucket_path = parsed.path.lstrip("/")
return bucket_name, bucket_path


class S3Downloader(BaseDownloader):
def __init__(self, model_uri):
super().__init__(model_uri)
model_name = envs.DOWNLOADER_MODEL_NAME
ak = envs.DOWNLOADER_AWS_ACCESS_KEY
sk = envs.DOWNLOADER_AWS_SECRET_KEY
endpoint = envs.DOWNLOADER_AWS_ENDPOINT
region = envs.DOWNLOADER_AWS_REGION
bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri)

def _check_config(self):
pass
self.client = boto3.client(
service_name="s3",
region_name=region,
endpoint_url=endpoint,
aws_access_key_id=ak,
aws_secret_access_key=sk,
)

super().__init__(
model_uri=model_uri,
model_name=model_name,
bucket_path=bucket_path,
bucket_name=bucket_name,
) # type: ignore

def _valid_config(self):
    """Validate that the S3 bucket name/path are set and the bucket exists.

    Raises:
        AssertionError: if the bucket name or bucket path is missing/empty.
        ValueError: if the bucket is not reachable via ``head_bucket``.
    """
    # BUGFIX: the original conditions used `or`
    # (`x is not None or x == ""`), which is satisfied by an empty
    # string; the intent is that both fields are non-empty.
    assert (
        self.bucket_name is not None and self.bucket_name != ""
    ), "S3 bucket name is not set."
    assert (
        self.bucket_path is not None and self.bucket_path != ""
    ), "S3 bucket path is not set."
    try:
        self.client.head_bucket(Bucket=self.bucket_name)
    except Exception as e:
        # BUGFIX: `assert False` is stripped under `python -O`, which would
        # silently swallow the failure; raise explicitly and chain the cause.
        raise ValueError(f"S3 bucket {self.bucket_name} not exist for {e}.") from e

@lru_cache()
def _is_directory(self) -> bool:
"""Check if model_uri is a directory."""
return False
if self.bucket_path.endswith("/"):
return True
objects_out = self.client.list_objects_v2(
Bucket=self.bucket_name, Delimiter="/", Prefix=self.bucket_path
)
contents = objects_out.get("Contents", [])
if len(contents) == 1 and contents[0].get("Key") == self.bucket_path:
return False
return True

def _directory_list(self, path: str) -> List[str]:
return []
objects_out = self.client.list_objects_v2(
Bucket=self.bucket_name, Delimiter="/", Prefix=path
)
contents = objects_out.get("Contents", [])
return [content.get("Key") for content in contents]

def _support_range_download(self) -> bool:
    # S3 supports ranged GETs, so multi-threaded / multipart downloads
    # are allowed (consumed by the base class when choosing a strategy).
    return True

def download(self, path: str, local_path: Path, enable_range: bool = True):
pass
def download(
    self,
    local_path: Path,
    bucket_path: str,
    bucket_name: Optional[str] = None,
    enable_range: bool = True,
):
    """Download a single S3 object ``bucket_path`` into ``local_path``.

    Args:
        local_path: Local directory the file is written into.
        bucket_path: Object key inside the bucket.
        bucket_name: Bucket to read from; falls back to the bucket parsed
            from ``model_uri`` when omitted.
        enable_range: Whether multi-threaded (ranged) transfer is allowed.

    Raises:
        ValueError: if the object does not exist (``head_object`` fails).
    """
    # BUGFIX: fall back to the configured bucket so direct callers may
    # omit bucket_name instead of hitting an opaque boto3 error.
    if bucket_name is None:
        bucket_name = self.bucket_name

    # check if file exist
    try:
        meta_data = self.client.head_object(Bucket=bucket_name, Key=bucket_path)
    except Exception as e:
        # chain the original exception so the root cause is preserved
        raise ValueError(f"S3 bucket path {bucket_path} not exist for {e}.") from e

    _file_name = bucket_path.split("/")[-1]
    # S3 client does not support Path, convert it to str
    local_file = str(local_path.joinpath(_file_name).absolute())

    # construct TransferConfig; disabling threads forces a single-stream
    # (non-ranged) download when enable_range is False
    config_kwargs = {
        "max_concurrency": envs.DOWNLOADER_NUM_THREADS,
        "use_threads": enable_range,
    }
    if envs.DOWNLOADER_PART_THRESHOLD is not None:
        config_kwargs["multipart_threshold"] = envs.DOWNLOADER_PART_THRESHOLD
    if envs.DOWNLOADER_PART_CHUNKSIZE is not None:
        config_kwargs["multipart_chunksize"] = envs.DOWNLOADER_PART_CHUNKSIZE

    config = TransferConfig(**config_kwargs)

    # download file with a progress bar
    # BUGFIX: unit "B" (bytes) instead of "b" (bits) for correct display
    total_length = int(meta_data.get("ContentLength", 0))
    with tqdm(total=total_length, unit="B", unit_scale=True) as pbar:

        def download_progress(bytes_transferred):
            pbar.update(bytes_transferred)

        self.client.download_file(
            Bucket=bucket_name,
            Key=bucket_path,
            Filename=local_file,
            Config=config,
            Callback=download_progress,
        )
Loading

0 comments on commit c834ae4

Please sign in to comment.