From 551f524170b12900cfaa3fef1ec8a0f9f437ee4c Mon Sep 17 00:00:00 2001 From: Jiakai Li <50531391+jiakai-li@users.noreply.github.com> Date: Tue, 7 Jan 2025 03:47:43 +1300 Subject: [PATCH] Fix read from multiple s3 regions (#1453) * Take netloc into account for s3 filesystem when calling `_initialize_fs` * Fix unit test for s3 fileystem * Update ArrowScan to use different FileSystem per file * Add unit test for `PyArrorFileIO.fs_by_scheme` cache behavior * Add error handling * Update tests/io/test_pyarrow.py Co-authored-by: Kevin Liu * Update `s3.region` document and a test case * Add test case for `PyArrowFileIO.new_input` multi region * Shuffle code location for better maintainability * Comment for future integration test * Typo fix * Document wording * Add warning when the bucket region for a file cannot be resolved (for `pyarrow.S3FileSystem`) * Fix code linting * Update mkdocs/docs/configuration.md Co-authored-by: Kevin Liu * Code refactoring * Unit test * Code refactoring * Test cases * Code format * Code tidy-up * Update pyiceberg/io/pyarrow.py Co-authored-by: Kevin Liu --------- Co-authored-by: Kevin Liu --- mkdocs/docs/configuration.md | 30 ++--- pyiceberg/io/pyarrow.py | 212 +++++++++++++++++++++----------- tests/integration/test_reads.py | 29 +++++ tests/io/test_pyarrow.py | 96 ++++++++++++++- 4 files changed, 273 insertions(+), 94 deletions(-) diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md index 621b313613..06eaac1bed 100644 --- a/mkdocs/docs/configuration.md +++ b/mkdocs/docs/configuration.md @@ -102,21 +102,21 @@ For the FileIO there are several configuration options available: -| Key | Example | Description | -|----------------------|----------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| s3.endpoint | 
| Configure an alternative endpoint of the S3 service for the FileIO to access. This could be used to use S3FileIO with any s3-compatible object storage service that has a different endpoint, or access a private S3 endpoint in a virtual private cloud. | -| s3.access-key-id | admin | Configure the static access key id used to access the FileIO. | -| s3.secret-access-key | password | Configure the static secret access key used to access the FileIO. | -| s3.session-token | AQoDYXdzEJr... | Configure the static session token used to access the FileIO. | -| s3.role-session-name | session | An optional identifier for the assumed role session. | -| s3.role-arn | arn:aws:... | AWS Role ARN. If provided instead of access_key and secret_key, temporary credentials will be fetched by assuming this role. | -| s3.signer | bearer | Configure the signature version of the FileIO. | -| s3.signer.uri | | Configure the remote signing uri if it differs from the catalog uri. Remote signing is only implemented for `FsspecFileIO`. The final request is sent to `/`. | -| s3.signer.endpoint | v1/main/s3-sign | Configure the remote signing endpoint. Remote signing is only implemented for `FsspecFileIO`. The final request is sent to `/`. (default : v1/aws/s3/sign). | -| s3.region | us-west-2 | Sets the region of the bucket | -| s3.proxy-uri | | Configure the proxy server to be used by the FileIO. | -| s3.connect-timeout | 60.0 | Configure socket connection timeout, in seconds. | -| s3.force-virtual-addressing | False | Whether to use virtual addressing of buckets. If true, then virtual addressing is always enabled. If false, then virtual addressing is only enabled if endpoint_override is empty. This can be used for non-AWS backends that only support virtual hosted-style access. 
| +| Key | Example | Description | +|----------------------|----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| s3.endpoint | | Configure an alternative endpoint of the S3 service for the FileIO to access. This could be used to use S3FileIO with any s3-compatible object storage service that has a different endpoint, or access a private S3 endpoint in a virtual private cloud. | +| s3.access-key-id | admin | Configure the static access key id used to access the FileIO. | +| s3.secret-access-key | password | Configure the static secret access key used to access the FileIO. | +| s3.session-token | AQoDYXdzEJr... | Configure the static session token used to access the FileIO. | +| s3.role-session-name | session | An optional identifier for the assumed role session. | +| s3.role-arn | arn:aws:... | AWS Role ARN. If provided instead of access_key and secret_key, temporary credentials will be fetched by assuming this role. | +| s3.signer | bearer | Configure the signature version of the FileIO. | +| s3.signer.uri | | Configure the remote signing uri if it differs from the catalog uri. Remote signing is only implemented for `FsspecFileIO`. The final request is sent to `/`. | +| s3.signer.endpoint | v1/main/s3-sign | Configure the remote signing endpoint. Remote signing is only implemented for `FsspecFileIO`. The final request is sent to `/`. (default : v1/aws/s3/sign). | +| s3.region | us-west-2 | Configure the default region used to initialize an `S3FileSystem`. `PyArrowFileIO` attempts to automatically resolve the region for each S3 bucket, falling back to this value if resolution fails. | +| s3.proxy-uri | | Configure the proxy server to be used by the FileIO. 
| +| s3.connect-timeout | 60.0 | Configure socket connection timeout, in seconds. | +| s3.force-virtual-addressing | False | Whether to use virtual addressing of buckets. If true, then virtual addressing is always enabled. If false, then virtual addressing is only enabled if endpoint_override is empty. This can be used for non-AWS backends that only support virtual hosted-style access. | diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index dc41a7d6a1..ad7e4f4f85 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -351,77 +351,141 @@ def parse_location(location: str) -> Tuple[str, str, str]: return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}" def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem: - if scheme in {"s3", "s3a", "s3n", "oss"}: - from pyarrow.fs import S3FileSystem - - client_kwargs: Dict[str, Any] = { - "endpoint_override": self.properties.get(S3_ENDPOINT), - "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), - "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), - "session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), - "region": get_first_property_value(self.properties, S3_REGION, AWS_REGION), - } - - if proxy_uri := self.properties.get(S3_PROXY_URI): - client_kwargs["proxy_options"] = proxy_uri - - if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT): - client_kwargs["connect_timeout"] = float(connect_timeout) - - if role_arn := get_first_property_value(self.properties, S3_ROLE_ARN, AWS_ROLE_ARN): - client_kwargs["role_arn"] = role_arn - - if session_name := get_first_property_value(self.properties, S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME): - client_kwargs["session_name"] = session_name - - if force_virtual_addressing := self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING): - client_kwargs["force_virtual_addressing"] = 
property_as_bool(self.properties, force_virtual_addressing, False) - - return S3FileSystem(**client_kwargs) - elif scheme in ("hdfs", "viewfs"): - from pyarrow.fs import HadoopFileSystem - - hdfs_kwargs: Dict[str, Any] = {} - if netloc: - return HadoopFileSystem.from_uri(f"{scheme}://{netloc}") - if host := self.properties.get(HDFS_HOST): - hdfs_kwargs["host"] = host - if port := self.properties.get(HDFS_PORT): - # port should be an integer type - hdfs_kwargs["port"] = int(port) - if user := self.properties.get(HDFS_USER): - hdfs_kwargs["user"] = user - if kerb_ticket := self.properties.get(HDFS_KERB_TICKET): - hdfs_kwargs["kerb_ticket"] = kerb_ticket - - return HadoopFileSystem(**hdfs_kwargs) + """Initialize FileSystem for different scheme.""" + if scheme in {"oss"}: + return self._initialize_oss_fs() + + elif scheme in {"s3", "s3a", "s3n"}: + return self._initialize_s3_fs(netloc) + + elif scheme in {"hdfs", "viewfs"}: + return self._initialize_hdfs_fs(scheme, netloc) + elif scheme in {"gs", "gcs"}: - from pyarrow.fs import GcsFileSystem - - gcs_kwargs: Dict[str, Any] = {} - if access_token := self.properties.get(GCS_TOKEN): - gcs_kwargs["access_token"] = access_token - if expiration := self.properties.get(GCS_TOKEN_EXPIRES_AT_MS): - gcs_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration)) - if bucket_location := self.properties.get(GCS_DEFAULT_LOCATION): - gcs_kwargs["default_bucket_location"] = bucket_location - if endpoint := get_first_property_value(self.properties, GCS_SERVICE_HOST, GCS_ENDPOINT): - if self.properties.get(GCS_ENDPOINT): - deprecation_message( - deprecated_in="0.8.0", - removed_in="0.9.0", - help_message=f"The property {GCS_ENDPOINT} is deprecated, please use {GCS_SERVICE_HOST} instead", - ) - url_parts = urlparse(endpoint) - gcs_kwargs["scheme"] = url_parts.scheme - gcs_kwargs["endpoint_override"] = url_parts.netloc + return self._initialize_gcs_fs() + + elif scheme in {"file"}: + return self._initialize_local_fs() - 
return GcsFileSystem(**gcs_kwargs) - elif scheme == "file": - return PyArrowLocalFileSystem() else: raise ValueError(f"Unrecognized filesystem type in URI: {scheme}") + def _initialize_oss_fs(self) -> FileSystem: + from pyarrow.fs import S3FileSystem + + client_kwargs: Dict[str, Any] = { + "endpoint_override": self.properties.get(S3_ENDPOINT), + "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), + "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), + "session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), + "region": get_first_property_value(self.properties, S3_REGION, AWS_REGION), + } + + if proxy_uri := self.properties.get(S3_PROXY_URI): + client_kwargs["proxy_options"] = proxy_uri + + if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT): + client_kwargs["connect_timeout"] = float(connect_timeout) + + if role_arn := get_first_property_value(self.properties, S3_ROLE_ARN, AWS_ROLE_ARN): + client_kwargs["role_arn"] = role_arn + + if session_name := get_first_property_value(self.properties, S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME): + client_kwargs["session_name"] = session_name + + if force_virtual_addressing := self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING): + client_kwargs["force_virtual_addressing"] = property_as_bool(self.properties, force_virtual_addressing, False) + + return S3FileSystem(**client_kwargs) + + def _initialize_s3_fs(self, netloc: Optional[str]) -> FileSystem: + from pyarrow.fs import S3FileSystem, resolve_s3_region + + # Resolve region from netloc(bucket), fallback to user-provided region + provided_region = get_first_property_value(self.properties, S3_REGION, AWS_REGION) + + try: + bucket_region = resolve_s3_region(bucket=netloc) + except (OSError, TypeError): + bucket_region = None + logger.warning(f"Unable to resolve region for bucket {netloc}, using default region {provided_region}") + + 
bucket_region = bucket_region or provided_region + if bucket_region != provided_region: + logger.warning( + f"PyArrow FileIO overriding S3 bucket region for bucket {netloc}: " + f"provided region {provided_region}, actual region {bucket_region}" + ) + + client_kwargs: Dict[str, Any] = { + "endpoint_override": self.properties.get(S3_ENDPOINT), + "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID), + "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY), + "session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN), + "region": bucket_region, + } + + if proxy_uri := self.properties.get(S3_PROXY_URI): + client_kwargs["proxy_options"] = proxy_uri + + if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT): + client_kwargs["connect_timeout"] = float(connect_timeout) + + if role_arn := get_first_property_value(self.properties, S3_ROLE_ARN, AWS_ROLE_ARN): + client_kwargs["role_arn"] = role_arn + + if session_name := get_first_property_value(self.properties, S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME): + client_kwargs["session_name"] = session_name + + if force_virtual_addressing := self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING): + client_kwargs["force_virtual_addressing"] = property_as_bool(self.properties, force_virtual_addressing, False) + + return S3FileSystem(**client_kwargs) + + def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem: + from pyarrow.fs import HadoopFileSystem + + hdfs_kwargs: Dict[str, Any] = {} + if netloc: + return HadoopFileSystem.from_uri(f"{scheme}://{netloc}") + if host := self.properties.get(HDFS_HOST): + hdfs_kwargs["host"] = host + if port := self.properties.get(HDFS_PORT): + # port should be an integer type + hdfs_kwargs["port"] = int(port) + if user := self.properties.get(HDFS_USER): + hdfs_kwargs["user"] = user + if kerb_ticket := self.properties.get(HDFS_KERB_TICKET): + 
hdfs_kwargs["kerb_ticket"] = kerb_ticket + + return HadoopFileSystem(**hdfs_kwargs) + + def _initialize_gcs_fs(self) -> FileSystem: + from pyarrow.fs import GcsFileSystem + + gcs_kwargs: Dict[str, Any] = {} + if access_token := self.properties.get(GCS_TOKEN): + gcs_kwargs["access_token"] = access_token + if expiration := self.properties.get(GCS_TOKEN_EXPIRES_AT_MS): + gcs_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration)) + if bucket_location := self.properties.get(GCS_DEFAULT_LOCATION): + gcs_kwargs["default_bucket_location"] = bucket_location + if endpoint := get_first_property_value(self.properties, GCS_SERVICE_HOST, GCS_ENDPOINT): + if self.properties.get(GCS_ENDPOINT): + deprecation_message( + deprecated_in="0.8.0", + removed_in="0.9.0", + help_message=f"The property {GCS_ENDPOINT} is deprecated, please use {GCS_SERVICE_HOST} instead", + ) + url_parts = urlparse(endpoint) + gcs_kwargs["scheme"] = url_parts.scheme + gcs_kwargs["endpoint_override"] = url_parts.netloc + + return GcsFileSystem(**gcs_kwargs) + + def _initialize_local_fs(self) -> FileSystem: + return PyArrowLocalFileSystem() + def new_input(self, location: str) -> PyArrowFile: """Get a PyArrowFile instance to read bytes from the file at the given location. 
@@ -1326,13 +1390,14 @@ def _task_to_table( return None -def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: +def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: deletes_per_file: Dict[str, List[ChunkedArray]] = {} unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) if len(unique_deletes) > 0: executor = ExecutorFactory.get_or_create() deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map( - lambda args: _read_deletes(*args), [(fs, delete) for delete in unique_deletes] + lambda args: _read_deletes(*args), + [(_fs_from_file_path(io, delete_file.file_path), delete_file) for delete_file in unique_deletes], ) for delete in deletes_per_files: for file, arr in delete.items(): @@ -1344,7 +1409,7 @@ def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dic return deletes_per_file -def _fs_from_file_path(file_path: str, io: FileIO) -> FileSystem: +def _fs_from_file_path(io: FileIO, file_path: str) -> FileSystem: scheme, netloc, _ = _parse_location(file_path) if isinstance(io, PyArrowFileIO): return io.fs_by_scheme(scheme, netloc) @@ -1366,7 +1431,6 @@ def _fs_from_file_path(file_path: str, io: FileIO) -> FileSystem: class ArrowScan: _table_metadata: TableMetadata _io: FileIO - _fs: FileSystem _projected_schema: Schema _bound_row_filter: BooleanExpression _case_sensitive: bool @@ -1376,7 +1440,6 @@ class ArrowScan: Attributes: _table_metadata: Current table metadata of the Iceberg table _io: PyIceberg FileIO implementation from which to fetch the io properties - _fs: PyArrow FileSystem to use to read the files _projected_schema: Iceberg Schema to project onto the data files _bound_row_filter: Schema bound row expression to filter the data with _case_sensitive: Case sensitivity when looking up column names @@ -1394,7 +1457,6 @@ def __init__( ) -> None: self._table_metadata = table_metadata 
self._io = io - self._fs = _fs_from_file_path(table_metadata.location, io) # TODO: use different FileSystem per file self._projected_schema = projected_schema self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive) self._case_sensitive = case_sensitive @@ -1434,7 +1496,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - deletes_per_file = _read_all_delete_files(self._fs, tasks) + deletes_per_file = _read_all_delete_files(self._io, tasks) executor = ExecutorFactory.get_or_create() def _table_from_scan_task(task: FileScanTask) -> pa.Table: @@ -1497,7 +1559,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - deletes_per_file = _read_all_delete_files(self._fs, tasks) + deletes_per_file = _read_all_delete_files(self._io, tasks) return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file) def _record_batches_from_scan_tasks_and_deletes( @@ -1508,7 +1570,7 @@ def _record_batches_from_scan_tasks_and_deletes( if self._limit is not None and total_row_count >= self._limit: break batches = _task_to_record_batches( - self._fs, + _fs_from_file_path(self._io, task.file.file_path), task, self._bound_row_filter, self._projected_schema, diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 8d13724087..f2e79bae60 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -19,6 +19,7 @@ import math import time import uuid +from pathlib import PosixPath from urllib.parse import urlparse import pyarrow as pa @@ -921,3 +922,31 @@ def test_table_scan_empty_table(catalog: Catalog) -> None: result_table = tbl.scan().to_arrow() 
assert len(result_table) == 0 + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_read_from_s3_and_local_fs(catalog: Catalog, tmp_path: PosixPath) -> None: + identifier = "default.test_read_from_s3_and_local_fs" + schema = pa.schema([pa.field("colA", pa.string())]) + arrow_table = pa.Table.from_arrays([pa.array(["one"])], schema=schema) + + tmp_dir = tmp_path / "data" + tmp_dir.mkdir() + local_file = tmp_dir / "local_file.parquet" + + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + tbl = catalog.create_table(identifier, schema=schema) + + # Append table to s3 endpoint + tbl.append(arrow_table) + + # Append a local file + pq.write_table(arrow_table, local_file) + tbl.add_files([str(local_file)]) + + result_table = tbl.scan().to_arrow() + assert result_table["colA"].to_pylist() == ["one", "one"] diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 8bb97e150a..8beb750f49 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=protected-access,unused-argument,redefined-outer-name - +import logging import os import tempfile import uuid @@ -27,7 +27,7 @@ import pyarrow as pa import pyarrow.parquet as pq import pytest -from pyarrow.fs import FileType, LocalFileSystem +from pyarrow.fs import FileType, LocalFileSystem, S3FileSystem from pyiceberg.exceptions import ResolveError from pyiceberg.expressions import ( @@ -360,10 +360,12 @@ def test_pyarrow_s3_session_properties() -> None: **UNIFIED_AWS_SESSION_PROPERTIES, } - with patch("pyarrow.fs.S3FileSystem") as mock_s3fs: + with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: s3_fileio = PyArrowFileIO(properties=session_properties) filename = str(uuid.uuid4()) + # Mock `resolve_s3_region` to prevent the location used from resolving to a different S3 region + mock_s3_region_resolver.side_effect = OSError("S3 bucket is not found") s3_fileio.new_input(location=f"s3://warehouse/{filename}") mock_s3fs.assert_called_with( @@ -381,10 +383,11 @@ def test_pyarrow_unified_session_properties() -> None: **UNIFIED_AWS_SESSION_PROPERTIES, } - with patch("pyarrow.fs.S3FileSystem") as mock_s3fs: + with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: s3_fileio = PyArrowFileIO(properties=session_properties) filename = str(uuid.uuid4()) + mock_s3_region_resolver.return_value = "client.region" s3_fileio.new_input(location=f"s3://warehouse/{filename}") mock_s3fs.assert_called_with( @@ -2096,3 +2099,88 @@ def test__to_requested_schema_timestamps_without_downcast_raises_exception( _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False) assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value) + + +def test_pyarrow_file_io_fs_by_scheme_cache() -> None: + # It's better to set up multi-region minio servers for an 
integration test once `endpoint_url` argument becomes available for `resolve_s3_region` + # Refer to: https://github.com/apache/arrow/issues/43713 + + pyarrow_file_io = PyArrowFileIO() + us_east_1_region = "us-east-1" + ap_southeast_2_region = "ap-southeast-2" + + with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: + # Call with new argument resolves region automatically + mock_s3_region_resolver.return_value = us_east_1_region + filesystem_us = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket") + assert filesystem_us.region == us_east_1_region + assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 1 # type: ignore + assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 1 # type: ignore + + # Call with different argument also resolves region automatically + mock_s3_region_resolver.return_value = ap_southeast_2_region + filesystem_ap_southeast_2 = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket") + assert filesystem_ap_southeast_2.region == ap_southeast_2_region + assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 2 # type: ignore + assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 2 # type: ignore + + # Call with same argument hits cache + filesystem_us_cached = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket") + assert filesystem_us_cached.region == us_east_1_region + assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 1 # type: ignore + + # Call with same argument hits cache + filesystem_ap_southeast_2_cached = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket") + assert filesystem_ap_southeast_2_cached.region == ap_southeast_2_region + assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 2 # type: ignore + + +def test_pyarrow_io_new_input_multi_region(caplog: Any) -> None: + # It's better to set up multi-region minio servers for an integration test once `endpoint_url` argument becomes available for `resolve_s3_region` + # Refer to: https://github.com/apache/arrow/issues/43713 
+ user_provided_region = "ap-southeast-1" + bucket_regions = [ + ("us-east-2-bucket", "us-east-2"), + ("ap-southeast-2-bucket", "ap-southeast-2"), + ] + + def _s3_region_map(bucket: str) -> str: + for bucket_region in bucket_regions: + if bucket_region[0] == bucket: + return bucket_region[1] + raise OSError("Unknown bucket") + + # For a pyarrow io instance with configured default s3 region + pyarrow_file_io = PyArrowFileIO({"s3.region": user_provided_region}) + with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: + mock_s3_region_resolver.side_effect = _s3_region_map + + # The region is set to provided region if bucket region cannot be resolved + with caplog.at_level(logging.WARNING): + assert pyarrow_file_io.new_input("s3://non-exist-bucket/path/to/file")._filesystem.region == user_provided_region + assert f"Unable to resolve region for bucket non-exist-bucket, using default region {user_provided_region}" in caplog.text + + for bucket_region in bucket_regions: + # For s3 scheme, region is overwritten by resolved bucket region if different from user provided region + with caplog.at_level(logging.WARNING): + assert pyarrow_file_io.new_input(f"s3://{bucket_region[0]}/path/to/file")._filesystem.region == bucket_region[1] + assert ( + f"PyArrow FileIO overriding S3 bucket region for bucket {bucket_region[0]}: " + f"provided region {user_provided_region}, actual region {bucket_region[1]}" in caplog.text + ) + + # For oss scheme, user provided region is used instead + assert pyarrow_file_io.new_input(f"oss://{bucket_region[0]}/path/to/file")._filesystem.region == user_provided_region + + +def test_pyarrow_io_multi_fs() -> None: + pyarrow_file_io = PyArrowFileIO({"s3.region": "ap-southeast-1"}) + + with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver: + mock_s3_region_resolver.return_value = None + + # The PyArrowFileIO instance resolves s3 file input to S3FileSystem + assert 
isinstance(pyarrow_file_io.new_input("s3://bucket/path/to/file")._filesystem, S3FileSystem) + + # Same PyArrowFileIO instance resolves local file input to LocalFileSystem + assert isinstance(pyarrow_file_io.new_input("file:///path/to/file")._filesystem, LocalFileSystem)