From 5cce906db89fa1edbb57bb423ba371598ce50acb Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 13 Aug 2024 14:34:24 +0200 Subject: [PATCH] Use `VisitorWithPartner` for name-mapping (#1014) * Use `VisitorWithPartner` for name-mapping This will correctly handle fields with `.` in the name. * Fix versions in deprecation Co-authored-by: Sung Yun <107272191+sungwy@users.noreply.github.com> * Use full path in error --------- Co-authored-by: Sung Yun <107272191+sungwy@users.noreply.github.com> --- pyiceberg/io/pyarrow.py | 16 ++-- pyiceberg/table/name_mapping.py | 134 ++++++++++++++++++++++++++++++- tests/table/test_name_mapping.py | 52 +++++++++++- 3 files changed, 189 insertions(+), 13 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 719d289717..b2cb167adb 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -130,7 +130,7 @@ visit_with_partner, ) from pyiceberg.table.metadata import TableMetadata -from pyiceberg.table.name_mapping import NameMapping +from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping from pyiceberg.transforms import TruncateTransform from pyiceberg.typedef import EMPTY_DICT, Properties, Record from pyiceberg.types import ( @@ -818,14 +818,14 @@ def pyarrow_to_schema( ) -> Schema: has_ids = visit_pyarrow(schema, _HasIds()) if has_ids: - visitor = _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) + return visit_pyarrow(schema, _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)) elif name_mapping is not None: - visitor = _ConvertToIceberg(name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) + schema_without_ids = _pyarrow_to_schema_without_ids(schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) + return apply_name_mapping(schema_without_ids, name_mapping) else: raise ValueError( "Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined" ) - return visit_pyarrow(schema, visitor) def _pyarrow_to_schema_without_ids(schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> Schema: @@ -1002,17 +1002,13 @@ class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]): """Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided.""" _field_names: List[str] - _name_mapping: Optional[NameMapping] - def __init__(self, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False) -> None: + def __init__(self, downcast_ns_timestamp_to_us: bool = False) -> None: self._field_names = [] - self._name_mapping = name_mapping self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us def _field_id(self, field: pa.Field) -> int: - if self._name_mapping: - return self._name_mapping.find(*self._field_names).field_id - elif (field_id := _get_field_id(field)) is not None: + if (field_id := _get_field_id(field)) is not None: return field_id else: raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.") diff --git a/pyiceberg/table/name_mapping.py b/pyiceberg/table/name_mapping.py index cb9f72bf97..eaf5fc855d 100644 --- a/pyiceberg/table/name_mapping.py +++ b/pyiceberg/table/name_mapping.py @@ -30,9 +30,10 @@ from pydantic import Field, conlist, field_validator, model_serializer -from pyiceberg.schema import Schema, SchemaVisitor, visit +from pyiceberg.schema import P, PartnerAccessor, Schema, SchemaVisitor, SchemaWithPartnerVisitor, visit, visit_with_partner from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel -from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType +from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType +from pyiceberg.utils.deprecated import deprecated class MappedField(IcebergBaseModel): @@ -74,6 +75,11 @@ class NameMapping(IcebergRootModel[List[MappedField]]): def _field_by_name(self) -> Dict[str, MappedField]: return visit_name_mapping(self, _IndexByName()) + @deprecated( + deprecated_in="0.8.0", + removed_in="0.9.0", + help_message="Please use `apply_name_mapping` instead", + ) def find(self, *names: str) -> MappedField: name = ".".join(names) try: @@ -248,3 +254,127 @@ def create_mapping_from_schema(schema: Schema) -> NameMapping: def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping: return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, adds))) + + +class NameMappingAccessor(PartnerAccessor[MappedField]): + def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]: + return partner + + def field_partner( + self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str + ) -> Optional[MappedField]: + if partner_struct is not None: + if isinstance(partner_struct, MappedField): + partner_struct = partner_struct.fields + + for field in partner_struct: + if field_name in field.names: + return field + + return None + + def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]: + if partner_list is not None: + for field in partner_list.fields: + if "element" in field.names: + return field + return None + + def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]: + if partner_map is not None: + for field in partner_map.fields: + if "key" in field.names: + return field + return None + + def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]: + if partner_map is not None: + for field in partner_map.fields: + if "value" in field.names: + return field + return None + + +class NameMappingProjectionVisitor(SchemaWithPartnerVisitor[MappedField, IcebergType]): + current_path: List[str] + + def __init__(self) -> None: + # For keeping track where we are in case when a field cannot be found + self.current_path = [] + + def before_field(self, field: NestedField, field_partner: Optional[P]) -> None: + self.current_path.append(field.name) + + def after_field(self, field: NestedField, field_partner: Optional[P]) -> None: + self.current_path.pop() + + def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None: + self.current_path.append("element") + + def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None: + self.current_path.pop() + + def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None: + self.current_path.append("key") + + def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None: + self.current_path.pop() + + def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None: + self.current_path.append("value") + + def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None: + self.current_path.pop() + + def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType: + return Schema(*struct_result.fields, schema_id=schema.schema_id) + + def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType: + return StructType(*field_results) + + def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType: + if field_partner is None: + raise ValueError(f"Field missing from NameMapping: {'.'.join(self.current_path)}") + + return NestedField( + field_id=field_partner.field_id, + name=field.name, + field_type=field_result, + required=field.required, + doc=field.doc, + initial_default=field.initial_default, + initial_write=field.write_default, + ) + + def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType: + if list_partner is None: + raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}") + + element_id = next(field for field in list_partner.fields if "element" in field.names).field_id + return ListType(element_id=element_id, element=element_result, element_required=list_type.element_required) + + def map( + self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType + ) -> IcebergType: + if map_partner is None: + raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}") + + key_id = next(field for field in map_partner.fields if "key" in field.names).field_id + value_id = next(field for field in map_partner.fields if "value" in field.names).field_id + return MapType( + key_id=key_id, + key_type=key_result, + value_id=value_id, + value_type=value_result, + value_required=map_type.value_required, + ) + + def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType: + if primitive_partner is None: + raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}") + + return primitive + + +def apply_name_mapping(schema_without_ids: Schema, name_mapping: NameMapping) -> Schema: + return visit_with_partner(schema_without_ids, name_mapping, NameMappingProjectionVisitor(), NameMappingAccessor()) # type: ignore diff --git a/tests/table/test_name_mapping.py b/tests/table/test_name_mapping.py index 3c50a24e5e..647644fa98 100644 --- a/tests/table/test_name_mapping.py +++ b/tests/table/test_name_mapping.py @@ -20,11 +20,12 @@ from pyiceberg.table.name_mapping import ( MappedField, NameMapping, + apply_name_mapping, create_mapping_from_schema, parse_mapping_from_json, update_mapping, ) -from pyiceberg.types import NestedField, StringType +from pyiceberg.types import BooleanType, FloatType, IntegerType, ListType, MapType, NestedField, StringType, StructType @pytest.fixture(scope="session") @@ -321,3 +322,52 @@ def test_update_mapping(table_name_mapping_nested: NameMapping) -> None: MappedField(field_id=18, names=["add_18"]), ]) assert update_mapping(table_name_mapping_nested, updates, adds) == expected + + +def test_mapping_using_by_visitor(table_schema_nested: Schema, table_name_mapping_nested: NameMapping) -> None: + schema_without_ids = Schema( + NestedField(field_id=0, name="foo", field_type=StringType(), required=False), + NestedField(field_id=0, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=0, name="baz", field_type=BooleanType(), required=False), + NestedField( + field_id=0, + name="qux", + field_type=ListType(element_id=0, element_type=StringType(), element_required=True), + required=True, + ), + NestedField( + field_id=0, + name="quux", + field_type=MapType( + key_id=0, + key_type=StringType(), + value_id=0, + value_type=MapType(key_id=0, key_type=StringType(), value_id=0, value_type=IntegerType(), value_required=True), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=0, + name="location", + field_type=ListType( + element_id=0, + element_type=StructType( + NestedField(field_id=0, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=0, name="longitude", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=0, + name="person", + field_type=StructType( + NestedField(field_id=0, name="name", field_type=StringType(), required=False), + NestedField(field_id=0, name="age", field_type=IntegerType(), required=True), + ), + required=False, + ), + ) + assert apply_name_mapping(schema_without_ids, table_name_mapping_nested).fields == table_schema_nested.fields