diff --git a/src/datachain/cache.py b/src/datachain/cache.py
index 3bfd305b5..166584007 100644
--- a/src/datachain/cache.py
+++ b/src/datachain/cache.py
@@ -76,9 +76,9 @@ def remove(self, file: "File") -> None:
     async def download(
         self, file: "File", client: "Client", callback: Optional[Callback] = None
     ) -> None:
-        from_path = f"{file.source}/{file.path}"
         from dvc_objects.fs.utils import tmp_fname
 
+        from_path = file.get_uri()
         odb_fs = self.odb.fs
         tmp_info = odb_fs.join(self.odb.tmp_dir, tmp_fname())  # type: ignore[arg-type]
         size = file.size
diff --git a/src/datachain/client/fsspec.py b/src/datachain/client/fsspec.py
index 1b8ca1776..71a619849 100644
--- a/src/datachain/client/fsspec.py
+++ b/src/datachain/client/fsspec.py
@@ -207,13 +207,14 @@ def url(self, path: str, expires: int = 3600, **kwargs) -> str:
         )
 
     async def get_current_etag(self, file: "File") -> str:
+        file_path = file.get_path_normalized()
         kwargs = {}
         if self._is_version_aware():
             kwargs["version_id"] = file.version
         info = await self.fs._info(
-            self.get_full_path(file.path, file.version), **kwargs
+            self.get_full_path(file_path, file.version), **kwargs
         )
-        return self.info_to_file(info, file.path).etag
+        return self.info_to_file(info, file_path).etag
 
     def get_file_info(self, path: str, version_id: Optional[str] = None) -> "File":
         info = self.fs.info(self.get_full_path(path, version_id), version_id=version_id)
@@ -385,7 +386,8 @@ def open_object(
             return open(cache_path, mode="rb")
         assert not file.location
         return FileWrapper(
-            self.fs.open(self.get_full_path(file.path, file.version)), cb
+            self.fs.open(self.get_full_path(file.get_path_normalized(), file.version)),
+            cb,
         )  # type: ignore[return-value]
 
     def upload(self, data: bytes, path: str) -> "File":
diff --git a/src/datachain/client/local.py b/src/datachain/client/local.py
index bcb3dfa73..ec4bab82a 100644
--- a/src/datachain/client/local.py
+++ b/src/datachain/client/local.py
@@ -99,7 +99,7 @@ def from_source(
         )
 
     async def get_current_etag(self, file: "File") -> str:
-        info = self.fs.info(self.get_full_path(file.path))
+        info = self.fs.info(self.get_full_path(file.get_path_normalized()))
         return self.info_to_file(info, "").etag
 
     async def get_size(self, path: str, version_id: Optional[str] = None) -> int:
@@ -138,8 +138,8 @@ def fetch_nodes(
         if not self.use_symlinks:
             super().fetch_nodes(nodes, shared_progress_bar)
 
-    def do_instantiate_object(self, uid, dst):
+    def do_instantiate_object(self, file: File, dst: str) -> None:
         if self.use_symlinks:
-            os.symlink(Path(self.name, uid.path), dst)
+            os.symlink(Path(self.name, file.path), dst)
         else:
-            super().do_instantiate_object(uid, dst)
+            super().do_instantiate_object(file, dst)
diff --git a/src/datachain/lib/arrow.py b/src/datachain/lib/arrow.py
index 442d80cba..210214e07 100644
--- a/src/datachain/lib/arrow.py
+++ b/src/datachain/lib/arrow.py
@@ -76,7 +76,7 @@ def process(self, file: File):
             fs_path = file.path
             fs = ReferenceFileSystem({fs_path: [cache_path]})
         else:
-            fs, fs_path = file.get_fs(), file.get_path()
+            fs, fs_path = file.get_fs(), file.get_fs_path()
 
         kwargs = self.kwargs
         if format := kwargs.get("format"):
@@ -161,7 +161,7 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
     schemas = []
     for file in chain.collect("file"):
-        ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
+        ds = dataset(file.get_fs_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
         schemas.append(ds.schema)
 
     if not schemas:
         raise ValueError(
diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py
index d69ed2e21..9af9ccc8a 100644
--- a/src/datachain/lib/file.py
+++ b/src/datachain/lib/file.py
@@ -5,13 +5,14 @@
 import logging
 import os
 import posixpath
+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterator
 from contextlib import contextmanager
 from datetime import datetime
 from functools import partial
 from io import BytesIO
-from pathlib import Path, PurePosixPath
+from pathlib import Path, PurePath, PurePosixPath
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
 from urllib.parse import unquote, urlparse
 from urllib.request import url2pathname
@@ -69,7 +70,7 @@ def done_task(self, done):
         for task in done:
             task.result()
 
-    def do_task(self, file):
+    def do_task(self, file: "File"):
         file.export(
             self.output,
             self.placement,
@@ -256,8 +257,8 @@ def validate_location(cls, v):
 
     @field_validator("path", mode="before")
     @classmethod
-    def validate_path(cls, path):
-        return Path(path).as_posix() if path else ""
+    def validate_path(cls, path: str) -> str:
+        return PurePath(path).as_posix() if path else ""
 
     def model_dump_custom(self):
         res = self.model_dump()
@@ -319,11 +320,11 @@ def _from_row(cls, row: "RowDict") -> "Self":
         return cls(**{key: row[key] for key in cls._datachain_column_types})
 
     @property
-    def name(self):
+    def name(self) -> str:
         return PurePosixPath(self.path).name
 
     @property
-    def parent(self):
+    def parent(self) -> str:
         return str(PurePosixPath(self.path).parent)
 
     @contextmanager
@@ -366,7 +367,7 @@ def save(self, destination: str, client_config: Optional[dict] = None):
 
         client.upload(self.read(), destination)
 
-    def _symlink_to(self, destination: str):
+    def _symlink_to(self, destination: str) -> None:
         if self.location:
             raise OSError(errno.ENOTSUP, "Symlinking virtual file is not supported")
 
@@ -375,7 +376,7 @@ def _symlink_to(self, destination: str):
             source = self.get_local_path()
             assert source, "File was not cached"
         elif self.source.startswith("file://"):
-            source = self.get_path()
+            source = self.get_fs_path()
         else:
             raise OSError(errno.EXDEV, "can't link across filesystems")
 
@@ -452,27 +453,62 @@ def get_file_suffix(self):
 
     def get_file_ext(self):
         """Returns last part of file name without `.`."""
-        return PurePosixPath(self.path).suffix.strip(".")
+        return PurePosixPath(self.path).suffix.lstrip(".")
 
     def get_file_stem(self):
         """Returns file name without extension."""
         return PurePosixPath(self.path).stem
 
     def get_full_name(self):
-        """Returns name with parent directories."""
+        """
+        [DEPRECATED] Use `file.path` directly instead.
+
+        Returns name with parent directories.
+        """
+        warnings.warn(
+            "file.get_full_name() is deprecated and will be removed "
+            "in a future version. Use `file.path` directly.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         return self.path
 
-    def get_uri(self):
+    def get_path_normalized(self) -> str:
+        if not self.path:
+            raise FileError("path must not be empty", self.source, self.path)
+
+        if self.path.endswith("/"):
+            raise FileError("path must not be a directory", self.source, self.path)
+
+        normpath = os.path.normpath(self.path)
+        normpath = PurePath(normpath).as_posix()
+
+        if normpath == ".":
+            raise FileError("path must not be a directory", self.source, self.path)
+
+        if any(part == ".." for part in PurePath(normpath).parts):
+            raise FileError("path must not contain '..'", self.source, self.path)
+
+        return normpath
+
+    def get_uri(self) -> str:
         """Returns file URI."""
-        return f"{self.source}/{self.get_full_name()}"
+        return f"{self.source}/{self.get_path_normalized()}"
+
+    def get_fs_path(self) -> str:
+        """
+        Returns the file path with respect to the file scheme.
+
+        The path is normalized to remove any redundant
+        separators and up-level references.
 
-    def get_path(self) -> str:
-        """Returns file path."""
+        If the file scheme is "file", the path is converted to a local file path
+        using `url2pathname`. Otherwise, the original path with scheme is returned.
+        """
         path = unquote(self.get_uri())
-        source = urlparse(self.source)
-        if source.scheme == "file":
-            path = urlparse(path).path
-            path = url2pathname(path)
+        path_parsed = urlparse(path)
+        if path_parsed.scheme == "file":
+            path = url2pathname(path_parsed.path)
         return path
 
     def get_destination_path(
@@ -487,7 +523,7 @@ def get_destination_path(
         elif placement == "etag":
             path = f"{self.etag}{self.get_file_suffix()}"
         elif placement == "fullpath":
-            path = unquote(self.get_full_name())
+            path = unquote(self.get_path_normalized())
             source = urlparse(self.source)
             if source.scheme and source.scheme != "file":
                 path = posixpath.join(source.netloc, path)
@@ -525,8 +561,9 @@ def resolve(self) -> "Self":
             ) from e
 
         try:
-            info = client.fs.info(client.get_full_path(self.path))
-            converted_info = client.info_to_file(info, self.path)
+            normalized_path = self.get_path_normalized()
+            info = client.fs.info(client.get_full_path(normalized_path))
+            converted_info = client.info_to_file(info, normalized_path)
             return type(self)(
                 path=self.path,
                 source=self.source,
@@ -537,8 +574,17 @@ def resolve(self) -> "Self":
                 last_modified=converted_info.last_modified,
                 location=self.location,
             )
+        except FileError as e:
+            logger.warning(
+                "File error when resolving %s/%s: %s", self.source, self.path, str(e)
+            )
        except (FileNotFoundError, PermissionError, OSError) as e:
-            logger.warning("File system error when resolving %s: %s", self.path, str(e))
+            logger.warning(
+                "File system error when resolving %s/%s: %s",
+                self.source,
+                self.path,
+                str(e),
+            )
 
         return type(self)(
             path=self.path,
@@ -554,6 +600,8 @@
 
 def resolve(file: File) -> File:
     """
+    [DEPRECATED] Use `file.resolve()` directly instead.
+
     Resolve a File object by checking its existence and updating its metadata.
 
     This function is a wrapper around the File.resolve() method, designed to be
@@ -569,6 +617,12 @@ def resolve(file: File) -> File:
         RuntimeError: If the file's catalog is not set or if the file source
             protocol is unsupported.
     """
+    warnings.warn(
+        "resolve() is deprecated and will be removed "
+        "in a future version. Use file.resolve() directly.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     return file.resolve()
 
 
@@ -916,7 +970,7 @@ def open(self):
             ds = dataset(path, **self.kwargs)
 
         else:
-            path = self.file.get_path()
+            path = self.file.get_fs_path()
             ds = dataset(path, filesystem=self.file.get_fs(), **self.kwargs)
 
         return ds.take([self.index]).to_reader()
diff --git a/src/datachain/lib/tar.py b/src/datachain/lib/tar.py
index cd64abd4d..f693da048 100644
--- a/src/datachain/lib/tar.py
+++ b/src/datachain/lib/tar.py
@@ -6,12 +6,11 @@
 
 
 def build_tar_member(parent: File, info: tarfile.TarInfo) -> File:
-    new_parent = parent.get_full_name()
     etag_string = "-".join([parent.etag, info.name, str(info.mtime)])
     etag = hashlib.md5(etag_string.encode(), usedforsecurity=False).hexdigest()
     return File(
         source=parent.source,
-        path=f"{new_parent}/{info.name}",
+        path=f"{parent.path}/{info.name}",
         version=parent.version,
         size=info.size,
         etag=etag,
diff --git a/src/datachain/lib/webdataset.py b/src/datachain/lib/webdataset.py
index 365462a02..69a029127 100644
--- a/src/datachain/lib/webdataset.py
+++ b/src/datachain/lib/webdataset.py
@@ -35,7 +35,7 @@
 
 class WDSError(DataChainError):
     def __init__(self, tar_stream, message: str):
-        super().__init__(f"WebDataset error '{tar_stream.get_full_name()}': {message}")
+        super().__init__(f"WebDataset error '{tar_stream.path}': {message}")
 
 
 class CoreFileDuplicationError(WDSError):
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index b479aeaf4..910dbd370 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -345,11 +345,6 @@ def test_to_storage(
     ctc = cloud_test_catalog
     df = dc.read_storage(ctc.src_uri, type=file_type, session=test_session)
     if use_map:
-        df.settings(cache=use_cache).to_storage(
-            tmp_dir / "output",
-            placement=placement,
-            num_threads=num_threads,
-        )
         df.settings(cache=use_cache).map(
             res=lambda file: file.export(tmp_dir / "output", placement=placement)
         ).exec()
diff --git a/tests/func/test_file.py b/tests/func/test_file.py
index 2306c2ade..9365a4fca 100644
--- a/tests/func/test_file.py
+++ b/tests/func/test_file.py
@@ -3,18 +3,19 @@
 import datachain as dc
 from datachain.data_storage.sqlite import SQLiteWarehouse
-from datachain.lib.file import File
+from datachain.lib.file import File, FileError
 from datachain.utils import TIME_ZERO
 
 
 @pytest.mark.parametrize("cloud_type", ["s3"], indirect=True)
 def test_get_path_cloud(cloud_test_catalog):
     file = File(path="dir/file", source="s3://")
-    file._catalog = cloud_test_catalog.catalog
-    assert file.get_path().strip("/") == "s3:///dir/file"
+    file._set_stream(catalog=cloud_test_catalog.catalog)
+    assert file.get_fs_path().strip("/") == "s3:///dir/file"
 
 
-def test_resolve_file(cloud_test_catalog):
+@pytest.mark.parametrize("caching_enabled", [True, False])
+def test_resolve_file(cloud_test_catalog, caching_enabled):
     ctc = cloud_test_catalog
 
     is_sqlite = isinstance(cloud_test_catalog.catalog.warehouse, SQLiteWarehouse)
@@ -25,7 +26,7 @@ def test_resolve_file(cloud_test_catalog):
         source=orig_file.source,
         path=orig_file.path,
     )
-    file._catalog = ctc.catalog
+    file._set_stream(catalog=ctc.catalog, caching_enabled=caching_enabled)
     resolved_file = file.resolve()
     if not is_sqlite:
         resolved_file.last_modified = resolved_file.last_modified.replace(
@@ -33,18 +34,43 @@ def test_resolve_file(cloud_test_catalog):
         )
     assert orig_file == resolved_file
 
+    file.ensure_cached()
+
 
 def test_resolve_file_no_exist(cloud_test_catalog):
     ctc = cloud_test_catalog
 
     non_existent_file = File(source=ctc.src_uri, path="non_existent_file.txt")
-    non_existent_file._catalog = ctc.catalog
+    non_existent_file._set_stream(catalog=ctc.catalog)
     resolved_non_existent = non_existent_file.resolve()
     assert resolved_non_existent.size == 0
     assert resolved_non_existent.etag == ""
     assert resolved_non_existent.last_modified == TIME_ZERO
 
 
+@pytest.mark.parametrize("path", ["", ".", "..", "/", "dir/../../file.txt"])
+def test_resolve_file_wrong_path(cloud_test_catalog, path):
+    ctc = cloud_test_catalog
+
+    wrong_file = File(source=ctc.src_uri, path=path)
+    wrong_file._set_stream(catalog=ctc.catalog)
+    resolved_wrong = wrong_file.resolve()
+    assert resolved_wrong.size == 0
+    assert resolved_wrong.etag == ""
+    assert resolved_wrong.last_modified == TIME_ZERO
+
+
+@pytest.mark.parametrize("caching_enabled", [True, False])
+@pytest.mark.parametrize("path", ["", ".", "..", "/", "dir/../../file.txt"])
+def test_cache_file_wrong_path(cloud_test_catalog, path, caching_enabled):
+    ctc = cloud_test_catalog
+
+    wrong_file = File(source=ctc.src_uri, path=path)
+    wrong_file._set_stream(catalog=ctc.catalog, caching_enabled=caching_enabled)
+    with pytest.raises(FileError):
+        wrong_file.ensure_cached()
+
+
 def test_upload(cloud_test_catalog):
     ctc = cloud_test_catalog
diff --git a/tests/unit/lib/test_file.py b/tests/unit/lib/test_file.py
index c4da6e76d..3dec21b5f 100644
--- a/tests/unit/lib/test_file.py
+++ b/tests/unit/lib/test_file.py
@@ -7,7 +7,7 @@
 from PIL import Image
 
 from datachain.catalog import Catalog
-from datachain.lib.file import File, ImageFile, TextFile, resolve
+from datachain.lib.file import File, FileError, ImageFile, TextFile, resolve
 
 
 def create_file(source: str):
@@ -33,12 +33,6 @@ def test_file_suffix():
     assert s.get_file_suffix() == ".txt"
 
 
-@pytest.mark.parametrize("name", [".file.jpg.txt", "dir1/dir2/name"])
-def test_full_name(name):
-    f = File(path=name)
-    assert f.get_full_name() == name
-
-
 def test_cache_get_path(catalog: Catalog):
     stream = File(path="test.txt1", source="s3://mybkt")
     stream._set_stream(catalog)
@@ -244,7 +238,7 @@ def test_file_info_jsons():
 def test_get_path_local(catalog):
     file = File(path="dir/file", source="file:///")
     file._catalog = catalog
-    assert file.get_path().replace("\\", "/").strip("/") == "dir/file"
+    assert file.get_fs_path().replace("\\", "/").strip("/") == "dir/file"
 
 
 def test_get_fs(catalog):
@@ -367,3 +361,51 @@ def test_export_with_symlink(tmp_path, catalog, use_cache):
 
     dst = Path(file.get_local_path()) if use_cache else path
     assert (tmp_path / "dir" / "myfile.txt").resolve() == dst
+
+
+@pytest.mark.parametrize(
+    "path,expected",
+    [
+        ("", ""),
+        (".", "."),
+        ("dir/file.txt", "dir/file.txt"),
+        ("../dir/file.txt", "../dir/file.txt"),
+        ("/dir/file.txt", "/dir/file.txt"),
+    ],
+)
+def test_path_validation(path, expected):
+    assert File(path=path, source="file:///").path == expected
+
+
+@pytest.mark.parametrize(
+    "path,expected,raises",
+    [
+        ("", None, "must not be empty"),
+        ("/", None, "must not be a directory"),
+        (".", None, "must not be a directory"),
+        ("dir/..", None, "must not be a directory"),
+        ("dir/file.txt", "dir/file.txt", None),
+        ("dir//file.txt", "dir/file.txt", None),
+        ("./dir/file.txt", "dir/file.txt", None),
+        ("dir/./file.txt", "dir/file.txt", None),
+        ("dir/../file.txt", "file.txt", None),
+        ("dir/foo/../file.txt", "dir/file.txt", None),
+        ("./dir/./foo/.././../file.txt", "file.txt", None),
+        ("./dir", "dir", None),
+        ("dir/.", "dir", None),
+        ("./dir/.", "dir", None),
+        ("/dir", "/dir", None),
+        ("/dir/file.txt", "/dir/file.txt", None),
+        ("/dir/../file.txt", "/file.txt", None),
+        ("..", None, "must not contain '..'"),
+        ("../file.txt", None, "must not contain '..'"),
+        ("dir/../../file.txt", None, "must not contain '..'"),
+    ],
+)
+def test_path_normalized(path, expected, raises):
+    file = File(path=path, source="s3://bucket")
+    if raises:
+        with pytest.raises(FileError, match=raises):
+            file.get_path_normalized()
+    else:
+        assert file.get_path_normalized() == expected