diff --git a/README.md b/README.md index 451aeb8ac..d6cb24164 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ the table, the following tracks the current write support: | Append Data Files | X | | Rewrite Files | | | Rewrite manifests | | -| Overwrite Files | | +| Overwrite Files | X | | Write Pos Delete | | | Write Eq Delete | | | Row Delta | | diff --git a/internal/recipe/validation.py b/internal/recipe/validation.py index 1ac9421a6..7c1b2ecbd 100644 --- a/internal/recipe/validation.py +++ b/internal/recipe/validation.py @@ -26,16 +26,40 @@ def testSetProperties(): def testAddedFile(): - spark.sql("SELECT COUNT(*) FROM default.test_partitioned_by_days").show(truncate=False) + spark.sql("SELECT COUNT(*) FROM default.test_partitioned_by_days").show( + truncate=False + ) def testReadDifferentDataTypes(): - spark.sql("DESCRIBE TABLE EXTENDED default.go_test_different_data_types").show(truncate=False) + spark.sql("DESCRIBE TABLE EXTENDED default.go_test_different_data_types").show( + truncate=False + ) spark.sql("SELECT * FROM default.go_test_different_data_types").show(truncate=False) def testReadSpecUpdate(): - spark.sql("DESCRIBE TABLE EXTENDED default.go_test_update_spec").show(truncate=False) + spark.sql("DESCRIBE TABLE EXTENDED default.go_test_update_spec").show( + truncate=False + ) + + +def testOverwriteBasic(): + spark.sql("SELECT COUNT(*) FROM default.go_test_overwrite_basic").show( + truncate=False + ) + spark.sql("SELECT * FROM default.go_test_overwrite_basic ORDER BY baz").show( + truncate=False + ) + + +def testOverwriteWithFilter(): + spark.sql("SELECT COUNT(*) FROM default.go_test_overwrite_filter").show( + truncate=False + ) + spark.sql("SELECT * FROM default.go_test_overwrite_filter ORDER BY baz").show( + truncate=False + ) if __name__ == "__main__": @@ -54,3 +78,9 @@ def testReadSpecUpdate(): if args.test == "TestReadSpecUpdate": testReadSpecUpdate() + + if args.test == "TestOverwriteBasic": + testOverwriteBasic() + + if args.test == "TestOverwriteWithFilter": + testOverwriteWithFilter() diff --git a/table/table.go b/table/table.go index 8abf6fb8c..5654d077c 100644 --- a/table/table.go +++ b/table/table.go @@ -129,6 +129,26 @@ func (t Table) Append(ctx context.Context, rdr array.RecordReader, snapshotProps return txn.Commit(ctx) } +// OverwriteTable is a shortcut for NewTransaction().OverwriteTable() and then committing the transaction +func (t Table) OverwriteTable(ctx context.Context, tbl arrow.Table, batchSize int64, filter iceberg.BooleanExpression, caseSensitive bool, snapshotProps iceberg.Properties) (*Table, error) { + txn := t.NewTransaction() + if err := txn.OverwriteTable(ctx, tbl, batchSize, filter, caseSensitive, snapshotProps); err != nil { + return nil, err + } + + return txn.Commit(ctx) +} + +// Overwrite is a shortcut for NewTransaction().Overwrite() and then committing the transaction +func (t Table) Overwrite(ctx context.Context, rdr array.RecordReader, filter iceberg.BooleanExpression, caseSensitive bool, snapshotProps iceberg.Properties) (*Table, error) { + txn := t.NewTransaction() + if err := txn.Overwrite(ctx, rdr, filter, caseSensitive, snapshotProps); err != nil { + return nil, err + } + + return txn.Commit(ctx) +} + func (t Table) AllManifests(ctx context.Context) iter.Seq2[iceberg.ManifestFile, error] { fs, err := t.fsF(ctx) if err != nil { diff --git a/table/table_test.go b/table/table_test.go index 437c461fc..fe26b7c03 100644 --- a/table/table_test.go +++ b/table/table_test.go @@ -368,7 +368,7 @@ func (t *TableWritingTestSuite) 
createTable(identifier table.Identifier, formatV func(ctx context.Context) (iceio.IO, error) { return iceio.LocalFS{}, nil }, - nil, + &mockedCatalog{meta}, ) } @@ -1281,6 +1281,50 @@ func (t *TableWritingTestSuite) TestMergeManifests() { t.True(array.TableEqual(resultB, resultC), "expected:\n %s\ngot:\n %s", resultB, resultC) } +// TestOverwriteTable verifies that Table.OverwriteTable properly delegates to Transaction.OverwriteTable +func (t *TableWritingTestSuite) TestOverwriteTable() { + ident := table.Identifier{"default", "overwrite_table_wrapper_v" + strconv.Itoa(t.formatVersion)} + tbl := t.createTable(ident, t.formatVersion, *iceberg.UnpartitionedSpec, t.tableSchema) + newTable, err := array.TableFromJSON(memory.DefaultAllocator, t.arrSchema, []string{ + `[{"foo": false, "bar": "wrapper_test", "baz": 123, "qux": "2024-01-01"}]`, + }) + t.Require().NoError(err) + defer newTable.Release() + resultTbl, err := tbl.OverwriteTable(t.ctx, newTable, 1, nil, true, nil) + t.Require().NoError(err) + t.NotNil(resultTbl) + + snapshot := resultTbl.CurrentSnapshot() + t.NotNil(snapshot) + t.Equal(table.OpAppend, snapshot.Summary.Operation) // Empty table overwrite becomes append +} + +// TestOverwriteRecord verifies that Table.Overwrite properly delegates to Transaction.Overwrite +func (t *TableWritingTestSuite) TestOverwriteRecord() { + ident := table.Identifier{"default", "overwrite_record_wrapper_v" + strconv.Itoa(t.formatVersion)} + tbl := t.createTable(ident, t.formatVersion, *iceberg.UnpartitionedSpec, t.tableSchema) + + // Create test data as RecordReader + testTable, err := array.TableFromJSON(memory.DefaultAllocator, t.arrSchema, []string{ + `[{"foo": true, "bar": "record_test", "baz": 456, "qux": "2024-01-02"}]`, + }) + t.Require().NoError(err) + defer testTable.Release() + + rdr := array.NewTableReader(testTable, 1) + defer rdr.Release() + + // Test that Table.Overwrite works (delegates to transaction) + resultTbl, err := tbl.Overwrite(t.ctx, rdr, nil, true, nil) + t.Require().NoError(err) + t.NotNil(resultTbl) + + // Verify the operation worked + snapshot := resultTbl.CurrentSnapshot() + t.NotNil(snapshot) + t.Equal(table.OpAppend, snapshot.Summary.Operation) // Empty table overwrite becomes append +} + func TestTableWriting(t *testing.T) { suite.Run(t, &TableWritingTestSuite{formatVersion: 1}) suite.Run(t, &TableWritingTestSuite{formatVersion: 2}) diff --git a/table/transaction.go b/table/transaction.go index d949df3af..3d16b915a 100644 --- a/table/transaction.go +++ b/table/transaction.go @@ -474,6 +474,280 @@ func (t *Transaction) AddFiles(ctx context.Context, files []string, snapshotProp return t.apply(updates, reqs) } +// OverwriteTable overwrites the table data using an Arrow Table, optionally with a filter. +// If filter is nil or AlwaysTrue, all existing data will be replaced. +// If filter is provided, only data matching the filter will be replaced. +func (t *Transaction) OverwriteTable(ctx context.Context, tbl arrow.Table, batchSize int64, filter iceberg.BooleanExpression, caseSensitive bool, snapshotProps iceberg.Properties) error { + rdr := array.NewTableReader(tbl, batchSize) + defer rdr.Release() + + return t.Overwrite(ctx, rdr, filter, caseSensitive, snapshotProps) +} + +// Overwrite overwrites the table data using a RecordReader, optionally with a filter. +// If filter is nil or AlwaysTrue, all existing data will be replaced. +// If filter is provided, only data matching the filter will be replaced. 
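+// Existing files that only partially match the filter are rewritten: rows that do
+// not match the filter are preserved and written back out alongside the new data.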
+func (t *Transaction) Overwrite(ctx context.Context, rdr array.RecordReader, filter iceberg.BooleanExpression, caseSensitive bool, snapshotProps iceberg.Properties) error { + fs, err := t.tbl.fsF(ctx) + if err != nil { + return err + } + + if t.meta.NameMapping() == nil { + nameMapping := t.meta.CurrentSchema().NameMapping() + mappingJson, err := json.Marshal(nameMapping) + if err != nil { + return err + } + err = t.SetProperties(iceberg.Properties{DefaultNameMappingKey: string(mappingJson)}) + if err != nil { + return err + } + } + + commitUUID := uuid.New() + updater := t.updateSnapshot(fs, snapshotProps).mergeOverwrite(&commitUUID) + + filesToDelete, filesToRewrite, err := t.classifyFilesForOverwrite(ctx, fs, filter) + if err != nil { + return err + } + + for _, df := range filesToDelete { + updater.deleteDataFile(df) + } + + if len(filesToRewrite) > 0 { + if err := t.rewriteFilesWithFilter(ctx, fs, updater, filesToRewrite, filter); err != nil { + return err + } + } + + itr := recordsToDataFiles(ctx, t.tbl.Location(), t.meta, recordWritingArgs{ + sc: rdr.Schema(), + itr: array.IterFromReader(rdr), + fs: fs.(io.WriteFileIO), + writeUUID: &updater.commitUuid, + }) + + for df, err := range itr { + if err != nil { + return err + } + updater.appendDataFile(df) + } + + updates, reqs, err := updater.commit() + if err != nil { + return err + } + + return t.apply(updates, reqs) +} + +// classifyFilesForOverwrite classifies existing data files based on the provided filter. +// Returns files to delete completely, files to rewrite partially, and any error. +func (t *Transaction) classifyFilesForOverwrite(ctx context.Context, fs io.IO, filter iceberg.BooleanExpression) (filesToDelete, filesToRewrite []iceberg.DataFile, err error) { + s := t.meta.currentSnapshot() + if s == nil { + return nil, nil, nil + } + + if filter == nil || filter.Equals(iceberg.AlwaysTrue{}) { + for df, err := range s.dataFiles(fs, nil) { + if err != nil { + return nil, nil, err + } + if df.ContentType() == iceberg.EntryContentData { + filesToDelete = append(filesToDelete, df) + } + } + + return filesToDelete, filesToRewrite, nil + } + + return t.classifyFilesForFilteredOverwrite(ctx, fs, filter) +} + +// classifyFilesForFilteredOverwrite classifies files for filtered overwrite operations. +// Returns files to delete completely, files to rewrite partially, and any error. 
+func (t *Transaction) classifyFilesForFilteredOverwrite(ctx context.Context, fs io.IO, filter iceberg.BooleanExpression) (filesToDelete, filesToRewrite []iceberg.DataFile, err error) { + schema := t.meta.CurrentSchema() + + inclusiveEvaluator, err := newInclusiveMetricsEvaluator(schema, filter, true, false) + if err != nil { + return nil, nil, fmt.Errorf("failed to create inclusive metrics evaluator: %w", err) + } + + strictEvaluator, err := newStrictMetricsEvaluator(schema, filter, true, false) + if err != nil { + return nil, nil, fmt.Errorf("failed to create strict metrics evaluator: %w", err) + } + + var manifestEval func(iceberg.ManifestFile) (bool, error) + meta, err := t.meta.Build() + if err != nil { + return nil, nil, fmt.Errorf("failed to build metadata: %w", err) + } + spec := meta.PartitionSpec() + if !spec.IsUnpartitioned() { + manifestEval, err = newManifestEvaluator(spec, schema, filter, true) + if err != nil { + return nil, nil, fmt.Errorf("failed to create manifest evaluator: %w", err) + } + } + + s := t.meta.currentSnapshot() + manifests, err := s.Manifests(fs) + if err != nil { + return nil, nil, fmt.Errorf("failed to get manifests: %w", err) + } + + for _, manifest := range manifests { + if manifestEval != nil { + match, err := manifestEval(manifest) + if err != nil { + return nil, nil, fmt.Errorf("failed to evaluate manifest: %w", err) + } + if !match { + continue + } + } + + entries, err := manifest.FetchEntries(fs, false) + if err != nil { + return nil, nil, fmt.Errorf("failed to fetch manifest entries: %w", err) + } + + for _, entry := range entries { + if entry.Status() == iceberg.EntryStatusDELETED { + continue + } + + df := entry.DataFile() + if df.ContentType() != iceberg.EntryContentData { + continue + } + + inclusive, err := inclusiveEvaluator(df) + if err != nil { + return nil, nil, fmt.Errorf("failed to evaluate data file %s with inclusive evaluator: %w", df.FilePath(), err) + } + + if !inclusive { + continue + } + + strict, err := strictEvaluator(df) + if err != nil { + return nil, nil, fmt.Errorf("failed to evaluate data file %s with strict evaluator: %w", df.FilePath(), err) + } + + if strict { + filesToDelete = append(filesToDelete, df) + } else { + filesToRewrite = append(filesToRewrite, df) + } + } + } + + return filesToDelete, filesToRewrite, nil +} + +// rewriteFilesWithFilter rewrites data files by preserving only rows that do NOT match the filter +func (t *Transaction) rewriteFilesWithFilter(ctx context.Context, fs io.IO, updater *snapshotProducer, files []iceberg.DataFile, filter iceberg.BooleanExpression) error { + complementFilter := iceberg.NewNot(filter) + + for _, originalFile := range files { + rewrittenFiles, err := t.rewriteSingleFile(ctx, fs, originalFile, complementFilter, updater.commitUuid) + if err != nil { + return fmt.Errorf("failed to rewrite file %s: %w", originalFile.FilePath(), err) + } + + updater.deleteDataFile(originalFile) + for _, rewrittenFile := range rewrittenFiles { + updater.appendDataFile(rewrittenFile) + } + } + + return nil +} + +// rewriteSingleFile reads a single data file, applies the filter, and writes new files with filtered data +func (t *Transaction) rewriteSingleFile(ctx context.Context, fs io.IO, originalFile iceberg.DataFile, filter iceberg.BooleanExpression, commitUUID uuid.UUID) ([]iceberg.DataFile, error) { + scanTask := &FileScanTask{ + File: originalFile, + Start: 0, + Length: originalFile.FileSizeBytes(), + } + + boundFilter, err := iceberg.BindExpr(t.meta.CurrentSchema(), filter, true) + if err != 
nil { + return nil, fmt.Errorf("failed to bind filter: %w", err) + } + + meta, err := t.meta.Build() + if err != nil { + return nil, fmt.Errorf("failed to build metadata: %w", err) + } + + scanner := &arrowScan{ + metadata: meta, + fs: fs, + projectedSchema: t.meta.CurrentSchema(), + boundRowFilter: boundFilter, + caseSensitive: true, + rowLimit: -1, // No limit + concurrency: 1, + } + + _, recordIter, err := scanner.GetRecords(ctx, []FileScanTask{*scanTask}) + if err != nil { + return nil, fmt.Errorf("failed to get records from original file: %w", err) + } + + var records []arrow.RecordBatch + for record, err := range recordIter { + if err != nil { + return nil, fmt.Errorf("failed to read record: %w", err) + } + records = append(records, record) + } + + // If no records remain after filtering, don't create any new files + // we shouldn't hit this case given that we only run this after determining this is a file to rewrite + if len(records) == 0 { + return nil, nil + } + + arrowSchema, err := SchemaToArrowSchema(t.meta.CurrentSchema(), nil, false, false) + if err != nil { + return nil, fmt.Errorf("failed to convert schema to arrow: %w", err) + } + table := array.NewTableFromRecords(arrowSchema, records) + defer table.Release() + + rdr := array.NewTableReader(table, table.NumRows()) + defer rdr.Release() + + var result []iceberg.DataFile + itr := recordsToDataFiles(ctx, t.tbl.Location(), t.meta, recordWritingArgs{ + sc: rdr.Schema(), + itr: array.IterFromReader(rdr), + fs: fs.(io.WriteFileIO), + writeUUID: &commitUUID, + }) + + for df, err := range itr { + if err != nil { + return nil, err + } + result = append(result, df) + } + + return result, nil +} + func (t *Transaction) Scan(opts ...ScanOption) (*Scan, error) { updatedMeta, err := t.meta.Build() if err != nil { diff --git a/table/transaction_test.go b/table/transaction_test.go index a0269cb3f..d5912af3d 100644 --- a/table/transaction_test.go +++ b/table/transaction_test.go @@ -363,6 +363,139 @@ func (s *SparkIntegrationTestSuite) TestUpdateSpec() { ) } +func (s *SparkIntegrationTestSuite) TestOverwriteBasic() { + icebergSchema := iceberg.NewSchema(0, + iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.Bool}, + iceberg.NestedField{ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Int32}, + ) + + tbl, err := s.cat.CreateTable(s.ctx, catalog.ToIdentifier("default", "go_test_overwrite_basic"), icebergSchema) + s.Require().NoError(err) + + // Create initial data + arrowSchema, err := table.SchemaToArrowSchema(icebergSchema, nil, true, false) + s.Require().NoError(err) + + initialTable, err := array.TableFromJSON(memory.DefaultAllocator, arrowSchema, []string{ + `[ + {"foo": true, "bar": "initial", "baz": 100}, + {"foo": false, "bar": "old_data", "baz": 200} + ]`, + }) + s.Require().NoError(err) + defer initialTable.Release() + + tx := tbl.NewTransaction() + err = tx.AppendTable(s.ctx, initialTable, 2, nil) + s.Require().NoError(err) + tbl, err = tx.Commit(s.ctx) + s.Require().NoError(err) + + overwriteTable, err := array.TableFromJSON(memory.DefaultAllocator, arrowSchema, []string{ + `[ + {"foo": false, "bar": "overwritten", "baz": 300}, + {"foo": true, "bar": "new_data", "baz": 400} + ]`, + }) + s.Require().NoError(err) + defer overwriteTable.Release() + + tx = tbl.NewTransaction() + err = tx.OverwriteTable(s.ctx, overwriteTable, 2, nil, true, nil) + s.Require().NoError(err) + _, err = tx.Commit(s.ctx) + s.Require().NoError(err) + + 
expectedOutput := ` ++--------+ +|count(1)| ++--------+ +|2 | ++--------+ + ++-----+-----------+----+ +|foo |bar |baz | ++-----+-----------+----+ +|false|overwritten|300 | +|true |new_data |400 | ++-----+-----------+----+ +` + + output, err := recipe.ExecuteSpark(s.T(), "./validation.py", "--test", "TestOverwriteBasic") + s.Require().NoError(err) + s.Require().True( + strings.HasSuffix(strings.TrimSpace(output), strings.TrimSpace(expectedOutput)), + "result does not contain expected output: %s", expectedOutput, + ) +} + +func (s *SparkIntegrationTestSuite) TestOverwriteWithFilter() { + icebergSchema := iceberg.NewSchema(0, + iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.Bool}, + iceberg.NestedField{ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Int32}, + ) + + tbl, err := s.cat.CreateTable(s.ctx, catalog.ToIdentifier("default", "go_test_overwrite_filter"), icebergSchema) + s.Require().NoError(err) + + arrowSchema, err := table.SchemaToArrowSchema(icebergSchema, nil, true, false) + s.Require().NoError(err) + + initialTable, err := array.TableFromJSON(memory.DefaultAllocator, arrowSchema, []string{ + `[ + {"foo": true, "bar": "should_be_replaced", "baz": 100}, + {"foo": false, "bar": "should_remain", "baz": 200}, + {"foo": true, "bar": "also_replaced", "baz": 300} + ]`, + }) + s.Require().NoError(err) + defer initialTable.Release() + + tx := tbl.NewTransaction() + err = tx.AppendTable(s.ctx, initialTable, 3, nil) + s.Require().NoError(err) + tbl, err = tx.Commit(s.ctx) + s.Require().NoError(err) + + overwriteTable, err := array.TableFromJSON(memory.DefaultAllocator, arrowSchema, []string{ + `[ + {"foo": true, "bar": "new_replacement", "baz": 999} + ]`, + }) + s.Require().NoError(err) + defer overwriteTable.Release() + + filter := iceberg.EqualTo(iceberg.Reference("foo"), true) + tx = tbl.NewTransaction() + err = tx.OverwriteTable(s.ctx, overwriteTable, 1, filter, true, nil) + s.Require().NoError(err) + _, err = tx.Commit(s.ctx) + s.Require().NoError(err) + + expectedOutput := ` ++--------+ +|count(1)| ++--------+ +|1 | ++--------+ + ++-----+---------------+----+ +|foo |bar |baz | ++-----+---------------+----+ +|true |new_replacement|999 | ++-----+---------------+----+ +` + + output, err := recipe.ExecuteSpark(s.T(), "./validation.py", "--test", "TestOverwriteWithFilter") + s.Require().NoError(err) + s.Require().True( + strings.HasSuffix(strings.TrimSpace(output), strings.TrimSpace(expectedOutput)), + "result does not contain expected output: %s", expectedOutput, + ) +} + func TestSparkIntegration(t *testing.T) { suite.Run(t, new(SparkIntegrationTestSuite)) }
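
For reference, a minimal caller-side sketch of the new `Table.Overwrite` / `Table.OverwriteTable` API introduced above. This is illustrative only: it assumes `ctx`, `tbl` (a `*table.Table`), `rdr` (an `array.RecordReader` holding the replacement rows), and `newArrowTable` (an `arrow.Table`) already exist, and it reuses the filter expression from the integration test.

```go
// Sketch, not part of the diff: ctx, tbl, rdr, and newArrowTable are assumed
// to be set up elsewhere.
filter := iceberg.EqualTo(iceberg.Reference("foo"), true)

// Replace every row where foo = true with the contents of rdr. Rows in existing
// data files that do not match the filter are rewritten and kept.
updated, err := tbl.Overwrite(ctx, rdr, filter, true, nil)
if err != nil {
    log.Fatal(err)
}

// Passing a nil filter (or iceberg.AlwaysTrue{}) replaces the whole table; the
// arrow.Table variant takes an extra batch size argument.
updated, err = tbl.OverwriteTable(ctx, newArrowTable, 1024, nil, true, nil)
if err != nil {
    log.Fatal(err)
}
_ = updated
```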