[Core] Add Downloader implementation for runtime #96

Merged
merged 12 commits on Aug 29, 2024
1 change: 1 addition & 0 deletions .github/workflows/python-tests.yml
@@ -31,6 +31,7 @@ jobs:
run: |
cd python/aibrix
python -m ruff check .
python -m ruff format --check .
- name: Run isort
run: |
cd python/aibrix
38 changes: 26 additions & 12 deletions python/aibrix/aibrix/downloader/base.py
@@ -32,17 +32,17 @@ class BaseDownloader(ABC):

model_uri: str
model_name: str
required_envs: List[str] = field(default_factory=list)
optional_envs: List[str] = field(default_factory=list)
allow_file_suffix: List[str] = field(default_factory=list)
allow_file_suffix: Optional[List[str]] = field(
default_factory=lambda: envs.DOWNLOADER_ALLOW_FILE_SUFFIX
)

def __post_init__(self):
# ensure downloader required envs are set
self._check_config()
# valid downloader config
self._valid_config()
self.model_name_path = self.model_name.replace("/", "_")

@abstractmethod
def _check_config(self):
def _valid_config(self):
pass

@abstractmethod
@@ -59,7 +59,7 @@ def _support_range_download(self) -> bool:
pass

@abstractmethod
def download(self, path: str, local_path: Path, enable_range: bool = True):
def download(self, filename: str, local_path: Path, enable_range: bool = True):
pass

def download_directory(self, local_path: Path):
@@ -69,8 +69,8 @@ def download_directory(self, local_path: Path):
used to download the directory.
"""
directory_list = self._directory_list(self.model_uri)
if len(self.allow_file_suffix) == 0:
logger.info("All files from {self.model_uri} will be downloaded.")
if self.allow_file_suffix is None:
logger.info(f"All files from {self.model_uri} will be downloaded.")
filtered_files = directory_list
else:
filtered_files = [
@@ -91,7 +91,10 @@ def download_directory(self, local_path: Path):
executor = ThreadPoolExecutor(num_threads)
futures = [
executor.submit(
self.download, path=file, local_path=local_path, enable_range=False
self.download,
filename=file,
local_path=local_path,
enable_range=False,
)
for file in filtered_files
]
@@ -117,7 +120,7 @@ def download_directory(self, local_path: Path):
f"duration: {duration:.2f} seconds."
)

def download_model(self, local_path: Optional[str]):
def download_model(self, local_path: Optional[str] = None):
if local_path is None:
local_path = envs.DOWNLOADER_LOCAL_DIR
Path(local_path).mkdir(parents=True, exist_ok=True)
@@ -131,9 +134,20 @@ def download_model(self, local_path: Optional[str]):
if self._is_directory():
self.download_directory(model_path)
else:
self.download(self.model_uri, model_path)
self.download(
filename=self.model_uri,
Collaborator:

I feel the naming is a little bit weird. file should be a subset of the model_uri.

Collaborator (Author):

Yes, I also feel it's a bit awkward here. For TOS, the bucket_name is in the host name of the model_uri. For AWS, the bucket_name is in the path of the model_uri. It is difficult to distinguish in the base class which subset of the model_uri is being passed. Or we could rename filename.

Collaborator:

You mean the TOS URI is like bucket-name-tos-beijing/path and the S3 one is s3://bucket-name/path? For the internal methods, I suggest having two variables like bucket and path so we can construct the path as needed. But for the user's input, they probably give the URI directly, right? Is there a way to split the string into bucket and path?
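For reference, a minimal sketch of splitting an s3:// style URI into bucket and path with urllib.parse; this mirrors the _parse_bucket_info_from_uri helper added to s3.py later in this PR, and the example URI is made up:

from urllib.parse import urlparse

def parse_bucket_info_from_uri(uri: str):
    # "s3://my-bucket/models/llama" -> ("my-bucket", "models/llama")
    parsed = urlparse(uri, scheme="s3")
    return parsed.netloc, parsed.path.lstrip("/")

print(parse_bucket_info_from_uri("s3://my-bucket/models/llama"))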

local_path=model_path,
enable_range=self._support_range_download(),
)

# create completed file to indicate download completed
completed_file = model_path.joinpath(".completed")
Collaborator:

Where do you use it for the completion check?

Collaborator (Author):

What I want is for the main container to first run an entrypoint script that checks whether the .completed file exists in the model directory, and only then execute the subsequent model startup logic. This ensures the download finishes before startup.

Collaborator:

Let's say we use vLLM; in that case, do we need to change the entrypoint or add a wrapper?

Collaborator (Author):

Yes, it's like using a script such as the following as the entrypoint:

#!/bin/bash

MODEL_PATH="/path/to/model"

# Wait until the downloader has created the .completed sentinel file.
while true; do
    if [ -e "$MODEL_PATH/.completed" ]; then
        echo "File exists. Continuing execution..."
        break
    else
        sleep 1
    fi
done

python3 -m vllm.entrypoints.openai.api_server "$@" --model "$MODEL_PATH"

completed_file.touch()
return model_path

def __hash__(self):
return hash(tuple(self.__dict__))


def get_downloader(model_uri: str) -> BaseDownloader:
"""Get downloader for model_uri."""
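The body of get_downloader is collapsed in this diff view. As a rough illustration only, here is a minimal usage sketch of the downloader API added in this PR; the example URI and local path are made up, and get_downloader is assumed to be importable from the module shown above:

from aibrix.downloader.base import get_downloader

# Hypothetical model URI; get_downloader picks the matching BaseDownloader subclass.
downloader = get_downloader("s3://my-bucket/models/llama")

# download_model falls back to envs.DOWNLOADER_LOCAL_DIR when local_path is omitted,
# creates the directory, downloads the files, and touches a .completed marker.
model_path = downloader.download_model(local_path="/tmp/models")
print(f"Model downloaded to {model_path}")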
78 changes: 67 additions & 11 deletions python/aibrix/aibrix/downloader/huggingface.py
@@ -13,27 +13,83 @@
# limitations under the License.

from pathlib import Path
from typing import List
from typing import List, Optional

from huggingface_hub import HfApi, hf_hub_download, snapshot_download

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.logger import init_logger

logger = init_logger(__name__)


def _parse_model_name_from_uri(model_uri: str) -> str:
return model_uri


class HuggingFaceDownloader(BaseDownloader):
def __init__(self, model_uri):
super().__init__(model_uri)
def __init__(self, model_uri: str, model_name: Optional[str] = None):
if model_name is None and envs.DOWNLOADER_MODEL_NAME is None:
model_name = _parse_model_name_from_uri(model_uri)

super().__init__(model_uri=model_uri, model_name=model_name) # type: ignore

self.hf_token = envs.DOWNLOADER_HF_TOKEN
self.hf_endpoint = envs.DOWNLOADER_HF_ENDPOINT
self.hf_revision = envs.DOWNLOADER_HF_REVISION

self.allow_patterns = (
None
if self.allow_file_suffix is None
else [f"*.{suffix}" for suffix in self.allow_file_suffix]
)
logger.debug(
f"Downloader {self.__class__.__name__} initialized."
f"HF Settings are followed: \n"
f"hf_token={self.hf_token}, \n"
f"hf_endpoint={self.hf_endpoint}"
)

def _check_config(self):
pass
def _valid_config(self):
assert (
len(self.model_uri.split("/")) == 2
), "Model uri must be in `repo/name` format."

assert self.model_name is not None, "Model name is not set."

def _is_directory(self) -> bool:
"""Check if model_uri is a directory."""
return False
"""Check if model_uri is a directory.
model_uri in `repo/name` format must be a directory.
"""
return True

def _directory_list(self, path: str) -> List[str]:
return []
hf_api = HfApi(endpoint=self.hf_endpoint, token=self.hf_token)
return hf_api.list_repo_files(repo_id=self.model_uri)

def _support_range_download(self) -> bool:
return True
return False

def download(self, filename: str, local_path: Path, enable_range: bool = True):
hf_hub_download(
repo_id=self.model_uri,
filename=filename,
local_dir=local_path,
revision=self.hf_revision,
token=self.hf_token,
endpoint=self.hf_endpoint,
local_dir_use_symlinks=False,
)

def download(self, path: str, local_path: Path, enable_range: bool = True):
pass
def download_directory(self, local_path: Path):
snapshot_download(
self.model_uri,
local_dir=local_path,
revision=self.hf_revision,
token=self.hf_token,
allow_patterns=self.allow_patterns,
max_workers=envs.DOWNLOADER_NUM_THREADS,
endpoint=self.hf_endpoint,
local_dir_use_symlinks=False,
)
107 changes: 99 additions & 8 deletions python/aibrix/aibrix/downloader/s3.py
@@ -11,29 +11,120 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache
from pathlib import Path
from typing import List
from urllib.parse import urlparse

import boto3
from boto3.s3.transfer import TransferConfig
from tqdm import tqdm

from aibrix import envs
from aibrix.downloader.base import BaseDownloader


def _parse_bucket_info_from_uri(uri):
parsed = urlparse(uri, scheme="s3")
bucket_name = parsed.netloc
bucket_path = parsed.path.lstrip("/")
return bucket_name, bucket_path


class S3Downloader(BaseDownloader):
def __init__(self, model_uri):
super().__init__(model_uri)
model_name = envs.DOWNLOADER_MODEL_NAME
ak = envs.DOWNLOADER_AWS_ACCESS_KEY
sk = envs.DOWNLOADER_AWS_SECRET_KEY
endpoint = envs.DOWNLOADER_AWS_ENDPOINT
region = envs.DOWNLOADER_AWS_REGION
bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri)
self.bucket_name = bucket_name
self.bucket_path = bucket_path

def _check_config(self):
pass
self.client = boto3.client(
service_name="s3",
region_name=region,
endpoint_url=endpoint,
aws_access_key_id=ak,
aws_secret_access_key=sk,
)

super().__init__(model_uri=model_uri, model_name=model_name) # type: ignore

def _valid_config(self):
assert (
self.bucket_name is not None and self.bucket_name != ""
), "S3 bucket name is not set."
assert (
self.bucket_path is not None and self.bucket_path != ""
), "S3 bucket path is not set."
try:
self.client.head_bucket(Bucket=self.bucket_name)
except Exception as e:
assert False, f"S3 bucket {self.bucket_name} not exist for {e}."

@lru_cache()
def _is_directory(self) -> bool:
"""Check if model_uri is a directory."""
return False
if self.bucket_path.endswith("/"):
return True
objects_out = self.client.list_objects_v2(
Bucket=self.bucket_name, Delimiter="/", Prefix=self.bucket_path
)
contents = objects_out.get("Contents", [])
if len(contents) == 1 and contents[0].get("Key") == self.bucket_path:
return False
return True

def _directory_list(self, path: str) -> List[str]:
return []
objects_out = self.client.list_objects_v2(
Bucket=self.bucket_name, Delimiter="/", Prefix=self.bucket_path
)
contents = objects_out.get("Contents", [])
return [content.get("Key") for content in contents]

def _support_range_download(self) -> bool:
return True

def download(self, path: str, local_path: Path, enable_range: bool = True):
pass
def download(self, filename: str, local_path: Path, enable_range: bool = True):
# filename should extract from model_uri when it is not a directory
if not self._is_directory():
filename = self.bucket_path

# check if file exist
try:
meta_data = self.client.head_object(Bucket=self.bucket_name, Key=filename)
except Exception as e:
raise ValueError(f"TOS file {filename} not exist for {e}.")

_file_name = filename.split("/")[-1]
# S3 client does not support Path, convert it to str
local_file = str(local_path.joinpath(_file_name).absolute())

# construct TransferConfig
config_kwargs = {
"max_concurrency": envs.DOWNLOADER_NUM_THREADS,
"use_threads": enable_range,
}
if envs.DOWNLOADER_PART_THRESHOLD is not None:
config_kwargs["multipart_threshold"] = envs.DOWNLOADER_PART_THRESHOLD
if envs.DOWNLOADER_PART_CHUNKSIZE is not None:
config_kwargs["multipart_chunksize"] = envs.DOWNLOADER_PART_CHUNKSIZE

config = TransferConfig(**config_kwargs)

# download file
total_length = int(meta_data.get("ContentLength", 0))
with tqdm(total=total_length, unit="b", unit_scale=True) as pbar:

def download_progress(bytes_transferred):
pbar.update(bytes_transferred)

self.client.download_file(
Bucket=self.bucket_name,
Key=filename,
Filename=local_file,
Config=config,
Callback=download_progress,
)