Skip to content

Commit

Permalink
Add include_field_ids flag in schema_to_pyarrow (#789)
Browse files Browse the repository at this point in the history
* include_field_ids flag

* include_field_ids flag
  • Loading branch information
sungwy authored Jun 3, 2024
1 parent 31c6c23 commit e61ef57
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 37 deletions.
25 changes: 16 additions & 9 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,15 +469,18 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.fs_by_scheme = lru_cache(self._initialize_fs)


def schema_to_pyarrow(
    schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True
) -> pa.schema:
    """Convert an Iceberg schema (or a single Iceberg type) to a PyArrow schema.

    Args:
        schema: The Iceberg ``Schema`` or ``IcebergType`` to convert.
        metadata: Optional schema-level metadata attached to the resulting schema.
        include_field_ids: When True (the default, preserving prior behavior),
            each Arrow field carries its Iceberg field id in the field metadata
            under the Parquet field-id key; pass False to omit the ids.

    Returns:
        The visitor result — a ``pa.schema`` for a ``Schema`` input (a bare
        ``IcebergType`` yields the corresponding Arrow type object).
    """
    # NOTE(review): the stale pre-change definition that shadowed this one in
    # the diff residue has been dropped; only the post-change version remains.
    return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))


class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
_metadata: Dict[bytes, bytes]

def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True) -> None:
    """Initialize the schema-conversion visitor.

    Args:
        metadata: Schema-level metadata to attach to the generated ``pa.schema``.
        include_field_ids: When True, ``field`` writes each Iceberg field id
            into the Arrow field metadata under the Parquet field-id key.
    """
    # The stale pre-change __init__ signature from the diff residue is removed;
    # only the two-argument post-change constructor remains.
    self._metadata = metadata
    self._include_field_ids = include_field_ids

def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema:
    """Assemble the top-level ``pa.schema`` from the visited struct fields."""
    fields = list(struct_result)
    return pa.schema(fields, metadata=self._metadata)
Expand All @@ -486,13 +489,17 @@ def struct(self, _: StructType, field_results: List[pa.DataType]) -> pa.DataType
return pa.struct(field_results)

def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
    """Convert a ``NestedField`` into a ``pa.Field``.

    The doc string is stored under the PyArrow doc key only when present, and
    the Iceberg field id is stored under the Parquet field-id key only when the
    visitor was constructed with ``include_field_ids=True``.
    """
    # Build metadata incrementally instead of via the old nested conditional
    # expression; the diff residue's duplicate ``metadata=`` keyword (a syntax
    # error) is removed — only the dict built here is passed through.
    metadata = {}
    if field.doc:
        metadata[PYARROW_FIELD_DOC_KEY] = field.doc
    if self._include_field_ids:
        metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)

    return pa.field(
        name=field.name,
        type=field_result,
        nullable=field.optional,
        metadata=metadata,
    )

def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
Expand Down Expand Up @@ -1130,7 +1137,7 @@ def project_table(
tables = [f.result() for f in completed_futures if f.result()]

if len(tables) < 1:
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema))
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))

result = pa.concat_tables(tables)

Expand Down Expand Up @@ -1161,7 +1168,7 @@ def __init__(self, file_schema: Schema):
def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
    """Cast ``values`` to the projected primitive type when the file's type differs.

    Field ids are omitted from the promotion target so the cast target is a
    plain Arrow type without Iceberg field-id metadata.
    """
    file_field = self.file_schema.find_field(field.field_id)
    if field.field_type.is_primitive and field.field_type != file_field.field_type:
        # Duplicate return line left over from the diff (old + new) removed;
        # only the post-change call with include_field_ids=False remains.
        return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
    return values

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
Expand All @@ -1188,7 +1195,7 @@ def struct(
field_arrays.append(array)
fields.append(self._construct_field(field, array.type))
elif field.optional:
arrow_type = schema_to_pyarrow(field.field_type)
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False)
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
fields.append(self._construct_field(field, arrow_type))
else:
Expand Down
57 changes: 29 additions & 28 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test_deleting_hdfs_file_not_found() -> None:
assert "Cannot delete file, does not exist:" in str(exc_info.value)


def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None:
actual = schema_to_pyarrow(table_schema_nested)
expected = """foo: string
-- field metadata --
Expand Down Expand Up @@ -402,6 +402,30 @@ def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
assert repr(actual) == expected


def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: Schema) -> None:
    """With include_field_ids=False, the schema repr carries no field metadata.

    Unlike the include variant above, no ``PARQUET:field_id`` (or doc) metadata
    blocks appear anywhere in the converted schema's repr.
    NOTE(review): the expected string's leading indentation appears stripped by
    extraction — confirm against the upstream test file before relying on it.
    """
actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False)
expected = """foo: string
bar: int32 not null
baz: bool
qux: list<element: string not null> not null
child 0, element: string not null
quux: map<string, map<string, int32>> not null
child 0, entries: struct<key: string not null, value: map<string, int32> not null> not null
child 0, key: string not null
child 1, value: map<string, int32> not null
child 0, entries: struct<key: string not null, value: int32 not null> not null
child 0, key: string not null
child 1, value: int32 not null
location: list<element: struct<latitude: float, longitude: float> not null> not null
child 0, element: struct<latitude: float, longitude: float> not null
child 0, latitude: float
child 1, longitude: float
person: struct<name: string, age: int32 not null>
child 0, name: string
child 1, age: int32 not null"""
assert repr(actual) == expected



def test_fixed_type_to_pyarrow() -> None:
length = 22
iceberg_type = FixedType(length)
Expand Down Expand Up @@ -945,23 +969,13 @@ def test_projection_add_column(file_int: str) -> None:
== """id: int32
list: list<element: int32>
child 0, element: int32
-- field metadata --
PARQUET:field_id: '21'
map: map<int32, string>
child 0, entries: struct<key: int32 not null, value: string> not null
child 0, key: int32 not null
-- field metadata --
PARQUET:field_id: '31'
child 1, value: string
-- field metadata --
PARQUET:field_id: '32'
location: struct<lat: double, lon: double>
child 0, lat: double
-- field metadata --
PARQUET:field_id: '41'
child 1, lon: double
-- field metadata --
PARQUET:field_id: '42'"""
child 1, lon: double"""
)


Expand Down Expand Up @@ -1014,11 +1028,7 @@ def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None
== """id: map<int32, string>
child 0, entries: struct<key: int32 not null, value: string> not null
child 0, key: int32 not null
-- field metadata --
PARQUET:field_id: '3'
child 1, value: string
-- field metadata --
PARQUET:field_id: '4'"""
child 1, value: string"""
)


Expand Down Expand Up @@ -1062,12 +1072,7 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
def test_projection_filter(schema_int: Schema, file_int: str) -> None:
    """A filter matching no rows yields an empty table whose schema repr
    contains no ``PARQUET:field_id`` metadata (ids excluded on empty result)."""
    result_table = project(schema_int, [file_int], GreaterThan("id", 4))
    assert len(result_table.columns[0]) == 0
    # Stale pre-change multi-line assertion (expecting field-id metadata in the
    # repr) removed from the diff residue; only the post-change check remains.
    assert repr(result_table.schema) == """id: int32"""


def test_projection_filter_renamed_column(file_int: str) -> None:
Expand Down Expand Up @@ -1304,11 +1309,7 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
repr(result_table.schema)
== """location: struct<lat: double, long: double>
child 0, lat: double
-- field metadata --
PARQUET:field_id: '41'
child 1, long: double
-- field metadata --
PARQUET:field_id: '42'"""
child 1, long: double"""
)


Expand Down

0 comments on commit e61ef57

Please sign in to comment.