diff --git a/python/aibrix/aibrix/downloader/s3.py b/python/aibrix/aibrix/downloader/s3.py index 43e87c8c..b3112ac3 100644 --- a/python/aibrix/aibrix/downloader/s3.py +++ b/python/aibrix/aibrix/downloader/s3.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import lru_cache from pathlib import Path -from typing import List +from typing import List, Tuple from urllib.parse import urlparse import boto3 @@ -24,7 +24,7 @@ from aibrix.downloader.base import BaseDownloader -def _parse_bucket_info_from_uri(uri): +def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]: parsed = urlparse(uri, scheme="s3") bucket_name = parsed.netloc bucket_path = parsed.path.lstrip("/") diff --git a/python/aibrix/aibrix/downloader/tos.py b/python/aibrix/aibrix/downloader/tos.py index 4ff25bbe..a77b9db3 100644 --- a/python/aibrix/aibrix/downloader/tos.py +++ b/python/aibrix/aibrix/downloader/tos.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import re from functools import lru_cache from pathlib import Path from typing import List, Tuple @@ -29,18 +28,11 @@ tos_logger.setLevel(logging.WARNING) -def _parse_bucket_info_from_uri(model_uri: str) -> Tuple[str, str, str, str]: - if not re.match(envs.DOWNLOADER_TOS_REGEX, model_uri): - raise ValueError(f"TOS uri {model_uri} not valid format.") - parsed = urlparse(model_uri) - bucket_name, endpoint = parsed.netloc.split(".", 1) +def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]: + parsed = urlparse(uri, scheme="tos") + bucket_name = parsed.netloc bucket_path = parsed.path.lstrip("/") - matched = re.match(r"tos-(.+?).volces.com", endpoint) - if matched is None: - raise ValueError(f"TOS endpoint {endpoint} not valid format.") - - region = matched.group(1) - return bucket_name, bucket_path, endpoint, region + return bucket_name, bucket_path class TOSDownloader(BaseDownloader): @@ -48,12 +40,17 @@ def __init__(self, model_uri): model_name = envs.DOWNLOADER_MODEL_NAME ak = envs.DOWNLOADER_TOS_ACCESS_KEY sk = envs.DOWNLOADER_TOS_SECRET_KEY - bucket_name, bucket_path, endpoint, region = _parse_bucket_info_from_uri( - model_uri - ) + endpoint = envs.DOWNLOADER_TOS_ENDPOINT or '' + region = envs.DOWNLOADER_TOS_REGION or '' + bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri) self.bucket_name = bucket_name self.bucket_path = bucket_path - self.client = tos.TosClientV2(ak, sk, endpoint, region) + + self.client = tos.TosClientV2(ak=ak, + sk=sk, + endpoint=endpoint, + region=region) + super().__init__(model_uri=model_uri, model_name=model_name) # type: ignore def _valid_config(self): diff --git a/python/aibrix/aibrix/envs.py b/python/aibrix/aibrix/envs.py index 90465da0..eef792ee 100644 --- a/python/aibrix/aibrix/envs.py +++ b/python/aibrix/aibrix/envs.py @@ -52,7 +52,7 @@ def _parse_int_or_none(value: Optional[str]) -> Optional[int]: # Downloader Regex DOWNLOADER_S3_REGEX = r"^s3://" -DOWNLOADER_TOS_REGEX = r"https://(.+?).tos-(.+?).volces.com/(.+)" +DOWNLOADER_TOS_REGEX = r"^tos://" # Downloader HuggingFace Envs DOWNLOADER_HF_TOKEN = os.getenv("HF_TOKEN")