Skip to content

Commit 6449392

Browse files
committed
materialize-snowflake: transient tables + split insert and merge
1 parent 4bf727c commit 6449392

7 files changed

Lines changed: 687 additions & 126 deletions

File tree

materialize-snowflake/snowflake.go

Lines changed: 210 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
sql "github.com/estuary/connectors/materialize-sql"
2020
pf "github.com/estuary/flow/go/protocols/flow"
2121
pm "github.com/estuary/flow/go/protocols/materialize"
22+
"github.com/google/uuid"
2223
"github.com/jmoiron/sqlx"
2324
log "github.com/sirupsen/logrus"
2425
sf "github.com/snowflakedb/gosnowflake"
@@ -326,10 +327,13 @@ type binding struct {
326327
}
327328
// Variables accessed by Prepare, Store, and Commit.
328329
store struct {
329-
stage *stagedFile
330+
insertStage *stagedFile // For rows where !it.Exists (new rows)
331+
mergeStage *stagedFile // For rows where it.Exists (existing rows)
332+
hasInserts bool // Track if we have insert operations
333+
hasMerges bool // Track if we have merge operations
330334
mergeInto string
331335
copyInto string
332-
mustMerge bool
336+
mustMerge bool // Deprecated: use hasMerges instead
333337
mergeBounds *sql.MergeBoundsBuilder
334338
}
335339
}
@@ -365,7 +369,8 @@ func (d *transactor) addBinding(ctx context.Context, target sql.Table, streaming
365369
}
366370

367371
b.load.stage = newStagedFile(os.TempDir())
368-
b.store.stage = newStagedFile(os.TempDir())
372+
b.store.insertStage = newStagedFile(os.TempDir())
373+
b.store.mergeStage = newStagedFile(os.TempDir())
369374

370375
if b.target.DeltaUpdates && d.cfg.Credentials.AuthType == snowflake_auth.JWT {
371376
var keyBegin = fmt.Sprintf("%08x", d._range.KeyBegin)
@@ -557,8 +562,14 @@ func (d *transactor) pipeExists(ctx context.Context, pipeName string) (bool, err
557562

558563
type checkpointItem struct {
559564
Table string
560-
Query string
561-
StagedDir string
565+
Query string // Deprecated: use InsertQuery/MergeQuery for new code
566+
InsertQuery string // Query for insert operations (from insertStage)
567+
MergeQuery string // Query for merge operations (from mergeStage)
568+
StagedDir string // Deprecated: use InsertStagedDir/MergeStagedDir
569+
InsertStagedDir string // Staged directory for inserts
570+
MergeStagedDir string // Staged directory for merges
571+
InsertStagingTable string // Transient table name for inserts (for cleanup)
572+
MergeStagingTable string // Transient table name for merges (for cleanup)
562573
StreamBlobs []*blobMetadata
563574
PipeName string
564575
PipeFiles []fileRecord
@@ -579,12 +590,21 @@ func (d *transactor) Store(it *m.StoreIterator) (m.StartCommitFunc, error) {
579590
continue
580591
}
581592

593+
// Route to appropriate stage based on whether key was loaded
594+
var targetStage *stagedFile
582595
if it.Exists {
583-
b.store.mustMerge = true
596+
// Key was loaded, so it exists - this is an UPDATE or DELETE
597+
targetStage = b.store.mergeStage
598+
b.store.hasMerges = true
599+
b.store.mustMerge = true // Keep for backward compatibility
600+
} else {
601+
// Key was not loaded, so it doesn't exist - this is an INSERT
602+
targetStage = b.store.insertStage
603+
b.store.hasInserts = true
584604
}
585605

586606
if !b.streaming {
587-
if err := b.store.stage.start(ctx, d.db); err != nil {
607+
if err := targetStage.start(ctx, d.db); err != nil {
588608
return nil, err
589609
}
590610
}
@@ -594,10 +614,13 @@ func (d *transactor) Store(it *m.StoreIterator) (m.StartCommitFunc, error) {
594614
if err := d.streamManager.writeRow(ctx, it.Binding, converted); err != nil {
595615
return nil, fmt.Errorf("encoding Store to stream for resource %s: %w", b.target.Path, err)
596616
}
597-
} else if err = b.store.stage.writeRow(append(converted, flowDelete)); err != nil {
617+
} else if err = targetStage.writeRow(append(converted, flowDelete)); err != nil {
598618
return nil, fmt.Errorf("writing Store to scratch file: %w", err)
599619
} else {
600-
b.store.mergeBounds.NextKey(converted[:len(b.target.Keys)])
620+
// Only track merge bounds for rows that will be merged
621+
if it.Exists {
622+
b.store.mergeBounds.NextKey(converted[:len(b.target.Keys)])
623+
}
601624
}
602625
}
603626
if it.Err() != nil {
@@ -638,37 +661,84 @@ func (d *transactor) buildDriverCheckpoint(ctx context.Context, runtimeCheckpoin
638661
continue
639662
}
640663

641-
if !b.store.stage.started {
664+
// Skip if neither stage has data
665+
if !b.store.insertStage.started && !b.store.mergeStage.started {
642666
continue
643667
}
644668

645-
dir, err := b.store.stage.flush()
646-
if err != nil {
647-
return nil, err
669+
var cpItem = &checkpointItem{
670+
Table: b.target.Identifier,
671+
Version: d.version,
648672
}
649673

650-
if b.store.mustMerge {
651-
mergeIntoQuery, err := renderBoundedQueryTemplate(d.templates.mergeInto, b.target, dir, b.store.mergeBounds.Build())
674+
// Handle INSERT operations (new rows)
675+
if b.store.hasInserts && b.store.insertStage.started {
676+
insertDir, err := b.store.insertStage.flush()
652677
if err != nil {
653-
return nil, fmt.Errorf("mergeInto template: %w", err)
678+
return nil, fmt.Errorf("flushing insert stage: %w", err)
654679
}
655-
d.cp[b.target.StateKey] = &checkpointItem{
656-
Table: b.target.Identifier,
657-
Query: mergeIntoQuery,
658-
StagedDir: dir,
659-
Version: d.version,
680+
681+
if b.pipeName != "" {
682+
// Snowpipe for inserts
683+
cpItem.InsertStagedDir = insertDir
684+
cpItem.PipeFiles = b.store.insertStage.uploaded
685+
cpItem.PipeName = b.pipeName
686+
} else {
687+
// Direct COPY INTO for inserts (no transient table needed)
688+
insertQuery, err := renderTableAndFileTemplate(b.target, insertDir, d.templates.copyInto)
689+
if err != nil {
690+
return nil, fmt.Errorf("copyInto template for inserts: %w", err)
691+
}
692+
cpItem.InsertQuery = insertQuery
693+
cpItem.InsertStagedDir = insertDir
660694
}
661-
// Reset for next round.
662-
b.store.mustMerge = false
663-
} else if b.pipeName != "" {
695+
}
696+
697+
// Handle MERGE operations (existing rows - updates/deletes)
698+
if b.store.hasMerges && b.store.mergeStage.started {
699+
mergeDir, err := b.store.mergeStage.flush()
700+
if err != nil {
701+
return nil, fmt.Errorf("flushing merge stage: %w", err)
702+
}
703+
704+
// Generate unique transient table name for merges
705+
// Replace hyphens with underscores to avoid identifier quoting issues
706+
mergeStagingTable := fmt.Sprintf("%s.FLOW_STAGING_MERGE_%d_%s",
707+
d.cfg.Schema, idx, strings.ReplaceAll(uuid.New().String(), "-", "_"))
708+
709+
// Build compound query: CREATE + COPY + MERGE + DROP
710+
createSQL, err := renderCreateTransientTable(b.target, mergeStagingTable, d.templates.createTransientTable)
711+
if err != nil {
712+
return nil, fmt.Errorf("createTransientTable template for merges: %w", err)
713+
}
714+
715+
copySQL, err := renderCopyIntoTransient(b.target, mergeStagingTable, mergeDir, d.templates.copyIntoTransient)
716+
if err != nil {
717+
return nil, fmt.Errorf("copyIntoTransient template for merges: %w", err)
718+
}
719+
720+
mergeSQL, err := renderMergeFromTransient(b.target, mergeStagingTable, b.store.mergeBounds.Build(), d.templates.mergeFromTransient)
721+
if err != nil {
722+
return nil, fmt.Errorf("mergeFromTransient template: %w", err)
723+
}
724+
725+
dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s;", mergeStagingTable)
726+
727+
cpItem.MergeQuery = fmt.Sprintf("%s\n%s\n%s\n%s",
728+
createSQL, copySQL, mergeSQL, dropSQL)
729+
cpItem.MergeStagedDir = mergeDir
730+
cpItem.MergeStagingTable = mergeStagingTable
731+
}
732+
733+
// Handle Snowpipe creation if needed
734+
if b.pipeName != "" && (b.store.hasInserts || b.store.hasMerges) {
664735
// Check to see if a pipe for this version already exists
665736
exists, err := d.pipeExists(ctx, b.pipeName)
666737
if err != nil {
667738
return nil, err
668739
}
669740

670-
// Only create the pipe if it doesn't exist. Since the pipe name is versioned by the spec
671-
// it means if the spec has been updated, we will end up creating a new pipe
741+
// Only create the pipe if it doesn't exist
672742
if !exists {
673743
log.WithField("name", b.pipeName).Info("store: creating pipe")
674744
if createPipe, err := renderTablePipeTemplate(b.target, b.pipeName, d.templates.createPipe); err != nil {
@@ -678,8 +748,7 @@ func (d *transactor) buildDriverCheckpoint(ctx context.Context, runtimeCheckpoin
678748
}
679749
}
680750

681-
// Our understanding is that CREATE PIPE is _eventually consistent_, and so we
682-
// wait until we can make sure the pipe exists before continuing
751+
// CREATE PIPE is eventually consistent, so poll until the pipe is visible
683752
for !exists {
684753
exists, err = d.pipeExists(ctx, b.pipeName)
685754
if err != nil {
@@ -689,25 +758,15 @@ func (d *transactor) buildDriverCheckpoint(ctx context.Context, runtimeCheckpoin
689758
time.Sleep(5 * time.Second)
690759
}
691760
}
761+
}
692762

693-
d.cp[b.target.StateKey] = &checkpointItem{
694-
Table: b.target.Identifier,
695-
StagedDir: dir,
696-
PipeFiles: b.store.stage.uploaded,
697-
PipeName: b.pipeName,
698-
Version: d.version,
699-
}
700-
701-
} else {
702-
if copyIntoQuery, err := renderTableAndFileTemplate(b.target, dir, d.templates.copyInto); err != nil {
703-
return nil, fmt.Errorf("copyInto template: %w", err)
704-
} else {
705-
d.cp[b.target.StateKey] = &checkpointItem{
706-
Table: b.target.Identifier,
707-
Query: copyIntoQuery,
708-
StagedDir: dir,
709-
}
710-
}
763+
// Store checkpoint item if we have any operations
764+
if cpItem.InsertQuery != "" || cpItem.MergeQuery != "" || cpItem.PipeName != "" {
765+
d.cp[b.target.StateKey] = cpItem
766+
// Reset flags for next transaction
767+
b.store.hasInserts = false
768+
b.store.hasMerges = false
769+
b.store.mustMerge = false
711770
}
712771
}
713772

@@ -808,25 +867,112 @@ func (d *transactor) Acknowledge(ctx context.Context) (*pf.ConnectorState, error
808867
continue
809868
}
810869

870+
// Execute INSERT query if present
871+
if len(item.InsertQuery) > 0 {
872+
item := item
873+
group.Go(func() error {
874+
d.be.StartedResourceCommit(path)
875+
876+
queryStart := time.Now()
877+
result, err := d.db.ExecContext(ctx, item.InsertQuery)
878+
queryDuration := time.Since(queryStart)
879+
880+
if err != nil {
881+
// Best-effort cleanup of transient table on failure
882+
d.cleanupTransientTable(ctx, item.InsertStagingTable)
883+
return fmt.Errorf("insert query failed: %w", err)
884+
}
885+
886+
// Log performance metrics for performance test tables
887+
if strings.HasPrefix(item.Table, "perf_") {
888+
rowsAffected, _ := result.RowsAffected()
889+
log.WithFields(log.Fields{
890+
"perf_query_duration_ms": queryDuration.Milliseconds(),
891+
"perf_rows_affected": rowsAffected,
892+
"perf_table": item.Table,
893+
"perf_query_type": "INSERT",
894+
}).Info("SNOWFLAKE_PERF_METRIC")
895+
}
896+
897+
d.be.FinishedResourceCommit(path)
898+
if err := d.deleteFiles(groupCtx, []string{item.InsertStagedDir}); err != nil {
899+
return fmt.Errorf("cleaning up insert files: %w", err)
900+
}
901+
902+
return nil
903+
})
904+
}
905+
906+
// Execute MERGE query if present
907+
if len(item.MergeQuery) > 0 {
908+
item := item
909+
group.Go(func() error {
910+
d.be.StartedResourceCommit(path)
911+
912+
queryStart := time.Now()
913+
result, err := d.db.ExecContext(ctx, item.MergeQuery)
914+
queryDuration := time.Since(queryStart)
915+
916+
if err != nil {
917+
// Best-effort cleanup of transient table on failure
918+
d.cleanupTransientTable(ctx, item.MergeStagingTable)
919+
return fmt.Errorf("merge query failed: %w", err)
920+
}
921+
922+
// Log performance metrics for performance test tables
923+
if strings.HasPrefix(item.Table, "perf_") {
924+
rowsAffected, _ := result.RowsAffected()
925+
log.WithFields(log.Fields{
926+
"perf_query_duration_ms": queryDuration.Milliseconds(),
927+
"perf_rows_affected": rowsAffected,
928+
"perf_table": item.Table,
929+
"perf_query_type": "MERGE",
930+
}).Info("SNOWFLAKE_PERF_METRIC")
931+
}
932+
933+
d.be.FinishedResourceCommit(path)
934+
if err := d.deleteFiles(groupCtx, []string{item.MergeStagedDir}); err != nil {
935+
return fmt.Errorf("cleaning up merge files: %w", err)
936+
}
937+
938+
return nil
939+
})
940+
}
941+
942+
// Legacy: Handle old Query field for backward compatibility
811943
if len(item.Query) > 0 {
812944
item := item
813945
group.Go(func() error {
814946
d.be.StartedResourceCommit(path)
815-
// NB: Not using groupTx here since the Go Snowflake driver
816-
// retains contexts internally, and groupCtx is cancelled after
817-
// group.Wait() returns.
818-
if _, err := d.db.ExecContext(ctx, item.Query); err != nil {
947+
948+
queryStart := time.Now()
949+
result, err := d.db.ExecContext(ctx, item.Query)
950+
queryDuration := time.Since(queryStart)
951+
952+
if err != nil {
819953
return fmt.Errorf("query %q failed: %w", item.Query, err)
820954
}
821955

956+
// Log performance metrics for performance test tables
957+
if strings.HasPrefix(item.Table, "perf_") {
958+
rowsAffected, _ := result.RowsAffected()
959+
log.WithFields(log.Fields{
960+
"perf_query_duration_ms": queryDuration.Milliseconds(),
961+
"perf_rows_affected": rowsAffected,
962+
"perf_table": item.Table,
963+
}).Info("SNOWFLAKE_PERF_METRIC")
964+
}
965+
822966
d.be.FinishedResourceCommit(path)
823967
if err := d.deleteFiles(groupCtx, []string{item.StagedDir}); err != nil {
824968
return fmt.Errorf("cleaning up files: %w", err)
825969
}
826970

827971
return nil
828972
})
829-
} else if len(item.StreamBlobs) > 0 {
973+
}
974+
975+
if len(item.StreamBlobs) > 0 {
830976
group.Go(func() error {
831977
d.be.StartedResourceCommit(path)
832978
if err := d.streamManager.write(groupCtx, item.StreamBlobs); err != nil {
@@ -1048,6 +1194,19 @@ func (d *transactor) Acknowledge(ctx context.Context) (*pf.ConnectorState, error
10481194
return &pf.ConnectorState{UpdatedJson: json.RawMessage(checkpointJSON), MergePatch: true}, nil
10491195
}
10501196

1197+
// cleanupTransientTable issues DROP TABLE IF EXISTS for the given transient
// staging table. It is best-effort: an empty tableName is a no-op, and a
// failed drop is logged as a warning rather than returned to the caller.
// NOTE(review): tableName is interpolated into the SQL text unquoted; confirm
// callers only pass connector-generated identifiers.
1198+
func (d *transactor) cleanupTransientTable(ctx context.Context, tableName string) {
1199+
if tableName == "" {
1200+
return
1201+
}
1202+
dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
1203+
if _, err := d.db.ExecContext(ctx, dropSQL); err != nil {
1204+
log.WithError(err).WithField("table", tableName).Warn("failed to cleanup transient staging table")
1205+
} else {
1206+
log.WithField("table", tableName).Debug("cleaned up transient staging table")
1207+
}
1208+
}
1209+
10511210
func (d *transactor) pathForStateKey(stateKey string) []string {
10521211
for _, b := range d.bindings {
10531212
if b.target.StateKey == stateKey {

0 commit comments

Comments
 (0)