@@ -114,16 +114,27 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
114
114
raise Exception (f"Failed to read dataset at { url } " ) from None
115
115
116
116
117
- def _verify_files_dont_exist (paths : Iterable [Union [str , Path ]]) -> None :
117
+ def _verify_files_dont_exist (
118
+ paths : Iterable [Union [str , Path ]], remove_if_exist : bool = False
119
+ ) -> None :
118
120
"""
119
121
Verifies all paths in 'paths' don't exist.
120
122
:param paths: A iterable of strs or pathlib.Paths.
123
+ :param remove_if_exist=False: Removes file at path if they already exist.
121
124
:returns: None
122
125
:raises FileExistsError: On the first path found that already exists.
123
126
"""
124
127
for path in paths :
125
- if Path (path ).exists ():
126
- raise FileExistsError (f"Error: File '{ path } ' already exists." )
128
+ path = Path (path )
129
+ if path .exists ():
130
+ if remove_if_exist :
131
+ if path .is_symlink ():
132
+ realpath = path .resolve ()
133
+ path .unlink (realpath )
134
+ else :
135
+ shutil .rmtree (path )
136
+ else :
137
+ raise FileExistsError (f"Error: File '{ path } ' already exists." )
127
138
128
139
129
140
def _is_file_to_symlink (path : Path ) -> bool :
@@ -188,7 +199,9 @@ async def read(url: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> bytes:
188
199
return b"" .join ([chunk async for chunk in _get_chunks (url , chunk_size )])
189
200
190
201
191
- async def prepare (url : str , path : Optional [str ] = None , verbose : bool = True ) -> None :
202
+ async def prepare (
203
+ url : str , path : Optional [str ] = None , verbose : bool = True , overwrite : bool = False
204
+ ) -> None :
192
205
"""
193
206
Prepares a dataset for learners. Downloads a dataset from the given url,
194
207
decompresses it if necessary. If not using jupyterlite, will extract to
@@ -200,6 +213,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
200
213
201
214
:param url: The URL to download the dataset from.
202
215
:param path: The path the dataset will be available at. Current working directory by default.
216
+ :param verbose=True: Prints saved path if True.
217
+ :param overwrite=False: Overwrites any existing files at destination if they exist.
203
218
:raise InvalidURLException: When URL is invalid.
204
219
:raise FileExistsError: it raises this when a file to be symlinked already exists.
205
220
:raise ValueError: When requested path is in /tmp, or cannot be saved to path.
@@ -239,7 +254,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
239
254
path / child .name
240
255
for child in map (Path , tf .getnames ())
241
256
if len (child .parents ) == 1 and _is_file_to_symlink (child )
242
- ]
257
+ ],
258
+ overwrite ,
243
259
) # Only check if top-level fileobject
244
260
pbar = tqdm (iterable = tf .getmembers (), total = len (tf .getmembers ()))
245
261
pbar .set_description (f"Extracting { filename } " )
@@ -253,15 +269,16 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
253
269
path / child .name
254
270
for child in map (Path , zf .namelist ())
255
271
if len (child .parents ) == 1 and _is_file_to_symlink (child )
256
- ]
272
+ ],
273
+ overwrite ,
257
274
)
258
275
pbar = tqdm (iterable = zf .infolist (), total = len (zf .infolist ()))
259
276
pbar .set_description (f"Extracting { filename } " )
260
277
for member in pbar :
261
278
zf .extract (member = member , path = extract_dir )
262
279
tmp_download_file .unlink ()
263
280
else :
264
- _verify_files_dont_exist ([path / filename ])
281
+ _verify_files_dont_exist ([path / filename ], overwrite )
265
282
shutil .move (tmp_download_file , extract_dir / filename )
266
283
267
284
# If in jupyterlite environment, the extract_dir = path, so the files are already there.
@@ -274,8 +291,36 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
274
291
print (f"Saved to '{ relpath (path .resolve ())} '" )
275
292
276
293
277
- if _is_jupyterlite ():
278
- tqdm .monitor_interval = 0
294
+ def setup () -> None :
295
+ if _is_jupyterlite ():
296
+ tqdm .monitor_interval = 0
297
+
298
+ try :
299
+ import sys # pyright: ignore
300
+
301
+ ipython = get_ipython ()
302
+
303
+ def hide_traceback (
304
+ exc_tuple = None ,
305
+ filename = None ,
306
+ tb_offset = None ,
307
+ exception_only = False ,
308
+ running_compiled_code = False ,
309
+ ):
310
+ etype , value , tb = sys .exc_info ()
311
+ value .__cause__ = None # suppress chained exceptions
312
+ return ipython ._showtraceback (
313
+ etype , value , ipython .InteractiveTB .get_exception_only (etype , value )
314
+ )
315
+
316
+ ipython .showtraceback = hide_traceback
317
+
318
+ except NameError :
319
+ pass
320
+
321
+
322
+ setup ()
323
+
279
324
280
325
# For backwards compatibility
281
326
download_dataset = download
0 commit comments