diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 7978fdc9b4..91411344a2 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -1088,6 +1088,36 @@ with table.transaction() as transaction: # ... Update properties etc ``` +### Overwrite schema + +To overwrite the entire schema of a table, use the `overwrite` method: + +```python +from pyiceberg.catalog import load_catalog +from pyiceberg.schema import Schema +from pyiceberg.types import NestedField, StringType, DoubleType + +catalog = load_catalog() + +initial_schema = Schema( + NestedField(1, "city_name", StringType(), required=False), + NestedField(2, "latitude", DoubleType(), required=False), + NestedField(3, "longitude", DoubleType(), required=False), +) + +table = catalog.create_table("default.locations", initial_schema) + +new_schema = Schema( + NestedField(1, "city", StringType(), required=False), + NestedField(2, "lat", DoubleType(), required=False), + NestedField(3, "long", DoubleType(), required=False), + NestedField(4, "population", LongType(), required=False), +) + +with table.update_schema() as update: + update.overwrite(new_schema) +``` + ### Union by Name Using `.union_by_name()` you can merge another schema into an existing schema without having to worry about field-IDs: diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py index 8ee3b43c24..fb940c758a 100644 --- a/pyiceberg/table/update/schema.py +++ b/pyiceberg/table/update/schema.py @@ -140,6 +140,30 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema: self._case_sensitive = case_sensitive return self + def overwrite(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema: + """Overwrite the schema with a new schema. + + Args: + new_schema: The new schema to overwrite with. + + Returns: + This for method chaining. + """ + from pyiceberg.catalog import Catalog + + new_schema = Catalog._convert_schema_if_needed(new_schema) + + if self._schema == new_schema: + return self + + for field in self._schema.fields: + self.delete_column(field.name) + + for field in new_schema.fields: + self.add_column(field.name, field.field_type, field.doc, field.required) + + return self + def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema: from pyiceberg.catalog import Catalog diff --git a/tests/test_schema.py b/tests/test_schema.py index daa46dee1f..a89b55199a 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1653,3 +1653,57 @@ def test_arrow_schema() -> None: ) assert base_schema.as_arrow() == expected_schema + + +def test_overwrite_schema() -> None: + base_schema = Schema(NestedField(field_id=1, name="old", field_type=StringType(), required=True)) + + schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=False), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + ) + + new_schema = UpdateSchema(transaction=None, schema=base_schema).overwrite(schema)._apply() # type: ignore + + expected_schema = Schema( + NestedField(field_id=2, name="foo", field_type=StringType(), required=False), + NestedField(field_id=3, name="bar", field_type=IntegerType(), required=False), + NestedField(field_id=4, name="baz", field_type=BooleanType(), required=False), + ) + + assert new_schema == expected_schema + + +def test_overwrite_with_pa_schema() -> None: + base_schema = Schema(NestedField(field_id=1, name="old", field_type=StringType(), required=True)) + + pa_schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + pa.field("baz", pa.bool_(), nullable=True), + ] + ) + + new_schema = UpdateSchema(transaction=None, schema=base_schema).overwrite(pa_schema)._apply() # type: ignore + + expected_schema = Schema( + NestedField(field_id=2, name="foo", field_type=StringType(), required=False), + NestedField(field_id=3, name="bar", field_type=IntegerType(), required=False), + NestedField(field_id=4, name="baz", field_type=BooleanType(), required=False), + ) + + assert new_schema == expected_schema + + +def test_overwrite_schema_no_changes() -> None: + base_schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=False), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + ) + + new_schema = UpdateSchema(transaction=None, schema=base_schema).overwrite(base_schema)._apply() # type: ignore + + assert new_schema == base_schema