Skip to content

Commit

Permalink
[Misc] Check AI Runtime download env settings (#221)
Browse files Browse the repository at this point in the history
* fix: assert DOWNLOADER_MODEL_NAME setting during download from tos or s3

* test: add test case about model name not set

* test: fix tos func name

* style
  • Loading branch information
brosoul authored Sep 25, 2024
1 parent 9b47a9a commit 9c3432c
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 6 deletions.
3 changes: 3 additions & 0 deletions python/aibrix/aibrix/downloader/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def __init__(self, model_uri):
) # type: ignore

def _valid_config(self):
assert (
self.model_name is not None and self.model_name != ""
), "S3 model name is not set, please set env variable DOWNLOADER_MODEL_NAME."
assert (
self.bucket_name is not None and self.bucket_name != ""
), "S3 bucket name is not set."
Expand Down
7 changes: 5 additions & 2 deletions python/aibrix/aibrix/downloader/tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
class TOSDownloader(BaseDownloader):
def __init__(self, model_uri):
model_name = envs.DOWNLOADER_MODEL_NAME
ak = envs.DOWNLOADER_TOS_ACCESS_KEY
sk = envs.DOWNLOADER_TOS_SECRET_KEY
ak = envs.DOWNLOADER_TOS_ACCESS_KEY or ""
sk = envs.DOWNLOADER_TOS_SECRET_KEY or ""
endpoint = envs.DOWNLOADER_TOS_ENDPOINT or ""
region = envs.DOWNLOADER_TOS_REGION or ""
enable_crc = envs.DOWNLOADER_TOS_ENABLE_CRC
Expand All @@ -60,6 +60,9 @@ def __init__(self, model_uri):
) # type: ignore

def _valid_config(self):
assert (
self.model_name is not None and self.model_name != ""
), "TOS model name is not set, please set env variable DOWNLOADER_MODEL_NAME."
assert (
self.bucket_name is not None and self.bucket_name != ""
), "TOS bucket name is not set."
Expand Down
25 changes: 25 additions & 0 deletions python/aibrix/tests/downloader/test_downloader_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from aibrix.downloader.s3 import S3Downloader

S3_BOTO3_MODULE = "aibrix.downloader.s3.boto3"
ENVS_MODULE = "aibrix.downloader.s3.envs"


def mock_not_exsit_boto3(mock_boto3):
Expand All @@ -34,6 +35,17 @@ def mock_exsit_boto3(mock_boto3):
mock_client.head_bucket.return_value = mock.Mock()


env_group = mock.Mock()
env_group.DOWNLOADER_MODEL_NAME = "model_name"
env_group.DOWNLOADER_NUM_THREADS = 4


env_group_no_model_name = mock.Mock()
env_group_no_model_name.DOWNLOADER_MODEL_NAME = None
env_group_no_model_name.DOWNLOADER_NUM_THREADS = 4


@mock.patch(ENVS_MODULE, env_group)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3(mock_boto3):
mock_exsit_boto3(mock_boto3)
Expand All @@ -42,6 +54,7 @@ def test_get_downloader_s3(mock_boto3):
assert isinstance(downloader, S3Downloader)


@mock.patch(ENVS_MODULE, env_group)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_path_not_exist(mock_boto3):
mock_not_exsit_boto3(mock_boto3)
Expand All @@ -51,6 +64,7 @@ def test_get_downloader_s3_path_not_exist(mock_boto3):
assert "not exist" in str(exception.value)


@mock.patch(ENVS_MODULE, env_group)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_path_empty(mock_boto3):
mock_exsit_boto3(mock_boto3)
Expand All @@ -62,6 +76,7 @@ def test_get_downloader_s3_path_empty(mock_boto3):
assert "S3 bucket name is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_group)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_path_empty_path(mock_boto3):
mock_exsit_boto3(mock_boto3)
Expand All @@ -70,3 +85,13 @@ def test_get_downloader_s3_path_empty_path(mock_boto3):
with pytest.raises(AssertionError) as exception:
get_downloader("s3://bucket/")
assert "S3 bucket path is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_group_no_model_name)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_no_model_name(mock_tos):
mock_exsit_boto3(mock_tos)

with pytest.raises(AssertionError) as exception:
get_downloader("s3://bucket/path")
assert "S3 model name is not set" in str(exception.value)
31 changes: 27 additions & 4 deletions python/aibrix/tests/downloader/test_downloader_tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from aibrix.downloader.tos import TOSDownloader

TOS_MODULE = "aibrix.downloader.tos.tos"
ENVS_MODULE = "aibrix.downloader.tos.envs"


def mock_not_exsit_tos(mock_tos):
Expand All @@ -34,25 +35,36 @@ def mock_exsit_tos(mock_tos):
mock_client.head_bucket.return_value = mock.Mock()


env_group = mock.Mock()
env_group.DOWNLOADER_MODEL_NAME = "model_name"


env_group_no_model_name = mock.Mock()
env_group_no_model_name.DOWNLOADER_MODEL_NAME = None


@mock.patch(ENVS_MODULE, env_group)
@mock.patch(TOS_MODULE)
def test_get_downloader_s3(mock_tos):
def test_get_downloader_tos(mock_tos):
mock_exsit_tos(mock_tos)

downloader = get_downloader("tos://bucket/path")
assert isinstance(downloader, TOSDownloader)


@mock.patch(ENVS_MODULE, env_group)
@mock.patch(TOS_MODULE)
def test_get_downloader_s3_path_not_exist(mock_tos):
def test_get_downloader_tos_path_not_exist(mock_tos):
mock_not_exsit_tos(mock_tos)

with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/not_exsit_path")
assert "not exist" in str(exception.value)


@mock.patch(ENVS_MODULE, env_group)
@mock.patch(TOS_MODULE)
def test_get_downloader_s3_path_empty(mock_tos):
def test_get_downloader_tos_path_empty(mock_tos):
mock_exsit_tos(mock_tos)

# Bucket name and path both are empty,
Expand All @@ -62,11 +74,22 @@ def test_get_downloader_s3_path_empty(mock_tos):
assert "TOS bucket name is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_group)
@mock.patch(TOS_MODULE)
def test_get_downloader_s3_path_empty_path(mock_tos):
def test_get_downloader_tos_path_empty_path(mock_tos):
mock_exsit_tos(mock_tos)

# bucket path is empty
with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/")
assert "TOS bucket path is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_group_no_model_name)
@mock.patch(TOS_MODULE)
def test_get_downloader_tos_no_model_name(mock_tos):
mock_exsit_tos(mock_tos)

with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/path")
assert "TOS model name is not set" in str(exception.value)

0 comments on commit 9c3432c

Please sign in to comment.