Commit 523ff61
trying out naive hack
1 parent 7ec1c04 commit 523ff61

1 file changed: pyiceberg/io/pyarrow.py (89 additions, 52 deletions)
@@ -590,14 +590,16 @@ def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
     raise ValueError(f"Unsupported file format: {file_format}")
 
 
-def _construct_fragment(fs: FileSystem, data_file: DataFile, file_format_kwargs: Dict[str, Any] = EMPTY_DICT) -> ds.Fragment:
+def _construct_fragment(fs: FileSystem, data_file: DataFile,
+                        file_format_kwargs: Dict[str, Any] = EMPTY_DICT) -> ds.Fragment:
     _, _, path = PyArrowFileIO.parse_location(data_file.file_path)
     return _get_file_format(data_file.file_format, **file_format_kwargs).make_fragment(path, fs)
 
 
 def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedArray]:
     delete_fragment = _construct_fragment(
-        fs, data_file, file_format_kwargs={"dictionary_columns": ("file_path",), "pre_buffer": True, "buffer_size": ONE_MEGABYTE}
+        fs, data_file,
+        file_format_kwargs={"dictionary_columns": ("file_path",), "pre_buffer": True, "buffer_size": ONE_MEGABYTE}
     )
     table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table()
     table = table.unify_dictionaries()
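
As context for this hunk (not part of the commit): the touched helpers hand a Parquet delete file to pyarrow as a dataset fragment and scan it back into a table. A minimal sketch, assuming pyarrow is installed; the file path and contents here are invented for illustration:

# Hedged sketch of the fragment round-trip that _construct_fragment and
# _read_deletes perform; toy data, local path instead of object storage.
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
from pyarrow.fs import LocalFileSystem

# Write a toy positional-delete file with the shape Iceberg uses.
pq.write_table(pa.table({"file_path": ["s3://bucket/a.parquet"], "pos": [0]}),
               "/tmp/deletes.parquet")

# Same kwargs the diff routes through _get_file_format to ds.ParquetFileFormat.
file_format = ds.ParquetFileFormat(dictionary_columns=("file_path",),
                                   pre_buffer=True, buffer_size=1024 * 1024)
fragment = file_format.make_fragment("/tmp/deletes.parquet", LocalFileSystem())
print(ds.Scanner.from_fragment(fragment=fragment).to_table())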
@@ -729,7 +731,8 @@ def _get_field_doc(field: pa.Field) -> Optional[str]:
 
 
 class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
-    def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
+    def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[
+            NestedField]:
         fields = []
         for i, field in enumerate(arrow_fields):
             field_id = _get_field_id(field)
@@ -753,7 +756,7 @@ def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
         return None
 
     def map(
-        self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
+            self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
     ) -> Optional[IcebergType]:
         key_field = map_type.key_field
         key_id = _get_field_id(key_field)
@@ -798,17 +801,39 @@ def primitive(self, primitive: pa.DataType) -> IcebergType:
 
         raise TypeError(f"Unsupported type: {primitive}")
 
+# ToDo get guidance on where this should be and if we can find an exhaustive list of magic
+# Apache Iceberg -> Parquet sanitizes column names within the Parquet file
+# itself, e.g. "foo:bar" is stored as "foo_x3Abar"
+parquet_magic_columns = {
+    ":": "_x3A",
+}
+
+# ToDo get guidance on where this should be, and how we want to flag it
+def _hack_names(column_name_list: List[str], enabled: bool) -> List[str]:
+    if enabled:
+        o = []
+        # ToDo fix time and space complexity
+        for column_name in column_name_list:
+            for key, replacement in parquet_magic_columns.items():
+                column_name = column_name.replace(key, replacement)
+            o.append(column_name)
+        return o
+    return column_name_list
 
 def _task_to_table(
-    fs: FileSystem,
-    task: FileScanTask,
-    bound_row_filter: BooleanExpression,
-    projected_schema: Schema,
-    projected_field_ids: Set[int],
-    positional_deletes: Optional[List[ChunkedArray]],
-    case_sensitive: bool,
-    row_counts: List[int],
-    limit: Optional[int] = None,
+        fs: FileSystem,
+        task: FileScanTask,
+        bound_row_filter: BooleanExpression,
+        projected_schema: Schema,
+        projected_field_ids: Set[int],
+        positional_deletes: Optional[List[ChunkedArray]],
+        case_sensitive: bool,
+        row_counts: List[int],
+        limit: Optional[int] = None,
 ) -> Optional[pa.Table]:
     if limit and sum(row_counts) >= limit:
         return None
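
Taken on its own, the hack is a one-way rename of column names before they reach Arrow. A standalone sketch of the intended behavior, with invented column names (the dict and function mirror the ones added above):

# Self-contained demo of the escaping above; "foo:bar" and "plain" are made up.
parquet_magic_columns = {":": "_x3A"}

def _hack_names(column_name_list, enabled):
    if not enabled:
        return column_name_list
    escaped = []
    for column_name in column_name_list:
        # Apply every known magic replacement to each name exactly once.
        for key, replacement in parquet_magic_columns.items():
            column_name = column_name.replace(key, replacement)
        escaped.append(column_name)
    return escaped

print(_hack_names(["foo:bar", "plain"], True))   # ['foo_x3Abar', 'plain']
print(_hack_names(["foo:bar", "plain"], False))  # ['foo:bar', 'plain']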
@@ -823,15 +848,17 @@ def _task_to_table(
     schema_raw = metadata.get(ICEBERG_SCHEMA)
     # TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
     # see https://github.com/apache/iceberg/issues/7451
-    file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)
+    file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(
+        physical_schema)
 
     pyarrow_filter = None
     if bound_row_filter is not AlwaysTrue():
         translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
         bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
         pyarrow_filter = expression_to_pyarrow(bound_file_filter)
 
-    file_project_schema = sanitize_column_names(prune_columns(file_schema, projected_field_ids, select_full_types=False))
+    file_project_schema = sanitize_column_names(
+        prune_columns(file_schema, projected_field_ids, select_full_types=False))
 
     if file_schema is None:
         raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
@@ -842,7 +869,7 @@ def _task_to_table(
         # This will push down the query to Arrow.
         # But in case there are positional deletes, we have to apply them first
         filter=pyarrow_filter if not positional_deletes else None,
-        columns=[col.name for col in file_project_schema.columns],
+        columns=_hack_names([col.name for col in file_project_schema.columns], True),
     )
 
     if positional_deletes:
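
Why the call site needs the escaped names, sketched outside the diff: the Parquet file physically stores the sanitized column name, so projecting by the raw Iceberg name finds no matching column. A hedged demo with an invented column and path:

# Hedged sketch (not from the commit): the file stores "foo_x3Abar", so the
# scanner must be given the escaped name rather than the raw "foo:bar".
import pyarrow as pa
import pyarrow.parquet as pq

pq.write_table(pa.table({"foo_x3Abar": [1, 2, 3]}), "/tmp/escaped.parquet")

raw = ["foo:bar"]
escaped = [name.replace(":", "_x3A") for name in raw]  # what _hack_names(raw, True) yields
print(pq.read_table("/tmp/escaped.parquet", columns=escaped))
# Projecting with columns=raw instead would fail: no field named "foo:bar" exists in the file.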
@@ -902,12 +929,12 @@ def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
 
 
 def project_table(
-    tasks: Iterable[FileScanTask],
-    table: Table,
-    row_filter: BooleanExpression,
-    projected_schema: Schema,
-    case_sensitive: bool = True,
-    limit: Optional[int] = None,
+        tasks: Iterable[FileScanTask],
+        table: Table,
+        row_filter: BooleanExpression,
+        projected_schema: Schema,
+        case_sensitive: bool = True,
+        limit: Optional[int] = None,
 ) -> pa.Table:
     """Resolve the right columns based on the identifier.
@@ -992,7 +1019,8 @@ def project_table(
 
 
 def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
-    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema),
+                                      ArrowAccessor(file_schema))
 
     arrays = []
     fields = []
@@ -1015,11 +1043,12 @@ def cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
             return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
         return values
 
-    def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
+    def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[
+            pa.Array]:
         return struct_result
 
     def struct(
-        self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]]
+            self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]]
     ) -> Optional[pa.Array]:
         if struct_array is None:
             return None
@@ -1042,15 +1071,17 @@ def struct(
     def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]:
         return field_array
 
-    def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
+    def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[
+            pa.Array]:
         return (
             pa.ListArray.from_arrays(list_array.offsets, self.cast_if_needed(list_type.element_field, value_array))
             if isinstance(list_array, pa.ListArray)
             else None
         )
 
     def map(
-        self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
+            self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array],
+            value_result: Optional[pa.Array]
     ) -> Optional[pa.Array]:
         return (
             pa.MapArray.from_arrays(
@@ -1171,7 +1202,8 @@ class StatsAggregator:
     current_max: Any
     trunc_length: Optional[int]
 
-    def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc_length: Optional[int] = None) -> None:
+    def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str,
+                 trunc_length: Optional[int] = None) -> None:
         self.current_min = None
         self.current_max = None
         self.trunc_length = trunc_length
@@ -1284,27 +1316,30 @@ def __init__(self, schema: Schema, properties: Dict[str, str]):
         self._properties = properties
         self._default_mode = self._properties.get(DEFAULT_METRICS_MODE_KEY)
 
-    def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
+    def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsCollector]]) -> List[
+            StatisticsCollector]:
         return struct_result()
 
     def struct(
-        self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
+            self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
     ) -> List[StatisticsCollector]:
         return list(chain(*[result() for result in field_results]))
 
-    def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
+    def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[
+            StatisticsCollector]:
         self._field_id = field.field_id
         return field_result()
 
-    def list(self, list_type: ListType, element_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
+    def list(self, list_type: ListType, element_result: Callable[[], List[StatisticsCollector]]) -> List[
+            StatisticsCollector]:
         self._field_id = list_type.element_id
         return element_result()
 
     def map(
-        self,
-        map_type: MapType,
-        key_result: Callable[[], List[StatisticsCollector]],
-        value_result: Callable[[], List[StatisticsCollector]],
+            self,
+            map_type: MapType,
+            key_result: Callable[[], List[StatisticsCollector]],
+            value_result: Callable[[], List[StatisticsCollector]],
     ) -> List[StatisticsCollector]:
         self._field_id = map_type.key_id
         k = key_result()
@@ -1327,8 +1362,8 @@ def primitive(self, primitive: PrimitiveType) -> List[StatisticsCollector]:
         metrics_mode = match_metrics_mode(col_mode)
 
         if (
-            not (isinstance(primitive, StringType) or isinstance(primitive, BinaryType))
-            and metrics_mode.type == MetricModeTypes.TRUNCATE
+                not (isinstance(primitive, StringType) or isinstance(primitive, BinaryType))
+                and metrics_mode.type == MetricModeTypes.TRUNCATE
         ):
             metrics_mode = MetricsMode(MetricModeTypes.FULL)
@@ -1337,12 +1372,13 @@ def primitive(self, primitive: PrimitiveType) -> List[StatisticsCollector]:
         if is_nested and metrics_mode.type in [MetricModeTypes.TRUNCATE, MetricModeTypes.FULL]:
             metrics_mode = MetricsMode(MetricModeTypes.COUNTS)
 
-        return [StatisticsCollector(field_id=self._field_id, iceberg_type=primitive, mode=metrics_mode, column_name=column_name)]
+        return [StatisticsCollector(field_id=self._field_id, iceberg_type=primitive, mode=metrics_mode,
+                                    column_name=column_name)]
 
 
 def compute_statistics_plan(
-    schema: Schema,
-    table_properties: Dict[str, str],
+        schema: Schema,
+        table_properties: Dict[str, str],
 ) -> Dict[int, StatisticsCollector]:
     """
     Compute the statistics plan for all columns.
@@ -1381,7 +1417,8 @@ def __init__(self) -> None:
     def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
         return struct_result()
 
-    def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]:
+    def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[
+            ID2ParquetPath]:
         return list(chain(*[result() for result in field_results]))
 
     def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
@@ -1399,10 +1436,10 @@ def list(self, list_type: ListType, element_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
         return result
 
     def map(
-        self,
-        map_type: MapType,
-        key_result: Callable[[], List[ID2ParquetPath]],
-        value_result: Callable[[], List[ID2ParquetPath]],
+            self,
+            map_type: MapType,
+            key_result: Callable[[], List[ID2ParquetPath]],
+            value_result: Callable[[], List[ID2ParquetPath]],
     ) -> List[ID2ParquetPath]:
         self._field_id = map_type.key_id
         self._path.append("key_value.key")
@@ -1419,7 +1456,7 @@ def primitive(self, primitive: PrimitiveType) -> List[ID2ParquetPath]:
 
 
 def parquet_path_to_id_mapping(
-    schema: Schema,
+        schema: Schema,
 ) -> Dict[str, int]:
     """
     Compute the mapping of parquet column path to Iceberg ID.
@@ -1438,11 +1475,11 @@ def parquet_path_to_id_mapping(
 
 
 def fill_parquet_file_metadata(
-    df: DataFile,
-    parquet_metadata: pq.FileMetaData,
-    file_size: int,
-    stats_columns: Dict[int, StatisticsCollector],
-    parquet_column_mapping: Dict[str, int],
+        df: DataFile,
+        parquet_metadata: pq.FileMetaData,
+        file_size: int,
+        stats_columns: Dict[int, StatisticsCollector],
+        parquet_column_mapping: Dict[str, int],
 ) -> None:
     """
     Compute and fill the following fields of the DataFile object.
