
Commit 339ba53

Fokko and HonahX authored
Improve error message in case of a mismatch (#352)
* Nice error
* Simplify a bit
* Property
* Property
* Whitespace

Co-authored-by: Honah J. <[email protected]>
1 parent 44948cd commit 339ba53

File tree

5 files changed: +143 -2 lines changed


pyiceberg/io/pyarrow.py

Lines changed: 4 additions & 0 deletions
@@ -655,6 +655,10 @@ def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = N
     return visit_pyarrow(schema, visitor)
 
 
+def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
+    return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
+
+
 @singledispatch
 def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
     """Apply a pyarrow schema visitor to any point within a schema.

pyiceberg/schema.py

Lines changed: 12 additions & 0 deletions
@@ -22,6 +22,7 @@
 from dataclasses import dataclass
 from functools import cached_property, partial, singledispatch
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -62,6 +63,11 @@
     UUIDType,
 )
 
+if TYPE_CHECKING:
+    from pyiceberg.table.name_mapping import (
+        NameMapping,
+    )
+
 T = TypeVar("T")
 P = TypeVar("P")
 
@@ -221,6 +227,12 @@ def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) ->
     def highest_field_id(self) -> int:
         return max(self._lazy_id_to_name.keys(), default=0)
 
+    @cached_property
+    def name_mapping(self) -> NameMapping:
+        from pyiceberg.table.name_mapping import create_mapping_from_schema
+
+        return create_mapping_from_schema(self)
+
     def find_column_name(self, column_id: int) -> Optional[str]:
         """Find a column name given a column ID.
 
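With the new cached property, any Schema can produce its own NameMapping on demand; the deferred import avoids a circular dependency between pyiceberg.schema and pyiceberg.table.name_mapping. A small sketch of how it could be used; the field names and types below are assumptions for illustration:

from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField, StringType

# Illustrative schema (an assumption, not taken from this commit).
schema = Schema(
    NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
    NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
)

# Built once via create_mapping_from_schema and cached on the instance,
# so repeated access reuses the same NameMapping object.
mapping = schema.name_mapping
print(mapping)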

pyiceberg/table/__init__.py

Lines changed: 44 additions & 1 deletion
@@ -84,7 +84,11 @@
     TableMetadata,
     TableMetadataUtil,
 )
-from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json, update_mapping
+from pyiceberg.table.name_mapping import (
+    NameMapping,
+    parse_mapping_from_json,
+    update_mapping,
+)
 from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
 from pyiceberg.table.snapshots import (
     Operation,
@@ -129,6 +133,41 @@
 _JAVA_LONG_MAX = 9223372036854775807
 
 
+def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+    from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
+
+    name_mapping = table_schema.name_mapping
+    try:
+        task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
+    except ValueError as e:
+        other_schema = _pyarrow_to_schema_without_ids(other_schema)
+        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
+        raise ValueError(
+            f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
+        ) from e
+
+    if table_schema.as_struct() != task_schema.as_struct():
+        from rich.console import Console
+        from rich.table import Table as RichTable
+
+        console = Console(record=True)
+
+        rich_table = RichTable(show_header=True, header_style="bold")
+        rich_table.add_column("")
+        rich_table.add_column("Table field")
+        rich_table.add_column("Dataframe field")
+
+        for lhs in table_schema.fields:
+            try:
+                rhs = task_schema.find_field(lhs.field_id)
+                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
+            except ValueError:
+                rich_table.add_row("❌", str(lhs), "Missing")
+
+        console.print(rich_table)
+        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+
+
 class TableProperties:
     PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
     PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024  # 128 MB
@@ -1015,6 +1054,8 @@ def append(self, df: pa.Table) -> None:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
+        _check_schema(self.schema(), other_schema=df.schema)
+
         with self.update_snapshot().fast_append() as update_snapshot:
             # skip writing data files if the dataframe is empty
             if df.shape[0] > 0:
@@ -1045,6 +1086,8 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
+        _check_schema(self.schema(), other_schema=df.schema)
+
         with self.update_snapshot().overwrite() as update_snapshot:
            # skip writing data files if the dataframe is empty
            if df.shape[0] > 0:
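With _check_schema called at the top of both append and overwrite, a mismatched DataFrame is rejected before any data files are written, and the error carries the rich-rendered field-by-field comparison. A hedged sketch of what a caller would see; the catalog name, table identifier, and DataFrame contents are assumptions:

import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table; substitute your own.
catalog = load_catalog("default")
table = catalog.load_table("default.taxis")

df = pa.Table.from_pylist([{"foo": "a", "bar": 1, "baz": True}])

try:
    table.append(df)
except ValueError as err:
    # On a mismatch this now prints the ✅/❌ table built by _check_schema
    # instead of a terse conversion error.
    print(err)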

pyiceberg/table/name_mapping.py

Lines changed: 5 additions & 1 deletion
@@ -26,7 +26,7 @@
 from abc import ABC, abstractmethod
 from collections import ChainMap
 from functools import cached_property, singledispatch
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
 
 from pydantic import Field, conlist, field_validator, model_serializer
 
@@ -97,6 +97,10 @@ def __len__(self) -> int:
         """Return the number of mappings."""
         return len(self.root)
 
+    def __iter__(self) -> Iterator[MappedField]:
+        """Iterate over the mapped fields."""
+        return iter(self.root)
+
     def __str__(self) -> str:
         """Convert the name-mapping into a nicely formatted string."""
         if len(self.root) == 0:
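Adding __iter__ makes a NameMapping directly iterable over its MappedField entries, which is convenient for code that walks the mapping the way _check_schema walks schema fields. A brief sketch; the mapping JSON below is an assumption for illustration:

from pyiceberg.table.name_mapping import parse_mapping_from_json

# Illustrative name-mapping JSON (an assumption, not taken from this commit).
mapping = parse_mapping_from_json('[{"field-id": 1, "names": ["foo"]}, {"field-id": 2, "names": ["bar"]}]')

# __iter__ delegates to the underlying root list, so a plain for-loop works.
for mapped_field in mapping:
    print(mapped_field.field_id, mapped_field.names)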

tests/table/test_init.py

Lines changed: 78 additions & 0 deletions
@@ -19,6 +19,7 @@
 from copy import copy
 from typing import Dict
 
+import pyarrow as pa
 import pytest
 from sortedcontainers import SortedList
 
@@ -58,6 +59,7 @@
     Table,
     UpdateSchema,
     _apply_table_update,
+    _check_schema,
     _generate_snapshot_id,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
@@ -982,3 +984,79 @@ def test_correct_schema() -> None:
         _ = t.scan(snapshot_id=-1).projection()
 
     assert "Snapshot not found: -1" in str(exc_info.value)
+
+
+def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.decimal128(18, 6), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = r"""Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field                 ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string         │
+│ ❌ │ 2: bar: required int     │ 2: bar: required decimal\(18, 6\) │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean        │
+└────┴──────────────────────────┴─────────────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field          ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
+│ ❌ │ 2: bar: required int     │ 2: bar: optional int     │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field          ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
+│ ❌ │ 2: bar: required int     │ Missing                  │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+        pa.field("new_field", pa.date32(), nullable=True),
+    ))
+
+    expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema(table_schema_simple, other_schema)

0 commit comments
