Skip to content

Commit 9ff3490

Browse files
authored
Merge pull request #3 from ibm-skills-network/prepare_overwrite
Add overwrite option to prepare
2 parents c67af18 + ada2a5e commit 9ff3490

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

doc/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
author = "Bradley Steinfeld, Sam Prokopchuk, James Reeve"
2424

2525
# The full version, including alpha/beta/rc tags
26-
release = "0.20.3"
26+
release = "0.20.4"
2727

2828

2929
# -- General configuration ---------------------------------------------------

skillsnetwork/core.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,27 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
114114
raise Exception(f"Failed to read dataset at {url}") from None
115115

116116

117-
def _verify_files_dont_exist(paths: Iterable[Union[str, Path]]) -> None:
117+
def _verify_files_dont_exist(
118+
paths: Iterable[Union[str, Path]], remove_if_exist: bool = False
119+
) -> None:
118120
"""
119121
Verifies all paths in 'paths' don't exist.
120122
:param paths: A iterable of strs or pathlib.Paths.
123+
:param remove_if_exist=False: Removes file at path if they already exist.
121124
:returns: None
122125
:raises FileExistsError: On the first path found that already exists.
123126
"""
124127
for path in paths:
125-
if Path(path).exists():
126-
raise FileExistsError(f"Error: File '{path}' already exists.")
128+
path = Path(path)
129+
if path.exists():
130+
if remove_if_exist:
131+
if path.is_symlink():
132+
realpath = path.resolve()
133+
path.unlink(realpath)
134+
else:
135+
shutil.rmtree(path)
136+
else:
137+
raise FileExistsError(f"Error: File '{path}' already exists.")
127138

128139

129140
def _is_file_to_symlink(path: Path) -> bool:
@@ -188,7 +199,9 @@ async def read(url: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> bytes:
188199
return b"".join([chunk async for chunk in _get_chunks(url, chunk_size)])
189200

190201

191-
async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) -> None:
202+
async def prepare(
203+
url: str, path: Optional[str] = None, verbose: bool = True, overwrite: bool = False
204+
) -> None:
192205
"""
193206
Prepares a dataset for learners. Downloads a dataset from the given url,
194207
decompresses it if necessary. If not using jupyterlite, will extract to
@@ -200,6 +213,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
200213
201214
:param url: The URL to download the dataset from.
202215
:param path: The path the dataset will be available at. Current working directory by default.
216+
:param verbose=True: Prints saved path if True.
217+
:param overwrite=False: Overwrites any existing files at destination if they exist.
203218
:raise InvalidURLException: When URL is invalid.
204219
:raise FileExistsError: it raises this when a file to be symlinked already exists.
205220
:raise ValueError: When requested path is in /tmp, or cannot be saved to path.
@@ -239,7 +254,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
239254
path / child.name
240255
for child in map(Path, tf.getnames())
241256
if len(child.parents) == 1 and _is_file_to_symlink(child)
242-
]
257+
],
258+
overwrite,
243259
) # Only check if top-level fileobject
244260
pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
245261
pbar.set_description(f"Extracting {filename}")
@@ -253,15 +269,16 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
253269
path / child.name
254270
for child in map(Path, zf.namelist())
255271
if len(child.parents) == 1 and _is_file_to_symlink(child)
256-
]
272+
],
273+
overwrite,
257274
)
258275
pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
259276
pbar.set_description(f"Extracting {filename}")
260277
for member in pbar:
261278
zf.extract(member=member, path=extract_dir)
262279
tmp_download_file.unlink()
263280
else:
264-
_verify_files_dont_exist([path / filename])
281+
_verify_files_dont_exist([path / filename], overwrite)
265282
shutil.move(tmp_download_file, extract_dir / filename)
266283

267284
# If in jupyterlite environment, the extract_dir = path, so the files are already there.
@@ -274,8 +291,36 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
274291
print(f"Saved to '{relpath(path.resolve())}'")
275292

276293

277-
if _is_jupyterlite():
278-
tqdm.monitor_interval = 0
294+
def setup() -> None:
295+
if _is_jupyterlite():
296+
tqdm.monitor_interval = 0
297+
298+
try:
299+
import sys # pyright: ignore
300+
301+
ipython = get_ipython()
302+
303+
def hide_traceback(
304+
exc_tuple=None,
305+
filename=None,
306+
tb_offset=None,
307+
exception_only=False,
308+
running_compiled_code=False,
309+
):
310+
etype, value, tb = sys.exc_info()
311+
value.__cause__ = None # suppress chained exceptions
312+
return ipython._showtraceback(
313+
etype, value, ipython.InteractiveTB.get_exception_only(etype, value)
314+
)
315+
316+
ipython.showtraceback = hide_traceback
317+
318+
except NameError:
319+
pass
320+
321+
322+
setup()
323+
279324

280325
# For backwards compatibility
281326
download_dataset = download

0 commit comments

Comments
 (0)