Skip to content

Commit

Permalink
[Feat] Add runtime model management api (#540)
Browse files Browse the repository at this point in the history
* refact: add download extra config into downloader

* refact: replace assert with Exception

* feat: add model management api

* fix: test cases

* fix allow_file_suffix

* fix style
  • Loading branch information
brosoul authored Dec 24, 2024
1 parent b9bb65a commit 0e4d76b
Show file tree
Hide file tree
Showing 16 changed files with 553 additions and 135 deletions.
22 changes: 22 additions & 0 deletions python/aibrix/aibrix/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import time
from pathlib import Path
from typing import Optional
from urllib.parse import urljoin

import uvicorn
Expand All @@ -24,8 +25,11 @@
REGISTRY,
)
from aibrix.openapi.engine.base import InferenceEngine, get_inference_engine
from aibrix.openapi.model import ModelManager
from aibrix.openapi.protocol import (
DownloadModelRequest,
ErrorResponse,
ListModelRequest,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest,
)
Expand Down Expand Up @@ -120,6 +124,24 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Re
return Response(status_code=200, content=response)


@router.post("/v1/model/download")
async def download_model(request: DownloadModelRequest):
    """Handle POST /v1/model/download by delegating to ModelManager.

    The manager returns either an ErrorResponse or a result object; both are
    serialized with ``model_dump()`` and sent back as JSON.
    """
    result = await ModelManager.model_download(request)
    # An ErrorResponse carries its own HTTP status; anything else is a 200.
    http_status = result.code if isinstance(result, ErrorResponse) else 200
    return JSONResponse(status_code=http_status, content=result.model_dump())


@router.get("/v1/model/list")
async def list_model(request: Optional[ListModelRequest] = None):
    """Handle GET /v1/model/list by delegating to ModelManager.

    ``request`` is optional; when omitted, ``None`` is forwarded to
    ``ModelManager.model_list``.
    """
    result = await ModelManager.model_list(request)
    # An ErrorResponse carries its own HTTP status; anything else is a 200.
    http_status = result.code if isinstance(result, ErrorResponse) else 200
    return JSONResponse(status_code=http_status, content=result.model_dump())


@router.get("/healthz")
async def liveness_check():
# Simply return a 200 status for liveness check
Expand Down
13 changes: 13 additions & 0 deletions python/aibrix/aibrix/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
56 changes: 56 additions & 0 deletions python/aibrix/aibrix/common/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional


class InvalidArgumentError(ValueError):
    """Base class for the argument-related errors in this module.

    Derives from ValueError, so callers catching ValueError also catch it.
    """


class ArgNotCongiuredError(InvalidArgumentError):
    """Raised when a required argument has not been configured.

    Attributes:
        arg_name: name of the missing argument.
        message: human-readable description, also returned by ``str()``.
    """

    def __init__(self, arg_name: str, arg_source: Optional[str] = None):
        # The hint is appended only when a non-empty source is given.
        if arg_source:
            hint = f" please check {arg_source}"
        else:
            hint = ""
        description = f"Argument `{arg_name}` is not configured"

        self.arg_name = arg_name
        self.message = description + hint
        super().__init__(self.message)

    def __str__(self):
        return self.message


class ArgNotFormatError(InvalidArgumentError):
    """Raised when an argument does not match the format it is expected to have.

    Attributes:
        arg_name: name of the offending argument.
        message: human-readable description, also returned by ``str()``.
    """

    def __init__(self, arg_name: str, expected_format: str):
        text = f"Argument `{arg_name}` is not in the expected format: {expected_format}"
        self.arg_name = arg_name
        self.message = text
        super().__init__(text)

    def __str__(self):
        return self.message


class ModelNotFoundError(Exception):
    """Raised when no model can be found at the given URI.

    Attributes:
        model_uri: the URI that was looked up.
        message: human-readable description, also returned by ``str()``.
    """

    def __init__(self, model_uri: str, detail_msg: Optional[str] = None):
        # Details go on a second line, and only when a non-empty message is given.
        if detail_msg:
            details = f"\nDetails: {detail_msg}"
        else:
            details = ""
        headline = f"Model not found at URI: {model_uri}"

        self.model_uri = model_uri
        self.message = headline + details
        super().__init__(self.message)

    def __str__(self):
        return self.message
7 changes: 5 additions & 2 deletions python/aibrix/aibrix/downloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Dict, Optional

from aibrix.downloader.base import get_downloader

Expand All @@ -21,6 +21,7 @@ def download_model(
model_uri: str,
local_path: Optional[str] = None,
model_name: Optional[str] = None,
download_extra_config: Optional[Dict] = None,
enable_progress_bar: bool = False,
):
"""Download model from model_uri to local_path.
Expand All @@ -30,7 +31,9 @@ def download_model(
local_path (str): local path to save model.
"""

downloader = get_downloader(model_uri, model_name, enable_progress_bar)
downloader = get_downloader(
model_uri, model_name, download_extra_config, enable_progress_bar
)
return downloader.download_model(local_path)


Expand Down
23 changes: 22 additions & 1 deletion python/aibrix/aibrix/downloader/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import argparse
import json
from typing import Dict, Optional

from aibrix.downloader import download_model


def str_to_dict(s) -> Optional[Dict]:
    """Parse a JSON object string into a dict.

    Used as the argparse ``type=`` converter for ``--download-extra-config``.

    Args:
        s: JSON text, or None when no value was supplied.

    Returns:
        None when ``s`` is None or decodes to JSON ``null``; otherwise the
        decoded dict.

    Raises:
        ValueError: if ``s`` is not valid JSON, or decodes to something other
            than a JSON object (list, number, string, ...). Rejecting these
            here reports the problem at parse time instead of failing later
            when the value is unpacked as keyword arguments.
    """
    if s is None:
        return None
    try:
        parsed = json.loads(s)
    except (TypeError, ValueError, RecursionError) as e:
        # json.JSONDecodeError is a ValueError; TypeError covers non-text input.
        raise ValueError(f"Invalid json string {s}") from e
    if parsed is not None and not isinstance(parsed, dict):
        raise ValueError(f"Invalid json string {s}: expected a JSON object")
    return parsed


def main():
parser = argparse.ArgumentParser(description="Download model from HuggingFace")
parser.add_argument(
Expand Down Expand Up @@ -30,9 +41,19 @@ def main():
default=False,
help="Enable download progress bar during downloading from TOS or S3",
)
parser.add_argument(
"--download-extra-config",
type=str_to_dict,
default=None,
help="Extra config for download, like auth config, parallel config, etc.",
)
args = parser.parse_args()
download_model(
args.model_uri, args.local_dir, args.model_name, args.enable_progress_bar
args.model_uri,
args.local_dir,
args.model_name,
args.download_extra_config,
args.enable_progress_bar,
)


Expand Down
82 changes: 71 additions & 11 deletions python/aibrix/aibrix/downloader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,45 @@
from concurrent.futures import ThreadPoolExecutor, wait
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
from typing import ClassVar, Dict, List, Optional

from aibrix import envs
from aibrix.downloader.entity import RemoteSource
from aibrix.logger import init_logger

logger = init_logger(__name__)


@dataclass
class DownloadExtraConfig:
    """Optional per-download settings for a downloader.

    Every field defaults to None, meaning "not specified by the caller".
    """

    # Authentication / location settings for S3 or TOS.
    ak: Optional[str] = None
    sk: Optional[str] = None
    endpoint: Optional[str] = None
    region: Optional[str] = None

    # Authentication / location settings for HuggingFace.
    hf_endpoint: Optional[str] = None
    hf_token: Optional[str] = None
    hf_revision: Optional[str] = None

    # Parallelism settings.
    num_threads: Optional[int] = None
    max_io_queue: Optional[int] = None
    io_chunksize: Optional[int] = None
    part_threshold: Optional[int] = None
    part_chunksize: Optional[int] = None

    # Miscellaneous settings.
    allow_file_suffix: Optional[List[str]] = None
    force_download: Optional[bool] = None


# All-defaults config used by get_downloader when the caller passes no extra
# config. This is one shared instance, so it should be treated as read-only.
DEFAULT_DOWNLOADER_EXTRA_CONFIG = DownloadExtraConfig()


@dataclass
class BaseDownloader(ABC):
"""Base class for downloader."""
Expand All @@ -34,15 +65,27 @@ class BaseDownloader(ABC):
model_name: str
bucket_path: str
bucket_name: Optional[str]
enable_progress_bar: bool = False
allow_file_suffix: Optional[List[str]] = field(
default_factory=lambda: envs.DOWNLOADER_ALLOW_FILE_SUFFIX
download_extra_config: DownloadExtraConfig = field(
default_factory=DownloadExtraConfig
)
enable_progress_bar: bool = False
_source: ClassVar[RemoteSource] = RemoteSource.UNKNOWN

def __post_init__(self):
    """Validate the downloader and resolve its effective settings.

    Settings given in ``download_extra_config`` take precedence over the
    corresponding ``envs`` defaults.
    """
    # valid downloader config
    self._valid_config()
    self.model_name_path = self.model_name

    extra = self.download_extra_config
    # An unset or empty suffix list falls back to the environment default.
    self.allow_file_suffix = (
        extra.allow_file_suffix or envs.DOWNLOADER_ALLOW_FILE_SUFFIX
    )
    # force_download is tri-state (None = unset). Test against None, not
    # truthiness, so an explicit False is not overridden by the env default.
    self.force_download = (
        envs.DOWNLOADER_FORCE_DOWNLOAD
        if extra.force_download is None
        else extra.force_download
    )

@property
def source(self) -> RemoteSource:
    """Return the remote source of this downloader, read from ``_source``.

    ``_source`` is a class-level attribute defaulting to RemoteSource.UNKNOWN;
    presumably concrete downloaders override it (confirm in the subclasses).
    """
    return self._source

@abstractmethod
def _valid_config(self):
Expand Down Expand Up @@ -81,7 +124,7 @@ def download_directory(self, local_path: Path):
# filter the directory path
files = [file for file in directory_list if not file.endswith("/")]

if self.allow_file_suffix is None:
if self.allow_file_suffix is None or len(self.allow_file_suffix) == 0:
logger.info(f"All files from {self.bucket_path} will be downloaded.")
filtered_files = files
else:
Expand All @@ -93,7 +136,9 @@ def download_directory(self, local_path: Path):

if not self._support_range_download():
# download using multi threads
num_threads = envs.DOWNLOADER_NUM_THREADS
num_threads = (
self.download_extra_config.num_threads or envs.DOWNLOADER_NUM_THREADS
)
logger.info(
f"Downloader {self.__class__.__name__} download "
f"{len(filtered_files)} files from {self.model_uri} "
Expand Down Expand Up @@ -157,23 +202,38 @@ def __hash__(self):


def get_downloader(
    model_uri: str,
    model_name: Optional[str] = None,
    download_extra_config: Optional[Dict] = None,
    enable_progress_bar: bool = False,
) -> BaseDownloader:
    """Get downloader for model_uri.

    The downloader class is chosen by matching ``model_uri`` against the S3
    and TOS regexes from ``envs``; any other URI is handled by the
    HuggingFace downloader.

    Args:
        model_uri: URI of the model to download.
        model_name: optional model name forwarded to the downloader.
        download_extra_config: optional mapping of DownloadExtraConfig field
            names to values; None selects the shared default config.
        enable_progress_bar: forwarded to the downloader.

    Returns:
        A downloader instance for the matched remote source.

    Raises:
        TypeError: if ``download_extra_config`` contains a key that is not a
            DownloadExtraConfig field.
    """
    download_config: DownloadExtraConfig = (
        DEFAULT_DOWNLOADER_EXTRA_CONFIG
        if download_extra_config is None
        else DownloadExtraConfig(**download_extra_config)
    )

    # Downloader modules are imported lazily, only for the branch selected.
    if re.match(envs.DOWNLOADER_S3_REGEX, model_uri):
        from aibrix.downloader.s3 import S3Downloader

        return S3Downloader(model_uri, model_name, download_config, enable_progress_bar)
    elif re.match(envs.DOWNLOADER_TOS_REGEX, model_uri):
        if envs.DOWNLOADER_TOS_VERSION == "v1":
            from aibrix.downloader.tos import TOSDownloaderV1

            return TOSDownloaderV1(
                model_uri, model_name, download_config, enable_progress_bar
            )
        else:
            from aibrix.downloader.tos import TOSDownloaderV2

            return TOSDownloaderV2(
                model_uri, model_name, download_config, enable_progress_bar
            )
    else:
        from aibrix.downloader.huggingface import HuggingFaceDownloader

        return HuggingFaceDownloader(
            model_uri, model_name, download_config, enable_progress_bar
        )
20 changes: 20 additions & 0 deletions python/aibrix/aibrix/downloader/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class RemoteSource(Enum):
S3 = "s3"
TOS = "tos"
HUGGINGFACE = "huggingface"
UNKNOWN = "unknown"

def __str__(self) -> str:
    """Return the member's value, so ``str(member)`` gives the raw string (e.g. "s3")."""
    return self.value


class FileDownloadStatus(Enum):
Expand All @@ -39,13 +43,20 @@ class FileDownloadStatus(Enum):
NO_OPERATION = "no_operation" # Interrupted from downloading
UNKNOWN = "unknown"

def __str__(self) -> str:
    """Return the member's value, so ``str(member)`` gives the raw string (e.g. "unknown")."""
    return self.value


class ModelDownloadStatus(Enum):
    """Download state of a model; ``str()`` of a member is its raw value."""

    NOT_EXIST = "not_exist"
    DOWNLOADING = "downloading"
    DOWNLOADED = "downloaded"
    # Interrupted from downloading
    NO_OPERATION = "no_operation"
    UNKNOWN = "unknown"

    def __str__(self) -> str:
        """Return the member's value instead of the default ``Class.NAME`` form."""
        return self.value


@dataclass
class DownloadFile:
Expand Down Expand Up @@ -125,13 +136,22 @@ def status(self):

return ModelDownloadStatus.UNKNOWN

@property
def model_root_path(self) -> Path:
    """Return the model's root directory: ``local_path`` joined with ``model_name``."""
    base_dir = Path(self.local_path)
    return base_dir / self.model_name

@classmethod
def infer_from_model_path(
cls, local_path: Path, model_name: str, source: RemoteSource
) -> Optional["DownloadModel"]:
assert source is not None

model_base_dir = Path(local_path).joinpath(model_name)

# model not exists
if not model_base_dir.exists():
return None

cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/")
cache_dir = Path(model_base_dir).joinpath(cache_sub_dir)
lock_files = list(Path(cache_dir).glob("*.lock"))
Expand Down
Loading

0 comments on commit 0e4d76b

Please sign in to comment.