
Commit

refactor pyarrow
kevinjqliu committed Feb 6, 2024
1 parent fa15877 · commit c85097e
Showing 2 changed files with 98 additions and 72 deletions.
160 changes: 93 additions & 67 deletions pyiceberg/io/pyarrow.py
@@ -173,6 +173,79 @@
T = TypeVar("T")


def _file(_: Properties) -> FileSystem:
return PyArrowLocalFileSystem()


def _s3(properties: Properties) -> FileSystem:
from pyarrow.fs import S3FileSystem

client_kwargs: Dict[str, Any] = {
"endpoint_override": properties.get(S3_ENDPOINT),
"access_key": properties.get(S3_ACCESS_KEY_ID),
"secret_key": properties.get(S3_SECRET_ACCESS_KEY),
"session_token": properties.get(S3_SESSION_TOKEN),
"region": properties.get(S3_REGION),
}

if proxy_uri := properties.get(S3_PROXY_URI):
client_kwargs["proxy_options"] = proxy_uri

if connect_timeout := properties.get(S3_CONNECT_TIMEOUT):
client_kwargs["connect_timeout"] = float(connect_timeout)

return S3FileSystem(**client_kwargs)


def _gs(properties: Properties) -> FileSystem:
from pyarrow.fs import GcsFileSystem

gcs_kwargs: Dict[str, Any] = {}
if access_token := properties.get(GCS_TOKEN):
gcs_kwargs["access_token"] = access_token
if expiration := properties.get(GCS_TOKEN_EXPIRES_AT_MS):
gcs_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration))
if bucket_location := properties.get(GCS_DEFAULT_LOCATION):
gcs_kwargs["default_bucket_location"] = bucket_location
if endpoint := properties.get(GCS_ENDPOINT):
url_parts = urlparse(endpoint)
gcs_kwargs["scheme"] = url_parts.scheme
gcs_kwargs["endpoint_override"] = url_parts.netloc

return GcsFileSystem(**gcs_kwargs)


def _hdfs(properties: Properties) -> FileSystem:
from pyarrow.fs import HadoopFileSystem

hdfs_kwargs: Dict[str, Any] = {}
# if netloc:
# return HadoopFileSystem.from_uri(f"hdfs://{netloc}")
if host := properties.get(HDFS_HOST):
hdfs_kwargs["host"] = host
if port := properties.get(HDFS_PORT):
# port should be an integer type
hdfs_kwargs["port"] = int(port)
if user := properties.get(HDFS_USER):
hdfs_kwargs["user"] = user
if kerb_ticket := properties.get(HDFS_KERB_TICKET):
hdfs_kwargs["kerb_ticket"] = kerb_ticket

return HadoopFileSystem(**hdfs_kwargs)


SCHEME_TO_FS = {
"": _file,
"file": _file,
"s3": _s3,
"s3a": _s3,
"s3n": _s3,
"gs": _gs,
"gcs": _gs,
"hdfs": _hdfs,
}


class PyArrowLocalFileSystem(pyarrow.fs.LocalFileSystem):
def open_output_stream(self, path: str, *args: Any, **kwargs: Any) -> pyarrow.NativeFile:
# In LocalFileSystem, parent directories must be first created before opening an output stream
@@ -316,10 +389,12 @@ def to_input_file(self) -> PyArrowFile:


class PyArrowFileIO(FileIO):
fs_by_scheme: Callable[[str, Optional[str]], FileSystem]
"""A FileIO implementation that uses pyarrow filesystems."""

def __init__(self, properties: Properties = EMPTY_DICT):
self.fs_by_scheme: Callable[[str, Optional[str]], FileSystem] = lru_cache(self._initialize_fs)
self._scheme_to_fs = {}
self._scheme_to_fs.update(SCHEME_TO_FS)
self.get_fs: Callable[[str], FileSystem] = lru_cache(self._get_fs)
super().__init__(properties=properties)

@staticmethod
@@ -333,63 +408,6 @@ def parse_location(location: str) -> Tuple[str, str, str]:
else:
return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}"

def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem:
if scheme in {"s3", "s3a", "s3n"}:
from pyarrow.fs import S3FileSystem

client_kwargs: Dict[str, Any] = {
"endpoint_override": self.properties.get(S3_ENDPOINT),
"access_key": self.properties.get(S3_ACCESS_KEY_ID),
"secret_key": self.properties.get(S3_SECRET_ACCESS_KEY),
"session_token": self.properties.get(S3_SESSION_TOKEN),
"region": self.properties.get(S3_REGION),
}

if proxy_uri := self.properties.get(S3_PROXY_URI):
client_kwargs["proxy_options"] = proxy_uri

if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT):
client_kwargs["connect_timeout"] = float(connect_timeout)

return S3FileSystem(**client_kwargs)
elif scheme == "hdfs":
from pyarrow.fs import HadoopFileSystem

hdfs_kwargs: Dict[str, Any] = {}
if netloc:
return HadoopFileSystem.from_uri(f"hdfs://{netloc}")
if host := self.properties.get(HDFS_HOST):
hdfs_kwargs["host"] = host
if port := self.properties.get(HDFS_PORT):
# port should be an integer type
hdfs_kwargs["port"] = int(port)
if user := self.properties.get(HDFS_USER):
hdfs_kwargs["user"] = user
if kerb_ticket := self.properties.get(HDFS_KERB_TICKET):
hdfs_kwargs["kerb_ticket"] = kerb_ticket

return HadoopFileSystem(**hdfs_kwargs)
elif scheme in {"gs", "gcs"}:
from pyarrow.fs import GcsFileSystem

gcs_kwargs: Dict[str, Any] = {}
if access_token := self.properties.get(GCS_TOKEN):
gcs_kwargs["access_token"] = access_token
if expiration := self.properties.get(GCS_TOKEN_EXPIRES_AT_MS):
gcs_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration))
if bucket_location := self.properties.get(GCS_DEFAULT_LOCATION):
gcs_kwargs["default_bucket_location"] = bucket_location
if endpoint := self.properties.get(GCS_ENDPOINT):
url_parts = urlparse(endpoint)
gcs_kwargs["scheme"] = url_parts.scheme
gcs_kwargs["endpoint_override"] = url_parts.netloc

return GcsFileSystem(**gcs_kwargs)
elif scheme == "file":
return PyArrowLocalFileSystem()
else:
raise ValueError(f"Unrecognized filesystem type in URI: {scheme}")

def new_input(self, location: str) -> PyArrowFile:
"""Get a PyArrowFile instance to read bytes from the file at the given location.
@@ -399,9 +417,10 @@ def new_input(self, location: str) -> PyArrowFile:
Returns:
PyArrowFile: A PyArrowFile instance for the given location.
"""
scheme, netloc, path = self.parse_location(location)
scheme, _, path = self.parse_location(location)
fs = self.get_fs(scheme)
return PyArrowFile(
fs=self.fs_by_scheme(scheme, netloc),
fs=fs,
location=location,
path=path,
buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),
@@ -416,9 +435,10 @@ def new_output(self, location: str) -> PyArrowFile:
Returns:
PyArrowFile: A PyArrowFile instance for the given location.
"""
scheme, netloc, path = self.parse_location(location)
scheme, _, path = self.parse_location(location)
fs = self.get_fs(scheme)
return PyArrowFile(
fs=self.fs_by_scheme(scheme, netloc),
fs=fs,
location=location,
path=path,
buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),
@@ -437,8 +457,8 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
an AWS error code 15.
"""
str_location = location.location if isinstance(location, (InputFile, OutputFile)) else location
scheme, netloc, path = self.parse_location(str_location)
fs = self.fs_by_scheme(scheme, netloc)
scheme, _, path = self.parse_location(str_location)
fs = self.get_fs(scheme)

try:
fs.delete_file(path)
@@ -453,6 +473,12 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
raise PermissionError(f"Cannot delete file, access denied: {location}") from e
raise # pragma: no cover - If some other kind of OSError, raise the raw error

def _get_fs(self, scheme: str) -> FileSystem:
"""Get a filesystem for a specific scheme."""
if scheme not in self._scheme_to_fs:
raise ValueError(f"No registered filesystem for scheme: {scheme}")
return self._scheme_to_fs[scheme](self.properties)


def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
return visit(schema, _ConvertToArrowSchema(metadata))
@@ -1048,9 +1074,9 @@ def project_table(
Raises:
ResolveError: When an incompatible query is done.
"""
scheme, netloc, _ = PyArrowFileIO.parse_location(table.location())
scheme, _, _ = PyArrowFileIO.parse_location(table.location())
if isinstance(table.io, PyArrowFileIO):
fs = table.io.fs_by_scheme(scheme, netloc)
fs = table.io.get_fs(scheme)
else:
try:
from pyiceberg.io.fsspec import FsspecFileIO
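
The net effect of the pyiceberg/io/pyarrow.py changes: filesystem construction moves out of the single _initialize_fs method into per-scheme factory functions (_file, _s3, _gs, _hdfs) registered in SCHEME_TO_FS, and PyArrowFileIO resolves a filesystem from the URI scheme alone through the cached get_fs lookup. The snippet below is a minimal usage sketch, not part of the commit; the property keys ("s3.endpoint", "s3.region") and the example URI are illustrative assumptions.

from pyiceberg.io.pyarrow import PyArrowFileIO

# Properties are passed straight through to the per-scheme factory
# (here _s3, which builds a pyarrow.fs.S3FileSystem from them).
# The property keys below are assumed values for S3_ENDPOINT and S3_REGION.
io = PyArrowFileIO(properties={"s3.endpoint": "http://localhost:9000", "s3.region": "us-east-1"})

# parse_location still returns (scheme, netloc, path), but after this
# refactor only the scheme is needed to pick a filesystem.
scheme, _, path = PyArrowFileIO.parse_location("s3://bucket/data/00000.parquet")

# get_fs is wrapped in lru_cache, so the S3FileSystem is built once per
# scheme and reused; an unregistered scheme raises
# ValueError("No registered filesystem for scheme: ...").
fs = io.get_fs(scheme)

Because __init__ seeds the per-instance _scheme_to_fs dict from the module-level SCHEME_TO_FS, a subclass could register additional schemes before get_fs is first called.
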
10 changes: 5 additions & 5 deletions tests/io/test_pyarrow.py
@@ -193,12 +193,12 @@ def test_pyarrow_invalid_scheme() -> None:
with pytest.raises(ValueError) as exc_info:
PyArrowFileIO().new_input("foo://bar/baz.txt")

assert "Unrecognized filesystem type in URI" in str(exc_info.value)
assert "No registered filesystem for scheme" in str(exc_info.value)

with pytest.raises(ValueError) as exc_info:
PyArrowFileIO().new_output("foo://bar/baz.txt")

assert "Unrecognized filesystem type in URI" in str(exc_info.value)
assert "No registered filesystem for scheme" in str(exc_info.value)


def test_pyarrow_violating_input_stream_protocol() -> None:
@@ -301,7 +301,7 @@ def test_deleting_s3_file_no_permission() -> None:
s3fs_mock = MagicMock()
s3fs_mock.delete_file.side_effect = OSError("AWS Error [code 15]")

with patch.object(PyArrowFileIO, "_initialize_fs") as submocked:
with patch.object(PyArrowFileIO, "_get_fs") as submocked:
submocked.return_value = s3fs_mock

with pytest.raises(PermissionError) as exc_info:
@@ -316,7 +316,7 @@ def test_deleting_s3_file_not_found() -> None:
s3fs_mock = MagicMock()
s3fs_mock.delete_file.side_effect = OSError("Path does not exist")

with patch.object(PyArrowFileIO, "_initialize_fs") as submocked:
with patch.object(PyArrowFileIO, "_get_fs") as submocked:
submocked.return_value = s3fs_mock

with pytest.raises(FileNotFoundError) as exc_info:
@@ -331,7 +331,7 @@ def test_deleting_hdfs_file_not_found() -> None:
hdfs_mock = MagicMock()
hdfs_mock.delete_file.side_effect = OSError("Path does not exist")

with patch.object(PyArrowFileIO, "_initialize_fs") as submocked:
with patch.object(PyArrowFileIO, "_get_fs") as submocked:
submocked.return_value = hdfs_mock

with pytest.raises(FileNotFoundError) as exc_info:
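
In the test file, the patched hook changes from _initialize_fs to _get_fs, and the expected message for an unknown scheme becomes "No registered filesystem for scheme". A minimal sketch of the patching pattern follows (the test name, URI, and mock wiring are illustrative, not the literal test bodies):

from unittest.mock import MagicMock, patch

import pytest

from pyiceberg.io.pyarrow import PyArrowFileIO


def test_delete_missing_file_sketch() -> None:
    fs_mock = MagicMock()
    fs_mock.delete_file.side_effect = OSError("Path does not exist")

    # Patching _get_fs makes every scheme resolve to the mock filesystem,
    # because the lru_cache-wrapped get_fs simply delegates to it.
    with patch.object(PyArrowFileIO, "_get_fs") as mocked_get_fs:
        mocked_get_fs.return_value = fs_mock

        # delete() maps the underlying OSError onto FileNotFoundError.
        with pytest.raises(FileNotFoundError):
            PyArrowFileIO().delete("s3://bucket/path/to/file.parquet")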
