18
18
from concurrent .futures import ThreadPoolExecutor , wait
19
19
from dataclasses import dataclass , field
20
20
from pathlib import Path
21
- from typing import List , Optional
21
+ from typing import ClassVar , Dict , List , Optional
22
22
23
23
from aibrix import envs
24
+ from aibrix .downloader .entity import RemoteSource
24
25
from aibrix .logger import init_logger
25
26
26
27
logger = init_logger (__name__ )
27
28
28
29
30
@dataclass
class DownloadExtraConfig:
    """Optional per-call overrides for a downloader.

    Every field defaults to ``None``. For ``allow_file_suffix``,
    ``force_download`` and ``num_threads`` the downloader substitutes the
    matching ``envs.DOWNLOADER_*`` setting when the field is left as ``None``.
    """

    # Authentication / endpoint settings for the S3 and TOS backends.
    ak: Optional[str] = None
    sk: Optional[str] = None
    endpoint: Optional[str] = None
    region: Optional[str] = None

    # Settings for the HuggingFace backend.
    hf_endpoint: Optional[str] = None
    hf_token: Optional[str] = None
    hf_revision: Optional[str] = None

    # Parallel download tuning.
    # NOTE(review): only num_threads is consumed in the code visible here;
    # the remaining knobs are presumably read by the concrete backends.
    num_threads: Optional[int] = None
    max_io_queue: Optional[int] = None
    io_chunksize: Optional[int] = None
    part_threshold: Optional[int] = None
    part_chunksize: Optional[int] = None

    # Miscellaneous: file-suffix filter and forced re-download switch.
    allow_file_suffix: Optional[List[str]] = None
    force_download: Optional[bool] = None


# Module-level instance with every field unset; get_downloader() hands this
# same object to the downloader when the caller supplies no extra config.
DEFAULT_DOWNLOADER_EXTRA_CONFIG = DownloadExtraConfig()
58
+
59
+
29
60
@dataclass
30
61
class BaseDownloader (ABC ):
31
62
"""Base class for downloader."""
@@ -34,15 +65,27 @@ class BaseDownloader(ABC):
34
65
model_name : str
35
66
bucket_path : str
36
67
bucket_name : Optional [str ]
37
- enable_progress_bar : bool = False
38
- allow_file_suffix : Optional [List [str ]] = field (
39
- default_factory = lambda : envs .DOWNLOADER_ALLOW_FILE_SUFFIX
68
+ download_extra_config : DownloadExtraConfig = field (
69
+ default_factory = DownloadExtraConfig
40
70
)
71
+ enable_progress_bar : bool = False
72
+ _source : ClassVar [RemoteSource ] = RemoteSource .UNKNOWN
41
73
42
74
def __post_init__(self):
    """Validate the config and resolve the effective download settings.

    Runs after the dataclass-generated ``__init__``. Values set on
    ``download_extra_config`` take precedence over the corresponding
    ``envs.DOWNLOADER_*`` settings.
    """
    # valid downloader config
    self._valid_config()
    self.model_name_path = self.model_name

    extra_config = self.download_extra_config

    # A missing or empty suffix list falls back to the env setting.
    self.allow_file_suffix = (
        extra_config.allow_file_suffix or envs.DOWNLOADER_ALLOW_FILE_SUFFIX
    )

    # force_download is tri-state (None / True / False). Test against None
    # instead of using `or`, otherwise an explicit False would be treated as
    # "unset" and could never override an env setting of True.
    if extra_config.force_download is not None:
        self.force_download = extra_config.force_download
    else:
        self.force_download = envs.DOWNLOADER_FORCE_DOWNLOAD
85
+
86
@property
def source(self) -> RemoteSource:
    """Return the ``RemoteSource`` stored in this downloader's ``_source``."""
    return self._source
46
89
47
90
@abstractmethod
48
91
def _valid_config (self ):
@@ -81,7 +124,7 @@ def download_directory(self, local_path: Path):
81
124
# filter the directory path
82
125
files = [file for file in directory_list if not file .endswith ("/" )]
83
126
84
- if self .allow_file_suffix is None :
127
+ if self .allow_file_suffix is None or len ( self . allow_file_suffix ) == 0 :
85
128
logger .info (f"All files from { self .bucket_path } will be downloaded." )
86
129
filtered_files = files
87
130
else :
@@ -93,7 +136,9 @@ def download_directory(self, local_path: Path):
93
136
94
137
if not self ._support_range_download ():
95
138
# download using multi threads
96
- num_threads = envs .DOWNLOADER_NUM_THREADS
139
+ num_threads = (
140
+ self .download_extra_config .num_threads or envs .DOWNLOADER_NUM_THREADS
141
+ )
97
142
logger .info (
98
143
f"Downloader { self .__class__ .__name__ } download "
99
144
f"{ len (filtered_files )} files from { self .model_uri } "
@@ -157,23 +202,38 @@ def __hash__(self):
157
202
158
203
159
204
def get_downloader(
    model_uri: str,
    model_name: Optional[str] = None,
    download_extra_config: Optional[Dict] = None,
    enable_progress_bar: bool = False,
) -> BaseDownloader:
    """Build the downloader that matches ``model_uri``.

    Args:
        model_uri: Location of the model. It is matched against
            ``envs.DOWNLOADER_S3_REGEX`` first, then
            ``envs.DOWNLOADER_TOS_REGEX``; anything else is handed to the
            HuggingFace downloader.
        model_name: Optional model name forwarded to the downloader.
        download_extra_config: Optional mapping expanded as keyword
            arguments into ``DownloadExtraConfig``, so its keys must be
            field names of that dataclass. When ``None``, the shared
            ``DEFAULT_DOWNLOADER_EXTRA_CONFIG`` instance is used.
        enable_progress_bar: Forwarded to the downloader unchanged.

    Returns:
        An instance of the selected downloader class.
    """
    if download_extra_config is not None:
        extra_config = DownloadExtraConfig(**download_extra_config)
    else:
        extra_config = DEFAULT_DOWNLOADER_EXTRA_CONFIG

    # Each backend module is imported only inside its own branch --
    # presumably so that only the selected backend's dependencies are needed.
    downloader_cls: type
    if re.match(envs.DOWNLOADER_S3_REGEX, model_uri):
        from aibrix.downloader.s3 import S3Downloader

        downloader_cls = S3Downloader
    elif re.match(envs.DOWNLOADER_TOS_REGEX, model_uri):
        # Two TOS implementations exist; "v1" is opt-in, anything else is V2.
        if envs.DOWNLOADER_TOS_VERSION == "v1":
            from aibrix.downloader.tos import TOSDownloaderV1

            downloader_cls = TOSDownloaderV1
        else:
            from aibrix.downloader.tos import TOSDownloaderV2

            downloader_cls = TOSDownloaderV2
    else:
        from aibrix.downloader.huggingface import HuggingFaceDownloader

        downloader_cls = HuggingFaceDownloader

    return downloader_cls(model_uri, model_name, extra_config, enable_progress_bar)
0 commit comments