diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 6f79835db0..650d391807 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -146,6 +146,25 @@ catalog.create_table(
 )
 ```
 
+To create a table using a PyArrow schema:
+
+```python
+import pyarrow as pa
+
+schema = pa.schema(
+    [
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ]
+)
+
+catalog.create_table(
+    identifier="docs_example.bids",
+    schema=schema,
+)
+```
+
 ## Load a table
 
 ### Catalog table
diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py
index a39d0e915c..6e5dc2748f 100644
--- a/pyiceberg/catalog/__init__.py
+++ b/pyiceberg/catalog/__init__.py
@@ -24,6 +24,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
+    TYPE_CHECKING,
     Callable,
     Dict,
     List,
@@ -56,6 +57,9 @@
 )
 from pyiceberg.utils.config import Config, merge_config
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 logger = logging.getLogger(__name__)
 
 _ENV_CONFIG = Config()
@@ -288,7 +292,7 @@ def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -512,6 +516,22 @@ def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> Non
         if overlap:
             raise ValueError(f"Updates and deletes have an overlap: {overlap}")
 
+    @staticmethod
+    def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema:
+        if isinstance(schema, Schema):
+            return schema
+        try:
+            import pyarrow as pa
+
+            from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow
+
+            if isinstance(schema, pa.Schema):
+                schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())  # type: ignore
+                return schema
+        except ModuleNotFoundError:
+            pass
+        raise ValueError(f"{type(schema)=}, but it must be pyiceberg.schema.Schema or pyarrow.Schema")
+
     def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str:
         if not location:
             return self._get_default_warehouse_location(database_name, table_name)
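For context on the hunk above: the new `_convert_schema_if_needed` hook is what every `create_table` implementation below funnels its `schema` argument through. A minimal sketch of its behavior, assuming pyarrow is installed; the helper is private, so this mirrors the unit tests near the end of this diff rather than documented API:

```python
import pyarrow as pa

from pyiceberg.catalog import Catalog
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType

# An Iceberg Schema is returned unchanged.
iceberg_schema = Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=False))
assert Catalog._convert_schema_if_needed(iceberg_schema) is iceberg_schema

# A PyArrow schema is converted, with every field id set to the -1 placeholder;
# fresh ids are assigned later, when the new table metadata is built.
arrow_schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])
converted = Catalog._convert_schema_if_needed(arrow_schema)
assert all(field.field_id == -1 for field in converted.fields)

# Anything else raises the ValueError from the last line of the hunk above.
```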
""" + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore + database_name, table_name = self.identifier_to_database_and_table(identifier) location = self._resolve_table_location(location, database_name, table_name) diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py index 645568f80a..8f860fabba 100644 --- a/pyiceberg/catalog/glue.py +++ b/pyiceberg/catalog/glue.py @@ -17,6 +17,7 @@ from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -88,6 +89,9 @@ UUIDType, ) +if TYPE_CHECKING: + import pyarrow as pa + # If Glue should skip archiving an old table version when creating a new version in a commit. By # default, Glue archives all old table versions after an UpdateTable call, but Glue has a default # max number of archived table versions (can be increased). So for streaming use case with lots @@ -329,7 +333,7 @@ def _get_glue_table(self, database_name: str, table_name: str) -> TableTypeDef: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, @@ -354,6 +358,8 @@ def create_table( ValueError: If the identifier is invalid, or no path is given to store metadata. """ + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore + database_name, table_name = self.identifier_to_database_and_table(identifier) location = self._resolve_table_location(location, database_name, table_name) diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py index 331b9ca80d..8069321095 100644 --- a/pyiceberg/catalog/hive.py +++ b/pyiceberg/catalog/hive.py @@ -18,6 +18,7 @@ import time from types import TracebackType from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -91,6 +92,10 @@ UUIDType, ) +if TYPE_CHECKING: + import pyarrow as pa + + # Replace by visitor hive_types = { BooleanType: "boolean", @@ -250,7 +255,7 @@ def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, @@ -273,6 +278,8 @@ def create_table( AlreadyExistsError: If a table with the name already exists. ValueError: If the identifier is invalid. """ + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore + properties = {**DEFAULT_PROPERTIES, **properties} database_name, table_name = self.identifier_to_database_and_table(identifier) current_time_millis = int(time.time() * 1000) diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py index 083f851d1c..a8b7154621 100644 --- a/pyiceberg/catalog/noop.py +++ b/pyiceberg/catalog/noop.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py
index 083f851d1c..a8b7154621 100644
--- a/pyiceberg/catalog/noop.py
+++ b/pyiceberg/catalog/noop.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from typing import (
+    TYPE_CHECKING,
     List,
     Optional,
     Set,
@@ -33,12 +34,15 @@
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
 from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 class NoopCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py
index de192a9e0b..34d75b5936 100644
--- a/pyiceberg/catalog/rest.py
+++ b/pyiceberg/catalog/rest.py
@@ -16,6 +16,7 @@
 # under the License.
 from json import JSONDecodeError
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -68,6 +69,9 @@
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 ICEBERG_REST_SPEC_VERSION = "0.14.1"
 
 
@@ -437,12 +441,14 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
         properties: Properties = EMPTY_DICT,
     ) -> Table:
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         namespace_and_table = self._split_identifier_for_path(identifier)
         request = CreateTableRequest(
             name=namespace_and_table["table"],
diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py
index 593c6b54a1..8a02b20dfc 100644
--- a/pyiceberg/catalog/sql.py
+++ b/pyiceberg/catalog/sql.py
@@ -16,6 +16,7 @@
 # under the License.
 
 from typing import (
+    TYPE_CHECKING,
     List,
     Optional,
     Set,
@@ -65,6 +66,9 @@
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 class SqlCatalogBaseTable(MappedAsDataclass, DeclarativeBase):
     pass
@@ -140,7 +144,7 @@ def _convert_orm_to_iceberg(self, orm_table: IcebergTables) -> Table:
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -165,6 +169,8 @@
             ValueError: If the identifier is invalid, or no path is given to store metadata.
 
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         database_name, table_name = self.identifier_to_database_and_table(identifier)
         if not self._namespace_exists(database_name):
             raise NoSuchNamespaceError(f"Namespace does not exist: {database_name}")
""" + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore + database_name, table_name = self.identifier_to_database_and_table(identifier) if not self._namespace_exists(database_name): raise NoSuchNamespaceError(f"Namespace does not exist: {database_name}") diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 1d7dcbef77..7a94ce4c7d 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -26,6 +26,7 @@ from __future__ import annotations import concurrent.futures +import itertools import logging import os import re @@ -34,7 +35,6 @@ from dataclasses import dataclass from enum import Enum from functools import lru_cache, singledispatch -from itertools import chain from typing import ( TYPE_CHECKING, Any, @@ -637,7 +637,7 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: if len(positional_deletes) == 1: all_chunks = positional_deletes[0] else: - all_chunks = pa.chunked_array(chain(*[arr.chunks for arr in positional_deletes])) + all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes])) return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False) @@ -912,6 +912,21 @@ def after_map_value(self, element: pa.Field) -> None: self._field_names.pop() +class _ConvertToIcebergWithoutIDs(_ConvertToIceberg): + """ + Converts PyArrowSchema to Iceberg Schema with all -1 ids. + + The schema generated through this visitor should always be + used in conjunction with `new_table_metadata` function to + assign new field ids in order. This is currently used only + when creating an Iceberg Schema from a PyArrow schema when + creating a new Iceberg table. + """ + + def _field_id(self, field: pa.Field) -> int: + return -1 + + def _task_to_table( fs: FileSystem, task: FileScanTask, @@ -999,7 +1014,7 @@ def _task_to_table( def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: deletes_per_file: Dict[str, List[ChunkedArray]] = {} - unique_deletes = set(chain.from_iterable([task.delete_files for task in tasks])) + unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) if len(unique_deletes) > 0: executor = ExecutorFactory.get_or_create() deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map( @@ -1421,7 +1436,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsColl def struct( self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]] ) -> List[StatisticsCollector]: - return list(chain(*[result() for result in field_results])) + return list(itertools.chain(*[result() for result in field_results])) def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]: self._field_id = field.field_id @@ -1513,7 +1528,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath return struct_result() def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]: - return list(chain(*[result() for result in field_results])) + return list(itertools.chain(*[result() for result in field_results])) def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]: self._field_id = field.field_id diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index b61e4678b9..6dd174f325 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1221,50 +1221,57 @@ def 
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index b61e4678b9..6dd174f325 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -1221,50 +1221,57 @@ def assign_fresh_schema_ids(schema_or_type: Union[Schema, IcebergType], next_id:
 class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
     """Traverses the schema and assigns monotonically increasing ids."""
 
-    reserved_ids: Dict[int, int]
+    old_id_to_new_id: Dict[int, int]
 
     def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None:
-        self.reserved_ids = {}
+        self.old_id_to_new_id = {}
         counter = itertools.count(1)
         self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter)
 
-    def _get_and_increment(self) -> int:
-        return self.next_id_func()
+    def _get_and_increment(self, current_id: int) -> int:
+        new_id = self.next_id_func()
+        self.old_id_to_new_id[current_id] = new_id
+        return new_id
 
     def schema(self, schema: Schema, struct_result: Callable[[], StructType]) -> Schema:
-        # First we keep the original identifier_field_ids here, we remap afterwards
-        fields = struct_result().fields
-        return Schema(*fields, identifier_field_ids=[self.reserved_ids[field_id] for field_id in schema.identifier_field_ids])
+        return Schema(
+            *struct_result().fields,
+            identifier_field_ids=[self.old_id_to_new_id[field_id] for field_id in schema.identifier_field_ids],
+        )
 
     def struct(self, struct: StructType, field_results: List[Callable[[], IcebergType]]) -> StructType:
-        # assign IDs for this struct's fields first
-        self.reserved_ids.update({field.field_id: self._get_and_increment() for field in struct.fields})
-        return StructType(*[field() for field in field_results])
+        new_ids = [self._get_and_increment(field.field_id) for field in struct.fields]
+        new_fields = []
+        for field_id, field, field_type in zip(new_ids, struct.fields, field_results):
+            new_fields.append(
+                NestedField(
+                    field_id=field_id,
+                    name=field.name,
+                    field_type=field_type(),
+                    required=field.required,
+                    doc=field.doc,
+                )
+            )
+        return StructType(*new_fields)
 
     def field(self, field: NestedField, field_result: Callable[[], IcebergType]) -> IcebergType:
-        return NestedField(
-            field_id=self.reserved_ids[field.field_id],
-            name=field.name,
-            field_type=field_result(),
-            required=field.required,
-            doc=field.doc,
-        )
+        return field_result()
 
     def list(self, list_type: ListType, element_result: Callable[[], IcebergType]) -> ListType:
-        self.reserved_ids[list_type.element_id] = self._get_and_increment()
+        element_id = self._get_and_increment(list_type.element_id)
         return ListType(
-            element_id=self.reserved_ids[list_type.element_id],
+            element_id=element_id,
             element=element_result(),
             element_required=list_type.element_required,
         )
 
     def map(self, map_type: MapType, key_result: Callable[[], IcebergType], value_result: Callable[[], IcebergType]) -> MapType:
-        self.reserved_ids[map_type.key_id] = self._get_and_increment()
-        self.reserved_ids[map_type.value_id] = self._get_and_increment()
+        key_id = self._get_and_increment(map_type.key_id)
+        value_id = self._get_and_increment(map_type.value_id)
         return MapType(
-            key_id=self.reserved_ids[map_type.key_id],
+            key_id=key_id,
             key_type=key_result(),
-            value_id=self.reserved_ids[map_type.value_id],
+            value_id=value_id,
             value_type=value_result(),
             value_required=map_type.value_required,
         )
diff --git a/pyproject.toml b/pyproject.toml
index e7f18b5551..d1bc82dc62 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -311,7 +311,7 @@ select = [
   "I",  # isort
   "UP", # pyupgrade
 ]
-ignore = ["E501","E203","B024","B028"]
+ignore = ["E501","E203","B024","B028","UP037"]
 
 # Allow autofix for all enabled rules (when `--fix`) is provided.
 fixable = ["ALL"]
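The `_SetFreshIDs` rework above is what the public `assign_fresh_schema_ids` function runs under the hood, and it is the step that replaces the -1 placeholders produced by `_ConvertToIcebergWithoutIDs`. A sketch of the behavior on a flat schema; the two-field input is illustrative:

```python
from pyiceberg.schema import Schema, assign_fresh_schema_ids
from pyiceberg.types import IntegerType, NestedField, StringType

# Placeholder ids, as produced by the PyArrow conversion above.
no_ids = Schema(
    NestedField(field_id=-1, name="foo", field_type=StringType(), required=False),
    NestedField(field_id=-1, name="bar", field_type=IntegerType(), required=True),
)

# The visitor walks the schema pre-order, hands out ids 1, 2, ... and keeps
# the old-to-new mapping so identifier_field_ids can be remapped at the end.
fresh = assign_fresh_schema_ids(no_ids)
assert [field.field_id for field in fresh.fields] == [1, 2]
```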
fixable = ["ALL"] diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py index 911c06b27a..d15c90fee3 100644 --- a/tests/catalog/test_base.py +++ b/tests/catalog/test_base.py @@ -24,7 +24,9 @@ Union, ) +import pyarrow as pa import pytest +from pytest_lazyfixture import lazy_fixture from pyiceberg.catalog import ( Catalog, @@ -72,12 +74,14 @@ def __init__(self, name: str, **properties: str) -> None: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, ) -> Table: + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore + identifier = Catalog.identifier_to_tuple(identifier) namespace = Catalog.namespace_from(identifier) @@ -330,6 +334,34 @@ def test_create_table(catalog: InMemoryCatalog) -> None: assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table +@pytest.mark.parametrize( + "schema,expected", + [ + (lazy_fixture("pyarrow_schema_simple_without_ids"), lazy_fixture("iceberg_schema_simple_no_ids")), + (lazy_fixture("iceberg_schema_simple"), lazy_fixture("iceberg_schema_simple")), + (lazy_fixture("iceberg_schema_nested"), lazy_fixture("iceberg_schema_nested")), + (lazy_fixture("pyarrow_schema_nested_without_ids"), lazy_fixture("iceberg_schema_nested_no_ids")), + ], +) +def test_convert_schema_if_needed( + schema: Union[Schema, pa.Schema], + expected: Schema, + catalog: InMemoryCatalog, +) -> None: + assert expected == catalog._convert_schema_if_needed(schema) + + +def test_create_table_pyarrow_schema(catalog: InMemoryCatalog, pyarrow_schema_simple_without_ids: pa.Schema) -> None: + table = catalog.create_table( + identifier=TEST_TABLE_IDENTIFIER, + schema=pyarrow_schema_simple_without_ids, + location=TEST_TABLE_LOCATION, + partition_spec=TEST_TABLE_PARTITION_SPEC, + properties=TEST_TABLE_PROPERTIES, + ) + assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table + + def test_create_table_raises_error_when_table_already_exists(catalog: InMemoryCatalog) -> None: # Given given_catalog_has_a_table(catalog) diff --git a/tests/catalog/test_dynamodb.py b/tests/catalog/test_dynamodb.py index 5af89ef3be..bc801463c5 100644 --- a/tests/catalog/test_dynamodb.py +++ b/tests/catalog/test_dynamodb.py @@ -18,6 +18,7 @@ from unittest import mock import boto3 +import pyarrow as pa import pytest from moto import mock_dynamodb @@ -71,6 +72,23 @@ def test_create_table_with_database_location( assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) +@mock_dynamodb +def test_create_table_with_pyarrow_schema( + _bucket_initialize: None, + moto_endpoint_url: str, + pyarrow_schema_simple_without_ids: pa.Schema, + database_name: str, + table_name: str, +) -> None: + catalog_name = "test_ddb_catalog" + identifier = (database_name, table_name) + test_catalog = DynamoDbCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url}) + test_catalog.create_namespace(namespace=database_name, properties={"location": f"s3://{BUCKET_NAME}/{database_name}.db"}) + table = test_catalog.create_table(identifier, pyarrow_schema_simple_without_ids) + assert table.identifier == (catalog_name,) + identifier + assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) + + @mock_dynamodb def test_create_table_with_default_warehouse( _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str diff --git 
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index b1f1371a04..63a213f94f 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -18,6 +18,7 @@
 from unittest import mock
 
 import boto3
+import pyarrow as pa
 import pytest
 from moto import mock_glue
 
@@ -101,6 +102,28 @@ def test_create_table_with_given_location(
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
 
 
+@mock_glue
+def test_create_table_with_pyarrow_schema(
+    _bucket_initialize: None,
+    moto_endpoint_url: str,
+    pyarrow_schema_simple_without_ids: pa.Schema,
+    database_name: str,
+    table_name: str,
+) -> None:
+    catalog_name = "glue"
+    identifier = (database_name, table_name)
+    test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
+    test_catalog.create_namespace(namespace=database_name)
+    table = test_catalog.create_table(
+        identifier=identifier,
+        schema=pyarrow_schema_simple_without_ids,
+        location=f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}",
+    )
+    assert table.identifier == (catalog_name,) + identifier
+    assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+    assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+
+
 @mock_glue
 def test_create_table_with_no_location(
     _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 9dbcf8f84e..1ca8fd16d2 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -158,6 +158,26 @@ def test_create_table_default_sort_order(catalog: SqlCatalog, table_schema_neste
     catalog.drop_table(random_identifier)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_create_table_with_pyarrow_schema(
+    catalog: SqlCatalog,
+    pyarrow_schema_simple_without_ids: pa.Schema,
+    iceberg_table_schema_simple: Schema,
+    random_identifier: Identifier,
+) -> None:
+    database_name, _table_name = random_identifier
+    catalog.create_namespace(database_name)
+    table = catalog.create_table(random_identifier, pyarrow_schema_simple_without_ids)
+    assert table.schema() == iceberg_table_schema_simple
+    catalog.drop_table(random_identifier)
+
+
 @pytest.mark.parametrize(
     'catalog',
     [
diff --git a/tests/conftest.py b/tests/conftest.py
index 9c53301776..d9a8dfdf07 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -45,6 +45,7 @@
 from urllib.parse import urlparse
 
 import boto3
+import pyarrow as pa
 import pytest
 from moto import mock_dynamodb, mock_glue
 from moto.server import ThreadedMotoServer  # type: ignore
@@ -267,6 +268,178 @@ def table_schema_nested_with_struct_key_map() -> Schema:
     )
 
 
+@pytest.fixture(scope="session")
+def pyarrow_schema_simple_without_ids() -> pa.Schema:
+    return pa.schema([
+        pa.field('foo', pa.string(), nullable=True),
+        pa.field('bar', pa.int32(), nullable=False),
+        pa.field('baz', pa.bool_(), nullable=True),
+    ])
+
+
+@pytest.fixture(scope="session")
+def pyarrow_schema_nested_without_ids() -> pa.Schema:
+    return pa.schema([
+        pa.field('foo', pa.string(), nullable=False),
+        pa.field('bar', pa.int32(), nullable=False),
+        pa.field('baz', pa.bool_(), nullable=True),
+        pa.field('qux', pa.list_(pa.string()), nullable=False),
+        pa.field(
+            'quux',
+            pa.map_(
+                pa.string(),
+                pa.map_(pa.string(), pa.int32()),
+            ),
+            nullable=False,
+        ),
+        pa.field(
+            'location',
+            pa.list_(
+                pa.struct([
+                    pa.field('latitude', pa.float32(), nullable=False),
+                    pa.field('longitude', pa.float32(), nullable=False),
+                ]),
+            ),
+            nullable=False,
+        ),
+        pa.field(
+            'person',
+            pa.struct([
+                pa.field('name', pa.string(), nullable=True),
+                pa.field('age', pa.int32(), nullable=False),
+            ]),
+            nullable=True,
+        ),
+    ])
+
+
+@pytest.fixture(scope="session")
+def iceberg_schema_simple() -> Schema:
+    return Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+    )
+
+
+@pytest.fixture(scope="session")
+def iceberg_schema_simple_no_ids() -> Schema:
+    return Schema(
+        NestedField(field_id=-1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=-1, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=-1, name="baz", field_type=BooleanType(), required=False),
+    )
+
+
+@pytest.fixture(scope="session")
+def iceberg_table_schema_simple() -> Schema:
+    return Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+        schema_id=0,
+        identifier_field_ids=[],
+    )
+
+
+@pytest.fixture(scope="session")
+def iceberg_schema_nested() -> Schema:
+    return Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+        NestedField(
+            field_id=4,
+            name="qux",
+            field_type=ListType(element_id=5, element_type=StringType(), element_required=False),
+            required=True,
+        ),
+        NestedField(
+            field_id=6,
+            name="quux",
+            field_type=MapType(
+                key_id=7,
+                key_type=StringType(),
+                value_id=8,
+                value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=False),
+                value_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=11,
+            name="location",
+            field_type=ListType(
+                element_id=12,
+                element_type=StructType(
+                    NestedField(field_id=13, name="latitude", field_type=FloatType(), required=True),
+                    NestedField(field_id=14, name="longitude", field_type=FloatType(), required=True),
+                ),
+                element_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=15,
+            name="person",
+            field_type=StructType(
+                NestedField(field_id=16, name="name", field_type=StringType(), required=False),
+                NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
+            ),
+            required=False,
+        ),
+    )
+
+
+@pytest.fixture(scope="session")
+def iceberg_schema_nested_no_ids() -> Schema:
+    return Schema(
+        NestedField(field_id=-1, name="foo", field_type=StringType(), required=True),
+        NestedField(field_id=-1, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=-1, name="baz", field_type=BooleanType(), required=False),
+        NestedField(
+            field_id=-1,
+            name="qux",
+            field_type=ListType(element_id=-1, element_type=StringType(), element_required=False),
+            required=True,
+        ),
+        NestedField(
+            field_id=-1,
+            name="quux",
+            field_type=MapType(
+                key_id=-1,
+                key_type=StringType(),
+                value_id=-1,
+                value_type=MapType(key_id=-1, key_type=StringType(), value_id=-1, value_type=IntegerType(), value_required=False),
+                value_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=-1,
+            name="location",
+            field_type=ListType(
+                element_id=-1,
+                element_type=StructType(
+                    NestedField(field_id=-1, name="latitude", field_type=FloatType(), required=True),
+                    NestedField(field_id=-1, name="longitude", field_type=FloatType(), required=True),
+                ),
+                element_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=-1,
+            name="person",
+            field_type=StructType(
+                NestedField(field_id=-1, name="name", field_type=StringType(), required=False),
+                NestedField(field_id=-1, name="age", field_type=IntegerType(), required=True),
+            ),
+            required=False,
+        ),
+    )
+
+
 @pytest.fixture(scope="session")
 def all_avro_types() -> Dict[str, Any]:
     return {
diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py
index 0986eac409..c7f364b920 100644
--- a/tests/io/test_pyarrow_visitor.py
+++ b/tests/io/test_pyarrow_visitor.py
@@ -23,6 +23,7 @@
 from pyiceberg.io.pyarrow import (
     _ConvertToArrowSchema,
     _ConvertToIceberg,
+    _ConvertToIcebergWithoutIDs,
     _HasIds,
     pyarrow_to_schema,
     schema_to_pyarrow,
@@ -51,104 +52,6 @@
 )
 
 
-@pytest.fixture(scope="module")
-def pyarrow_schema_simple_without_ids() -> pa.Schema:
-    return pa.schema([pa.field('some_int', pa.int32(), nullable=True), pa.field('some_string', pa.string(), nullable=False)])
-
-
-@pytest.fixture(scope="module")
-def pyarrow_schema_nested_without_ids() -> pa.Schema:
-    return pa.schema([
-        pa.field('foo', pa.string(), nullable=False),
-        pa.field('bar', pa.int32(), nullable=False),
-        pa.field('baz', pa.bool_(), nullable=True),
-        pa.field('qux', pa.list_(pa.string()), nullable=False),
-        pa.field(
-            'quux',
-            pa.map_(
-                pa.string(),
-                pa.map_(pa.string(), pa.int32()),
-            ),
-            nullable=False,
-        ),
-        pa.field(
-            'location',
-            pa.list_(
-                pa.struct([
-                    pa.field('latitude', pa.float32(), nullable=False),
-                    pa.field('longitude', pa.float32(), nullable=False),
-                ]),
-            ),
-            nullable=False,
-        ),
-        pa.field(
-            'person',
-            pa.struct([
-                pa.field('name', pa.string(), nullable=True),
-                pa.field('age', pa.int32(), nullable=False),
-            ]),
-            nullable=True,
-        ),
-    ])
-
-
-@pytest.fixture(scope="module")
-def iceberg_schema_simple() -> Schema:
-    return Schema(
-        NestedField(field_id=1, name="some_int", field_type=IntegerType(), required=False),
-        NestedField(field_id=2, name="some_string", field_type=StringType(), required=True),
-    )
-
-
-@pytest.fixture(scope="module")
-def iceberg_schema_nested() -> Schema:
-    return Schema(
-        NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
-        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
-        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
-        NestedField(
-            field_id=4,
-            name="qux",
-            field_type=ListType(element_id=5, element_type=StringType(), element_required=False),
-            required=True,
-        ),
-        NestedField(
-            field_id=6,
-            name="quux",
-            field_type=MapType(
-                key_id=7,
-                key_type=StringType(),
-                value_id=8,
-                value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=False),
-                value_required=False,
-            ),
-            required=True,
-        ),
-        NestedField(
-            field_id=11,
-            name="location",
-            field_type=ListType(
-                element_id=12,
-                element_type=StructType(
-                    NestedField(field_id=13, name="latitude", field_type=FloatType(), required=True),
-                    NestedField(field_id=14, name="longitude", field_type=FloatType(), required=True),
-                ),
-                element_required=False,
-            ),
-            required=True,
-        ),
-        NestedField(
-            field_id=15,
-            name="person",
-            field_type=StructType(
-                NestedField(field_id=16, name="name", field_type=StringType(), required=False),
-                NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
-            ),
-            required=False,
-        ),
-    )
-
-
 def test_pyarrow_binary_to_iceberg() -> None:
     length = 23
     pyarrow_type = pa.binary(length)
@@ -468,8 +371,9 @@ def test_simple_pyarrow_schema_to_schema_missing_ids_using_name_mapping(
 ) -> None:
     schema = pyarrow_schema_simple_without_ids
     name_mapping = NameMapping([
-        MappedField(field_id=1, names=['some_int']),
-        MappedField(field_id=2, names=['some_string']),
+        MappedField(field_id=1, names=['foo']),
+        MappedField(field_id=2, names=['bar']),
+        MappedField(field_id=3, names=['baz']),
     ])
 
     assert pyarrow_to_schema(schema, name_mapping) == iceberg_schema_simple
@@ -480,11 +384,11 @@ def test_simple_pyarrow_schema_to_schema_missing_ids_using_name_mapping_partial_
 ) -> None:
     schema = pyarrow_schema_simple_without_ids
     name_mapping = NameMapping([
-        MappedField(field_id=1, names=['some_string']),
+        MappedField(field_id=1, names=['foo']),
     ])
 
     with pytest.raises(ValueError) as exc_info:
         _ = pyarrow_to_schema(schema, name_mapping)
-    assert "Could not find field with name: some_int" in str(exc_info.value)
+    assert "Could not find field with name: bar" in str(exc_info.value)
 
 
 def test_nested_pyarrow_schema_to_schema_missing_ids_using_name_mapping(
@@ -572,3 +476,15 @@ def test_pyarrow_schema_to_schema_missing_ids_using_name_mapping_nested_missing_
     with pytest.raises(ValueError) as exc_info:
         _ = pyarrow_to_schema(schema, name_mapping)
     assert "Could not find field with name: quux.value.key" in str(exc_info.value)
+
+
+def test_pyarrow_schema_to_schema_fresh_ids_simple_schema(
+    pyarrow_schema_simple_without_ids: pa.Schema, iceberg_schema_simple_no_ids: Schema
+) -> None:
+    assert visit_pyarrow(pyarrow_schema_simple_without_ids, _ConvertToIcebergWithoutIDs()) == iceberg_schema_simple_no_ids
+
+
+def test_pyarrow_schema_to_schema_fresh_ids_nested_schema(
+    pyarrow_schema_nested_without_ids: pa.Schema, iceberg_schema_nested_no_ids: Schema
+) -> None:
+    assert visit_pyarrow(pyarrow_schema_nested_without_ids, _ConvertToIcebergWithoutIDs()) == iceberg_schema_nested_no_ids