Skip to content

Commit

Permalink
refact: refact tos model uri use tos protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
brosoul committed Aug 27, 2024
1 parent 8c5b9a1 commit 19a45ac
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 19 deletions.
4 changes: 2 additions & 2 deletions python/aibrix/aibrix/downloader/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from functools import lru_cache
from pathlib import Path
from typing import List
from typing import List, Tuple
from urllib.parse import urlparse

import boto3
Expand All @@ -24,7 +24,7 @@
from aibrix.downloader.base import BaseDownloader


def _parse_bucket_info_from_uri(uri):
def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
parsed = urlparse(uri, scheme="s3")
bucket_name = parsed.netloc
bucket_path = parsed.path.lstrip("/")
Expand Down
29 changes: 13 additions & 16 deletions python/aibrix/aibrix/downloader/tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import List, Tuple
Expand All @@ -29,31 +28,29 @@
tos_logger.setLevel(logging.WARNING)


def _parse_bucket_info_from_uri(model_uri: str) -> Tuple[str, str, str, str]:
if not re.match(envs.DOWNLOADER_TOS_REGEX, model_uri):
raise ValueError(f"TOS uri {model_uri} not valid format.")
parsed = urlparse(model_uri)
bucket_name, endpoint = parsed.netloc.split(".", 1)
def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
parsed = urlparse(uri, scheme="tos")
bucket_name = parsed.netloc
bucket_path = parsed.path.lstrip("/")
matched = re.match(r"tos-(.+?).volces.com", endpoint)
if matched is None:
raise ValueError(f"TOS endpoint {endpoint} not valid format.")

region = matched.group(1)
return bucket_name, bucket_path, endpoint, region
return bucket_name, bucket_path


class TOSDownloader(BaseDownloader):
def __init__(self, model_uri):
model_name = envs.DOWNLOADER_MODEL_NAME
ak = envs.DOWNLOADER_TOS_ACCESS_KEY
sk = envs.DOWNLOADER_TOS_SECRET_KEY
bucket_name, bucket_path, endpoint, region = _parse_bucket_info_from_uri(
model_uri
)
endpoint = envs.DOWNLOADER_TOS_ENDPOINT or ''
region = envs.DOWNLOADER_TOS_REGION or ''
bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri)
self.bucket_name = bucket_name
self.bucket_path = bucket_path
self.client = tos.TosClientV2(ak, sk, endpoint, region)

self.client = tos.TosClientV2(ak=ak,
sk=sk,
endpoint=endpoint,
region=region)

super().__init__(model_uri=model_uri, model_name=model_name) # type: ignore

def _valid_config(self):
Expand Down
2 changes: 1 addition & 1 deletion python/aibrix/aibrix/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _parse_int_or_none(value: Optional[str]) -> Optional[int]:

# Downloader Regex
DOWNLOADER_S3_REGEX = r"^s3://"
DOWNLOADER_TOS_REGEX = r"https://(.+?).tos-(.+?).volces.com/(.+)"
DOWNLOADER_TOS_REGEX = r"^tos://"

# Downloader HuggingFace Envs
DOWNLOADER_HF_TOKEN = os.getenv("HF_TOKEN")
Expand Down

0 comments on commit 19a45ac

Please sign in to comment.