diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 221a609e5c..26eecefd0f 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -932,6 +932,14 @@ def append(self, df: pa.Table) -> None:
         Args:
             df: The Arrow dataframe that will be appended to overwrite the table
         """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
@@ -954,6 +962,14 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
                               or a boolean expression in case of a partial overwrite
         """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
         if overwrite_filter != AlwaysTrue():
             raise NotImplementedError("Cannot overwrite a subset of a table")
 
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index a095c13315..17dc997163 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -391,3 +391,21 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
     assert [row.added_data_files_count for row in rows] == [1, 1, 0, 1, 1]
     assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0]
     assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]
+
+
+@pytest.mark.integration
+def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.arrow_data_files"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'})
+
+    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+        tbl.overwrite("not a df")
+
+    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+        tbl.append("not a df")
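
For reference, a minimal sketch of how these new guards surface to a caller; the catalog name and table identifier below are hypothetical, and a configured catalog is assumed:

    from pyiceberg.catalog import load_catalog
    import pyarrow as pa

    catalog = load_catalog("default")             # assumes a catalog named "default" is configured
    tbl = catalog.load_table("default.my_table")  # hypothetical table identifier

    # Non-pa.Table inputs are now rejected up front, before any write machinery runs:
    tbl.append("not a df")  # raises ValueError: Expected PyArrow table, got: not a df

    # A real PyArrow table still goes through (assuming it matches the table schema):
    tbl.append(pa.Table.from_pylist([{"foo": "a"}]))

Validating the input type eagerly also means the deferred `import pyarrow` inside the method doubles as an install check, so environments without the optional PyArrow dependency get a clear ModuleNotFoundError instead of a NameError deeper in the write path.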