Skip to content

Commit 0e4d76b

Browse files
authored
[Feat] Add runtime model management api (#540)
* refact: add download extra config into downloader * refact: replace assert with Exception * feat: add model management api * fix: test cases * fix allow_file_suffix * fix style
1 parent b9bb65a commit 0e4d76b

File tree

16 files changed

+553
-135
lines changed

16 files changed

+553
-135
lines changed

python/aibrix/aibrix/app.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import shutil
44
import time
55
from pathlib import Path
6+
from typing import Optional
67
from urllib.parse import urljoin
78

89
import uvicorn
@@ -24,8 +25,11 @@
2425
REGISTRY,
2526
)
2627
from aibrix.openapi.engine.base import InferenceEngine, get_inference_engine
28+
from aibrix.openapi.model import ModelManager
2729
from aibrix.openapi.protocol import (
30+
DownloadModelRequest,
2831
ErrorResponse,
32+
ListModelRequest,
2933
LoadLoraAdapterRequest,
3034
UnloadLoraAdapterRequest,
3135
)
@@ -120,6 +124,24 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Re
120124
return Response(status_code=200, content=response)
121125

122126

127+
@router.post("/v1/model/download")
async def download_model(request: DownloadModelRequest):
    """Handle ``POST /v1/model/download``.

    Delegates the request to ``ModelManager.model_download`` and returns the
    result serialised as JSON.

    Args:
        request: parsed download request.

    Returns:
        A ``JSONResponse`` carrying ``model_dump()`` of the manager's result.
        When the result is an ``ErrorResponse`` its ``code`` is used as the
        HTTP status; any other result is sent with status 200.
    """
    result = await ModelManager.model_download(request)
    # Errors carry their own HTTP status; everything else is a plain success.
    http_status = result.code if isinstance(result, ErrorResponse) else 200
    return JSONResponse(content=result.model_dump(), status_code=http_status)
134+
135+
136+
@router.get("/v1/model/list")
async def list_model(request: Optional[ListModelRequest] = None):
    """Handle ``GET /v1/model/list``.

    Delegates to ``ModelManager.model_list`` and returns the result
    serialised as JSON.

    Args:
        request: optional listing request; ``None`` is forwarded unchanged
            to the manager.
            NOTE(review): this is a GET route whose parameter is a model
            type -- presumably it is read from the request body; confirm
            that clients can actually send one.

    Returns:
        A ``JSONResponse`` carrying ``model_dump()`` of the manager's result.
        When the result is an ``ErrorResponse`` its ``code`` is used as the
        HTTP status; any other result is sent with status 200.
    """
    result = await ModelManager.model_list(request)
    # Errors carry their own HTTP status; everything else is a plain success.
    http_status = result.code if isinstance(result, ErrorResponse) else 200
    return JSONResponse(content=result.model_dump(), status_code=http_status)
143+
144+
123145
@router.get("/healthz")
124146
async def liveness_check():
125147
# Simply return a 200 status for liveness check
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The Aibrix Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

python/aibrix/aibrix/common/errors.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2024 The Aibrix Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from typing import Optional
17+
18+
19+
class InvalidArgumentError(ValueError):
    """Base class for argument-related errors; a ``ValueError`` subclass."""
21+
22+
23+
class ArgNotCongiuredError(InvalidArgumentError):
    """Raised when a required argument has not been configured.

    NOTE: the class name carries a historical typo ("Congiured"). It is kept
    so existing imports and ``except`` clauses keep working; new code should
    prefer the correctly spelled alias ``ArgNotConfiguredError`` defined
    right below this class.

    Attributes:
        arg_name: name of the argument that is missing.
        message: human-readable description, also returned by ``str(exc)``.
    """

    def __init__(self, arg_name: str, arg_source: Optional[str] = None):
        """Build the error message.

        Args:
            arg_name: name of the argument that is not configured.
            arg_source: optional hint naming where the argument should be
                set; appended to the message when truthy.
        """
        self.arg_name = arg_name
        # Separate the hint from the main clause; previously the two ran
        # together as "... is not configured please check ...".
        hint = f", please check {arg_source}" if arg_source else ""
        self.message = f"Argument `{arg_name}` is not configured{hint}"
        super().__init__(self.message)

    def __str__(self):
        return self.message


# Correctly spelled, backward-compatible alias for the class above.
ArgNotConfiguredError = ArgNotCongiuredError
33+
34+
35+
class ArgNotFormatError(InvalidArgumentError):
    """Raised when an argument does not match its expected format.

    Attributes:
        arg_name: name of the offending argument.
        message: human-readable description, also returned by ``str(exc)``.
    """

    def __init__(self, arg_name: str, expected_format: str):
        """Build the error message.

        Args:
            arg_name: name of the offending argument.
            expected_format: description of the format that was expected.
        """
        text = f"Argument `{arg_name}` is not in the expected format: {expected_format}"
        self.arg_name = arg_name
        self.message = text
        super().__init__(text)

    def __str__(self):
        return self.message
45+
46+
47+
class ModelNotFoundError(Exception):
    """Raised when no model can be found at the given URI.

    Attributes:
        model_uri: the URI that was looked up.
        message: human-readable description, also returned by ``str(exc)``.
    """

    def __init__(self, model_uri: str, detail_msg: Optional[str] = None):
        """Build the error message.

        Args:
            model_uri: the URI that was looked up.
            detail_msg: optional extra detail; appended on its own line
                when truthy.
        """
        self.model_uri = model_uri
        text = f"Model not found at URI: {model_uri}"
        if detail_msg:
            text += f"\nDetails: {detail_msg}"
        self.message = text
        super().__init__(text)

    def __str__(self):
        return self.message

python/aibrix/aibrix/downloader/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
15+
from typing import Dict, Optional
1616

1717
from aibrix.downloader.base import get_downloader
1818

@@ -21,6 +21,7 @@ def download_model(
2121
model_uri: str,
2222
local_path: Optional[str] = None,
2323
model_name: Optional[str] = None,
24+
download_extra_config: Optional[Dict] = None,
2425
enable_progress_bar: bool = False,
2526
):
2627
"""Download model from model_uri to local_path.
@@ -30,7 +31,9 @@ def download_model(
3031
local_path (str): local path to save model.
3132
"""
3233

33-
downloader = get_downloader(model_uri, model_name, enable_progress_bar)
34+
downloader = get_downloader(
35+
model_uri, model_name, download_extra_config, enable_progress_bar
36+
)
3437
return downloader.download_model(local_path)
3538

3639

python/aibrix/aibrix/downloader/__main__.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
import argparse
2+
import json
3+
from typing import Dict, Optional
24

35
from aibrix.downloader import download_model
46

57

8+
def str_to_dict(s) -> Optional[Dict]:
    """Parse a JSON object string into a dict.

    Used as the argparse ``type=`` converter for ``--download-extra-config``.

    Args:
        s: JSON text, or ``None``.

    Returns:
        ``None`` when ``s`` is ``None`` (or the JSON literal ``null``),
        otherwise the decoded dict.

    Raises:
        ValueError: if ``s`` is not valid JSON, or if it decodes to anything
            other than a JSON object (e.g. a list, number or string), since
            the result is later expanded as keyword arguments.
    """
    if s is None:
        return None
    try:
        parsed = json.loads(s)
    except (TypeError, ValueError, RecursionError) as e:
        # json.JSONDecodeError is a ValueError subclass; TypeError covers
        # inputs that are not str/bytes; RecursionError covers absurdly
        # deep nesting.
        raise ValueError(f"Invalid json string {s}") from e
    if parsed is not None and not isinstance(parsed, dict):
        raise ValueError(
            f"Invalid json string {s}: expected a JSON object, "
            f"got {type(parsed).__name__}"
        )
    return parsed
15+
16+
617
def main():
718
parser = argparse.ArgumentParser(description="Download model from HuggingFace")
819
parser.add_argument(
@@ -30,9 +41,19 @@ def main():
3041
default=False,
3142
help="Enable download progress bar during downloading from TOS or S3",
3243
)
44+
parser.add_argument(
45+
"--download-extra-config",
46+
type=str_to_dict,
47+
default=None,
48+
help="Extra config for download, like auth config, parallel config, etc.",
49+
)
3350
args = parser.parse_args()
3451
download_model(
35-
args.model_uri, args.local_dir, args.model_name, args.enable_progress_bar
52+
args.model_uri,
53+
args.local_dir,
54+
args.model_name,
55+
args.download_extra_config,
56+
args.enable_progress_bar,
3657
)
3758

3859

python/aibrix/aibrix/downloader/base.py

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,45 @@
1818
from concurrent.futures import ThreadPoolExecutor, wait
1919
from dataclasses import dataclass, field
2020
from pathlib import Path
21-
from typing import List, Optional
21+
from typing import ClassVar, Dict, List, Optional
2222

2323
from aibrix import envs
24+
from aibrix.downloader.entity import RemoteSource
2425
from aibrix.logger import init_logger
2526

2627
logger = init_logger(__name__)
2728

2829

30+
@dataclass
class DownloadExtraConfig:
    """Optional per-download overrides for a downloader.

    Every field defaults to ``None``. Field order is part of the interface
    (positional construction), so it must not be changed.
    """

    # Auth config for s3 or tos
    ak: Optional[str] = None
    sk: Optional[str] = None
    endpoint: Optional[str] = None
    region: Optional[str] = None

    # Auth config for huggingface
    hf_endpoint: Optional[str] = None
    hf_token: Optional[str] = None
    hf_revision: Optional[str] = None

    # parallel config
    num_threads: Optional[int] = None
    max_io_queue: Optional[int] = None
    io_chunksize: Optional[int] = None
    part_threshold: Optional[int] = None
    part_chunksize: Optional[int] = None

    # other config
    allow_file_suffix: Optional[List[str]] = None
    force_download: Optional[bool] = None
56+
57+
DEFAULT_DOWNLOADER_EXTRA_CONFIG = DownloadExtraConfig()
58+
59+
2960
@dataclass
3061
class BaseDownloader(ABC):
3162
"""Base class for downloader."""
@@ -34,15 +65,27 @@ class BaseDownloader(ABC):
3465
model_name: str
3566
bucket_path: str
3667
bucket_name: Optional[str]
37-
enable_progress_bar: bool = False
38-
allow_file_suffix: Optional[List[str]] = field(
39-
default_factory=lambda: envs.DOWNLOADER_ALLOW_FILE_SUFFIX
68+
download_extra_config: DownloadExtraConfig = field(
69+
default_factory=DownloadExtraConfig
4070
)
71+
enable_progress_bar: bool = False
72+
_source: ClassVar[RemoteSource] = RemoteSource.UNKNOWN
4173

4274
def __post_init__(self):
    """Validate the config and resolve the effective download settings.

    Sets ``model_name_path``, ``allow_file_suffix`` and ``force_download``
    on the instance. Values supplied through ``download_extra_config``
    take precedence over the ``envs`` defaults.
    """
    # validate downloader config
    self._valid_config()
    self.model_name_path = self.model_name

    extra = self.download_extra_config
    # A missing or empty suffix list falls back to the env default.
    self.allow_file_suffix = (
        extra.allow_file_suffix or envs.DOWNLOADER_ALLOW_FILE_SUFFIX
    )
    # Compare against None rather than using `or`: an explicit False must
    # be able to override an env default of True.
    self.force_download = (
        envs.DOWNLOADER_FORCE_DOWNLOAD
        if extra.force_download is None
        else extra.force_download
    )
85+
86+
@property
def source(self) -> RemoteSource:
    """Return the ``_source`` attribute (``RemoteSource.UNKNOWN`` on this base class)."""
    return self._source
4689

4790
@abstractmethod
4891
def _valid_config(self):
@@ -81,7 +124,7 @@ def download_directory(self, local_path: Path):
81124
# filter the directory path
82125
files = [file for file in directory_list if not file.endswith("/")]
83126

84-
if self.allow_file_suffix is None:
127+
if self.allow_file_suffix is None or len(self.allow_file_suffix) == 0:
85128
logger.info(f"All files from {self.bucket_path} will be downloaded.")
86129
filtered_files = files
87130
else:
@@ -93,7 +136,9 @@ def download_directory(self, local_path: Path):
93136

94137
if not self._support_range_download():
95138
# download using multi threads
96-
num_threads = envs.DOWNLOADER_NUM_THREADS
139+
num_threads = (
140+
self.download_extra_config.num_threads or envs.DOWNLOADER_NUM_THREADS
141+
)
97142
logger.info(
98143
f"Downloader {self.__class__.__name__} download "
99144
f"{len(filtered_files)} files from {self.model_uri} "
@@ -157,23 +202,38 @@ def __hash__(self):
157202

158203

159204
def get_downloader(
    model_uri: str,
    model_name: Optional[str] = None,
    download_extra_config: Optional[Dict] = None,
    enable_progress_bar: bool = False,
) -> BaseDownloader:
    """Get downloader for model_uri.

    The downloader class is picked by matching ``model_uri`` against the
    S3 regex first, then the TOS regex (v1 or v2 depending on
    ``envs.DOWNLOADER_TOS_VERSION``); anything else is handled by the
    HuggingFace downloader.

    Args:
        model_uri: URI of the model to download.
        model_name: optional model name passed through to the downloader.
        download_extra_config: optional mapping of ``DownloadExtraConfig``
            field names to values. ``None`` uses the shared module-level
            default config.
        enable_progress_bar: passed through to the downloader.

    Returns:
        The downloader instance for ``model_uri``.

    Raises:
        TypeError: if ``download_extra_config`` contains a key that is not
            a ``DownloadExtraConfig`` field.
    """
    if download_extra_config is None:
        download_config: DownloadExtraConfig = DEFAULT_DOWNLOADER_EXTRA_CONFIG
    else:
        download_config = DownloadExtraConfig(**download_extra_config)

    # Every downloader takes the same positional arguments.
    ctor_args = (model_uri, model_name, download_config, enable_progress_bar)

    # Backend modules are imported lazily, only for the branch that is used.
    if re.match(envs.DOWNLOADER_S3_REGEX, model_uri):
        from aibrix.downloader.s3 import S3Downloader

        return S3Downloader(*ctor_args)

    if re.match(envs.DOWNLOADER_TOS_REGEX, model_uri):
        if envs.DOWNLOADER_TOS_VERSION == "v1":
            from aibrix.downloader.tos import TOSDownloaderV1

            return TOSDownloaderV1(*ctor_args)

        from aibrix.downloader.tos import TOSDownloaderV2

        return TOSDownloaderV2(*ctor_args)

    from aibrix.downloader.huggingface import HuggingFaceDownloader

    return HuggingFaceDownloader(*ctor_args)

python/aibrix/aibrix/downloader/entity.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class RemoteSource(Enum):
3131
S3 = "s3"
3232
TOS = "tos"
3333
HUGGINGFACE = "huggingface"
34+
UNKNOWN = "unknown"
35+
36+
def __str__(self):
    """Return the member's raw ``value`` instead of the default enum repr."""
    return self.value
3438

3539

3640
class FileDownloadStatus(Enum):
@@ -39,13 +43,20 @@ class FileDownloadStatus(Enum):
3943
NO_OPERATION = "no_operation" # Interrupted from downloading
4044
UNKNOWN = "unknown"
4145

46+
def __str__(self):
    """Return the member's raw ``value`` instead of the default enum repr."""
    return self.value
48+
4249

4350
class ModelDownloadStatus(Enum):
51+
NOT_EXIST = "not_exist"
4452
DOWNLOADING = "downloading"
4553
DOWNLOADED = "downloaded"
4654
NO_OPERATION = "no_operation" # Interrupted from downloading
4755
UNKNOWN = "unknown"
4856

57+
def __str__(self):
    """Return the member's raw ``value`` instead of the default enum repr."""
    return self.value
59+
4960

5061
@dataclass
5162
class DownloadFile:
@@ -125,13 +136,22 @@ def status(self):
125136

126137
return ModelDownloadStatus.UNKNOWN
127138

139+
@property
def model_root_path(self) -> Path:
    """Return ``local_path`` joined with ``model_name`` as a ``Path``."""
    return Path(self.local_path) / self.model_name
142+
128143
@classmethod
129144
def infer_from_model_path(
130145
cls, local_path: Path, model_name: str, source: RemoteSource
131146
) -> Optional["DownloadModel"]:
132147
assert source is not None
133148

134149
model_base_dir = Path(local_path).joinpath(model_name)
150+
151+
# model not exists
152+
if not model_base_dir.exists():
153+
return None
154+
135155
cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/")
136156
cache_dir = Path(model_base_dir).joinpath(cache_sub_dir)
137157
lock_files = list(Path(cache_dir).glob("*.lock"))

0 commit comments

Comments
 (0)