Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push table concatenation to Arrow #116

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 1 addition & 35 deletions pyiceberg/avro/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@
)
from pyiceberg.exceptions import ResolveError
from pyiceberg.schema import (
PartnerAccessor,
PrimitiveWithPartnerVisitor,
Schema,
SchemaPartnerAccessor,
SchemaVisitorPerPrimitiveType,
promote,
visit,
Expand Down Expand Up @@ -472,37 +472,3 @@ def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) ->

def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Reader:
return BinaryReader()


class SchemaPartnerAccessor(PartnerAccessor[IcebergType]):
    """Resolves the file-side partner type for each element of the read schema.

    Each accessor validates that the partner is the expected Iceberg type and
    raises ResolveError when the two schemas are not aligned.
    """

    def schema_partner(self, partner: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the struct behind a partner Schema."""
        if not isinstance(partner, Schema):
            raise ResolveError(f"File/read schema are not aligned for schema, got {partner}")
        return partner.as_struct()

    def field_partner(self, partner: Optional[IcebergType], field_id: int, field_name: str) -> Optional[IcebergType]:
        """Look up the partner field's type by id; None when the field is absent."""
        if not isinstance(partner, StructType):
            raise ResolveError(f"File/read schema are not aligned for struct, got {partner}")
        matched = partner.field(field_id)
        return matched.field_type if matched else None

    def list_element_partner(self, partner_list: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the element type of a partner list."""
        if not isinstance(partner_list, ListType):
            raise ResolveError(f"File/read schema are not aligned for list, got {partner_list}")
        return partner_list.element_type

    def map_key_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the key type of a partner map."""
        if not isinstance(partner_map, MapType):
            raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}")
        return partner_map.key_type

    def map_value_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the value type of a partner map."""
        if not isinstance(partner_map, MapType):
            raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}")
        return partner_map.value_type
142 changes: 23 additions & 119 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
)
from sortedcontainers import SortedList

from pyiceberg.avro.resolver import ResolveError
from pyiceberg.conversions import to_bytes
from pyiceberg.expressions import (
AlwaysTrue,
Expand Down Expand Up @@ -105,17 +104,13 @@
)
from pyiceberg.manifest import DataFile, FileFormat
from pyiceberg.schema import (
PartnerAccessor,
PreOrderSchemaVisitor,
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
pre_order_visit,
promote,
prune_columns,
sanitize_column_names,
visit,
visit_with_partner,
)
from pyiceberg.transforms import TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties
Expand Down Expand Up @@ -809,6 +804,7 @@ def _task_to_table(
task: FileScanTask,
bound_row_filter: BooleanExpression,
projected_schema: Schema,
projected_arrow_schema: pa.schema,
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
Expand Down Expand Up @@ -841,13 +837,21 @@ def _task_to_table(
if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")

columns = {
# Projecting nested fields doesn't work...
projected_schema.find_column_name(col.field_id): pc.field(col.name).cast(
schema_to_pyarrow(col.field_type)
)
for col in file_project_schema.columns
}

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
schema=physical_schema,
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them first
filter=pyarrow_filter if not positional_deletes else None,
columns=[col.name for col in file_project_schema.columns],
columns=columns,
)

if positional_deletes:
Expand Down Expand Up @@ -885,7 +889,8 @@ def _task_to_table(

row_counts.append(len(arrow_table))

return to_requested_schema(projected_schema, file_project_schema, arrow_table)
# arrow_table.select(projected_arrow_schema)
return arrow_table.cast(projected_arrow_schema)


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
Expand Down Expand Up @@ -946,6 +951,10 @@ def project_table(

bound_row_filter = bind(table.schema(), row_filter, case_sensitive=case_sensitive)

# Will raise an exception
_ = table.schema().is_compatible(projected_schema)
projected_schema_arrow = schema_to_pyarrow(projected_schema)

projected_field_ids = {
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
}.union(extract_field_ids(bound_row_filter))
Expand All @@ -960,6 +969,7 @@ def project_table(
task,
bound_row_filter,
projected_schema,
projected_schema_arrow,
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
Expand All @@ -985,124 +995,18 @@ def project_table(

tables = [f.result() for f in completed_futures if f.result()]

empty_table = pa.Table.from_batches([], schema=projected_schema_arrow)

if len(tables) < 1:
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema))
return empty_table

result = pa.concat_tables(tables)
result = pa.concat_tables([empty_table] + tables, promote_options="permissive")

if limit is not None:
return result.slice(0, limit)

return result


def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
    """Project *table* onto *requested_schema*, resolving columns via *file_schema*.

    Visits the requested schema with the table as its partner, producing one
    struct array whose children are then unpacked into a new Arrow table.
    """
    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

    # One child array per requested field, in requested-schema order.
    columns = [struct_array.field(pos) for pos in range(len(requested_schema.fields))]
    arrow_fields = [
        pa.field(field.name, column.type, field.optional) for field, column in zip(requested_schema.fields, columns)
    ]
    return pa.Table.from_arrays(columns, schema=pa.schema(arrow_fields))


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
    """Rebuilds Arrow arrays so they match the requested (read) schema.

    Columns present in the file are cast where a primitive promotion applies;
    optional fields missing from the file are backfilled with nulls, and a
    missing required field raises ResolveError.
    """

    # Schema of the data file being read; consulted to decide whether a cast is needed.
    file_schema: Schema

    def __init__(self, file_schema: Schema):
        self.file_schema = file_schema

    def cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
        """Cast *values* when the requested primitive type is a promotion of the file type."""
        file_field = self.file_schema.find_field(field.field_id)
        if field.field_type.is_primitive and field.field_type != file_field.field_type:
            promoted = promote(file_field.field_type, field.field_type)
            return values.cast(schema_to_pyarrow(promoted))
        return values

    def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
        return struct_result

    def struct(
        self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]]
    ) -> Optional[pa.Array]:
        if struct_array is None:
            return None

        projected_arrays: List[pa.Array] = []
        projected_fields: List[pa.Field] = []
        for field, field_array in zip(struct.fields, field_results):
            if field_array is not None:
                cast_array = self.cast_if_needed(field, field_array)
                projected_arrays.append(cast_array)
                projected_fields.append(pa.field(field.name, cast_array.type, field.optional))
            elif field.optional:
                # Field is absent from the file: fill with nulls of the requested type.
                null_type = schema_to_pyarrow(field.field_type)
                projected_arrays.append(pa.nulls(len(struct_array), type=null_type))
                projected_fields.append(pa.field(field.name, null_type, field.optional))
            else:
                raise ResolveError(f"Field is required, and could not be found in the file: {field}")

        return pa.StructArray.from_arrays(arrays=projected_arrays, fields=pa.struct(projected_fields))

    def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]:
        return field_array

    def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
        if not isinstance(list_array, pa.ListArray):
            return None
        return pa.ListArray.from_arrays(list_array.offsets, self.cast_if_needed(list_type.element_field, value_array))

    def map(
        self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
    ) -> Optional[pa.Array]:
        if not isinstance(map_array, pa.MapArray):
            return None
        return pa.MapArray.from_arrays(
            map_array.offsets,
            self.cast_if_needed(map_type.key_field, key_result),
            self.cast_if_needed(map_type.value_field, value_result),
        )

    def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.Array]:
        return array


class ArrowAccessor(PartnerAccessor[pa.Array]):
    """Looks up the partner Arrow array for each schema element by field id."""

    # Schema of the file, used to translate field ids into the file's column names.
    file_schema: Schema

    def __init__(self, file_schema: Schema):
        self.file_schema = file_schema

    def schema_partner(self, partner: Optional[pa.Array]) -> Optional[pa.Array]:
        return partner

    def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]:
        if not partner_struct:
            return None

        # use the field name from the file schema
        try:
            file_column_name = self.file_schema.find_field(field_id).name
        except ValueError:
            # The field id does not exist in the file schema.
            return None

        if isinstance(partner_struct, pa.StructArray):
            return partner_struct.field(file_column_name)
        if isinstance(partner_struct, pa.Table):
            return partner_struct.column(file_column_name).combine_chunks()
        return None

    def list_element_partner(self, partner_list: Optional[pa.Array]) -> Optional[pa.Array]:
        if isinstance(partner_list, pa.ListArray):
            return partner_list.values
        return None

    def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
        if isinstance(partner_map, pa.MapArray):
            return partner_map.keys
        return None

    def map_value_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
        if isinstance(partner_map, pa.MapArray):
            return partner_map.items
        return None
# This cast is still needed for projecting away nested fields
return result.cast(projected_schema_arrow)


def _primitive_to_physical(iceberg_type: PrimitiveType) -> str:
Expand Down
77 changes: 77 additions & 0 deletions pyiceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,17 @@ def _validate_identifier_field(self, field_id: int) -> None:
f"Cannot add field {field.name} as an identifier field: must not be nested in an optional field {parent}"
)

def is_compatible(self, promoted: Schema) -> Schema:
    """Check that this schema can be read as the *promoted* schema.

    Visits *promoted* with this schema as the partner. Incompatibilities are
    reported by raising, not by a boolean return: the visitor raises
    ResolveError when a required field of *promoted* has no counterpart here,
    or when a primitive type cannot be promoted to the requested type.

    Args:
        promoted: the destination schema to promote to.

    Returns:
        The promoted schema, when the two schemas are compatible.

    Raises:
        ResolveError: if the schemas are not aligned or a promotion is invalid.
    """
    return visit_with_partner(promoted, self, _CheckSchemaCompatibility(), SchemaPartnerAccessor())


class SchemaVisitor(Generic[T], ABC):
def before_field(self, field: NestedField) -> None:
Expand Down Expand Up @@ -1589,3 +1600,69 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
return read_type
else:
raise ResolveError(f"Cannot promote {file_type} to {read_type}")


class SchemaPartnerAccessor(PartnerAccessor[IcebergType]):
    """Navigates a partner Iceberg schema alongside the schema being visited.

    Raises ResolveError whenever the partner is not the Iceberg type expected
    at the current position in the traversal.
    """

    def schema_partner(self, partner: Optional[IcebergType]) -> Optional[IcebergType]:
        """Unwrap a partner Schema into its top-level struct."""
        if not isinstance(partner, Schema):
            raise ResolveError(f"File/read schema are not aligned for schema, got {partner}")
        return partner.as_struct()

    def field_partner(self, partner: Optional[IcebergType], field_id: int, field_name: str) -> Optional[IcebergType]:
        """Find the partner field's type by id, or None when the id is absent."""
        if isinstance(partner, StructType):
            partner_field = partner.field(field_id)
            return partner_field.field_type if partner_field else None
        raise ResolveError(f"File/read schema are not aligned for struct, got {partner}")

    def list_element_partner(self, partner_list: Optional[IcebergType]) -> Optional[IcebergType]:
        """Element type of a partner list."""
        if not isinstance(partner_list, ListType):
            raise ResolveError(f"File/read schema are not aligned for list, got {partner_list}")
        return partner_list.element_type

    def map_key_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]:
        """Key type of a partner map."""
        if not isinstance(partner_map, MapType):
            raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}")
        return partner_map.key_type

    def map_value_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]:
        """Value type of a partner map."""
        if not isinstance(partner_map, MapType):
            raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}")
        return partner_map.value_type


class _CheckSchemaCompatibility(SchemaWithPartnerVisitor[IcebergType, IcebergType]):
    """Visitor that validates a requested schema against a partner (source) schema.

    Every visit returns the requested type unchanged; incompatibilities surface
    as exceptions — a required field with no partner raises ResolveError, and a
    primitive mismatch is delegated to promote(), which raises when invalid.
    """

    def schema(self, schema: Schema, schema_partner: Optional[IcebergType], struct_result: IcebergType) -> IcebergType:
        """Visit a schema with a partner."""
        return schema

    def struct(self, struct: StructType, struct_partner: Optional[IcebergType], field_results: List[IcebergType]) -> IcebergType:
        """Visit a struct type with a partner."""
        return struct

    def field(self, field: NestedField, field_partner: Optional[IcebergType], field_result: IcebergType) -> IcebergType:
        """Visit a nested field with a partner."""
        # A required field must have a counterpart in the partner schema.
        if field.required and field_partner is None:
            raise ResolveError(f"Field is required, and could not be found: {field}")
        return field

    def list(self, list_type: ListType, list_partner: Optional[IcebergType], element_result: IcebergType) -> IcebergType:
        """Visit a list type with a partner."""
        return list_type

    def map(
        self, map_type: MapType, map_partner: Optional[IcebergType], key_result: IcebergType, value_result: IcebergType
    ) -> IcebergType:
        """Visit a map type with a partner."""
        return map_type

    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[IcebergType]) -> IcebergType:
        """Visit a primitive with a partner; promote() raises on an invalid promotion."""
        if primitive_partner is None or primitive == primitive_partner:
            return primitive
        return promote(primitive_partner, primitive)
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ sortedcontainers = "2.4.0"
fsspec = ">=2023.1.0,<2024.1.0"
pyparsing = ">=3.1.0,<4.0.0"
zstandard = ">=0.13.0,<1.0.0"
pyarrow = { version = ">=9.0.0,<15.0.0", optional = true }
pyarrow = { version = ">=14.0.0", optional = true }
pandas = { version = ">=1.0.0,<3.0.0", optional = true }
duckdb = { version = ">=0.5.0,<1.0.0", optional = true }
ray = { version = ">=2.0.0,<3.0.0", optional = true }
Expand Down Expand Up @@ -123,9 +123,9 @@ markers = [
]

# Turns a warning into an error
filterwarnings = [
"error"
]
#filterwarnings = [
# "error"
#]

[tool.black]
line-length = 130
Expand Down
Loading