From ae1351f3cde9063a4a6699450f694fa0d7d38374 Mon Sep 17 00:00:00 2001 From: brosoul Date: Mon, 23 Dec 2024 17:23:16 +0800 Subject: [PATCH 1/6] refact: add download extra config into downloader --- python/aibrix/aibrix/downloader/__init__.py | 7 +- python/aibrix/aibrix/downloader/__main__.py | 23 +++++- python/aibrix/aibrix/downloader/base.py | 74 ++++++++++++++++--- .../aibrix/aibrix/downloader/huggingface.py | 27 +++++-- python/aibrix/aibrix/downloader/s3.py | 45 +++++++---- python/aibrix/aibrix/downloader/tos.py | 46 ++++++++---- 6 files changed, 173 insertions(+), 49 deletions(-) diff --git a/python/aibrix/aibrix/downloader/__init__.py b/python/aibrix/aibrix/downloader/__init__.py index 01468ae3..351fb6ac 100644 --- a/python/aibrix/aibrix/downloader/__init__.py +++ b/python/aibrix/aibrix/downloader/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Dict, Optional from aibrix.downloader.base import get_downloader @@ -21,6 +21,7 @@ def download_model( model_uri: str, local_path: Optional[str] = None, model_name: Optional[str] = None, + download_extra_config: Optional[Dict] = None, enable_progress_bar: bool = False, ): """Download model from model_uri to local_path. @@ -30,7 +31,9 @@ def download_model( local_path (str): local path to save model. """ - downloader = get_downloader(model_uri, model_name, enable_progress_bar) + downloader = get_downloader( + model_uri, model_name, download_extra_config, enable_progress_bar + ) return downloader.download_model(local_path) diff --git a/python/aibrix/aibrix/downloader/__main__.py b/python/aibrix/aibrix/downloader/__main__.py index 9c8e83c0..7effd60b 100644 --- a/python/aibrix/aibrix/downloader/__main__.py +++ b/python/aibrix/aibrix/downloader/__main__.py @@ -1,8 +1,19 @@ import argparse +import json +from typing import Dict, Optional from aibrix.downloader import download_model +def str_to_dict(s) -> Optional[Dict]: + if s is None: + return None + try: + return json.loads(s) + except Exception as e: + raise ValueError(f"Invalid json string {s}") from e + + def main(): parser = argparse.ArgumentParser(description="Download model from HuggingFace") parser.add_argument( @@ -30,9 +41,19 @@ def main(): default=False, help="Enable download progress bar during downloading from TOS or S3", ) + parser.add_argument( + "--download-extra-config", + type=str_to_dict, + default=None, + help="Extra config for download, like auth config, parallel config, etc.", + ) args = parser.parse_args() download_model( - args.model_uri, args.local_dir, args.model_name, args.enable_progress_bar + args.model_uri, + args.local_dir, + args.model_name, + args.download_extra_config, + args.enable_progress_bar, ) diff --git a/python/aibrix/aibrix/downloader/base.py b/python/aibrix/aibrix/downloader/base.py index 732a451a..649e2e27 100644 --- a/python/aibrix/aibrix/downloader/base.py +++ b/python/aibrix/aibrix/downloader/base.py @@ -18,7 +18,7 @@ from concurrent.futures import ThreadPoolExecutor, wait from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional from aibrix import envs from aibrix.logger import init_logger @@ -26,6 +26,36 @@ logger = init_logger(__name__) +@dataclass +class DownloadExtraConfig: + """Downloader extra config.""" + + # Auth config for s3 or tos + ak: Optional[str] = None + sk: Optional[str] = None + endpoint: Optional[str] = None + region: Optional[str] = None + + # Auth config for huggingface + hf_endpoint: Optional[str] = None + hf_token: Optional[str] = None + hf_revision: Optional[str] = None + + # parrallel config + num_threads: Optional[int] = None + max_io_queue: Optional[int] = None + io_chunksize: Optional[int] = None + part_threshold: Optional[int] = None + part_chunksize: Optional[int] = None + + # other config + allow_file_suffix: Optional[List[str]] = None + force_download: Optional[bool] = None + + +DEFAULT_DOWNLOADER_EXTRA_CONFIG = DownloadExtraConfig() + + @dataclass class BaseDownloader(ABC): """Base class for downloader.""" @@ -34,15 +64,22 @@ class BaseDownloader(ABC): model_name: str bucket_path: str bucket_name: Optional[str] - enable_progress_bar: bool = False - allow_file_suffix: Optional[List[str]] = field( - default_factory=lambda: envs.DOWNLOADER_ALLOW_FILE_SUFFIX + download_extra_config: DownloadExtraConfig = field( + default_factory=DownloadExtraConfig ) + enable_progress_bar: bool = False def __post_init__(self): # valid downloader config self._valid_config() self.model_name_path = self.model_name + self.allow_file_suffix = ( + self.download_extra_config.allow_file_suffix + or envs.DOWNLOADER_ALLOW_FILE_SUFFIX + ) + self.force_download = ( + self.download_extra_config.force_download or envs.DOWNLOADER_FORCE_DOWNLOAD + ) @abstractmethod def _valid_config(self): @@ -93,7 +130,9 @@ def download_directory(self, local_path: Path): if not self._support_range_download(): # download using multi threads - num_threads = envs.DOWNLOADER_NUM_THREADS + num_threads = ( + self.download_extra_config.num_threads or envs.DOWNLOADER_NUM_THREADS + ) logger.info( f"Downloader {self.__class__.__name__} download " f"{len(filtered_files)} files from {self.model_uri} " @@ -157,23 +196,38 @@ def __hash__(self): def get_downloader( - model_uri: str, model_name: Optional[str] = None, enable_progress_bar: bool = False + model_uri: str, + model_name: Optional[str] = None, + download_extra_config: Optional[Dict] = None, + enable_progress_bar: bool = False, ) -> BaseDownloader: """Get downloader for model_uri.""" + download_config: DownloadExtraConfig = ( + DEFAULT_DOWNLOADER_EXTRA_CONFIG + if download_extra_config is None + else DownloadExtraConfig(**download_extra_config) + ) + if re.match(envs.DOWNLOADER_S3_REGEX, model_uri): from aibrix.downloader.s3 import S3Downloader - return S3Downloader(model_uri, model_name, enable_progress_bar) + return S3Downloader(model_uri, model_name, download_config, enable_progress_bar) elif re.match(envs.DOWNLOADER_TOS_REGEX, model_uri): if envs.DOWNLOADER_TOS_VERSION == "v1": from aibrix.downloader.tos import TOSDownloaderV1 - return TOSDownloaderV1(model_uri, model_name, enable_progress_bar) + return TOSDownloaderV1( + model_uri, model_name, download_config, enable_progress_bar + ) else: from aibrix.downloader.tos import TOSDownloaderV2 - return TOSDownloaderV2(model_uri, model_name, enable_progress_bar) + return TOSDownloaderV2( + model_uri, model_name, download_config, enable_progress_bar + ) else: from aibrix.downloader.huggingface import HuggingFaceDownloader - return HuggingFaceDownloader(model_uri, model_name, enable_progress_bar) + return HuggingFaceDownloader( + model_uri, model_name, download_config, enable_progress_bar + ) diff --git a/python/aibrix/aibrix/downloader/huggingface.py b/python/aibrix/aibrix/downloader/huggingface.py index dc4de73d..3761ff2c 100644 --- a/python/aibrix/aibrix/downloader/huggingface.py +++ b/python/aibrix/aibrix/downloader/huggingface.py @@ -18,7 +18,11 @@ from huggingface_hub import HfApi, hf_hub_download, snapshot_download from aibrix import envs -from aibrix.downloader.base import BaseDownloader +from aibrix.downloader.base import ( + DEFAULT_DOWNLOADER_EXTRA_CONFIG, + BaseDownloader, + DownloadExtraConfig, +) from aibrix.logger import init_logger logger = init_logger(__name__) @@ -33,15 +37,20 @@ def __init__( self, model_uri: str, model_name: Optional[str] = None, + download_extra_config: DownloadExtraConfig = DEFAULT_DOWNLOADER_EXTRA_CONFIG, enable_progress_bar: bool = False, ): if model_name is None: model_name = _parse_model_name_from_uri(model_uri) logger.info(f"model_name is not set, using `{model_name}` as model_name") - self.hf_token = envs.DOWNLOADER_HF_TOKEN - self.hf_endpoint = envs.DOWNLOADER_HF_ENDPOINT - self.hf_revision = envs.DOWNLOADER_HF_REVISION + self.hf_token = self.download_extra_config.hf_token or envs.DOWNLOADER_HF_TOKEN + self.hf_endpoint = ( + self.download_extra_config.hf_endpoint or envs.DOWNLOADER_HF_ENDPOINT + ) + self.hf_revision = ( + self.download_extra_config.hf_revision or envs.DOWNLOADER_HF_REVISION + ) self.hf_api = HfApi(endpoint=self.hf_endpoint, token=self.hf_token) super().__init__( @@ -49,6 +58,7 @@ def __init__( model_name=model_name, bucket_path=model_uri, bucket_name=None, + download_extra_config=download_extra_config, enable_progress_bar=enable_progress_bar, ) # type: ignore @@ -103,18 +113,21 @@ def download( token=self.hf_token, endpoint=self.hf_endpoint, local_dir_use_symlinks=False, - force_download=envs.DOWNLOADER_FORCE_DOWNLOAD, + force_download=self.force_download, ) def download_directory(self, local_path: Path): + max_workers = ( + self.download_extra_config.num_threads or envs.DOWNLOADER_NUM_THREADS + ) snapshot_download( self.model_uri, local_dir=local_path, revision=self.hf_revision, token=self.hf_token, allow_patterns=self.allow_patterns, - max_workers=envs.DOWNLOADER_NUM_THREADS, + max_workers=max_workers, endpoint=self.hf_endpoint, local_dir_use_symlinks=False, - force_download=envs.DOWNLOADER_FORCE_DOWNLOAD, + force_download=self.force_download, ) diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index b46868d0..975e51f5 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -24,7 +24,11 @@ from tqdm import tqdm from aibrix import envs -from aibrix.downloader.base import BaseDownloader +from aibrix.downloader.base import ( + DEFAULT_DOWNLOADER_EXTRA_CONFIG, + BaseDownloader, + DownloadExtraConfig, +) from aibrix.downloader.entity import RemoteSource, get_local_download_paths from aibrix.downloader.utils import ( infer_model_name, @@ -52,6 +56,7 @@ def __init__( scheme: str, model_uri: str, model_name: Optional[str] = None, + download_extra_config: DownloadExtraConfig = DEFAULT_DOWNLOADER_EXTRA_CONFIG, enable_progress_bar: bool = False, ): if model_name is None: @@ -63,10 +68,12 @@ def __init__( # Avoid warning log "Connection pool is full" # Refs: https://github.com/boto/botocore/issues/619#issuecomment-583511406 + _num_threads = ( + self.download_extra_config.num_threads or envs.DOWNLOADER_NUM_THREADS + ) + max_pool_connections = ( - envs.DOWNLOADER_NUM_THREADS - if envs.DOWNLOADER_NUM_THREADS > MAX_POOL_CONNECTIONS - else MAX_POOL_CONNECTIONS + _num_threads if _num_threads > MAX_POOL_CONNECTIONS else MAX_POOL_CONNECTIONS ) client_config = Config( s3={"addressing_style": "virtual"}, @@ -82,6 +89,7 @@ def __init__( model_name=model_name, bucket_path=bucket_path, bucket_name=bucket_name, + extra_download_config=download_extra_config, enable_progress_bar=enable_progress_bar, ) # type: ignore @@ -156,15 +164,18 @@ def download( # construct TransferConfig config_kwargs = { - "max_concurrency": envs.DOWNLOADER_NUM_THREADS, + "max_concurrency": self.download_extra_config.num_threads + or envs.DOWNLOADER_NUM_THREADS, "use_threads": enable_range, - "max_io_queue": envs.DOWNLOADER_S3_MAX_IO_QUEUE, - "io_chunksize": envs.DOWNLOADER_S3_IO_CHUNKSIZE, + "max_io_queue": self.download_extra_config.max_io_queue + or envs.DOWNLOADER_S3_MAX_IO_QUEUE, + "io_chunksize": self.download_extra_config.io_chunksize + or envs.DOWNLOADER_S3_IO_CHUNKSIZE, + "multipart_threshold": self.download_extra_config.part_threshold + or envs.DOWNLOADER_PART_THRESHOLD, + "multipart_chunksize": self.download_extra_config.part_chunksize + or envs.DOWNLOADER_PART_CHUNKSIZE, } - if envs.DOWNLOADER_PART_THRESHOLD is not None: - config_kwargs["multipart_threshold"] = envs.DOWNLOADER_PART_THRESHOLD - if envs.DOWNLOADER_PART_CHUNKSIZE is not None: - config_kwargs["multipart_chunksize"] = envs.DOWNLOADER_PART_CHUNKSIZE config = TransferConfig(**config_kwargs) @@ -200,12 +211,14 @@ def __init__( self, model_uri, model_name: Optional[str] = None, + download_extra_config: DownloadExtraConfig = DEFAULT_DOWNLOADER_EXTRA_CONFIG, enable_progress_bar: bool = False, ): super().__init__( scheme="s3", model_uri=model_uri, model_name=model_name, + download_extra_config=download_extra_config, enable_progress_bar=enable_progress_bar, ) # type: ignore @@ -226,15 +239,17 @@ def _valid_config(self): def _get_auth_config(self) -> Dict[str, Optional[str]]: ak, sk = ( - envs.DOWNLOADER_AWS_ACCESS_KEY_ID, - envs.DOWNLOADER_AWS_SECRET_ACCESS_KEY, + self.download_extra_config.ak or envs.DOWNLOADER_AWS_ACCESS_KEY_ID, + self.download_extra_config.sk or envs.DOWNLOADER_AWS_SECRET_ACCESS_KEY, ) assert ak is not None and ak != "", "`AWS_ACCESS_KEY_ID` is not set." assert sk is not None and sk != "", "`AWS_SECRET_ACCESS_KEY` is not set." return { - "region_name": envs.DOWNLOADER_AWS_REGION, - "endpoint_url": envs.DOWNLOADER_AWS_ENDPOINT_URL, + "region_name": self.download_extra_config.region + or envs.DOWNLOADER_AWS_REGION, + "endpoint_url": self.download_extra_config.endpoint + or envs.DOWNLOADER_AWS_ENDPOINT_URL, "aws_access_key_id": ak, "aws_secret_access_key": sk, } diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index 12b5b5f3..6979257d 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -23,7 +23,11 @@ from tqdm import tqdm from aibrix import envs -from aibrix.downloader.base import BaseDownloader +from aibrix.downloader.base import ( + DEFAULT_DOWNLOADER_EXTRA_CONFIG, + BaseDownloader, + DownloadExtraConfig, +) from aibrix.downloader.entity import RemoteSource, get_local_download_paths from aibrix.downloader.s3 import S3BaseDownloader from aibrix.downloader.utils import ( @@ -53,16 +57,19 @@ def __init__( self, model_uri, model_name: Optional[str] = None, + download_extra_config: DownloadExtraConfig = DEFAULT_DOWNLOADER_EXTRA_CONFIG, enable_progress_bar: bool = False, ): if model_name is None: model_name = infer_model_name(model_uri) logger.info(f"model_name is not set, using `{model_name}` as model_name") - ak = envs.DOWNLOADER_TOS_ACCESS_KEY or "" - sk = envs.DOWNLOADER_TOS_SECRET_KEY or "" - endpoint = envs.DOWNLOADER_TOS_ENDPOINT or "" - region = envs.DOWNLOADER_TOS_REGION or "" + ak = self.download_extra_config.ak or envs.DOWNLOADER_TOS_ACCESS_KEY or "" + sk = self.download_extra_config.sk or envs.DOWNLOADER_TOS_SECRET_KEY or "" + endpoint = ( + self.download_extra_config.endpoint or envs.DOWNLOADER_TOS_ENDPOINT or "" + ) + region = self.download_extra_config.region or envs.DOWNLOADER_TOS_REGION or "" enable_crc = envs.DOWNLOADER_TOS_ENABLE_CRC bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri) @@ -75,6 +82,7 @@ def __init__( model_name=model_name, bucket_path=bucket_path, bucket_name=bucket_name, + download_extra_config=download_extra_config, enable_progress_bar=enable_progress_bar, ) # type: ignore @@ -142,12 +150,12 @@ def download( if not need_to_download(local_file, meta_data_file, file_size, etag): return + num_threads = ( + self.download_extra_config.num_threads or envs.DOWNLOADER_NUM_THREADS + ) + task_num = num_threads if enable_range else 1 - task_num = envs.DOWNLOADER_NUM_THREADS if enable_range else 1 - - download_kwargs = {} - if envs.DOWNLOADER_PART_CHUNKSIZE is not None: - download_kwargs["part_size"] = envs.DOWNLOADER_PART_CHUNKSIZE + download_kwargs = {"part_size": self.download_extra_config.part_chunksize} # download file total_length = meta_data.content_length @@ -187,12 +195,14 @@ def __init__( self, model_uri, model_name: Optional[str] = None, + download_extra_config: DownloadExtraConfig = DEFAULT_DOWNLOADER_EXTRA_CONFIG, enable_progress_bar: bool = False, ): super().__init__( scheme="tos", model_uri=model_uri, model_name=model_name, + download_extra_config=download_extra_config, enable_progress_bar=enable_progress_bar, ) # type: ignore @@ -213,8 +223,16 @@ def _valid_config(self): def _get_auth_config(self) -> Dict[str, Optional[str]]: return { - "region_name": envs.DOWNLOADER_TOS_REGION or "", - "endpoint_url": envs.DOWNLOADER_TOS_ENDPOINT or "", - "aws_access_key_id": envs.DOWNLOADER_TOS_ACCESS_KEY or "", - "aws_secret_access_key": envs.DOWNLOADER_TOS_SECRET_KEY or "", + "region_name": self.download_extra_config.region + or envs.DOWNLOADER_TOS_REGION + or "", + "endpoint_url": self.download_extra_config.endpoint + or envs.DOWNLOADER_TOS_ENDPOINT + or "", + "aws_access_key_id": self.download_extra_config.ak + or envs.DOWNLOADER_TOS_ACCESS_KEY + or "", + "aws_secret_access_key": self.download_extra_config.sk + or envs.DOWNLOADER_TOS_SECRET_KEY + or "", } From c2dfe6b68617d25d7cf6c591f1f7caedd22af148 Mon Sep 17 00:00:00 2001 From: brosoul Date: Mon, 23 Dec 2024 19:47:19 +0800 Subject: [PATCH 2/6] refact: replace assert with Exception --- python/aibrix/aibrix/common/__init__.py | 13 +++++ python/aibrix/aibrix/common/errors.py | 56 +++++++++++++++++++ .../aibrix/aibrix/downloader/huggingface.py | 21 ++++--- python/aibrix/aibrix/downloader/s3.py | 48 ++++++++++------ python/aibrix/aibrix/downloader/tos.py | 37 ++++-------- 5 files changed, 124 insertions(+), 51 deletions(-) create mode 100644 python/aibrix/aibrix/common/__init__.py create mode 100644 python/aibrix/aibrix/common/errors.py diff --git a/python/aibrix/aibrix/common/__init__.py b/python/aibrix/aibrix/common/__init__.py new file mode 100644 index 00000000..6461ec1a --- /dev/null +++ b/python/aibrix/aibrix/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/aibrix/aibrix/common/errors.py b/python/aibrix/aibrix/common/errors.py new file mode 100644 index 00000000..b3b151ad --- /dev/null +++ b/python/aibrix/aibrix/common/errors.py @@ -0,0 +1,56 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + + +class InvalidArgumentError(ValueError): + pass + + +class ArgNotCongiuredError(InvalidArgumentError): + def __init__(self, arg_name: str, arg_source: Optional[str] = None): + self.arg_name = arg_name + self.message = f"Argument `{arg_name}` is not configured" + ( + f" please check {arg_source}" if arg_source else "" + ) + super().__init__(self.message) + + def __str__(self): + return self.message + + +class ArgNotFormatError(InvalidArgumentError): + def __init__(self, arg_name: str, expected_format: str): + self.arg_name = arg_name + self.message = ( + f"Argument `{arg_name}` is not in the expected format: {expected_format}" + ) + super().__init__(self.message) + + def __str__(self): + return self.message + + +class ModelNotFoundError(Exception): + def __init__(self, model_uri: str, detail_msg: Optional[str] = None): + self.model_uri = model_uri + self.message = f"Model not found at URI: {model_uri}" + ( + f"\nDetails: {detail_msg}" if detail_msg else "" + ) + super().__init__(self.message) + + def __str__(self): + return self.message diff --git a/python/aibrix/aibrix/downloader/huggingface.py b/python/aibrix/aibrix/downloader/huggingface.py index 3761ff2c..e3f0a226 100644 --- a/python/aibrix/aibrix/downloader/huggingface.py +++ b/python/aibrix/aibrix/downloader/huggingface.py @@ -18,6 +18,11 @@ from huggingface_hub import HfApi, hf_hub_download, snapshot_download from aibrix import envs +from aibrix.common.errors import ( + ArgNotCongiuredError, + ArgNotFormatError, + ModelNotFoundError, +) from aibrix.downloader.base import ( DEFAULT_DOWNLOADER_EXTRA_CONFIG, BaseDownloader, @@ -77,14 +82,14 @@ def __init__( ) def _valid_config(self): - assert ( - len(self.model_uri.split("/")) == 2 - ), "Model uri must be in `repo/name` format." - assert self.bucket_name is None, "Bucket name is empty in HuggingFace." - assert self.model_name is not None, "Model name is not set." - assert self.hf_api.repo_exists( - repo_id=self.model_uri - ), f"Model {self.model_uri} not exist." + if len(self.model_uri.split("/")) != 2: + raise ArgNotFormatError(arg_name="model_uri", expected_format="repo/name") + + if self.model_name is None: + raise ArgNotCongiuredError(arg_name="model_name", arg_source="--model-name") + + if not self.hf_api.repo_exists(repo_id=self.model_uri): + raise ModelNotFoundError(model_uri=self.model_uri) def _is_directory(self) -> bool: """Check if model_uri is a directory. diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index 975e51f5..880e76e6 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -24,6 +24,7 @@ from tqdm import tqdm from aibrix import envs +from aibrix.common.errors import ArgNotCongiuredError, ModelNotFoundError from aibrix.downloader.base import ( DEFAULT_DOWNLOADER_EXTRA_CONFIG, BaseDownloader, @@ -73,7 +74,9 @@ def __init__( ) max_pool_connections = ( - _num_threads if _num_threads > MAX_POOL_CONNECTIONS else MAX_POOL_CONNECTIONS + _num_threads + if _num_threads > MAX_POOL_CONNECTIONS + else MAX_POOL_CONNECTIONS ) client_config = Config( s3={"addressing_style": "virtual"}, @@ -93,6 +96,24 @@ def __init__( enable_progress_bar=enable_progress_bar, ) # type: ignore + def _valid_config(self): + if self.model_name is None or self.model_name == "": + raise ArgNotCongiuredError(arg_name="model_name", arg_source="--model-name") + + if self.bucket_name is None or self.bucket_name == "": + raise ArgNotCongiuredError(arg_name="bucket_name", arg_source="--model-uri") + + if self.bucket_path is None or self.bucket_path == "": + raise ArgNotCongiuredError(arg_name="bucket_path", arg_source="--model-uri") + + try: + self.client.head_bucket(Bucket=self.bucket_name) + except Exception as e: + logger.error( + f"Bucket {self.bucket_name} not exist in {self.model_uri}\nFor {e}" + ) + raise ModelNotFoundError(model_uri=self.model_uri, detail_msg=str(e)) + @abstractmethod def _get_auth_config(self) -> Dict[str, Optional[str]]: """Get auth config for S3 client. @@ -222,28 +243,19 @@ def __init__( enable_progress_bar=enable_progress_bar, ) # type: ignore - def _valid_config(self): - assert ( - self.model_name is not None and self.model_name != "" - ), "S3 model name is not set, please check `--model-name`." - assert ( - self.bucket_name is not None and self.bucket_name != "" - ), "S3 bucket name is not set." - assert ( - self.bucket_path is not None and self.bucket_path != "" - ), "S3 bucket path is not set." - try: - self.client.head_bucket(Bucket=self.bucket_name) - except Exception as e: - assert False, f"S3 bucket {self.bucket_name} not exist for {e}." - def _get_auth_config(self) -> Dict[str, Optional[str]]: ak, sk = ( self.download_extra_config.ak or envs.DOWNLOADER_AWS_ACCESS_KEY_ID, self.download_extra_config.sk or envs.DOWNLOADER_AWS_SECRET_ACCESS_KEY, ) - assert ak is not None and ak != "", "`AWS_ACCESS_KEY_ID` is not set." - assert sk is not None and sk != "", "`AWS_SECRET_ACCESS_KEY` is not set." + if ak is None or ak == "": + raise ArgNotCongiuredError( + arg_name="ak", arg_source="--download-extra-config" + ) + if sk is None or sk == "": + raise ArgNotCongiuredError( + arg_name="sk", arg_source="--download-extra-config" + ) return { "region_name": self.download_extra_config.region diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index 6979257d..87597fa9 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -23,6 +23,7 @@ from tqdm import tqdm from aibrix import envs +from aibrix.common.errors import ArgNotCongiuredError, ModelNotFoundError from aibrix.downloader.base import ( DEFAULT_DOWNLOADER_EXTRA_CONFIG, BaseDownloader, @@ -87,19 +88,20 @@ def __init__( ) # type: ignore def _valid_config(self): - assert ( - self.model_name is not None and self.model_name != "" - ), "TOS model name is not set, please check `--model-name`." - assert ( - self.bucket_name is not None and self.bucket_name != "" - ), "TOS bucket name is not set." - assert ( - self.bucket_path is not None and self.bucket_path != "" - ), "TOS bucket path is not set." + if self.model_name is None or self.model_name == "": + raise ArgNotCongiuredError(arg_name="model_name", arg_source="--model-name") + + if self.bucket_name is None or self.bucket_name == "": + raise ArgNotCongiuredError(arg_name="bucket_name", arg_source="--model-uri") + + if self.bucket_path is None or self.bucket_path == "": + raise ArgNotCongiuredError(arg_name="bucket_path", arg_source="--model-uri") + try: self.client.head_bucket(self.bucket_name) except Exception as e: - assert False, f"TOS bucket {self.bucket_name} not exist for {e}." + logger.error(f"TOS bucket {self.bucket_name} not exist for {e}") + raise ModelNotFoundError(model_uri=self.model_uri, detail_msg=str(e)) @lru_cache() def _is_directory(self) -> bool: @@ -206,21 +208,6 @@ def __init__( enable_progress_bar=enable_progress_bar, ) # type: ignore - def _valid_config(self): - assert ( - self.model_name is not None and self.model_name != "" - ), "TOS model name is not set, please check `--model-name`." - assert ( - self.bucket_name is not None and self.bucket_name != "" - ), "TOS bucket name is not set." - assert ( - self.bucket_path is not None and self.bucket_path != "" - ), "TOS bucket path is not set." - try: - self.client.head_bucket(Bucket=self.bucket_name) - except Exception as e: - assert False, f"TOS bucket {self.bucket_name} not exist for {e}." - def _get_auth_config(self) -> Dict[str, Optional[str]]: return { "region_name": self.download_extra_config.region From f081bfa93caf9aa6b9ec7d9b52a8b152763944e2 Mon Sep 17 00:00:00 2001 From: brosoul Date: Tue, 24 Dec 2024 00:37:30 +0800 Subject: [PATCH 3/6] feat: add model management api --- python/aibrix/aibrix/app.py | 22 ++++ python/aibrix/aibrix/downloader/base.py | 8 +- python/aibrix/aibrix/downloader/entity.py | 20 ++++ python/aibrix/aibrix/downloader/s3.py | 9 +- python/aibrix/aibrix/downloader/tos.py | 7 +- python/aibrix/aibrix/openapi/model.py | 121 ++++++++++++++++++++++ python/aibrix/aibrix/openapi/protocol.py | 29 +++++- 7 files changed, 207 insertions(+), 9 deletions(-) create mode 100644 python/aibrix/aibrix/openapi/model.py diff --git a/python/aibrix/aibrix/app.py b/python/aibrix/aibrix/app.py index d5d07c7a..3a8f51c8 100644 --- a/python/aibrix/aibrix/app.py +++ b/python/aibrix/aibrix/app.py @@ -3,6 +3,7 @@ import shutil import time from pathlib import Path +from typing import Optional from urllib.parse import urljoin import uvicorn @@ -24,8 +25,11 @@ REGISTRY, ) from aibrix.openapi.engine.base import InferenceEngine, get_inference_engine +from aibrix.openapi.model import ModelManager from aibrix.openapi.protocol import ( + DownloadModelRequest, ErrorResponse, + ListModelRequest, LoadLoraAdapterRequest, UnloadLoraAdapterRequest, ) @@ -120,6 +124,24 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Re return Response(status_code=200, content=response) +@router.post("/v1/model/download") +async def download_model(request: DownloadModelRequest): + response = await ModelManager.model_download(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), status_code=response.code) + + return JSONResponse(status_code=200, content=response.model_dump()) + + +@router.get("/v1/model/list") +async def list_model(request: Optional[ListModelRequest] = None): + response = await ModelManager.model_list(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), status_code=response.code) + + return JSONResponse(status_code=200, content=response.model_dump()) + + @router.get("/healthz") async def liveness_check(): # Simply return a 200 status for liveness check diff --git a/python/aibrix/aibrix/downloader/base.py b/python/aibrix/aibrix/downloader/base.py index 649e2e27..cbb221ec 100644 --- a/python/aibrix/aibrix/downloader/base.py +++ b/python/aibrix/aibrix/downloader/base.py @@ -18,9 +18,10 @@ from concurrent.futures import ThreadPoolExecutor, wait from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, List, Optional +from typing import ClassVar, Dict, List, Optional from aibrix import envs +from aibrix.downloader.entity import RemoteSource from aibrix.logger import init_logger logger = init_logger(__name__) @@ -68,6 +69,7 @@ class BaseDownloader(ABC): default_factory=DownloadExtraConfig ) enable_progress_bar: bool = False + _source: ClassVar[RemoteSource] = RemoteSource.UNKNOWN def __post_init__(self): # valid downloader config @@ -81,6 +83,10 @@ def __post_init__(self): self.download_extra_config.force_download or envs.DOWNLOADER_FORCE_DOWNLOAD ) + @property + def source(self) -> RemoteSource: + return self._source + @abstractmethod def _valid_config(self): pass diff --git a/python/aibrix/aibrix/downloader/entity.py b/python/aibrix/aibrix/downloader/entity.py index e76cdb1f..2e700b06 100644 --- a/python/aibrix/aibrix/downloader/entity.py +++ b/python/aibrix/aibrix/downloader/entity.py @@ -31,6 +31,10 @@ class RemoteSource(Enum): S3 = "s3" TOS = "tos" HUGGINGFACE = "huggingface" + UNKNOWN = "unknown" + + def __str__(self): + return self.value class FileDownloadStatus(Enum): @@ -39,13 +43,20 @@ class FileDownloadStatus(Enum): NO_OPERATION = "no_operation" # Interrupted from downloading UNKNOWN = "unknown" + def __str__(self): + return self.value + class ModelDownloadStatus(Enum): + NOT_EXIST = "not_exist" DOWNLOADING = "downloading" DOWNLOADED = "downloaded" NO_OPERATION = "no_operation" # Interrupted from downloading UNKNOWN = "unknown" + def __str__(self): + return self.value + @dataclass class DownloadFile: @@ -125,6 +136,10 @@ def status(self): return ModelDownloadStatus.UNKNOWN + @property + def model_root_path(self) -> Path: + return Path(self.local_path).joinpath(self.model_name) + @classmethod def infer_from_model_path( cls, local_path: Path, model_name: str, source: RemoteSource @@ -132,6 +147,11 @@ def infer_from_model_path( assert source is not None model_base_dir = Path(local_path).joinpath(model_name) + + # model not exists + if not model_base_dir.exists(): + return None + cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/") cache_dir = Path(model_base_dir).joinpath(cache_sub_dir) lock_files = list(Path(cache_dir).glob("*.lock")) diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index 880e76e6..1621cc7c 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -15,7 +15,7 @@ from contextlib import nullcontext from functools import lru_cache from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import ClassVar, Dict, List, Optional, Tuple from urllib.parse import urlparse import boto3 @@ -50,7 +50,7 @@ def _parse_bucket_info_from_uri(uri: str, scheme: str = "s3") -> Tuple[str, str] class S3BaseDownloader(BaseDownloader): - _source: RemoteSource = RemoteSource.S3 + _source: ClassVar[RemoteSource] = RemoteSource.S3 def __init__( self, @@ -64,6 +64,7 @@ def __init__( model_name = infer_model_name(model_uri) logger.info(f"model_name is not set, using `{model_name}` as model_name") + self.download_extra_config = download_extra_config auth_config = self._get_auth_config() bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri, scheme=scheme) @@ -92,7 +93,7 @@ def __init__( model_name=model_name, bucket_path=bucket_path, bucket_name=bucket_name, - extra_download_config=download_extra_config, + download_extra_config=download_extra_config, enable_progress_bar=enable_progress_bar, ) # type: ignore @@ -226,7 +227,7 @@ def download_progress(bytes_transferred): class S3Downloader(S3BaseDownloader): - _source: RemoteSource = RemoteSource.S3 + _source: ClassVar[RemoteSource] = RemoteSource.S3 def __init__( self, diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index 87597fa9..99a3e646 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -15,7 +15,7 @@ from contextlib import nullcontext from functools import lru_cache from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import ClassVar, Dict, List, Optional, Tuple from urllib.parse import urlparse import tos @@ -52,7 +52,7 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]: class TOSDownloaderV1(BaseDownloader): - _source: RemoteSource = RemoteSource.TOS + _source: ClassVar[RemoteSource] = RemoteSource.TOS def __init__( self, @@ -65,6 +65,7 @@ def __init__( model_name = infer_model_name(model_uri) logger.info(f"model_name is not set, using `{model_name}` as model_name") + self.download_extra_config = download_extra_config ak = self.download_extra_config.ak or envs.DOWNLOADER_TOS_ACCESS_KEY or "" sk = self.download_extra_config.sk or envs.DOWNLOADER_TOS_SECRET_KEY or "" endpoint = ( @@ -191,7 +192,7 @@ def download_progress( class TOSDownloaderV2(S3BaseDownloader): - _source: RemoteSource = RemoteSource.TOS + _source: ClassVar[RemoteSource] = RemoteSource.TOS def __init__( self, diff --git a/python/aibrix/aibrix/openapi/model.py b/python/aibrix/aibrix/openapi/model.py new file mode 100644 index 00000000..c89e1d3c --- /dev/null +++ b/python/aibrix/aibrix/openapi/model.py @@ -0,0 +1,121 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from multiprocessing import Process +from pathlib import Path +from typing import Optional, Union + +from aibrix import envs +from aibrix.common.errors import ( + InvalidArgumentError, + ModelNotFoundError, +) +from aibrix.downloader import download_model +from aibrix.downloader.base import get_downloader +from aibrix.downloader.entity import ( + DownloadModel, + ModelDownloadStatus, +) +from aibrix.openapi.protocol import ( + DownloadModelRequest, + ErrorResponse, + ListModelRequest, + ListModelResponse, + ModelStatusCard, +) + + +class ModelManager: + @staticmethod + async def model_download( + request: DownloadModelRequest, + ) -> Union[ErrorResponse, ModelStatusCard]: + model_uri = request.model_uri + model_name = request.model_name + local_dir = request.local_dir or envs.DOWNLOADER_LOCAL_DIR + download_extra_config = request.download_extra_config + try: + downloader = get_downloader( + model_uri, model_name, download_extra_config, False + ) + except InvalidArgumentError as e: + return ErrorResponse( + message=str(e), + type="InvalidArgumentError", + code=HTTPStatus.UNPROCESSABLE_ENTITY.value, + ) + except ModelNotFoundError as e: + return ErrorResponse( + message=str(e), + type="ModelNotFoundError", + code=HTTPStatus.NOT_FOUND.value, + ) + + except Exception as e: + return ErrorResponse( + message=str(e), + type="InternalError", + code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + # Infer the model_name from model_uri and source type + model_name = downloader.model_name + local_path = Path(local_dir) + model = DownloadModel.infer_from_model_path( + local_path=local_path, + model_name=model_name, + source=downloader.source, + ) + + model_path = local_path.joinpath(downloader.model_name_path) + model_status = model.status if model else ModelDownloadStatus.NOT_EXIST + if model_status in [ + ModelDownloadStatus.DOWNLOADED, + ModelDownloadStatus.DOWNLOADING, + ]: + return ModelStatusCard( + model_name=model_name, + model_root_path=str(model_path), + model_status=str(model_status), + source=str(downloader.source), + ) + else: + # Start to download the model in background + Process(target=download_model, args=(model_uri, local_dir, model_name, download_extra_config)).start() + return ModelStatusCard( + model_name=model_name, + model_root_path=str(model_path), + model_status=str(ModelDownloadStatus.DOWNLOADING), + source=str(downloader.source), + ) + + @staticmethod + async def model_list( + request: Optional[ListModelRequest], + ) -> Union[ErrorResponse, ListModelResponse]: + local_dir = envs.DOWNLOADER_LOCAL_DIR if request is None else request.local_dir + local_path = Path(local_dir) + models = DownloadModel.infer_from_local_path(local_path) + cards = [] + cards = [ + ModelStatusCard( + model_name=model.model_name, + model_root_path=str(model.model_root_path), + model_status=str(model.status), + source=str(model.model_source), + ) for model in models + ] + response = ListModelResponse(data=cards) + return response diff --git a/python/aibrix/aibrix/openapi/protocol.py b/python/aibrix/aibrix/openapi/protocol.py index 44fa5620..fe200ef6 100644 --- a/python/aibrix/aibrix/openapi/protocol.py +++ b/python/aibrix/aibrix/openapi/protocol.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Dict, List, Optional from pydantic import BaseModel, ConfigDict, Field @@ -21,6 +21,10 @@ class NoExtraBaseModel(BaseModel): # The class does not allow extra fields model_config = ConfigDict(extra="forbid") +class NoProtectdBaseModel(BaseModel): + # The class does not allow extra fields + model_config = ConfigDict(extra="forbid", protected_namespaces=()) + class ErrorResponse(NoExtraBaseModel): object: str = "error" @@ -38,3 +42,26 @@ class LoadLoraAdapterRequest(NoExtraBaseModel): class UnloadLoraAdapterRequest(NoExtraBaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) + + +class DownloadModelRequest(NoProtectdBaseModel): + model_uri: str + local_dir: Optional[str] = None + model_name: Optional[str] = None + download_extra_config: Optional[Dict] = None + + +class ModelStatusCard(NoProtectdBaseModel): + model_name: str + model_root_path: str + source: str + model_status: str + + +class ListModelRequest(NoExtraBaseModel): + local_dir: str + + +class ListModelResponse(NoExtraBaseModel): + object: str = "list" + data: List[ModelStatusCard] = Field(default_factory=list) From 82b6c7460303e5be5381328d11435ba75f6a7563 Mon Sep 17 00:00:00 2001 From: brosoul Date: Tue, 24 Dec 2024 01:04:17 +0800 Subject: [PATCH 4/6] fix: test cases --- .../aibrix/aibrix/downloader/huggingface.py | 2 ++ .../tests/downloader/test_downloader_hf.py | 16 ++++++++----- .../tests/downloader/test_downloader_s3.py | 24 +++++++++++-------- .../tests/downloader/test_downloader_tos.py | 16 ++++++++----- .../downloader/test_downloader_tos_v1.py | 16 ++++++++----- 5 files changed, 46 insertions(+), 28 deletions(-) diff --git a/python/aibrix/aibrix/downloader/huggingface.py b/python/aibrix/aibrix/downloader/huggingface.py index e3f0a226..ea946a4b 100644 --- a/python/aibrix/aibrix/downloader/huggingface.py +++ b/python/aibrix/aibrix/downloader/huggingface.py @@ -48,6 +48,8 @@ def __init__( if model_name is None: model_name = _parse_model_name_from_uri(model_uri) logger.info(f"model_name is not set, using `{model_name}` as model_name") + + self.download_extra_config = download_extra_config self.hf_token = self.download_extra_config.hf_token or envs.DOWNLOADER_HF_TOKEN self.hf_endpoint = ( diff --git a/python/aibrix/tests/downloader/test_downloader_hf.py b/python/aibrix/tests/downloader/test_downloader_hf.py index fe0118bc..fb0e815d 100644 --- a/python/aibrix/tests/downloader/test_downloader_hf.py +++ b/python/aibrix/tests/downloader/test_downloader_hf.py @@ -14,6 +14,10 @@ import pytest +from aibrix.common.errors import ( + ArgNotFormatError, + ModelNotFoundError, +) from aibrix.downloader.base import get_downloader from aibrix.downloader.huggingface import HuggingFaceDownloader @@ -24,16 +28,16 @@ def test_get_downloader_hf(): def test_get_downloader_hf_not_exist(): - with pytest.raises(AssertionError) as exception: + with pytest.raises(ModelNotFoundError) as exception: get_downloader("not_exsit_path/model") - assert "not exist" in str(exception.value) + assert "Model not found" in str(exception.value) def test_get_downloader_hf_invalid_uri(): - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotFormatError) as exception: get_downloader("single_field") - assert "Model uri must be in `repo/name` format." in str(exception.value) + assert "not in the expected format: repo/name" in str(exception.value) - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotFormatError) as exception: get_downloader("multi/filed/repo") - assert "Model uri must be in `repo/name` format." in str(exception.value) + assert "not in the expected format: repo/name" in str(exception.value) diff --git a/python/aibrix/tests/downloader/test_downloader_s3.py b/python/aibrix/tests/downloader/test_downloader_s3.py index ec6501da..2071f985 100644 --- a/python/aibrix/tests/downloader/test_downloader_s3.py +++ b/python/aibrix/tests/downloader/test_downloader_s3.py @@ -16,6 +16,10 @@ import pytest +from aibrix.common.errors import ( + ArgNotCongiuredError, + ModelNotFoundError, +) from aibrix.downloader.base import get_downloader from aibrix.downloader.s3 import S3Downloader @@ -65,9 +69,9 @@ def test_get_downloader_s3(mock_boto3): def test_get_downloader_s3_path_not_exist(mock_boto3): mock_not_exsit_boto3(mock_boto3) - with pytest.raises(AssertionError) as exception: + with pytest.raises(ModelNotFoundError) as exception: get_downloader("s3://bucket/not_exsit_path") - assert "not exist" in str(exception.value) + assert "Model not found" in str(exception.value) @mock.patch(ENVS_MODULE, env_group) @@ -77,9 +81,9 @@ def test_get_downloader_s3_path_empty(mock_boto3): # Bucket name and path both are empty, # will first assert the name - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotCongiuredError) as exception: get_downloader("s3://") - assert "S3 bucket name is not set." in str(exception.value) + assert "`bucket_name` is not configured" in str(exception.value) @mock.patch(ENVS_MODULE, env_group) @@ -88,9 +92,9 @@ def test_get_downloader_s3_path_empty_path(mock_boto3): mock_exsit_boto3(mock_boto3) # bucket path is empty - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotCongiuredError) as exception: get_downloader("s3://bucket/") - assert "S3 bucket path is not set." in str(exception.value) + assert "`bucket_path` is not configured" in str(exception.value) @mock.patch(ENVS_MODULE, env_no_ak) @@ -98,9 +102,9 @@ def test_get_downloader_s3_path_empty_path(mock_boto3): def test_get_downloader_s3_no_ak(mock_boto3): mock_exsit_boto3(mock_boto3) - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotCongiuredError) as exception: get_downloader("s3://bucket/") - assert "`AWS_ACCESS_KEY_ID` is not set." in str(exception.value) + assert "`ak` is not configured" in str(exception.value) @mock.patch(ENVS_MODULE, env_no_sk) @@ -108,6 +112,6 @@ def test_get_downloader_s3_no_ak(mock_boto3): def test_get_downloader_s3_no_sk(mock_boto3): mock_exsit_boto3(mock_boto3) - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotCongiuredError) as exception: get_downloader("s3://bucket/") - assert "`AWS_SECRET_ACCESS_KEY` is not set." in str(exception.value) + assert "`sk` is not configured" in str(exception.value) diff --git a/python/aibrix/tests/downloader/test_downloader_tos.py b/python/aibrix/tests/downloader/test_downloader_tos.py index 5398f70b..8e74a8bd 100644 --- a/python/aibrix/tests/downloader/test_downloader_tos.py +++ b/python/aibrix/tests/downloader/test_downloader_tos.py @@ -16,6 +16,10 @@ import pytest +from aibrix.common.errors import ( + ArgNotCongiuredError, + ModelNotFoundError, +) from aibrix.downloader.base import get_downloader from aibrix.downloader.tos import TOSDownloaderV2 @@ -54,9 +58,9 @@ def test_get_downloader_tos(mock_boto3): def test_get_downloader_tos_path_not_exist(mock_boto3): mock_not_exsit_tos(mock_boto3) - with pytest.raises(AssertionError) as exception: + with pytest.raises(ModelNotFoundError) as exception: get_downloader("tos://bucket/not_exsit_path") - assert "not exist" in str(exception.value) + assert "Model not found" in str(exception.value) @mock.patch(ENVS_MODULE, env_group) @@ -66,9 +70,9 @@ def test_get_downloader_tos_path_empty(mock_boto3): # Bucket name and path both are empty, # will first assert the name - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotCongiuredError) as exception: get_downloader("tos://") - assert "TOS bucket name is not set." in str(exception.value) + assert "`bucket_name` is not configured" in str(exception.value) @mock.patch(ENVS_MODULE, env_group) @@ -77,6 +81,6 @@ def test_get_downloader_tos_path_empty_path(mock_boto3): mock_exsit_tos(mock_boto3) # bucket path is empty - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotCongiuredError) as exception: get_downloader("tos://bucket/") - assert "TOS bucket path is not set." in str(exception.value) + assert "`bucket_path` is not configured" in str(exception.value) diff --git a/python/aibrix/tests/downloader/test_downloader_tos_v1.py b/python/aibrix/tests/downloader/test_downloader_tos_v1.py index f71900a1..18528226 100644 --- a/python/aibrix/tests/downloader/test_downloader_tos_v1.py +++ b/python/aibrix/tests/downloader/test_downloader_tos_v1.py @@ -16,6 +16,10 @@ import pytest +from aibrix.common.errors import ( + ArgNotCongiuredError, + ModelNotFoundError, +) from aibrix.downloader.base import get_downloader from aibrix.downloader.tos import TOSDownloaderV1 @@ -56,9 +60,9 @@ def test_get_downloader_tos(mock_tos): def test_get_downloader_tos_path_not_exist(mock_tos): mock_not_exsit_tos(mock_tos) - with pytest.raises(AssertionError) as exception: + with pytest.raises(ModelNotFoundError) as exception: get_downloader("tos://bucket/not_exsit_path") - assert "not exist" in str(exception.value) + assert "Model not found" in str(exception.value) @mock.patch(ENVS_DOWNLOADER_TOS_VERSION, DOWNLOADER_TOS_VERSION) @@ -69,9 +73,9 @@ def test_get_downloader_tos_path_empty(mock_tos): # Bucket name and path both are empty, # will first assert the name - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotCongiuredError) as exception: get_downloader("tos://") - assert "TOS bucket name is not set." in str(exception.value) + assert "`bucket_name` is not configured" in str(exception.value) @mock.patch(ENVS_DOWNLOADER_TOS_VERSION, DOWNLOADER_TOS_VERSION) @@ -81,6 +85,6 @@ def test_get_downloader_tos_path_empty_path(mock_tos): mock_exsit_tos(mock_tos) # bucket path is empty - with pytest.raises(AssertionError) as exception: + with pytest.raises(ArgNotCongiuredError) as exception: get_downloader("tos://bucket/") - assert "TOS bucket path is not set." in str(exception.value) + assert "`bucket_path` is not configured" in str(exception.value) From 8e9b148a8ad0e0180fb8ba2002d9756902f9af22 Mon Sep 17 00:00:00 2001 From: brosoul Date: Tue, 24 Dec 2024 01:29:39 +0800 Subject: [PATCH 5/6] fix allow_file_suffix --- python/aibrix/aibrix/downloader/base.py | 2 +- python/aibrix/aibrix/downloader/huggingface.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/aibrix/aibrix/downloader/base.py b/python/aibrix/aibrix/downloader/base.py index cbb221ec..5b9d37e5 100644 --- a/python/aibrix/aibrix/downloader/base.py +++ b/python/aibrix/aibrix/downloader/base.py @@ -124,7 +124,7 @@ def download_directory(self, local_path: Path): # filter the directory path files = [file for file in directory_list if not file.endswith("/")] - if self.allow_file_suffix is None: + if self.allow_file_suffix is None or len(self.allow_file_suffix) == 0: logger.info(f"All files from {self.bucket_path} will be downloaded.") filtered_files = files else: diff --git a/python/aibrix/aibrix/downloader/huggingface.py b/python/aibrix/aibrix/downloader/huggingface.py index ea946a4b..2b749ab9 100644 --- a/python/aibrix/aibrix/downloader/huggingface.py +++ b/python/aibrix/aibrix/downloader/huggingface.py @@ -73,7 +73,7 @@ def __init__( # so place it after the super().__init__() call. self.allow_patterns = ( None - if self.allow_file_suffix is None + if self.allow_file_suffix is None or len(self.allow_file_suffix) == 0 else [f"*.{suffix}" for suffix in self.allow_file_suffix] ) logger.debug( From 48a13243fcd82708dc8480aaed64f489d4006403 Mon Sep 17 00:00:00 2001 From: brosoul Date: Tue, 24 Dec 2024 02:05:03 +0800 Subject: [PATCH 6/6] fix style --- python/aibrix/aibrix/downloader/huggingface.py | 2 +- python/aibrix/aibrix/openapi/model.py | 16 ++++++++++------ python/aibrix/aibrix/openapi/protocol.py | 1 + 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/aibrix/aibrix/downloader/huggingface.py b/python/aibrix/aibrix/downloader/huggingface.py index 2b749ab9..55939141 100644 --- a/python/aibrix/aibrix/downloader/huggingface.py +++ b/python/aibrix/aibrix/downloader/huggingface.py @@ -48,7 +48,7 @@ def __init__( if model_name is None: model_name = _parse_model_name_from_uri(model_uri) logger.info(f"model_name is not set, using `{model_name}` as model_name") - + self.download_extra_config = download_extra_config self.hf_token = self.download_extra_config.hf_token or envs.DOWNLOADER_HF_TOKEN diff --git a/python/aibrix/aibrix/openapi/model.py b/python/aibrix/aibrix/openapi/model.py index c89e1d3c..e2b0b9dc 100644 --- a/python/aibrix/aibrix/openapi/model.py +++ b/python/aibrix/aibrix/openapi/model.py @@ -93,7 +93,10 @@ async def model_download( ) else: # Start to download the model in background - Process(target=download_model, args=(model_uri, local_dir, model_name, download_extra_config)).start() + Process( + target=download_model, + args=(model_uri, local_dir, model_name, download_extra_config), + ).start() return ModelStatusCard( model_name=model_name, model_root_path=str(model_path), @@ -111,11 +114,12 @@ async def model_list( cards = [] cards = [ ModelStatusCard( - model_name=model.model_name, - model_root_path=str(model.model_root_path), - model_status=str(model.status), - source=str(model.model_source), - ) for model in models + model_name=model.model_name, + model_root_path=str(model.model_root_path), + model_status=str(model.status), + source=str(model.model_source), + ) + for model in models ] response = ListModelResponse(data=cards) return response diff --git a/python/aibrix/aibrix/openapi/protocol.py b/python/aibrix/aibrix/openapi/protocol.py index fe200ef6..cdb58e72 100644 --- a/python/aibrix/aibrix/openapi/protocol.py +++ b/python/aibrix/aibrix/openapi/protocol.py @@ -21,6 +21,7 @@ class NoExtraBaseModel(BaseModel): # The class does not allow extra fields model_config = ConfigDict(extra="forbid") + class NoProtectdBaseModel(BaseModel): # The class does not allow extra fields model_config = ConfigDict(extra="forbid", protected_namespaces=())