
Support Time Travel in InspectTable.entries #599

Merged · 5 commits · Apr 13, 2024
12 changes: 12 additions & 0 deletions mkdocs/docs/api.md
@@ -342,6 +342,18 @@ table.append(df)

To explore the table metadata, tables can be inspected.

<!-- prettier-ignore-start -->

!!! tip "Time Travel"
To inspect a table's metadata with the time travel feature, call the inspect table method with the `snapshot_id` argument.
Time travel is supported on all metadata tables except `snapshots` and `refs`.

```python
table.inspect.entries(snapshot_id=805611270568163028)
```

<!-- prettier-ignore-end -->
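
Building on the tip above, here is a minimal sketch that walks the snapshot log and inspects the manifest entries as of each snapshot. It assumes `table` is an already-loaded `Table` with at least one snapshot; the printed values are illustrative only.

```python
# Hypothetical usage sketch (assumes `table` is a loaded Table):
# inspect the manifest entries as of every snapshot in the table's history.
for snapshot in table.metadata.snapshots:
    entries = table.inspect.entries(snapshot_id=snapshot.snapshot_id)
    print(snapshot.snapshot_id, entries.num_rows)
```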

### Snapshots

Inspect the snapshots of the table:
128 changes: 70 additions & 58 deletions pyiceberg/table/__init__.py
@@ -3253,6 +3253,18 @@ def __init__(self, tbl: Table) -> None:
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For metadata operations PyArrow needs to be installed") from e

def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot:
if snapshot_id is not None:
if snapshot := self.tbl.metadata.snapshot_by_id(snapshot_id):
return snapshot
else:
raise ValueError(f"Cannot find snapshot with ID {snapshot_id}")
Contributor (kevinjqliu):

nit: if the return value is Optional[Snapshot], maybe the function should not raise ValueError

Collaborator Author (sungwy):

Hi @kevinjqliu thanks for the review.

I thought about this, and I stand by this behavior / type annotation. This is my rationale:

  1. If the user passes in `snapshot_id`, that's an indication that the user wants to look up a specific snapshot (instead of using the current one). Then we should look it up with `snapshot_by_id`, and if we can't find the corresponding Snapshot, we should raise.
  2. The return type is still `Optional[Snapshot]` even if we raise in the above case, because `tbl.metadata.current_snapshot()` returns `Optional[Snapshot]`:
     `def current_snapshot(self) -> Optional[Snapshot]:`

Collaborator Author (sungwy), Apr 12, 2024:

And I believe the above is because a newly created table isn't required to have a snapshot

Contributor:

👍 Makes sense, thanks for the explanation.

Collaborator Author (sungwy):

No problem, and thank you again for the review!

Contributor (Fokko):

I agree that we should raise an error when the snapshot cannot be found. What do you think of updating the signature to `Snapshot`, and also raising an exception when there is no current snapshot?

Collaborator Author (sungwy):

@Fokko Sure - I thought it would be more correct to return an empty metadata table (entries, partitions, etc.) if there's no snapshot in the table than to raise an exception, but this way I think we avoid extra if statements in each of the metadata-table-generating methods.
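
Putting the thread together, a minimal caller-side sketch of the behavior being discussed (the snapshot ID below is hypothetical, and `tbl` is assumed to be an already-loaded table): an unknown `snapshot_id` and a table without any snapshots both surface as a `ValueError`.

```python
# Caller-side sketch of the behavior discussed above.
# The snapshot ID is hypothetical and `tbl` is assumed to be a loaded Table.
try:
    tbl.inspect.entries(snapshot_id=123456789)  # snapshot ID that does not exist
except ValueError as exc:
    print(exc)  # "Cannot find snapshot with ID 123456789"

try:
    tbl.inspect.entries()  # no snapshot_id: falls back to the current snapshot
except ValueError as exc:
    # Raised only when the table has no snapshots at all:
    print(exc)  # "Cannot get a snapshot as the table does not have any."
```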


if snapshot := self.tbl.metadata.current_snapshot():
return snapshot
else:
raise ValueError("Cannot get a snapshot as the table does not have any.")

def snapshots(self) -> "pa.Table":
import pyarrow as pa

@@ -3287,7 +3299,7 @@ def snapshots(self) -> "pa.Table":
schema=snapshots_schema,
)

def entries(self) -> "pa.Table":
def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table":
import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -3346,64 +3358,64 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
])

entries = []
if snapshot := self.tbl.metadata.current_snapshot():
for manifest in snapshot.manifests(self.tbl.io):
for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
column_sizes = entry.data_file.column_sizes or {}
value_counts = entry.data_file.value_counts or {}
null_value_counts = entry.data_file.null_value_counts or {}
nan_value_counts = entry.data_file.nan_value_counts or {}
lower_bounds = entry.data_file.lower_bounds or {}
upper_bounds = entry.data_file.upper_bounds or {}
readable_metrics = {
schema.find_column_name(field.field_id): {
"column_size": column_sizes.get(field.field_id),
"value_count": value_counts.get(field.field_id),
"null_value_count": null_value_counts.get(field.field_id),
"nan_value_count": nan_value_counts.get(field.field_id),
# Makes them readable
"lower_bound": from_bytes(field.field_type, lower_bound)
if (lower_bound := lower_bounds.get(field.field_id))
else None,
"upper_bound": from_bytes(field.field_type, upper_bound)
if (upper_bound := upper_bounds.get(field.field_id))
else None,
}
for field in self.tbl.metadata.schema().fields
}

partition = entry.data_file.partition
partition_record_dict = {
field.name: partition[pos]
for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
snapshot = self._get_snapshot(snapshot_id)
for manifest in snapshot.manifests(self.tbl.io):
for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
column_sizes = entry.data_file.column_sizes or {}
value_counts = entry.data_file.value_counts or {}
null_value_counts = entry.data_file.null_value_counts or {}
nan_value_counts = entry.data_file.nan_value_counts or {}
lower_bounds = entry.data_file.lower_bounds or {}
upper_bounds = entry.data_file.upper_bounds or {}
readable_metrics = {
schema.find_column_name(field.field_id): {
"column_size": column_sizes.get(field.field_id),
"value_count": value_counts.get(field.field_id),
"null_value_count": null_value_counts.get(field.field_id),
"nan_value_count": nan_value_counts.get(field.field_id),
# Makes them readable
"lower_bound": from_bytes(field.field_type, lower_bound)
if (lower_bound := lower_bounds.get(field.field_id))
else None,
"upper_bound": from_bytes(field.field_type, upper_bound)
if (upper_bound := upper_bounds.get(field.field_id))
else None,
}

entries.append({
'status': entry.status.value,
'snapshot_id': entry.snapshot_id,
'sequence_number': entry.data_sequence_number,
'file_sequence_number': entry.file_sequence_number,
'data_file': {
"content": entry.data_file.content,
"file_path": entry.data_file.file_path,
"file_format": entry.data_file.file_format,
"partition": partition_record_dict,
"record_count": entry.data_file.record_count,
"file_size_in_bytes": entry.data_file.file_size_in_bytes,
"column_sizes": dict(entry.data_file.column_sizes),
"value_counts": dict(entry.data_file.value_counts),
"null_value_counts": dict(entry.data_file.null_value_counts),
"nan_value_counts": entry.data_file.nan_value_counts,
"lower_bounds": entry.data_file.lower_bounds,
"upper_bounds": entry.data_file.upper_bounds,
"key_metadata": entry.data_file.key_metadata,
"split_offsets": entry.data_file.split_offsets,
"equality_ids": entry.data_file.equality_ids,
"sort_order_id": entry.data_file.sort_order_id,
"spec_id": entry.data_file.spec_id,
},
'readable_metrics': readable_metrics,
})
for field in self.tbl.metadata.schema().fields
}

partition = entry.data_file.partition
partition_record_dict = {
field.name: partition[pos]
for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
}

entries.append({
'status': entry.status.value,
'snapshot_id': entry.snapshot_id,
'sequence_number': entry.data_sequence_number,
'file_sequence_number': entry.file_sequence_number,
'data_file': {
"content": entry.data_file.content,
"file_path": entry.data_file.file_path,
"file_format": entry.data_file.file_format,
"partition": partition_record_dict,
"record_count": entry.data_file.record_count,
"file_size_in_bytes": entry.data_file.file_size_in_bytes,
"column_sizes": dict(entry.data_file.column_sizes),
"value_counts": dict(entry.data_file.value_counts),
"null_value_counts": dict(entry.data_file.null_value_counts),
"nan_value_counts": entry.data_file.nan_value_counts,
"lower_bounds": entry.data_file.lower_bounds,
"upper_bounds": entry.data_file.upper_bounds,
"key_metadata": entry.data_file.key_metadata,
"split_offsets": entry.data_file.split_offsets,
"equality_ids": entry.data_file.equality_ids,
"sort_order_id": entry.data_file.sort_order_id,
"spec_id": entry.data_file.spec_id,
},
'readable_metrics': readable_metrics,
})

return pa.Table.from_pylist(
entries,
156 changes: 80 additions & 76 deletions tests/integration/test_inspect_table.py
@@ -22,7 +22,7 @@
import pyarrow as pa
import pytest
import pytz
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame, SparkSession

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
@@ -148,81 +148,85 @@ def test_inspect_entries(
# Write some data
tbl.append(arrow_table_with_null)

df = tbl.inspect.entries()

assert df.column_names == [
'status',
'snapshot_id',
'sequence_number',
'file_sequence_number',
'data_file',
'readable_metrics',
]

# Make sure that they are filled properly
for int_column in ['status', 'snapshot_id', 'sequence_number', 'file_sequence_number']:
for value in df[int_column]:
assert isinstance(value.as_py(), int)

for snapshot_id in df['snapshot_id']:
assert isinstance(snapshot_id.as_py(), int)

lhs = df.to_pandas()
rhs = spark.table(f"{identifier}.entries").toPandas()
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if column == 'data_file':
right = right.asDict(recursive=True)
for df_column in left.keys():
if df_column == 'partition':
# Spark leaves out the partition if the table is unpartitioned
continue

df_lhs = left[df_column]
df_rhs = right[df_column]
if isinstance(df_rhs, dict):
# Arrow turns dicts into lists of tuples
df_lhs = dict(df_lhs)

assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
elif column == 'readable_metrics':
right = right.asDict(recursive=True)

assert list(left.keys()) == [
'bool',
'string',
'string_long',
'int',
'long',
'float',
'double',
'timestamp',
'timestamptz',
'date',
'binary',
'fixed',
]

assert left.keys() == right.keys()

for rm_column in left.keys():
rm_lhs = left[rm_column]
rm_rhs = right[rm_column]

assert rm_lhs['column_size'] == rm_rhs['column_size']
assert rm_lhs['value_count'] == rm_rhs['value_count']
assert rm_lhs['null_value_count'] == rm_rhs['null_value_count']
assert rm_lhs['nan_value_count'] == rm_rhs['nan_value_count']

if rm_column == 'timestamptz':
# PySpark does not correctly set the timestamptz
rm_rhs['lower_bound'] = rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
rm_rhs['upper_bound'] = rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)

assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
else:
assert left == right, f"Difference in column {column}: {left} != {right}"
def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
assert df.column_names == [
'status',
'snapshot_id',
'sequence_number',
'file_sequence_number',
'data_file',
'readable_metrics',
]

# Make sure that they are filled properly
for int_column in ['status', 'snapshot_id', 'sequence_number', 'file_sequence_number']:
for value in df[int_column]:
assert isinstance(value.as_py(), int)

for snapshot_id in df['snapshot_id']:
assert isinstance(snapshot_id.as_py(), int)

lhs = df.to_pandas()
rhs = spark_df.toPandas()
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if column == 'data_file':
right = right.asDict(recursive=True)
for df_column in left.keys():
if df_column == 'partition':
# Spark leaves out the partition if the table is unpartitioned
continue

df_lhs = left[df_column]
df_rhs = right[df_column]
if isinstance(df_rhs, dict):
# Arrow turns dicts into lists of tuples
df_lhs = dict(df_lhs)

assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
elif column == 'readable_metrics':
right = right.asDict(recursive=True)

assert list(left.keys()) == [
'bool',
'string',
'string_long',
'int',
'long',
'float',
'double',
'timestamp',
'timestamptz',
'date',
'binary',
'fixed',
]

assert left.keys() == right.keys()

for rm_column in left.keys():
rm_lhs = left[rm_column]
rm_rhs = right[rm_column]

assert rm_lhs['column_size'] == rm_rhs['column_size']
assert rm_lhs['value_count'] == rm_rhs['value_count']
assert rm_lhs['null_value_count'] == rm_rhs['null_value_count']
assert rm_lhs['nan_value_count'] == rm_rhs['nan_value_count']

if rm_column == 'timestamptz':
# PySpark does not correctly set the timestamptz
rm_rhs['lower_bound'] = rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
rm_rhs['upper_bound'] = rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)

assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
else:
assert left == right, f"Difference in column {column}: {left} != {right}"

for snapshot in tbl.metadata.snapshots:
df = tbl.inspect.entries(snapshot_id=snapshot.snapshot_id)
spark_df = spark.sql(f"SELECT * FROM {identifier}.entries VERSION AS OF {snapshot.snapshot_id}")
check_pyiceberg_df_equals_spark_df(df, spark_df)


@pytest.mark.integration