Skip to content

Commit 0519448

Browse files
committed
Merge branch 'main' into manifest_compaction

Conflicts:
    pyiceberg/table/__init__.py
2 parents 7582bd2 + b31922f commit 0519448

File tree

6 files changed

+144
-40
lines changed

6 files changed

+144
-40
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,10 @@ def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = N
655655
return visit_pyarrow(schema, visitor)
656656

657657

658+
def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
    """Convert a PyArrow schema to an Iceberg schema without assigning field IDs.

    Thin wrapper around ``visit_pyarrow`` using the ID-less conversion visitor;
    used to inspect a dataframe schema that carries no Iceberg field-ID metadata.
    """
    visitor = _ConvertToIcebergWithoutIDs()
    return visit_pyarrow(schema, visitor)
660+
661+
658662
@singledispatch
659663
def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
660664
"""Apply a pyarrow schema visitor to any point within a schema.

pyiceberg/schema.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from dataclasses import dataclass
2323
from functools import cached_property, partial, singledispatch
2424
from typing import (
25+
TYPE_CHECKING,
2526
Any,
2627
Callable,
2728
Dict,
@@ -62,6 +63,11 @@
6263
UUIDType,
6364
)
6465

66+
if TYPE_CHECKING:
67+
from pyiceberg.table.name_mapping import (
68+
NameMapping,
69+
)
70+
6571
T = TypeVar("T")
6672
P = TypeVar("P")
6773

@@ -221,6 +227,12 @@ def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) ->
221227
def highest_field_id(self) -> int:
222228
return max(self._lazy_id_to_name.keys(), default=0)
223229

230+
@cached_property
def name_mapping(self) -> NameMapping:
    """Return the name mapping derived from this schema (computed once, then cached)."""
    # Imported locally to avoid a circular import between schema and name_mapping.
    from pyiceberg.table.name_mapping import create_mapping_from_schema as _build_mapping

    return _build_mapping(self)
235+
224236
def find_column_name(self, column_id: int) -> Optional[str]:
225237
"""Find a column name given a column ID.
226238

pyiceberg/table/__init__.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,11 @@
8888
TableMetadata,
8989
TableMetadataUtil,
9090
)
91-
from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json, update_mapping
91+
from pyiceberg.table.name_mapping import (
92+
NameMapping,
93+
parse_mapping_from_json,
94+
update_mapping,
95+
)
9296
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
9397
from pyiceberg.table.snapshots import (
9498
Operation,
@@ -134,6 +138,41 @@
134138
_JAVA_LONG_MAX = 9223372036854775807
135139

136140

141+
def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
    """Validate that a PyArrow schema is write-compatible with the table schema.

    Args:
        table_schema: The Iceberg schema of the target table.
        other_schema: The PyArrow schema of the dataframe being written.

    Raises:
        ValueError: If the dataframe has columns the table does not know about,
            or if any field differs (ID, name, type, or nullability); the latter
            message embeds a field-by-field comparison table.
    """
    from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

    try:
        provided_schema = pyarrow_to_schema(other_schema, name_mapping=table_schema.name_mapping)
    except ValueError as conversion_error:
        # Conversion fails when the dataframe carries columns unknown to the
        # table's name mapping; surface the extra column names to the caller.
        converted = _pyarrow_to_schema_without_ids(other_schema)
        extra_columns = set(converted.column_names) - set(table_schema.column_names)
        raise ValueError(
            f"PyArrow table contains more columns: {', '.join(sorted(extra_columns))}. Update the schema first (hint, use union_by_name)."
        ) from conversion_error

    if table_schema.as_struct() == provided_schema.as_struct():
        return

    # Schemas differ: render a side-by-side diff of every table field.
    from rich.console import Console
    from rich.table import Table as RichTable

    console = Console(record=True)

    comparison = RichTable(show_header=True, header_style="bold")
    comparison.add_column("")
    comparison.add_column("Table field")
    comparison.add_column("Dataframe field")

    for table_field in table_schema.fields:
        try:
            df_field = provided_schema.find_field(table_field.field_id)
        except ValueError:
            # Field ID absent from the dataframe schema entirely.
            comparison.add_row("❌", str(table_field), "Missing")
        else:
            comparison.add_row("✅" if table_field == df_field else "❌", str(table_field), str(df_field))

    console.print(comparison)
    raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
174+
175+
137176
class TableProperties:
138177
PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
139178
PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024 # 128 MB
@@ -1035,6 +1074,8 @@ def append(self, df: pa.Table) -> None:
10351074
if len(self.spec().fields) > 0:
10361075
raise ValueError("Cannot write to partitioned tables")
10371076

1077+
_check_schema(self.schema(), other_schema=df.schema)
1078+
10381079
with self.update_snapshot().merge_append() as update_snapshot:
10391080
# skip writing data files if the dataframe is empty
10401081
if df.shape[0] > 0:
@@ -1065,6 +1106,8 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
10651106
if len(self.spec().fields) > 0:
10661107
raise ValueError("Cannot write to partitioned tables")
10671108

1109+
_check_schema(self.schema(), other_schema=df.schema)
1110+
10681111
with self.update_snapshot().overwrite() as update_snapshot:
10691112
# skip writing data files if the dataframe is empty
10701113
if df.shape[0] > 0:

pyiceberg/table/name_mapping.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from abc import ABC, abstractmethod
2727
from collections import ChainMap
2828
from functools import cached_property, singledispatch
29-
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
29+
from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
3030

3131
from pydantic import Field, conlist, field_validator, model_serializer
3232

@@ -97,6 +97,10 @@ def __len__(self) -> int:
9797
"""Return the number of mappings."""
9898
return len(self.root)
9999

100+
def __iter__(self) -> Iterator[MappedField]:
    """Yield the mapped fields in their stored order."""
    yield from self.root
103+
100104
def __str__(self) -> str:
101105
"""Convert the name-mapping into a nicely formatted string."""
102106
if len(self.root) == 0:

tests/conftest.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
List,
4343
Optional,
4444
)
45-
from urllib.parse import urlparse
4645

4746
import boto3
4847
import pytest
@@ -57,8 +56,6 @@
5756
GCS_PROJECT_ID,
5857
GCS_TOKEN,
5958
GCS_TOKEN_EXPIRES_AT_MS,
60-
OutputFile,
61-
OutputStream,
6259
fsspec,
6360
load_file_io,
6461
)
@@ -88,7 +85,7 @@
8885
import pyarrow as pa
8986
from moto.server import ThreadedMotoServer # type: ignore
9087

91-
from pyiceberg.io.pyarrow import PyArrowFile, PyArrowFileIO
88+
from pyiceberg.io.pyarrow import PyArrowFileIO
9289

9390

9491
def pytest_collection_modifyitems(items: List[pytest.Item]) -> None:
@@ -1456,40 +1453,6 @@ def simple_map() -> MapType:
14561453
return MapType(key_id=19, key_type=StringType(), value_id=25, value_type=DoubleType(), value_required=False)
14571454

14581455

1459-
class LocalOutputFile(OutputFile):
    """An OutputFile implementation for local files (for test use only)."""

    def __init__(self, location: str) -> None:
        # Split the URI into scheme / netloc / path components.
        parts = urlparse(location)
        # Only `file` URIs (or bare, scheme-less paths) are valid locally.
        if parts.scheme and parts.scheme != "file":
            raise ValueError("LocalOutputFile location must have a scheme of `file`")
        elif parts.netloc:
            raise ValueError(f"Network location is not allowed for LocalOutputFile: {parts.netloc}")

        super().__init__(location=location)
        self._path = parts.path

    def __len__(self) -> int:
        """Return the length of an instance of the LocalOutputFile class."""
        return os.path.getsize(self._path)

    def exists(self) -> bool:
        """Return True when the local path currently exists."""
        return os.path.exists(self._path)

    def to_input_file(self) -> "PyArrowFile":
        """Open the same location for reading through PyArrowFileIO."""
        from pyiceberg.io.pyarrow import PyArrowFileIO

        return PyArrowFileIO().new_input(location=self.location)

    def create(self, overwrite: bool = False) -> OutputStream:
        """Open the path for writing; `xb` mode refuses to clobber unless overwrite=True."""
        mode = "wb" if overwrite else "xb"
        handle = open(self._path, mode)
        if not issubclass(type(handle), OutputStream):
            raise TypeError("Object returned from LocalOutputFile.create(...) does not match the OutputStream protocol.")
        return handle
1491-
1492-
14931456
@pytest.fixture(scope="session")
14941457
def generated_manifest_entry_file(avro_schema_manifest_entry: Dict[str, Any]) -> Generator[str, None, None]:
14951458
from fastavro import parse_schema, writer

tests/table/test_init.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from copy import copy
2020
from typing import Dict
2121

22+
import pyarrow as pa
2223
import pytest
2324
from sortedcontainers import SortedList
2425

@@ -58,6 +59,7 @@
5859
Table,
5960
UpdateSchema,
6061
_apply_table_update,
62+
_check_schema,
6163
_generate_snapshot_id,
6264
_match_deletes_to_data_file,
6365
_TableMetadataUpdateContext,
@@ -982,3 +984,79 @@ def test_correct_schema() -> None:
982984
_ = t.scan(snapshot_id=-1).projection()
983985

984986
assert "Snapshot not found: -1" in str(exc_info.value)
987+
988+
989+
# Verifies a type mismatch (int vs decimal) is reported field-by-field.
# NOTE(review): scraped diff view — the interleaved "NNN+" markers are page
# artifacts, and the expected rich-table spacing below may have been collapsed
# by the scrape; confirm against the rendered _check_schema output.
def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
990+
other_schema = pa.schema((
991+
pa.field("foo", pa.string(), nullable=True),
992+
pa.field("bar", pa.decimal128(18, 6), nullable=False),
993+
pa.field("baz", pa.bool_(), nullable=True),
994+
))
995+
996+
expected = r"""Mismatch in fields:
997+
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
998+
┃ ┃ Table field ┃ Dataframe field ┃
999+
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
1000+
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
1001+
│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │
1002+
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
1003+
└────┴──────────────────────────┴─────────────────────────────────┘
1004+
"""
1005+
1006+
with pytest.raises(ValueError, match=expected):
1007+
_check_schema(table_schema_simple, other_schema)
1008+
1009+
1010+
# Verifies a nullability mismatch (required vs optional int) is flagged.
# NOTE(review): scraped diff view — "NNN+" markers are page artifacts and the
# expected table's spacing may be collapsed; confirm against rendered output.
def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
1011+
other_schema = pa.schema((
1012+
pa.field("foo", pa.string(), nullable=True),
1013+
pa.field("bar", pa.int32(), nullable=True),
1014+
pa.field("baz", pa.bool_(), nullable=True),
1015+
))
1016+
1017+
expected = """Mismatch in fields:
1018+
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
1019+
┃ ┃ Table field ┃ Dataframe field ┃
1020+
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
1021+
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
1022+
│ ❌ │ 2: bar: required int │ 2: bar: optional int │
1023+
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
1024+
└────┴──────────────────────────┴──────────────────────────┘
1025+
"""
1026+
1027+
with pytest.raises(ValueError, match=expected):
1028+
_check_schema(table_schema_simple, other_schema)
1029+
1030+
1031+
# Verifies a column missing from the dataframe is rendered as "Missing".
# NOTE(review): scraped diff view — "NNN+" markers are page artifacts and the
# expected table's spacing may be collapsed; confirm against rendered output.
def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
1032+
other_schema = pa.schema((
1033+
pa.field("foo", pa.string(), nullable=True),
1034+
pa.field("baz", pa.bool_(), nullable=True),
1035+
))
1036+
1037+
expected = """Mismatch in fields:
1038+
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
1039+
┃ ┃ Table field ┃ Dataframe field ┃
1040+
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
1041+
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
1042+
│ ❌ │ 2: bar: required int │ Missing │
1043+
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
1044+
└────┴──────────────────────────┴──────────────────────────┘
1045+
"""
1046+
1047+
with pytest.raises(ValueError, match=expected):
1048+
_check_schema(table_schema_simple, other_schema)
1049+
1050+
1051+
def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
    """A dataframe with an extra column is rejected with a union_by_name hint."""
    df_fields = [
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
        pa.field("new_field", pa.date32(), nullable=True),
    ]
    other_schema = pa.schema(df_fields)

    expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

    with pytest.raises(ValueError, match=expected):
        _check_schema(table_schema_simple, other_schema)

0 commit comments

Comments
 (0)