support CreateTableTransaction for SQL
HonahX committed Apr 24, 2024
1 parent 3910e5e commit 12b9458
Showing 2 changed files with 129 additions and 35 deletions.
104 changes: 69 additions & 35 deletions pyiceberg/catalog/sql.py
@@ -376,55 +376,89 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
         identifier_tuple = self.identifier_to_tuple_without_catalog(
             tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
         )
-        current_table = self.load_table(identifier_tuple)
         database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
-        base_metadata = current_table.metadata
 
+        current_table: Optional[Table]
+        try:
+            current_table = self.load_table(identifier_tuple)
+        except NoSuchTableError:
+            current_table = None
+
         for requirement in table_request.requirements:
-            requirement.validate(base_metadata)
+            requirement.validate(current_table.metadata if current_table else None)
 
-        updated_metadata = update_table_metadata(base_metadata, table_request.updates)
-        if updated_metadata == base_metadata:
+        updated_metadata = update_table_metadata(
+            base_metadata=current_table.metadata if current_table else self._empty_table_metadata(),
+            updates=table_request.updates,
+            enforce_validation=current_table is None,
+        )
+        if current_table and updated_metadata == current_table.metadata:
             # no changes, do nothing
-            return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)
+            return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location)
 
         # write new metadata
-        new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
-        new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
-        self._write_metadata(updated_metadata, current_table.io, new_metadata_location)
+        new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 if current_table else 0
+        new_metadata_location = self._get_metadata_location(updated_metadata.location, new_metadata_version)
+        self._write_metadata(
+            metadata=updated_metadata,
+            io=self._load_file_io(updated_metadata.properties, new_metadata_location),
+            metadata_path=new_metadata_location,
+        )
 
         with Session(self.engine) as session:
-            if self.engine.dialect.supports_sane_rowcount:
-                stmt = (
-                    update(IcebergTables)
-                    .where(
-                        IcebergTables.catalog_name == self.name,
-                        IcebergTables.table_namespace == database_name,
-                        IcebergTables.table_name == table_name,
-                        IcebergTables.metadata_location == current_table.metadata_location,
-                    )
-                    .values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location)
-                )
-                result = session.execute(stmt)
-                if result.rowcount < 1:
-                    raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
-            else:
-                try:
-                    tbl = (
-                        session.query(IcebergTables)
-                        .with_for_update(of=IcebergTables)
-                        .filter(
+            if current_table:
+                # table exists, update it
+                if self.engine.dialect.supports_sane_rowcount:
+                    stmt = (
+                        update(IcebergTables)
+                        .where(
                             IcebergTables.catalog_name == self.name,
                             IcebergTables.table_namespace == database_name,
                             IcebergTables.table_name == table_name,
                             IcebergTables.metadata_location == current_table.metadata_location,
                         )
-                        .one()
+                        .values(
+                            metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location
+                        )
                     )
-                    tbl.metadata_location = new_metadata_location
-                    tbl.previous_metadata_location = current_table.metadata_location
-                except NoResultFound as e:
-                    raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}") from e
-                session.commit()
+                    result = session.execute(stmt)
+                    if result.rowcount < 1:
+                        raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
+                else:
+                    try:
+                        tbl = (
+                            session.query(IcebergTables)
+                            .with_for_update(of=IcebergTables)
+                            .filter(
+                                IcebergTables.catalog_name == self.name,
+                                IcebergTables.table_namespace == database_name,
+                                IcebergTables.table_name == table_name,
+                                IcebergTables.metadata_location == current_table.metadata_location,
+                            )
+                            .one()
+                        )
+                        tbl.metadata_location = new_metadata_location
+                        tbl.previous_metadata_location = current_table.metadata_location
+                    except NoResultFound as e:
+                        raise CommitFailedException(
+                            f"Table has been updated by another process: {database_name}.{table_name}"
+                        ) from e
+                    session.commit()
+            else:
+                # table does not exist, create it
+                try:
+                    session.add(
+                        IcebergTables(
+                            catalog_name=self.name,
+                            table_namespace=database_name,
+                            table_name=table_name,
+                            metadata_location=new_metadata_location,
+                            previous_metadata_location=None,
+                        )
+                    )
+                    session.commit()
+                except IntegrityError as e:
+                    raise TableAlreadyExistsError(f"Table {database_name}.{table_name} already exists") from e
 
         return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)
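For orientation, here is a minimal usage sketch (not part of the commit) of the two commit paths in the `_commit_table` change above: when no `IcebergTables` row exists for the identifier yet, the commit inserts a new row (the branch `create_table_transaction` relies on); when the table already exists, the commit goes through the optimistic `UPDATE ... WHERE metadata_location = <last-seen location>` guard, or the `SELECT ... FOR UPDATE` fallback. The catalog name, warehouse path, identifier, and property values are illustrative assumptions.

```python
import tempfile

import pyarrow as pa

from pyiceberg.catalog.sql import SqlCatalog

# Assumed setup: an in-memory SQLite catalog with a throwaway local warehouse,
# mirroring the catalog_memory fixture used by the tests below.
warehouse_path = tempfile.mkdtemp()
catalog = SqlCatalog("default", uri="sqlite:///:memory:", warehouse=f"file://{warehouse_path}")
catalog.create_namespace("default")

schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])

# Create path: no IcebergTables row exists yet, so committing the transaction
# takes the "table does not exist, create it" branch and inserts a new row.
with catalog.create_table_transaction("default.example", schema=schema) as txn:
    txn.set_properties(created_by="sketch")  # stage changes inside the same commit

# Update path: the table now exists, so a later commit runs the guarded UPDATE
# (or the SELECT ... FOR UPDATE fallback when rowcount support is unreliable).
tbl = catalog.load_table("default.example")
with tbl.transaction() as txn:
    txn.set_properties(owner="sketch")
```

On the update path, a writer that committed in between would have advanced `metadata_location`, so the WHERE clause matches zero rows (or `.one()` raises `NoResultFound`) and the commit surfaces as the `CommitFailedException` raised above instead of silently overwriting the other change.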

60 changes: 60 additions & 0 deletions tests/catalog/test_sql.py
@@ -948,6 +948,66 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
                 snapshot_update.append_data_file(data_file)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+        lazy_fixture('catalog_sqlite_without_rowcount'),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None:
+    identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}"
+    try:
+        catalog.create_namespace("default")
+    except NamespaceAlreadyExistsError:
+        pass
+
+    try:
+        catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    pa_table = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+        },
+        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+    )
+
+    pa_table_with_column = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+            'bar': [19, None, 25],
+        },
+        schema=pa.schema([
+            pa.field("foo", pa.string(), nullable=True),
+            pa.field("bar", pa.int32(), nullable=True),
+        ]),
+    )
+
+    with catalog.create_table_transaction(
+        identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}
+    ) as txn:
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io):
+                snapshot_update.append_data_file(data_file)
+
+        with txn.update_schema() as schema_txn:
+            schema_txn.union_by_name(pa_table_with_column.schema)
+
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io
+            ):
+                snapshot_update.append_data_file(data_file)
+
+    tbl = catalog.load_table(identifier=identifier)
+    assert tbl.format_version == format_version
+    assert len(tbl.scan().to_arrow()) == 6
+
+
 @pytest.mark.parametrize(
     'catalog',
     [
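As a companion to the test above, here is a hedged sketch (also not part of the commit) of the conflict the optimistic guard in sql.py exists for, continuing from the assumed `catalog` and `default.example` table of the earlier sketch. The exact failure point depends on which check trips first (requirement validation or the guarded UPDATE), but a stale writer is rejected rather than silently overwriting the concurrent commit.

```python
import pyarrow as pa

from pyiceberg.exceptions import CommitFailedException

df = pa.Table.from_pydict(
    {"foo": ["a", "b"]},
    schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
)

# Two independent handles to the same table, loaded at the same metadata_location.
writer_a = catalog.load_table("default.example")
writer_b = catalog.load_table("default.example")

writer_a.append(df)  # commits first and advances the table's metadata_location

try:
    # writer_b still carries the old snapshot state, so its commit is rejected,
    # either by requirement validation or by the guarded UPDATE shown in sql.py.
    writer_b.append(df)
except CommitFailedException as exc:
    print(f"Stale commit rejected: {exc}")
```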

0 comments on commit 12b9458
