Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Infer model name from model_uri and check AWS credential #250

Merged
merged 5 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/source/features/runtime.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ First Define the necessary environment variables for the HuggingFace model.
.. code-block:: bash

# General settings
export DOWNLOADER_MODEL_NAME=deepseek-ai/deepseek-coder-6.7b-instruct/
export DOWNLOADER_ALLOW_FILE_SUFFIX="json, safetensors"
export DOWNLOADER_NUM_THREADS=16
# HuggingFace settings
Expand All @@ -70,7 +69,6 @@ First Define the necessary environment variables for the S3 model.
.. code-block:: bash

# General settings
export DOWNLOADER_MODEL_NAME=deepseek-ai/deepseek-coder-6.7b-instruct/
export DOWNLOADER_ALLOW_FILE_SUFFIX="json, safetensors"
export DOWNLOADER_NUM_THREADS=16
# AWS settings
Expand All @@ -96,7 +94,6 @@ First Define the necessary environment variables for the TOS model.
.. code-block:: bash

# General settings
export DOWNLOADER_MODEL_NAME=deepseek-ai/deepseek-coder-6.7b-instruct/
export DOWNLOADER_ALLOW_FILE_SUFFIX="json, safetensors"
export DOWNLOADER_NUM_THREADS=16
# AWS settings
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorial/runtime/runtime-hf-download.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ spec:
- --port
- "8000"
- --model
- /models/deepseek-ai/deepseek-coder-6.7b-instruct
- /models/deepseek-coder-6.7b-instruct
- --served-model-name
- deepseek-ai/deepseek-coder-6.7b-instruct
- --distributed-executor-backend
Expand Down Expand Up @@ -96,9 +96,9 @@ spec:
- deepseek-ai/deepseek-coder-6.7b-instruct
- --local-dir
- /models/
- --model-name
- deepseek-coder-6.7b-instruct
env:
- name: DOWNLOADER_MODEL_NAME
value: deepseek-ai/deepseek-coder-6.7b-instruct
- name: DOWNLOADER_ALLOW_FILE_SUFFIX
value: json, safetensors
- name: HF_TOKEN
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorial/runtime/runtime-s3-download.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ spec:
- --port
- "8000"
- --model
- /models/deepseek-ai/deepseek-coder-6.7b-instruct
- /models/deepseek-coder-6.7b-instruct
- --served-model-name
- deepseek-ai/deepseek-coder-6.7b-instruct
- --distributed-executor-backend
Expand Down Expand Up @@ -96,9 +96,9 @@ spec:
- s3://<input your s3 bucket name>/<input your s3 bucket path>
- --local-dir
- /models/
- --model-name
- deepseek-coder-6.7b-instruct
env:
- name: DOWNLOADER_MODEL_NAME
value: deepseek-ai/deepseek-coder-6.7b-instruct
- name: DOWNLOADER_ALLOW_FILE_SUFFIX
value: json, safetensors
- name: AWS_ACCESS_KEY_ID
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorial/runtime/runtime-tos-download.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ spec:
- --port
- "8000"
- --model
- /models/deepseek-ai/deepseek-coder-6.7b-instruct
- /models/deepseek-coder-6.7b-instruct
- --served-model-name
- deepseek-ai/deepseek-coder-6.7b-instruct
- --distributed-executor-backend
Expand Down Expand Up @@ -96,9 +96,9 @@ spec:
- tos://<input your tos bucket name>/<input your tos bucket path>
- --local-dir
- /models/
- --model-name
- deepseek-coder-6.7b-instruct
env:
- name: DOWNLOADER_MODEL_NAME
value: deepseek-ai/deepseek-coder-6.7b-instruct
- name: DOWNLOADER_ALLOW_FILE_SUFFIX
value: json, safetensors
- name: TOS_ACCESS_KEY
Expand Down
6 changes: 4 additions & 2 deletions python/aibrix/aibrix/downloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
from aibrix.downloader.base import get_downloader


def download_model(model_uri: str, local_path: Optional[str] = None):
def download_model(
model_uri: str, local_path: Optional[str] = None, model_name: Optional[str] = None
):
"""Download model from model_uri to local_path.

Args:
model_uri (str): model uri.
local_path (str): local path to save model.
"""

downloader = get_downloader(model_uri)
downloader = get_downloader(model_uri, model_name)
return downloader.download_model(local_path)


Expand Down
10 changes: 8 additions & 2 deletions python/aibrix/aibrix/downloader/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@ def main():
"--local-dir",
type=str,
default=None,
help="dir to save model files",
help="base dir of the model file. If not set, it will used with env `DOWNLOADER_LOCAL_DIR`",
)
parser.add_argument(
"--model-name",
type=str,
default=None,
help="subdir of the base dir to save model files",
)
args = parser.parse_args()
download_model(args.model_uri, args.local_dir)
download_model(args.model_uri, args.local_dir, args.model_name)


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions python/aibrix/aibrix/downloader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,17 @@ def __hash__(self):
return hash(tuple(self.__dict__))


def get_downloader(model_uri: str) -> BaseDownloader:
def get_downloader(model_uri: str, model_name: Optional[str] = None) -> BaseDownloader:
"""Get downloader for model_uri."""
if re.match(envs.DOWNLOADER_S3_REGEX, model_uri):
from aibrix.downloader.s3 import S3Downloader

return S3Downloader(model_uri)
return S3Downloader(model_uri, model_name)
elif re.match(envs.DOWNLOADER_TOS_REGEX, model_uri):
from aibrix.downloader.tos import TOSDownloader

return TOSDownloader(model_uri)
return TOSDownloader(model_uri, model_name)
else:
from aibrix.downloader.huggingface import HuggingFaceDownloader

return HuggingFaceDownloader(model_uri)
return HuggingFaceDownloader(model_uri, model_name)
6 changes: 2 additions & 4 deletions python/aibrix/aibrix/downloader/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ def _parse_model_name_from_uri(model_uri: str) -> str:
class HuggingFaceDownloader(BaseDownloader):
def __init__(self, model_uri: str, model_name: Optional[str] = None):
if model_name is None:
if envs.DOWNLOADER_MODEL_NAME is not None:
model_name = envs.DOWNLOADER_MODEL_NAME
else:
model_name = _parse_model_name_from_uri(model_uri)
model_name = _parse_model_name_from_uri(model_uri)
logger.info(f"model_name is not set, using `{model_name}` as model_name")

self.hf_token = envs.DOWNLOADER_HF_TOKEN
self.hf_endpoint = envs.DOWNLOADER_HF_ENDPOINT
Expand Down
19 changes: 15 additions & 4 deletions python/aibrix/aibrix/downloader/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data
from aibrix.downloader.utils import (
infer_model_name,
meta_file,
need_to_download,
save_meta_data,
)
from aibrix.logger import init_logger

logger = init_logger(__name__)
Expand All @@ -37,14 +42,20 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:


class S3Downloader(BaseDownloader):
def __init__(self, model_uri):
model_name = envs.DOWNLOADER_MODEL_NAME
def __init__(self, model_uri, model_name: Optional[str] = None):
if model_name is None:
model_name = infer_model_name(model_uri)
logger.info(f"model_name is not set, using `{model_name}` as model_name")

ak = envs.DOWNLOADER_AWS_ACCESS_KEY_ID
sk = envs.DOWNLOADER_AWS_SECRET_ACCESS_KEY
endpoint = envs.DOWNLOADER_AWS_ENDPOINT_URL
region = envs.DOWNLOADER_AWS_REGION
bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri)

assert ak is not None and ak != "", "`AWS_ACCESS_KEY_ID` is not set."
assert sk is not None and sk != "", "`AWS_SECRET_ACCESS_KEY` is not set."

# Avoid warning log "Connection pool is full"
# Refs: https://github.com/boto/botocore/issues/619#issuecomment-583511406
max_pool_connections = (
Expand Down Expand Up @@ -75,7 +86,7 @@ def __init__(self, model_uri):
def _valid_config(self):
assert (
self.model_name is not None and self.model_name != ""
), "S3 model name is not set, please set env variable DOWNLOADER_MODEL_NAME."
), "S3 model name is not set, please check `--model-name`."
assert (
self.bucket_name is not None and self.bucket_name != ""
), "S3 bucket name is not set."
Expand Down
16 changes: 12 additions & 4 deletions python/aibrix/aibrix/downloader/tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data
from aibrix.downloader.utils import (
infer_model_name,
meta_file,
need_to_download,
save_meta_data,
)
from aibrix.logger import init_logger

tos_logger = logging.getLogger("tos")
Expand All @@ -39,8 +44,11 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:


class TOSDownloader(BaseDownloader):
def __init__(self, model_uri):
model_name = envs.DOWNLOADER_MODEL_NAME
def __init__(self, model_uri, model_name: Optional[str] = None):
if model_name is None:
model_name = infer_model_name(model_uri)
logger.info(f"model_name is not set, using `{model_name}` as model_name")

ak = envs.DOWNLOADER_TOS_ACCESS_KEY or ""
sk = envs.DOWNLOADER_TOS_SECRET_KEY or ""
endpoint = envs.DOWNLOADER_TOS_ENDPOINT or ""
Expand All @@ -62,7 +70,7 @@ def __init__(self, model_uri):
def _valid_config(self):
assert (
self.model_name is not None and self.model_name != ""
), "TOS model name is not set, please set env variable DOWNLOADER_MODEL_NAME."
), "TOS model name is not set, please check `--model-name`."
assert (
self.bucket_name is not None and self.bucket_name != ""
), "TOS bucket name is not set."
Expand Down
7 changes: 7 additions & 0 deletions python/aibrix/aibrix/downloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,10 @@ def need_to_download(
f"DOWNLOADER_CHECK_FILE_EXIST={envs.DOWNLOADER_CHECK_FILE_EXIST}"
)
return True


def infer_model_name(uri: str) -> str:
    """Infer a model name from the last path segment of a model uri.

    Surrounding whitespace and leading/trailing slashes are ignored, so
    ``s3://bucket/path/to/model/`` yields ``model``.

    Args:
        uri (str): model uri.

    Returns:
        str: the last ``/``-separated segment of the trimmed uri.

    Raises:
        ValueError: if ``uri`` is ``None``, empty, or consists only of
            whitespace and slashes, so that no name can be inferred.
    """
    if uri is None or uri == "":
        raise ValueError("Model uri is empty.")

    model_name = uri.strip().strip("/").split("/")[-1]
    # A uri such as "   " or "/" passes the check above but leaves nothing
    # usable after trimming; fail here instead of returning a blank name.
    if not model_name.strip():
        raise ValueError(f"Cannot infer model name from model uri `{uri}`.")
    return model_name
1 change: 0 additions & 1 deletion python/aibrix/aibrix/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def _parse_int_or_none(value: Optional[str]) -> Optional[int]:
DOWNLOADER_LOCAL_DIR = os.getenv("DOWNLOADER_LOCAL_DIR", "/tmp/aibrix/models/")


DOWNLOADER_MODEL_NAME = os.getenv("DOWNLOADER_MODEL_NAME")
DOWNLOADER_NUM_THREADS = int(os.getenv("DOWNLOADER_NUM_THREADS", "4"))
DOWNLOADER_PART_THRESHOLD = _parse_int_or_none(os.getenv("DOWNLOADER_PART_THRESHOLD"))
DOWNLOADER_PART_CHUNKSIZE = _parse_int_or_none(os.getenv("DOWNLOADER_PART_CHUNKSIZE"))
Expand Down
34 changes: 25 additions & 9 deletions python/aibrix/tests/downloader/test_downloader_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ def mock_exsit_boto3(mock_boto3):


env_group = mock.Mock()
env_group.DOWNLOADER_MODEL_NAME = "model_name"
env_group.DOWNLOADER_NUM_THREADS = 4
env_group.DOWNLOADER_AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID"
env_group.DOWNLOADER_AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY"

env_no_ak = mock.Mock()
env_no_ak.DOWNLOADER_NUM_THREADS = 4
env_no_ak.DOWNLOADER_AWS_ACCESS_KEY_ID = ""
env_no_ak.DOWNLOADER_AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY"

env_group_no_model_name = mock.Mock()
env_group_no_model_name.DOWNLOADER_MODEL_NAME = None
env_group_no_model_name.DOWNLOADER_NUM_THREADS = 4
env_no_sk = mock.Mock()
env_no_sk.DOWNLOADER_NUM_THREADS = 4
env_no_sk.DOWNLOADER_AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID"
env_no_sk.DOWNLOADER_AWS_SECRET_ACCESS_KEY = ""


@mock.patch(ENVS_MODULE, env_group)
Expand Down Expand Up @@ -87,11 +93,21 @@ def test_get_downloader_s3_path_empty_path(mock_boto3):
assert "S3 bucket path is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_group_no_model_name)
@mock.patch(ENVS_MODULE, env_no_ak)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_no_model_name(mock_tos):
mock_exsit_boto3(mock_tos)
def test_get_downloader_s3_no_ak(mock_boto3):
mock_exsit_boto3(mock_boto3)

with pytest.raises(AssertionError) as exception:
get_downloader("s3://bucket/path")
assert "S3 model name is not set" in str(exception.value)
get_downloader("s3://bucket/")
assert "`AWS_ACCESS_KEY_ID` is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_no_sk)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_no_sk(mock_boto3):
    """An empty AWS secret access key must abort S3 downloader creation."""
    mock_exsit_boto3(mock_boto3)

    with pytest.raises(AssertionError) as exc_info:
        get_downloader("s3://bucket/")

    error_message = str(exc_info.value)
    assert "`AWS_SECRET_ACCESS_KEY` is not set." in error_message
15 changes: 0 additions & 15 deletions python/aibrix/tests/downloader/test_downloader_tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ def mock_exsit_tos(mock_tos):


env_group = mock.Mock()
env_group.DOWNLOADER_MODEL_NAME = "model_name"


env_group_no_model_name = mock.Mock()
env_group_no_model_name.DOWNLOADER_MODEL_NAME = None


@mock.patch(ENVS_MODULE, env_group)
Expand Down Expand Up @@ -83,13 +78,3 @@ def test_get_downloader_tos_path_empty_path(mock_tos):
with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/")
assert "TOS bucket path is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_group_no_model_name)
@mock.patch(TOS_MODULE)
def test_get_downloader_tos_no_model_name(mock_tos):
mock_exsit_tos(mock_tos)

with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/path")
assert "TOS model name is not set" in str(exception.value)
13 changes: 13 additions & 0 deletions python/aibrix/tests/downloader/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
from aibrix.config import DOWNLOAD_CACHE_DIR
from aibrix.downloader.utils import (
check_file_exist,
infer_model_name,
load_meta_data,
meta_file,
need_to_download,
save_meta_data,
)
import pytest


def prepare_file_and_meta_data(file_path, meta_path, file_size, etag):
Expand Down Expand Up @@ -136,3 +138,14 @@ def test_need_to_download(mock_check: mock.Mock):
# recover envs
envs.DOWNLOADER_FORCE_DOWNLOAD = origin_force_download_env
envs.DOWNLOADER_CHECK_FILE_EXIST = origin_check_file_exist


def test_infer_model_name():
    """Check uri validation and last-segment extraction of infer_model_name."""
    # Empty and missing uris are both rejected with ValueError.
    for invalid_uri in ("", None):
        with pytest.raises(ValueError):
            infer_model_name(invalid_uri)

    # The final path component of the uri becomes the model name.
    assert infer_model_name("s3://bucket/path/to/model") == "model"
Loading