Support CreateTableTransaction for SqlCatalog #684

Merged · 5 commits · May 31, 2024
104 changes: 64 additions & 40 deletions pyiceberg/catalog/sql.py
@@ -60,7 +60,7 @@
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
+from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table
 from pyiceberg.table.metadata import new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
@@ -402,59 +402,83 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
         identifier_tuple = self.identifier_to_tuple_without_catalog(
             tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
         )
-        current_table = self.load_table(identifier_tuple)
         namespace_tuple = Catalog.namespace_from(identifier_tuple)
         namespace = Catalog.namespace_to_string(namespace_tuple)
         table_name = Catalog.table_name_from(identifier_tuple)
-        base_metadata = current_table.metadata
-        for requirement in table_request.requirements:
-            requirement.validate(base_metadata)
-
-        updated_metadata = update_table_metadata(base_metadata, table_request.updates)
-        if updated_metadata == base_metadata:
-            # no changes, do nothing
-            return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)
+        current_table: Optional[Table]
+        try:
+            current_table = self.load_table(identifier_tuple)
+        except NoSuchTableError:
+            current_table = None
 
-        # write new metadata
-        new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
-        new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
-        self._write_metadata(updated_metadata, current_table.io, new_metadata_location)
+        updated_staged_table = self._update_and_stage_table(current_table, table_request)
+        if current_table and updated_staged_table.metadata == current_table.metadata:
+            # no changes, do nothing
+            return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location)
+        self._write_metadata(
+            metadata=updated_staged_table.metadata,
+            io=updated_staged_table.io,
+            metadata_path=updated_staged_table.metadata_location,
+        )
 
         with Session(self.engine) as session:
-            if self.engine.dialect.supports_sane_rowcount:
-                stmt = (
-                    update(IcebergTables)
-                    .where(
-                        IcebergTables.catalog_name == self.name,
-                        IcebergTables.table_namespace == namespace,
-                        IcebergTables.table_name == table_name,
-                        IcebergTables.metadata_location == current_table.metadata_location,
-                    )
-                    .values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location)
-                )
-                result = session.execute(stmt)
-                if result.rowcount < 1:
-                    raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}")
-            else:
-                try:
-                    tbl = (
-                        session.query(IcebergTables)
-                        .with_for_update(of=IcebergTables)
-                        .filter(
-                            IcebergTables.catalog_name == self.name,
-                            IcebergTables.table_namespace == namespace,
-                            IcebergTables.table_name == table_name,
-                            IcebergTables.metadata_location == current_table.metadata_location,
-                        )
-                        .one()
-                    )
-                    tbl.metadata_location = new_metadata_location
-                    tbl.previous_metadata_location = current_table.metadata_location
-                except NoResultFound as e:
-                    raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") from e
-                session.commit()
+            if current_table:
+                # table exists, update it
+                if self.engine.dialect.supports_sane_rowcount:
+                    stmt = (
+                        update(IcebergTables)
+                        .where(
+                            IcebergTables.catalog_name == self.name,
+                            IcebergTables.table_namespace == namespace,
+                            IcebergTables.table_name == table_name,
+                            IcebergTables.metadata_location == current_table.metadata_location,
+                        )
+                        .values(
+                            metadata_location=updated_staged_table.metadata_location,
+                            previous_metadata_location=current_table.metadata_location,
+                        )
+                    )
+                    result = session.execute(stmt)
+                    if result.rowcount < 1:
+                        raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}")
+                else:
+                    try:
+                        tbl = (
+                            session.query(IcebergTables)
+                            .with_for_update(of=IcebergTables)
+                            .filter(
+                                IcebergTables.catalog_name == self.name,
+                                IcebergTables.table_namespace == namespace,
+                                IcebergTables.table_name == table_name,
+                                IcebergTables.metadata_location == current_table.metadata_location,
+                            )
+                            .one()
+                        )
+                        tbl.metadata_location = updated_staged_table.metadata_location
+                        tbl.previous_metadata_location = current_table.metadata_location
+                    except NoResultFound as e:
+                        raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") from e
+                    session.commit()
+            else:
+                # table does not exist, create it
+                try:
+                    session.add(
+                        IcebergTables(
+                            catalog_name=self.name,
+                            table_namespace=namespace,
+                            table_name=table_name,
+                            metadata_location=updated_staged_table.metadata_location,
+                            previous_metadata_location=None,
+                        )
+                    )
+                    session.commit()
+                except IntegrityError as e:
+                    raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e
 
-        return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)
+        return CommitTableResponse(
+            metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location
+        )
 
     def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool:
         namespace_tuple = Catalog.identifier_to_tuple(identifier)
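The requirement-validation and metadata-staging logic deleted above moves into a shared _update_and_stage_table helper, which this diff calls but does not show. Below is a rough sketch of what that helper has to do, reconstructed from the removed lines; the StagedTable return type, the _empty_table_metadata helper, and the exact signature are assumptions here, not code from this PR:

# Sketch only: reconstructed from the logic removed from _commit_table above.
# StagedTable, _empty_table_metadata, and the exact signature are assumptions.
def _update_and_stage_table(self, current_table: Optional[Table], table_request: CommitTableRequest) -> StagedTable:
    # Validate each requirement against the current metadata; for a brand-new
    # table there is no metadata yet, so requirements are checked against None.
    for requirement in table_request.requirements:
        requirement.validate(current_table.metadata if current_table else None)

    # Apply the requested updates on top of the existing metadata, or start
    # from empty metadata when the table does not exist yet.
    updated_metadata = update_table_metadata(
        base_metadata=current_table.metadata if current_table else self._empty_table_metadata(),
        updates=table_request.updates,
    )

    # Version 0 for a new table, previous version + 1 otherwise; this mirrors
    # the _parse_metadata_version arithmetic deleted from _commit_table.
    new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 if current_table else 0
    new_metadata_location = self._get_metadata_location(updated_metadata.location, new_metadata_version)

    return StagedTable(
        identifier=tuple(table_request.identifier.namespace.root + [table_request.identifier.name]),
        metadata=updated_metadata,
        metadata_location=new_metadata_location,
        io=self._load_file_io(properties=updated_metadata.properties, location=new_metadata_location),
        catalog=self,
    )

With staging factored out this way, _commit_table only has to decide between the UPDATE path (existing table, optimistic-concurrency check on the old metadata location) and the INSERT path (new table).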
60 changes: 60 additions & 0 deletions tests/catalog/test_sql.py
@@ -1350,6 +1350,66 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
                 snapshot_update.append_data_file(data_file)
 
 
+@pytest.mark.parametrize(
+    "catalog",
+    [
+        lazy_fixture("catalog_memory"),
+        lazy_fixture("catalog_sqlite"),
+        lazy_fixture("catalog_sqlite_without_rowcount"),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None:
+    identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}"
+    try:
+        catalog.create_namespace("default")
+    except NamespaceAlreadyExistsError:
+        pass
+
+    try:
+        catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    pa_table = pa.Table.from_pydict(
+        {
+            "foo": ["a", None, "z"],
+        },
+        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+    )
+
+    pa_table_with_column = pa.Table.from_pydict(
+        {
+            "foo": ["a", None, "z"],
+            "bar": [19, None, 25],
+        },
+        schema=pa.schema([
+            pa.field("foo", pa.string(), nullable=True),
+            pa.field("bar", pa.int32(), nullable=True),
+        ]),
+    )
+
+    with catalog.create_table_transaction(
+        identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}
+    ) as txn:
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io):
+                snapshot_update.append_data_file(data_file)
+
+        with txn.update_schema() as schema_txn:
+            schema_txn.union_by_name(pa_table_with_column.schema)
+
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io
+            ):
+                snapshot_update.append_data_file(data_file)
+
+    tbl = catalog.load_table(identifier=identifier)
+    assert tbl.format_version == format_version
+    assert len(tbl.scan().to_arrow()) == 6
+
+
 @pytest.mark.parametrize(
     "catalog",
     [
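Because the create path maps the unique-key violation to TableAlreadyExistsError, a second writer racing to create the same table now fails cleanly at commit time instead of surfacing a raw IntegrityError. A minimal sketch of that behavior against an in-memory catalog; the connection settings and table name are placeholders, not taken from this PR:

import pyarrow as pa

from pyiceberg.catalog.sql import SqlCatalog
from pyiceberg.exceptions import TableAlreadyExistsError

# Placeholder settings; any SQLAlchemy URI and warehouse location work here.
catalog = SqlCatalog("default", uri="sqlite:///:memory:", warehouse="file:///tmp/warehouse")
catalog.create_namespace("default")

schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])

# The catalog row is inserted when the transaction commits, i.e. when the
# context manager exits; current_table is None, so the INSERT path is taken.
with catalog.create_table_transaction("default.example", schema=schema) as txn:
    pass

# Creating the same identifier again hits the IntegrityError to
# TableAlreadyExistsError translation added in this PR.
try:
    with catalog.create_table_transaction("default.example", schema=schema) as txn:
        pass
except TableAlreadyExistsError as err:
    print(f"commit rejected: {err}")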