Skip to content

Improve file path validation #1110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datachain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def remove(self, file: "File") -> None:
async def download(
self, file: "File", client: "Client", callback: Optional[Callback] = None
) -> None:
from_path = f"{file.source}/{file.path}"
from dvc_objects.fs.utils import tmp_fname

from_path = file.get_uri()
odb_fs = self.odb.fs
tmp_info = odb_fs.join(self.odb.tmp_dir, tmp_fname()) # type: ignore[arg-type]
size = file.size
Expand Down
8 changes: 5 additions & 3 deletions src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,14 @@ def url(self, path: str, expires: int = 3600, **kwargs) -> str:
)

async def get_current_etag(self, file: "File") -> str:
    """Return the etag the filesystem currently reports for ``file``.

    The path is obtained once from ``file.get_path_normalized()`` and that
    same value is used both for the filesystem lookup and for building the
    file record the etag is read from.
    """
    file_path = file.get_path_normalized()
    # Only version-aware clients forward a version_id to the info call.
    kwargs = {"version_id": file.version} if self._is_version_aware() else {}
    info = await self.fs._info(
        self.get_full_path(file_path, file.version), **kwargs
    )
    return self.info_to_file(info, file_path).etag

def get_file_info(self, path: str, version_id: Optional[str] = None) -> "File":
info = self.fs.info(self.get_full_path(path, version_id), version_id=version_id)
Expand Down Expand Up @@ -385,7 +386,8 @@ def open_object(
return open(cache_path, mode="rb")
assert not file.location
return FileWrapper(
self.fs.open(self.get_full_path(file.path, file.version)), cb
self.fs.open(self.get_full_path(file.get_path_normalized(), file.version)),
cb,
) # type: ignore[return-value]

def upload(self, data: bytes, path: str) -> "File":
Expand Down
8 changes: 4 additions & 4 deletions src/datachain/client/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def from_source(
)

async def get_current_etag(self, file: "File") -> str:
    """Return the etag the local filesystem currently reports for ``file``.

    The lookup uses the path returned by ``file.get_path_normalized()``
    rather than the raw ``file.path``.
    """
    full_path = self.get_full_path(file.get_path_normalized())
    info = self.fs.info(full_path)
    # Only the etag is read from the record, so an empty path is passed.
    return self.info_to_file(info, "").etag

async def get_size(self, path: str, version_id: Optional[str] = None) -> int:
Expand Down Expand Up @@ -138,8 +138,8 @@ def fetch_nodes(
if not self.use_symlinks:
super().fetch_nodes(nodes, shared_progress_bar)

def do_instantiate_object(self, file: "File", dst: str) -> None:
    """Materialize ``file`` at ``dst``.

    When ``self.use_symlinks`` is set, ``dst`` is created as a symlink
    pointing at ``Path(self.name, file.path)``; otherwise the parent class
    implementation is used.
    """
    if self.use_symlinks:
        os.symlink(Path(self.name, file.path), dst)
    else:
        super().do_instantiate_object(file, dst)
4 changes: 2 additions & 2 deletions src/datachain/lib/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def process(self, file: File):
fs_path = file.path
fs = ReferenceFileSystem({fs_path: [cache_path]})
else:
fs, fs_path = file.get_fs(), file.get_path()
fs, fs_path = file.get_fs(), file.get_fs_path()

kwargs = self.kwargs
if format := kwargs.get("format"):
Expand Down Expand Up @@ -161,7 +161,7 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:

schemas = []
for file in chain.collect("file"):
ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs) # type: ignore[union-attr]
ds = dataset(file.get_fs_path(), filesystem=file.get_fs(), **kwargs) # type: ignore[union-attr]
schemas.append(ds.schema)
if not schemas:
raise ValueError(
Expand Down
100 changes: 77 additions & 23 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import logging
import os
import posixpath
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path, PurePosixPath
from pathlib import Path, PurePath, PurePosixPath
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname
Expand Down Expand Up @@ -69,7 +70,7 @@ def done_task(self, done):
for task in done:
task.result()

def do_task(self, file):
def do_task(self, file: "File"):
file.export(
self.output,
self.placement,
Expand Down Expand Up @@ -256,8 +257,8 @@ def validate_location(cls, v):

@field_validator("path", mode="before")
@classmethod
def validate_path(cls, path: str) -> str:
    """Convert the incoming ``path`` value to POSIX form.

    A falsy value yields an empty string. Anything else is passed through
    ``PurePath(...).as_posix()``, which is pure string manipulation and
    does not touch the filesystem.
    """
    if not path:
        return ""
    return PurePath(path).as_posix()

def model_dump_custom(self):
res = self.model_dump()
Expand Down Expand Up @@ -319,11 +320,11 @@ def _from_row(cls, row: "RowDict") -> "Self":
return cls(**{key: row[key] for key in cls._datachain_column_types})

@property
def name(self) -> str:
    """Final component of ``self.path``, treated as a POSIX path."""
    return PurePosixPath(self.path).name

@property
def parent(self) -> str:
    """Parent directory of ``self.path`` as a string, treated as a POSIX path."""
    return str(PurePosixPath(self.path).parent)

@contextmanager
Expand Down Expand Up @@ -366,7 +367,7 @@ def save(self, destination: str, client_config: Optional[dict] = None):

client.upload(self.read(), destination)

def _symlink_to(self, destination: str):
def _symlink_to(self, destination: str) -> None:
if self.location:
raise OSError(errno.ENOTSUP, "Symlinking virtual file is not supported")

Expand All @@ -375,7 +376,7 @@ def _symlink_to(self, destination: str):
source = self.get_local_path()
assert source, "File was not cached"
elif self.source.startswith("file://"):
source = self.get_path()
source = self.get_fs_path()
else:
raise OSError(errno.EXDEV, "can't link across filesystems")

Expand Down Expand Up @@ -452,27 +453,62 @@ def get_file_suffix(self):

def get_file_ext(self):
    """Returns last part of file name without `.`."""
    suffix = PurePosixPath(self.path).suffix
    # Drop only the leading dot that introduces the suffix.
    return suffix.lstrip(".")

def get_file_stem(self):
    """Returns the final path component with its last suffix removed."""
    posix_path = PurePosixPath(self.path)
    return posix_path.stem

def get_full_name(self) -> str:
    """
    [DEPRECATED] Use `file.path` directly instead.

    Returns name with parent directories.
    """
    # stacklevel=2 attributes the warning to the caller of this method.
    warnings.warn(
        "file.get_full_name() is deprecated and will be removed "
        "in a future version. Use `file.path` directly.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.path

def get_uri(self):
def get_path_normalized(self) -> str:
    """
    Returns `self.path` in normalized POSIX form.

    Raises `FileError` when the path is empty, denotes a directory
    (ends with "/" or normalizes to "."), or still contains a ".."
    component after normalization.
    """
    raw_path = self.path
    if not raw_path:
        raise FileError("path must not be empty", self.source, raw_path)
    if raw_path.endswith("/"):
        raise FileError("path must not be a directory", self.source, raw_path)

    # Collapse redundant separators and "." / ".." segments, then force
    # forward slashes regardless of the host platform.
    cleaned = PurePath(os.path.normpath(raw_path)).as_posix()

    if cleaned == ".":
        raise FileError("path must not be a directory", self.source, raw_path)
    if ".." in PurePath(cleaned).parts:
        raise FileError("path must not contain '..'", self.source, raw_path)

    return cleaned

def get_uri(self) -> str:
    """Returns file URI: `self.source` joined with the normalized path by "/"."""
    return f"{self.source}/{self.get_path_normalized()}"

def get_fs_path(self) -> str:
    """
    Returns file path with respect to the filescheme.

    The URI from `get_uri()` is percent-decoded first. If its scheme is
    "file", the path component is converted to a local file path using
    `url2pathname`. Otherwise, the decoded URI, scheme included, is returned.
    """
    path = unquote(self.get_uri())
    path_parsed = urlparse(path)
    # NOTE(review): parsing happens after decoding, so a literal "#" or "?"
    # in the path would be split off as fragment/query — confirm this is
    # acceptable for file:// sources.
    if path_parsed.scheme == "file":
        path = url2pathname(path_parsed.path)
    return path

def get_destination_path(
Expand All @@ -487,7 +523,7 @@ def get_destination_path(
elif placement == "etag":
path = f"{self.etag}{self.get_file_suffix()}"
elif placement == "fullpath":
path = unquote(self.get_full_name())
path = unquote(self.get_path_normalized())
source = urlparse(self.source)
if source.scheme and source.scheme != "file":
path = posixpath.join(source.netloc, path)
Expand Down Expand Up @@ -525,8 +561,9 @@ def resolve(self) -> "Self":
) from e

try:
info = client.fs.info(client.get_full_path(self.path))
converted_info = client.info_to_file(info, self.path)
normalized_path = self.get_path_normalized()
info = client.fs.info(client.get_full_path(normalized_path))
converted_info = client.info_to_file(info, normalized_path)
return type(self)(
path=self.path,
source=self.source,
Expand All @@ -537,8 +574,17 @@ def resolve(self) -> "Self":
last_modified=converted_info.last_modified,
location=self.location,
)
except FileError as e:
logger.warning(
"File error when resolving %s/%s: %s", self.source, self.path, str(e)
)
except (FileNotFoundError, PermissionError, OSError) as e:
logger.warning("File system error when resolving %s: %s", self.path, str(e))
logger.warning(
"File system error when resolving %s/%s: %s",
self.source,
self.path,
str(e),
)

return type(self)(
path=self.path,
Expand All @@ -554,6 +600,8 @@ def resolve(self) -> "Self":

def resolve(file: File) -> File:
"""
[DEPRECATED] Use `file.resolve()` directly instead.

Resolve a File object by checking its existence and updating its metadata.

This function is a wrapper around the File.resolve() method, designed to be
Expand All @@ -569,6 +617,12 @@ def resolve(file: File) -> File:
RuntimeError: If the file's catalog is not set or if
the file source protocol is unsupported.
"""
warnings.warn(
"resolve() is deprecated and will be removed "
"in a future version. Use file.resolve() directly.",
DeprecationWarning,
stacklevel=2,
)
return file.resolve()


Expand Down Expand Up @@ -916,7 +970,7 @@ def open(self):
ds = dataset(path, **self.kwargs)

else:
path = self.file.get_path()
path = self.file.get_fs_path()
ds = dataset(path, filesystem=self.file.get_fs(), **self.kwargs)

return ds.take([self.index]).to_reader()
Expand Down
3 changes: 1 addition & 2 deletions src/datachain/lib/tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@


def build_tar_member(parent: File, info: tarfile.TarInfo) -> File:
new_parent = parent.get_full_name()
etag_string = "-".join([parent.etag, info.name, str(info.mtime)])
etag = hashlib.md5(etag_string.encode(), usedforsecurity=False).hexdigest()
return File(
source=parent.source,
path=f"{new_parent}/{info.name}",
path=f"{parent.name}/{info.name}",
version=parent.version,
size=info.size,
etag=etag,
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

class WDSError(DataChainError):
    """Error raised for a WebDataset tar stream; the message names the stream."""

    def __init__(self, tar_stream, message: str):
        # Uses the stream's `name` attribute rather than the deprecated
        # get_full_name() accessor.
        super().__init__(f"WebDataset error '{tar_stream.name}': {message}")


class CoreFileDuplicationError(WDSError):
Expand Down
5 changes: 0 additions & 5 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,6 @@ def test_to_storage(
ctc = cloud_test_catalog
df = dc.read_storage(ctc.src_uri, type=file_type, session=test_session)
if use_map:
df.settings(cache=use_cache).to_storage(
tmp_dir / "output",
placement=placement,
num_threads=num_threads,
)
df.settings(cache=use_cache).map(
res=lambda file: file.export(tmp_dir / "output", placement=placement)
).exec()
Expand Down
Loading
Loading