removing whole file reformatting
MarquisC committed Oct 27, 2023
1 parent 523ff61 commit b8c2ae3
Showing 1 changed file with 51 additions and 66 deletions.
117 changes: 51 additions & 66 deletions pyiceberg/io/pyarrow.py
@@ -590,16 +590,14 @@ def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.Fi
raise ValueError(f"Unsupported file format: {file_format}")


-def _construct_fragment(fs: FileSystem, data_file: DataFile,
-                        file_format_kwargs: Dict[str, Any] = EMPTY_DICT) -> ds.Fragment:
+def _construct_fragment(fs: FileSystem, data_file: DataFile, file_format_kwargs: Dict[str, Any] = EMPTY_DICT) -> ds.Fragment:
_, _, path = PyArrowFileIO.parse_location(data_file.file_path)
return _get_file_format(data_file.file_format, **file_format_kwargs).make_fragment(path, fs)


def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedArray]:
delete_fragment = _construct_fragment(
-        fs, data_file,
-        file_format_kwargs={"dictionary_columns": ("file_path",), "pre_buffer": True, "buffer_size": ONE_MEGABYTE}
+        fs, data_file, file_format_kwargs={"dictionary_columns": ("file_path",), "pre_buffer": True, "buffer_size": ONE_MEGABYTE}
)
table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table()
table = table.unify_dictionaries()
@@ -731,8 +729,7 @@ def _get_field_doc(field: pa.Field) -> Optional[str]:


class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
-    def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[
-        NestedField]:
+    def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
fields = []
for i, field in enumerate(arrow_fields):
field_id = _get_field_id(field)
@@ -756,7 +753,7 @@ def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) ->
return None

def map(
-            self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
+        self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
) -> Optional[IcebergType]:
key_field = map_type.key_field
key_id = _get_field_id(key_field)
@@ -825,15 +822,15 @@ def _hack_names(column_name_list: list[str], enabled: bool):
return column_name_list

def _task_to_table(
-        fs: FileSystem,
-        task: FileScanTask,
-        bound_row_filter: BooleanExpression,
-        projected_schema: Schema,
-        projected_field_ids: Set[int],
-        positional_deletes: Optional[List[ChunkedArray]],
-        case_sensitive: bool,
-        row_counts: List[int],
-        limit: Optional[int] = None,
+    fs: FileSystem,
+    task: FileScanTask,
+    bound_row_filter: BooleanExpression,
+    projected_schema: Schema,
+    projected_field_ids: Set[int],
+    positional_deletes: Optional[List[ChunkedArray]],
+    case_sensitive: bool,
+    row_counts: List[int],
+    limit: Optional[int] = None,
) -> Optional[pa.Table]:
if limit and sum(row_counts) >= limit:
return None
@@ -848,17 +845,15 @@ def _task_to_table(
schema_raw = metadata.get(ICEBERG_SCHEMA)
# TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
# see https://github.com/apache/iceberg/issues/7451
-        file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(
-            physical_schema)
+        file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

-        file_project_schema = sanitize_column_names(
-            prune_columns(file_schema, projected_field_ids, select_full_types=False))
+        file_project_schema = sanitize_column_names(prune_columns(file_schema, projected_field_ids, select_full_types=False))

if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
@@ -929,12 +924,12 @@ def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dic


def project_table(
-        tasks: Iterable[FileScanTask],
-        table: Table,
-        row_filter: BooleanExpression,
-        projected_schema: Schema,
-        case_sensitive: bool = True,
-        limit: Optional[int] = None,
+    tasks: Iterable[FileScanTask],
+    table: Table,
+    row_filter: BooleanExpression,
+    projected_schema: Schema,
+    case_sensitive: bool = True,
+    limit: Optional[int] = None,
) -> pa.Table:
"""Resolve the right columns based on the identifier.
@@ -1019,8 +1014,7 @@ def project_table(


def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
-    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema),
-                                      ArrowAccessor(file_schema))
+    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

arrays = []
fields = []
@@ -1043,12 +1037,11 @@ def cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
return values

-    def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[
-        pa.Array]:
+    def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
return struct_result

def struct(
-            self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]]
+        self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]]
) -> Optional[pa.Array]:
if struct_array is None:
return None
@@ -1071,17 +1064,15 @@ def struct(
def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]:
return field_array

-    def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[
-        pa.Array]:
+    def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
return (
pa.ListArray.from_arrays(list_array.offsets, self.cast_if_needed(list_type.element_field, value_array))
if isinstance(list_array, pa.ListArray)
else None
)

def map(
-            self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array],
-            value_result: Optional[pa.Array]
+        self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
) -> Optional[pa.Array]:
return (
pa.MapArray.from_arrays(
@@ -1202,8 +1193,7 @@ class StatsAggregator:
current_max: Any
trunc_length: Optional[int]

-    def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str,
-                 trunc_length: Optional[int] = None) -> None:
+    def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc_length: Optional[int] = None) -> None:
self.current_min = None
self.current_max = None
self.trunc_length = trunc_length
@@ -1316,30 +1306,27 @@ def __init__(self, schema: Schema, properties: Dict[str, str]):
self._properties = properties
self._default_mode = self._properties.get(DEFAULT_METRICS_MODE_KEY)

-    def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsCollector]]) -> List[
-        StatisticsCollector]:
+    def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
return struct_result()

def struct(
-            self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
+        self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
) -> List[StatisticsCollector]:
return list(chain(*[result() for result in field_results]))

-    def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[
-        StatisticsCollector]:
+    def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
self._field_id = field.field_id
return field_result()

-    def list(self, list_type: ListType, element_result: Callable[[], List[StatisticsCollector]]) -> List[
-        StatisticsCollector]:
+    def list(self, list_type: ListType, element_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
self._field_id = list_type.element_id
return element_result()

def map(
-            self,
-            map_type: MapType,
-            key_result: Callable[[], List[StatisticsCollector]],
-            value_result: Callable[[], List[StatisticsCollector]],
+        self,
+        map_type: MapType,
+        key_result: Callable[[], List[StatisticsCollector]],
+        value_result: Callable[[], List[StatisticsCollector]],
) -> List[StatisticsCollector]:
self._field_id = map_type.key_id
k = key_result()
@@ -1362,8 +1349,8 @@ def primitive(self, primitive: PrimitiveType) -> List[StatisticsCollector]:
metrics_mode = match_metrics_mode(col_mode)

if (
-                not (isinstance(primitive, StringType) or isinstance(primitive, BinaryType))
-                and metrics_mode.type == MetricModeTypes.TRUNCATE
+            not (isinstance(primitive, StringType) or isinstance(primitive, BinaryType))
+            and metrics_mode.type == MetricModeTypes.TRUNCATE
):
metrics_mode = MetricsMode(MetricModeTypes.FULL)

@@ -1372,13 +1359,12 @@ def primitive(self, primitive: PrimitiveType) -> List[StatisticsCollector]:
if is_nested and metrics_mode.type in [MetricModeTypes.TRUNCATE, MetricModeTypes.FULL]:
metrics_mode = MetricsMode(MetricModeTypes.COUNTS)

-        return [StatisticsCollector(field_id=self._field_id, iceberg_type=primitive, mode=metrics_mode,
-                                    column_name=column_name)]
+        return [StatisticsCollector(field_id=self._field_id, iceberg_type=primitive, mode=metrics_mode, column_name=column_name)]


def compute_statistics_plan(
-        schema: Schema,
-        table_properties: Dict[str, str],
+    schema: Schema,
+    table_properties: Dict[str, str],
) -> Dict[int, StatisticsCollector]:
"""
Compute the statistics plan for all columns.
Expand Down Expand Up @@ -1417,8 +1403,7 @@ def __init__(self) -> None:
def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
return struct_result()

-    def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[
-        ID2ParquetPath]:
+    def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]:
return list(chain(*[result() for result in field_results]))

def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
@@ -1436,10 +1421,10 @@ def list(self, list_type: ListType, element_result: Callable[[], List[ID2Parquet
return result

def map(
-            self,
-            map_type: MapType,
-            key_result: Callable[[], List[ID2ParquetPath]],
-            value_result: Callable[[], List[ID2ParquetPath]],
+        self,
+        map_type: MapType,
+        key_result: Callable[[], List[ID2ParquetPath]],
+        value_result: Callable[[], List[ID2ParquetPath]],
) -> List[ID2ParquetPath]:
self._field_id = map_type.key_id
self._path.append("key_value.key")
@@ -1456,7 +1441,7 @@ def primitive(self, primitive: PrimitiveType) -> List[ID2ParquetPath]:


def parquet_path_to_id_mapping(
-        schema: Schema,
+    schema: Schema,
) -> Dict[str, int]:
"""
Compute the mapping of parquet column path to Iceberg ID.
@@ -1475,11 +1460,11 @@ def parquet_path_to_id_mapping(


def fill_parquet_file_metadata(
-        df: DataFile,
-        parquet_metadata: pq.FileMetaData,
-        file_size: int,
-        stats_columns: Dict[int, StatisticsCollector],
-        parquet_column_mapping: Dict[str, int],
+    df: DataFile,
+    parquet_metadata: pq.FileMetaData,
+    file_size: int,
+    stats_columns: Dict[int, StatisticsCollector],
+    parquet_column_mapping: Dict[str, int],
) -> None:
"""
Compute and fill the following fields of the DataFile object.
