Skip to content

Commit e009ea0

Browse files
authored
Merge pull request #4 from ibm-skills-network/prepare_overwrite
Add new test cases for overwrite=True
2 parents 9ff3490 + 6f6289c commit e009ea0

File tree

2 files changed

+138
-73
lines changed

2 files changed

+138
-73
lines changed

skillsnetwork/core.py

Lines changed: 65 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
8989
pbar.update(len(value))
9090
pbar.close()
9191
except JsException:
92-
raise Exception(f"Failed to read dataset at {url}") from None
92+
raise Exception(f"Failed to read dataset at '{url}'.") from None
9393
else:
9494
import requests # pyright: ignore
9595
from requests.exceptions import ConnectionError # pyright: ignore
@@ -99,7 +99,7 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
9999
# If requests.get fails, it will return readable error
100100
if response.status_code >= 400:
101101
raise Exception(
102-
f"received status code {response.status_code} from {url}"
102+
f"received status code {response.status_code} from '{url}'."
103103
)
104104
pbar = tqdm(
105105
miniters=1,
@@ -111,28 +111,36 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
111111
pbar.update(len(chunk))
112112
pbar.close()
113113
except ConnectionError:
114-
raise Exception(f"Failed to read dataset at {url}") from None
114+
raise Exception(f"Failed to read dataset at '{url}'.") from None
115+
116+
117+
def _rmrf(path: Path) -> None:
118+
if path.is_dir():
119+
shutil.rmtree(path)
120+
else:
121+
path.unlink()
115122

116123

117124
def _verify_files_dont_exist(
118-
paths: Iterable[Union[str, Path]], remove_if_exist: bool = False
125+
paths: Iterable[Path], remove_if_exist: bool = False
119126
) -> None:
120127
"""
121128
Verifies all paths in 'paths' don't exist.
122-
:param paths: A iterable of strs or pathlib.Paths.
123-
:param remove_if_exist=False: Removes file at path if they already exist.
129+
:param paths: A iterable of pathlib.Path s.
130+
:param remove_if_exist=False: Remove each file at each path in paths if they already exist.
124131
:returns: None
125-
:raises FileExistsError: On the first path found that already exists.
132+
:raises FileExistsError: On the first path found that already exists if remove_if_exist is False.
126133
"""
127134
for path in paths:
128-
path = Path(path)
129-
if path.exists():
135+
# Could be a broken symlink => path.exists() is False
136+
if path.exists() or path.is_symlink():
130137
if remove_if_exist:
131-
if path.is_symlink():
132-
realpath = path.resolve()
133-
path.unlink(realpath)
134-
else:
135-
shutil.rmtree(path)
138+
while path.is_symlink():
139+
temp = path.readlink()
140+
path.unlink(missing_ok=True)
141+
path = temp
142+
if path.exists():
143+
_rmrf(path)
136144
else:
137145
raise FileExistsError(f"Error: File '{path}' already exists.")
138146

@@ -224,14 +232,13 @@ async def prepare(
224232
path = Path.cwd() if path is None else Path(path)
225233
# Check if path contains /tmp
226234
if Path("/tmp") in path.parents:
227-
raise ValueError("path must not be in /tmp")
235+
raise ValueError("path must not be in /tmp.")
228236
elif path.is_file():
229-
raise ValueError("Datasets must be prepared to directories, not files")
237+
raise ValueError("Datasets must be prepared to directories, not files.")
230238
# Create the target path if it doesn't exist yet
231239
path.mkdir(exist_ok=True)
232240

233241
# For avoiding collisions with any other files the user may have downloaded to /tmp/
234-
235242
dname = f"skills-network-{hash(url)}"
236243
# The file to extract data to. If not jupyterlite, to be symlinked to as well
237244
extract_dir = path if _is_jupyterlite() else Path(f"/tmp/{dname}")
@@ -247,44 +254,52 @@ async def prepare(
247254
shutil.rmtree(extract_dir)
248255
extract_dir.mkdir()
249256

250-
if tarfile.is_tarfile(tmp_download_file):
251-
with tarfile.open(tmp_download_file) as tf:
252-
_verify_files_dont_exist(
253-
[
254-
path / child.name
255-
for child in map(Path, tf.getnames())
256-
if len(child.parents) == 1 and _is_file_to_symlink(child)
257-
],
258-
overwrite,
259-
) # Only check if top-level fileobject
260-
pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
261-
pbar.set_description(f"Extracting {filename}")
262-
for member in pbar:
263-
tf.extract(member=member, path=extract_dir)
264-
tmp_download_file.unlink()
265-
elif zipfile.is_zipfile(tmp_download_file):
266-
with zipfile.ZipFile(tmp_download_file) as zf:
267-
_verify_files_dont_exist(
268-
[
269-
path / child.name
270-
for child in map(Path, zf.namelist())
271-
if len(child.parents) == 1 and _is_file_to_symlink(child)
272-
],
273-
overwrite,
274-
)
275-
pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
276-
pbar.set_description(f"Extracting {filename}")
277-
for member in pbar:
278-
zf.extract(member=member, path=extract_dir)
279-
tmp_download_file.unlink()
280-
else:
281-
_verify_files_dont_exist([path / filename], overwrite)
282-
shutil.move(tmp_download_file, extract_dir / filename)
257+
try:
258+
if tarfile.is_tarfile(tmp_download_file):
259+
with tarfile.open(tmp_download_file) as tf:
260+
_verify_files_dont_exist(
261+
[
262+
path / child.name
263+
for child in map(Path, tf.getnames())
264+
if len(child.parents) == 1 and _is_file_to_symlink(child)
265+
], # Only check if top-level fileobject
266+
remove_if_exist=overwrite,
267+
)
268+
pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
269+
pbar.set_description(f"Extracting {filename}")
270+
for member in pbar:
271+
tf.extract(member=member, path=extract_dir)
272+
tmp_download_file.unlink()
273+
elif zipfile.is_zipfile(tmp_download_file):
274+
with zipfile.ZipFile(tmp_download_file) as zf:
275+
_verify_files_dont_exist(
276+
[
277+
path / child.name
278+
for child in map(Path, zf.namelist())
279+
if len(child.parents) == 1 and _is_file_to_symlink(child)
280+
], # Only check if top-level fileobject
281+
remove_if_exist=overwrite,
282+
)
283+
pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
284+
pbar.set_description(f"Extracting {filename}")
285+
for member in pbar:
286+
zf.extract(member=member, path=extract_dir)
287+
tmp_download_file.unlink()
288+
else:
289+
_verify_files_dont_exist([path / filename], remove_if_exist=overwrite)
290+
shutil.move(tmp_download_file, extract_dir / filename)
291+
except FileExistsError as e:
292+
raise FileExistsError(
293+
str(e)
294+
+ "\nIf you want to overwrite any existing files, use prepare(..., overwrite=True)."
295+
) from None
283296

284297
# If in jupyterlite environment, the extract_dir = path, so the files are already there.
285298
if not _is_jupyterlite():
286299
# If not in jupyterlite environment, symlink top-level file objects in extract_dir
287300
for child in filter(_is_file_to_symlink, extract_dir.iterdir()):
301+
if (path / child.name).is_symlink() and overwrite:
302+
(path / child.name).unlink()
288303
(path / child.name).symlink_to(child, target_is_directory=child.is_dir())
289304

290305
if verbose:
@@ -295,29 +310,6 @@ def setup() -> None:
295310
if _is_jupyterlite():
296311
tqdm.monitor_interval = 0
297312

298-
try:
299-
import sys # pyright: ignore
300-
301-
ipython = get_ipython()
302-
303-
def hide_traceback(
304-
exc_tuple=None,
305-
filename=None,
306-
tb_offset=None,
307-
exception_only=False,
308-
running_compiled_code=False,
309-
):
310-
etype, value, tb = sys.exc_info()
311-
value.__cause__ = None # suppress chained exceptions
312-
return ipython._showtraceback(
313-
etype, value, ipython.InteractiveTB.get_exception_only(etype, value)
314-
)
315-
316-
ipython.showtraceback = hide_traceback
317-
318-
except NameError:
319-
pass
320-
321313

322314
setup()
323315

tests/test_skillsnetwork.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,76 @@ async def test_prepare_non_compressed_dataset_with_path(httpserver):
134134
await skillsnetwork.prepare_dataset(httpserver.url_for(url), path=path)
135135
assert expected_path.exists()
136136
expected_path.unlink()
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_prepare_non_compressed_dataset_no_path_with_overwrite(httpserver):
141+
url = "/test.csv"
142+
expected_path = Path("./test.csv")
143+
with open("tests/test.csv", "rb") as expected_data:
144+
httpserver.expect_request(url).respond_with_data(expected_data)
145+
await skillsnetwork.prepare_dataset(httpserver.url_for(url))
146+
assert expected_path.exists()
147+
httpserver.clear()
148+
with open("tests/test.csv", "rb") as expected_data:
149+
httpserver.expect_request(url).respond_with_data(expected_data)
150+
await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True)
151+
assert expected_path.exists()
152+
assert Path(expected_path).stat().st_size == 540
153+
expected_path.unlink()
154+
155+
156+
@pytest.mark.asyncio
157+
async def test_prepare_dataset_tar_no_path_with_overwrite(httpserver):
158+
url = "/test.tar.gz"
159+
expected_directory = Path("test")
160+
try:
161+
shutil.rmtree(expected_directory) # clean up any previous test
162+
except FileNotFoundError as e:
163+
print(e)
164+
pass
165+
166+
with open("tests/test.tar.gz", "rb") as expected_data:
167+
httpserver.expect_request(url).respond_with_data(expected_data)
168+
await skillsnetwork.prepare_dataset(httpserver.url_for(url))
169+
170+
assert os.path.isdir(expected_directory)
171+
with open(expected_directory / "1.txt") as f:
172+
assert "I am the first test file" in f.read()
173+
httpserver.clear()
174+
175+
with open("tests/test.tar.gz", "rb") as expected_data:
176+
httpserver.expect_request(url).respond_with_data(expected_data)
177+
await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True)
178+
assert os.path.isdir(expected_directory)
179+
with open(expected_directory / "1.txt") as f:
180+
assert "I am the first test file" in f.read()
181+
expected_directory.unlink()
182+
183+
184+
@pytest.mark.asyncio
185+
async def test_prepare_dataset_zip_no_path_with_overwrite(httpserver):
186+
url = "/test.zip"
187+
expected_directory = Path("test")
188+
try:
189+
shutil.rmtree(expected_directory) # clean up any previous test
190+
except FileNotFoundError as e:
191+
print(e)
192+
pass
193+
194+
with open("tests/test.zip", "rb") as expected_data:
195+
httpserver.expect_request(url).respond_with_data(expected_data)
196+
await skillsnetwork.prepare_dataset(httpserver.url_for(url))
197+
198+
assert os.path.isdir(expected_directory)
199+
with open(expected_directory / "1.txt") as f:
200+
assert "I am the first test file" in f.read()
201+
httpserver.clear()
202+
203+
with open("tests/test.zip", "rb") as expected_data:
204+
httpserver.expect_request(url).respond_with_data(expected_data)
205+
await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True)
206+
assert os.path.isdir(expected_directory)
207+
with open(expected_directory / "1.txt") as f:
208+
assert "I am the first test file" in f.read()
209+
expected_directory.unlink()

0 commit comments

Comments
 (0)