Commit: support time travel

sungwy committed Apr 13, 2024
1 parent 5f13f4f commit 9a26b7b
Showing 2 changed files with 24 additions and 22 deletions.
22 changes: 11 additions & 11 deletions pyiceberg/table/__init__.py
@@ -3423,7 +3423,7 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
             schema=entries_schema,
         )
 
-    def partitions(self) -> "pa.Table":
+    def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table":
         import pyarrow as pa
 
         from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -3495,16 +3495,16 @@ def update_partitions_map(
                raise ValueError(f"Unknown DataFileContent ({file.content})")
 
         partitions_map: Dict[Tuple[str, Any], Any] = {}
-        if snapshot := self.tbl.metadata.current_snapshot():
-            for manifest in snapshot.manifests(self.tbl.io):
-                for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
-                    partition = entry.data_file.partition
-                    partition_record_dict = {
-                        field.name: partition[pos]
-                        for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
-                    }
-                    entry_snapshot = self.tbl.snapshot_by_id(entry.snapshot_id) if entry.snapshot_id is not None else None
-                    update_partitions_map(partitions_map, entry.data_file, partition_record_dict, entry_snapshot)
+        snapshot = self._get_snapshot(snapshot_id)
+        for manifest in snapshot.manifests(self.tbl.io):
+            for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+                partition = entry.data_file.partition
+                partition_record_dict = {
+                    field.name: partition[pos]
+                    for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+                }
+                entry_snapshot = self.tbl.snapshot_by_id(entry.snapshot_id) if entry.snapshot_id is not None else None
+                update_partitions_map(partitions_map, entry.data_file, partition_record_dict, entry_snapshot)
 
         return pa.Table.from_pylist(
             partitions_map.values(),
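
For context, the new code path delegates snapshot resolution to a _get_snapshot helper whose body is not part of this diff. The sketch below is an assumed contract, written as a standalone function and inferred from the lines removed above: resolve an explicit snapshot_id if given, otherwise fall back to the current snapshot, and raise instead of silently returning None.

    from typing import Optional

    from pyiceberg.table import Table
    from pyiceberg.table.snapshots import Snapshot

    def get_snapshot(tbl: Table, snapshot_id: Optional[int] = None) -> Snapshot:
        # Assumption: mirrors the private _get_snapshot helper used in the diff.
        if snapshot_id is not None:
            if snapshot := tbl.snapshot_by_id(snapshot_id):
                return snapshot
            raise ValueError(f"Cannot find snapshot with ID {snapshot_id}")
        # No explicit ID: use the table's current snapshot, as the old code did.
        if snapshot := tbl.metadata.current_snapshot():
            return snapshot
        raise ValueError("Cannot get a snapshot as the table does not have any.")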
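
With this change, callers can time-travel the partitions metadata table. A minimal usage sketch; the catalog name and table identifier below are hypothetical:

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")              # hypothetical catalog name
    tbl = catalog.load_table("default.my_table")   # hypothetical table identifier

    current = tbl.inspect.partitions()             # current snapshot, as before

    # Time travel: inspect partitions as of each snapshot in the table's history.
    for snapshot in tbl.metadata.snapshots:
        df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id)
        print(snapshot.snapshot_id, df.num_rows)
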
24 changes: 13 additions & 11 deletions tests/integration/test_inspect_table.py
@@ -376,15 +376,17 @@ def test_inspect_partitions_partitioned(spark: SparkSession, session_catalog: Ca
         """
     )
 
-    df = session_catalog.load_table(identifier).inspect.partitions()
-
-    lhs = df.to_pandas()
-    rhs = spark.table(f"{identifier}.partitions").toPandas()
+    def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
+        lhs = df.to_pandas().sort_values('spec_id')
+        rhs = spark_df.toPandas().sort_values('spec_id')
+        for column in df.column_names:
+            for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+                if column == "partition":
+                    right = right.asDict()
+                assert left == right, f"Difference in column {column}: {left} != {right}"

-    lhs.sort_values('spec_id', inplace=True)
-    rhs.sort_values('spec_id', inplace=True)
-    for column in df.column_names:
-        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
-            if column == "partition":
-                right = right.asDict()
-            assert left == right, f"Difference in column {column}: {left} != {right}"
+    tbl = session_catalog.load_table(identifier)
+    for snapshot in tbl.metadata.snapshots:
+        df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id)
+        spark_df = spark.sql(f"SELECT * FROM {identifier}.partitions VERSION AS OF {snapshot.snapshot_id}")
+        check_pyiceberg_df_equals_spark_df(df, spark_df)
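
A note on the test restructure: the column-wise comparison moves into a check_pyiceberg_df_equals_spark_df helper, both frames are sorted by spec_id first since neither side guarantees row order, and Spark's partition column is a Row that must be converted with asDict() before comparing against pyiceberg's plain dict. Looping over tbl.metadata.snapshots and pairing each result with Spark's VERSION AS OF query exercises the new snapshot_id parameter across the table's entire history.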
