Check the types when writing (#313)
Fokko authored Jan 28, 2024
1 parent 9e03949 commit acc934f
Showing 2 changed files with 34 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pyiceberg/table/__init__.py
@@ -932,6 +932,14 @@ def append(self, df: pa.Table) -> None:
Args:
df: The Arrow dataframe that will be appended to the table
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

@@ -954,6 +962,14 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial overwrite
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if overwrite_filter != AlwaysTrue():
raise NotImplementedError("Cannot overwrite a subset of a table")

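Both write methods now share the same fail-fast pattern: import the optional PyArrow dependency lazily, then validate the argument type before any write work starts. Below is a minimal standalone sketch of that pattern; the FakeTable class is illustrative only and not part of PyIceberg:

from typing import Any


class FakeTable:
    """Illustrative stand-in for PyIceberg's Table write path."""

    def append(self, df: Any) -> None:
        # Lazy import: PyArrow is an optional dependency, only needed for writes.
        try:
            import pyarrow as pa
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

        # Fail fast with a clear message instead of erroring deep in the write path.
        if not isinstance(df, pa.Table):
            raise ValueError(f"Expected PyArrow table, got: {df}")

        print(f"would append {df.num_rows} rows")


if __name__ == "__main__":
    import pyarrow as pa

    FakeTable().append(pa.table({"x": [1, 2, 3]}))  # prints: would append 3 rows
    FakeTable().append("not a df")                  # raises ValueError
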
18 changes: 18 additions & 0 deletions tests/integration/test_writes.py
@@ -391,3 +391,21 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
assert [row.added_data_files_count for row in rows] == [1, 1, 0, 1, 1]
assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0]
assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]


@pytest.mark.integration
def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_data_files"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'})

with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
tbl.overwrite("not a df")

with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
tbl.append("not a df")
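
A caller holding a pandas DataFrame can satisfy the new check by converting to Arrow first. A minimal sketch of the conversion, assuming pandas and pyarrow are installed (no catalog needed):

import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({"id": [1, 2, 3]})
df = pa.Table.from_pandas(pdf)  # now a genuine pa.Table
assert isinstance(df, pa.Table)  # passes the guard in append()/overwrite()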
