Skip to content

Commit

Permalink
Implement __getstate__ and __setstate__ so that FileIO instances can …
Browse files Browse the repository at this point in the history
…be pickled
  • Loading branch information
amogh-jahagirdar committed Mar 24, 2024
1 parent 6989b92 commit 63f0f5f
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
12 changes: 12 additions & 0 deletions pyiceberg/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
import logging
import os
from copy import copy
from functools import lru_cache, partial
from typing import (
Any,
Expand Down Expand Up @@ -338,3 +339,14 @@ def _get_fs(self, scheme: str) -> AbstractFileSystem:
if scheme not in self._scheme_to_fs:
raise ValueError(f"No registered filesystem for scheme: {scheme}")
return self._scheme_to_fs[scheme](self.properties)

def __getstate__(self) -> Dict[str, Any]:
"""Create a dictionary of the FsSpecFileIO fields used when pickling."""
fileio_copy = copy(self.__dict__)
fileio_copy["fs_by_scheme"] = None
return fileio_copy

def __setstate__(self, state: Dict[str, Any]) -> None:
"""Deserialize the state into a FsSpecFileIO instance."""
self.__dict__ = state
self.fs_by_scheme = lru_cache(self._get_fs)
12 changes: 12 additions & 0 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import re
from abc import ABC, abstractmethod
from concurrent.futures import Future
from copy import copy
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache, singledispatch
Expand Down Expand Up @@ -456,6 +457,17 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
raise PermissionError(f"Cannot delete file, access denied: {location}") from e
raise # pragma: no cover - If some other kind of OSError, raise the raw error

def __getstate__(self) -> Dict[str, Any]:
"""Create a dictionary of the PyArrowFileIO fields used when pickling."""
fileio_copy = copy(self.__dict__)
fileio_copy["fs_by_scheme"] = None
return fileio_copy

def __setstate__(self, state: Dict[str, Any]) -> None:
"""Deserialize the state into a PyArrowFileIO instance."""
self.__dict__ = state
self.fs_by_scheme = lru_cache(self._initialize_fs)


def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
return visit(schema, _ConvertToArrowSchema(metadata))
Expand Down
9 changes: 9 additions & 0 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=protected-access,unused-argument,redefined-outer-name

import os
import pickle
import tempfile
from datetime import date
from typing import Any, List, Optional
Expand Down Expand Up @@ -256,6 +257,14 @@ def test_raise_on_opening_a_local_file_not_found() -> None:
assert "[Errno 2] Failed to open local file" in str(exc_info.value)


def test_pickle_pyarrow_file_io() -> None:
f = PyArrowFileIO()
serialized = pickle.dumps(f)
assert serialized is not None
deserialized = pickle.loads(serialized)
assert deserialized is not None


def test_raise_on_opening_an_s3_file_no_permission() -> None:
"""Test that opening a PyArrowFile raises a PermissionError when the pyarrow error includes 'AWS Error [code 15]'"""

Expand Down

0 comments on commit 63f0f5f

Please sign in to comment.