diff --git a/internal/storage/v2/clickhouse/tracestore/driver_test.go b/internal/storage/v2/clickhouse/tracestore/driver_test.go index 388a30dbdc3..91db1199050 100644 --- a/internal/storage/v2/clickhouse/tracestore/driver_test.go +++ b/internal/storage/v2/clickhouse/tracestore/driver_test.go @@ -70,19 +70,35 @@ func (*testBatch) Close() error { return nil } +type testQueryResponse struct { + rows driver.Rows + err error +} + +type testBatchResponse struct { + batch *testBatch + err error +} + type testDriver struct { driver.Conn - t *testing.T - rows driver.Rows - err error - batch *testBatch - recordedQuery string + t *testing.T + queryResponses map[string]*testQueryResponse + batchResponses map[string]*testBatchResponse + recordedQueries []string } func (t *testDriver) Query(_ context.Context, query string, _ ...any) (driver.Rows, error) { - t.recordedQuery = query - return t.rows, t.err + t.recordedQueries = append(t.recordedQueries, query) + + for querySubstring, response := range t.queryResponses { + if strings.Contains(query, querySubstring) { + return response.rows, response.err + } + } + + return nil, nil } type testRows[T any] struct { @@ -138,9 +154,13 @@ func (t *testDriver) PrepareBatch( query string, _ ...driver.PrepareBatchOption, ) (driver.Batch, error) { - t.recordedQuery = query - if t.err != nil { - return nil, t.err + t.recordedQueries = append(t.recordedQueries, query) + + for querySubstring, response := range t.batchResponses { + if strings.Contains(query, querySubstring) { + return response.batch, response.err + } } - return t.batch, nil + + return nil, nil } diff --git a/internal/storage/v2/clickhouse/tracestore/reader_test.go b/internal/storage/v2/clickhouse/tracestore/reader_test.go index 62b8a8eafb6..55d37bf59d1 100644 --- a/internal/storage/v2/clickhouse/tracestore/reader_test.go +++ b/internal/storage/v2/clickhouse/tracestore/reader_test.go @@ -18,6 +18,7 @@ import ( "github.com/jaegertracing/jaeger/internal/jiter" "github.com/jaegertracing/jaeger/internal/storage/v2/api/tracestore" + "github.com/jaegertracing/jaeger/internal/storage/v2/clickhouse/sql" "github.com/jaegertracing/jaeger/internal/storage/v2/clickhouse/tracestore/dbmodel" ) @@ -181,9 +182,14 @@ func TestGetTraces_Success(t *testing.T) { t.Run(tt.name, func(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: tt.data, - scanFn: scanSpanRowFn(), + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansByTraceID: { + rows: &testRows[*dbmodel.SpanRow]{ + data: tt.data, + scanFn: scanSpanRowFn(), + }, + err: nil, + }, }, } @@ -194,7 +200,8 @@ func TestGetTraces_Success(t *testing.T) { traces, err := jiter.FlattenWithErrors(getTracesIter) require.NoError(t, err) - verifyQuerySnapshot(t, conn.recordedQuery) + require.Len(t, conn.recordedQueries, 1) + verifyQuerySnapshot(t, conn.recordedQueries[0]) requireTracesEqual(t, tt.data, traces) }) } @@ -209,8 +216,13 @@ func TestGetTraces_ErrorCases(t *testing.T) { { name: "QueryError", driver: &testDriver{ - t: t, - err: assert.AnError, + t: t, + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansByTraceID: { + rows: nil, + err: assert.AnError, + }, + }, }, expectedErr: "failed to query trace", }, @@ -218,9 +230,14 @@ func TestGetTraces_ErrorCases(t *testing.T) { name: "ScanError", driver: &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: singleSpan, - scanErr: assert.AnError, + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansByTraceID: { + rows: &testRows[*dbmodel.SpanRow]{ + data: singleSpan, + scanErr: assert.AnError, + }, + err: nil, + }, }, }, expectedErr: "failed to scan span row", @@ -229,10 +246,15 @@ func TestGetTraces_ErrorCases(t *testing.T) { name: "CloseError", driver: &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: singleSpan, - scanFn: scanSpanRowFn(), - closeErr: assert.AnError, + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansByTraceID: { + rows: &testRows[*dbmodel.SpanRow]{ + data: singleSpan, + scanFn: scanSpanRowFn(), + closeErr: assert.AnError, + }, + err: nil, + }, }, }, expectedErr: "failed to close rows", @@ -264,9 +286,14 @@ func TestGetTraces_ScanErrorContinues(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: multipleSpans, - scanFn: scanFn, + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansByTraceID: { + rows: &testRows[*dbmodel.SpanRow]{ + data: multipleSpans, + scanFn: scanFn, + }, + err: nil, + }, }, } @@ -288,9 +315,14 @@ func TestGetTraces_ScanErrorContinues(t *testing.T) { func TestGetTraces_YieldFalseOnSuccessStopsIteration(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: multipleSpans, - scanFn: scanSpanRowFn(), + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansByTraceID: { + rows: &testRows[*dbmodel.SpanRow]{ + data: multipleSpans, + scanFn: scanSpanRowFn(), + }, + err: nil, + }, }, } @@ -321,19 +353,24 @@ func TestGetServices(t *testing.T) { name: "successfully returns services", conn: &testDriver{ t: t, - rows: &testRows[dbmodel.Service]{ - data: []dbmodel.Service{ - {Name: "serviceA"}, - {Name: "serviceB"}, - {Name: "serviceC"}, - }, - scanFn: func(dest any, src dbmodel.Service) error { - svc, ok := dest.(*dbmodel.Service) - if !ok { - return errors.New("dest is not *dbmodel.Service") - } - *svc = src - return nil + queryResponses: map[string]*testQueryResponse{ + sql.SelectServices: { + rows: &testRows[dbmodel.Service]{ + data: []dbmodel.Service{ + {Name: "serviceA"}, + {Name: "serviceB"}, + {Name: "serviceC"}, + }, + scanFn: func(dest any, src dbmodel.Service) error { + svc, ok := dest.(*dbmodel.Service) + if !ok { + return errors.New("dest is not *dbmodel.Service") + } + *svc = src + return nil + }, + }, + err: nil, }, }, }, @@ -342,8 +379,13 @@ func TestGetServices(t *testing.T) { { name: "query error", conn: &testDriver{ - t: t, - err: assert.AnError, + t: t, + queryResponses: map[string]*testQueryResponse{ + sql.SelectServices: { + rows: nil, + err: assert.AnError, + }, + }, }, expectError: "failed to query services", }, @@ -351,21 +393,26 @@ func TestGetServices(t *testing.T) { name: "scan error", conn: &testDriver{ t: t, - rows: &testRows[dbmodel.Service]{ - data: []dbmodel.Service{ - {Name: "serviceA"}, - {Name: "serviceB"}, - {Name: "serviceC"}, - }, - scanFn: func(dest any, src dbmodel.Service) error { - svc, ok := dest.(*dbmodel.Service) - if !ok { - return errors.New("dest is not *dbmodel.Service") - } - *svc = src - return nil + queryResponses: map[string]*testQueryResponse{ + sql.SelectServices: { + rows: &testRows[dbmodel.Service]{ + data: []dbmodel.Service{ + {Name: "serviceA"}, + {Name: "serviceB"}, + {Name: "serviceC"}, + }, + scanFn: func(dest any, src dbmodel.Service) error { + svc, ok := dest.(*dbmodel.Service) + if !ok { + return errors.New("dest is not *dbmodel.Service") + } + *svc = src + return nil + }, + scanErr: assert.AnError, + }, + err: nil, }, - scanErr: assert.AnError, }, }, expectError: "failed to scan row", @@ -382,7 +429,8 @@ func TestGetServices(t *testing.T) { require.ErrorContains(t, err, test.expectError) } else { require.NoError(t, err) - verifyQuerySnapshot(t, test.conn.recordedQuery) + require.Len(t, test.conn.recordedQueries, 1) + verifyQuerySnapshot(t, test.conn.recordedQueries[0]) require.Equal(t, test.expected, result) } }) @@ -401,19 +449,24 @@ func TestGetOperations(t *testing.T) { name: "successfully returns operations for all kinds", conn: &testDriver{ t: t, - rows: &testRows[dbmodel.Operation]{ - data: []dbmodel.Operation{ - {Name: "operationA"}, - {Name: "operationB"}, - {Name: "operationC"}, - }, - scanFn: func(dest any, src dbmodel.Operation) error { - svc, ok := dest.(*dbmodel.Operation) - if !ok { - return errors.New("dest is not *dbmodel.Operation") - } - *svc = src - return nil + queryResponses: map[string]*testQueryResponse{ + sql.SelectOperationsAllKinds: { + rows: &testRows[dbmodel.Operation]{ + data: []dbmodel.Operation{ + {Name: "operationA"}, + {Name: "operationB"}, + {Name: "operationC"}, + }, + scanFn: func(dest any, src dbmodel.Operation) error { + svc, ok := dest.(*dbmodel.Operation) + if !ok { + return errors.New("dest is not *dbmodel.Operation") + } + *svc = src + return nil + }, + }, + err: nil, }, }, }, @@ -436,19 +489,24 @@ func TestGetOperations(t *testing.T) { name: "successfully returns operations by kind", conn: &testDriver{ t: t, - rows: &testRows[dbmodel.Operation]{ - data: []dbmodel.Operation{ - {Name: "operationA", SpanKind: "server"}, - {Name: "operationB", SpanKind: "server"}, - {Name: "operationC", SpanKind: "server"}, - }, - scanFn: func(dest any, src dbmodel.Operation) error { - svc, ok := dest.(*dbmodel.Operation) - if !ok { - return errors.New("dest is not *dbmodel.Operation") - } - *svc = src - return nil + queryResponses: map[string]*testQueryResponse{ + sql.SelectOperationsByKind: { + rows: &testRows[dbmodel.Operation]{ + data: []dbmodel.Operation{ + {Name: "operationA", SpanKind: "server"}, + {Name: "operationB", SpanKind: "server"}, + {Name: "operationC", SpanKind: "server"}, + }, + scanFn: func(dest any, src dbmodel.Operation) error { + svc, ok := dest.(*dbmodel.Operation) + if !ok { + return errors.New("dest is not *dbmodel.Operation") + } + *svc = src + return nil + }, + }, + err: nil, }, }, }, @@ -474,8 +532,13 @@ func TestGetOperations(t *testing.T) { { name: "query error", conn: &testDriver{ - t: t, - err: assert.AnError, + t: t, + queryResponses: map[string]*testQueryResponse{ + sql.SelectOperationsAllKinds: { + rows: nil, + err: assert.AnError, + }, + }, }, expectError: "failed to query operations", }, @@ -483,21 +546,26 @@ func TestGetOperations(t *testing.T) { name: "scan error", conn: &testDriver{ t: t, - rows: &testRows[dbmodel.Operation]{ - data: []dbmodel.Operation{ - {Name: "operationA"}, - {Name: "operationB"}, - {Name: "operationC"}, - }, - scanFn: func(dest any, src dbmodel.Operation) error { - svc, ok := dest.(*dbmodel.Operation) - if !ok { - return errors.New("dest is not *dbmodel.Operation") - } - *svc = src - return nil + queryResponses: map[string]*testQueryResponse{ + sql.SelectOperationsAllKinds: { + rows: &testRows[dbmodel.Operation]{ + data: []dbmodel.Operation{ + {Name: "operationA"}, + {Name: "operationB"}, + {Name: "operationC"}, + }, + scanFn: func(dest any, src dbmodel.Operation) error { + svc, ok := dest.(*dbmodel.Operation) + if !ok { + return errors.New("dest is not *dbmodel.Operation") + } + *svc = src + return nil + }, + scanErr: assert.AnError, + }, + err: nil, }, - scanErr: assert.AnError, }, }, expectError: "failed to scan row", @@ -514,7 +582,8 @@ func TestGetOperations(t *testing.T) { require.ErrorContains(t, err, test.expectError) } else { require.NoError(t, err) - verifyQuerySnapshot(t, test.conn.recordedQuery) + require.Len(t, test.conn.recordedQueries, 1) + verifyQuerySnapshot(t, test.conn.recordedQueries[0]) require.Equal(t, test.expected, result) } }) @@ -540,9 +609,14 @@ func TestFindTraces_Success(t *testing.T) { t.Run(tt.name, func(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: tt.data, - scanFn: scanSpanRowFn(), + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansQuery: { + rows: &testRows[*dbmodel.SpanRow]{ + data: tt.data, + scanFn: scanSpanRowFn(), + }, + err: nil, + }, }, } @@ -553,7 +627,8 @@ func TestFindTraces_Success(t *testing.T) { traces, err := jiter.FlattenWithErrors(findTracesIter) require.NoError(t, err) - verifyQuerySnapshot(t, conn.recordedQuery) + require.Len(t, conn.recordedQueries, 1) + verifyQuerySnapshot(t, conn.recordedQueries[0]) requireTracesEqual(t, tt.data, traces) }) } @@ -562,9 +637,14 @@ func TestFindTraces_Success(t *testing.T) { func TestFindTraces_WithFilters(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: multipleSpans, - scanFn: scanSpanRowFn(), + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansQuery: { + rows: &testRows[*dbmodel.SpanRow]{ + data: multipleSpans, + scanFn: scanSpanRowFn(), + }, + err: nil, + }, }, } @@ -594,16 +674,22 @@ func TestFindTraces_WithFilters(t *testing.T) { }) traces, err := jiter.FlattenWithErrors(iter) require.NoError(t, err) - verifyQuerySnapshot(t, conn.recordedQuery) + require.Len(t, conn.recordedQueries, 1) + verifyQuerySnapshot(t, conn.recordedQueries[0]) requireTracesEqual(t, multipleSpans, traces) } func TestFindTraces_SearchDepthExceedsMax(t *testing.T) { driver := &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: singleSpan, - scanFn: scanSpanRowFn(), + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansQuery: { + rows: &testRows[*dbmodel.SpanRow]{ + data: singleSpan, + scanFn: scanSpanRowFn(), + }, + err: nil, + }, }, } reader := NewReader(driver, testReaderConfig) @@ -618,9 +704,14 @@ func TestFindTraces_SearchDepthExceedsMax(t *testing.T) { func TestFindTraces_YieldFalseOnSuccessStopsIteration(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: multipleSpans, - scanFn: scanSpanRowFn(), + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansQuery: { + rows: &testRows[*dbmodel.SpanRow]{ + data: multipleSpans, + scanFn: scanSpanRowFn(), + }, + err: nil, + }, }, } @@ -653,9 +744,14 @@ func TestFindTraces_ScanErrorContinues(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: multipleSpans, - scanFn: scanFn, + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansQuery: { + rows: &testRows[*dbmodel.SpanRow]{ + data: multipleSpans, + scanFn: scanFn, + }, + err: nil, + }, }, } @@ -683,8 +779,13 @@ func TestFindTraces_ErrorCases(t *testing.T) { { name: "QueryError", driver: &testDriver{ - t: t, - err: assert.AnError, + t: t, + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansQuery: { + rows: nil, + err: assert.AnError, + }, + }, }, expectedErr: "failed to query traces", }, @@ -692,9 +793,14 @@ func TestFindTraces_ErrorCases(t *testing.T) { name: "ScanError", driver: &testDriver{ t: t, - rows: &testRows[*dbmodel.SpanRow]{ - data: singleSpan, - scanErr: assert.AnError, + queryResponses: map[string]*testQueryResponse{ + sql.SelectSpansQuery: { + rows: &testRows[*dbmodel.SpanRow]{ + data: singleSpan, + scanErr: assert.AnError, + }, + err: nil, + }, }, }, expectedErr: "failed to scan span row", @@ -736,9 +842,14 @@ func TestFindTraces_BuildQueryError(t *testing.T) { func TestFindTraceIDs(t *testing.T) { driver := &testDriver{ t: t, - rows: &testRows[[]any]{ - data: testTraceIDsData, - scanFn: scanTraceIDFn(), + queryResponses: map[string]*testQueryResponse{ + sql.SearchTraceIDs: { + rows: &testRows[[]any]{ + data: testTraceIDsData, + scanFn: scanTraceIDFn(), + }, + err: nil, + }, }, } reader := NewReader(driver, testReaderConfig) @@ -767,7 +878,8 @@ func TestFindTraceIDs(t *testing.T) { }) ids, err := jiter.FlattenWithErrors(iter) require.NoError(t, err) - verifyQuerySnapshot(t, driver.recordedQuery) + require.Len(t, driver.recordedQueries, 1) + verifyQuerySnapshot(t, driver.recordedQueries[0]) require.Equal(t, []tracestore.FoundTraceID{ { TraceID: pcommon.TraceID([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), @@ -783,20 +895,25 @@ func TestFindTraceIDs(t *testing.T) { func TestFindTraceIDs_SearchDepthExceedsMax(t *testing.T) { driver := &testDriver{ t: t, - rows: &testRows[[]any]{ - data: [][]any{ - { - "00000000000000000000000000000001", - time.Now().Add(-1 * time.Hour), - time.Now().Add(-1 * time.Minute), - }, - { - "00000000000000000000000000000002", - time.Now().Add(-2 * time.Hour), - time.Now().Add(-2 * time.Minute), + queryResponses: map[string]*testQueryResponse{ + sql.SearchTraceIDs: { + rows: &testRows[[]any]{ + data: [][]any{ + { + "00000000000000000000000000000001", + time.Now().Add(-1 * time.Hour), + time.Now().Add(-1 * time.Minute), + }, + { + "00000000000000000000000000000002", + time.Now().Add(-2 * time.Hour), + time.Now().Add(-2 * time.Minute), + }, + }, + scanFn: scanTraceIDFn(), }, + err: nil, }, - scanFn: scanTraceIDFn(), }, } reader := NewReader(driver, testReaderConfig) @@ -810,9 +927,14 @@ func TestFindTraceIDs_SearchDepthExceedsMax(t *testing.T) { func TestFindTraceIDs_YieldFalseOnSuccessStopsIteration(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[[]any]{ - data: testTraceIDsData, - scanFn: scanTraceIDFn(), + queryResponses: map[string]*testQueryResponse{ + sql.SearchTraceIDs: { + rows: &testRows[[]any]{ + data: testTraceIDsData, + scanFn: scanTraceIDFn(), + }, + err: nil, + }, }, } @@ -851,9 +973,14 @@ func TestFindTraceIDs_ScanErrorContinues(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[[]any]{ - data: testTraceIDsData, - scanFn: scanFn, + queryResponses: map[string]*testQueryResponse{ + sql.SearchTraceIDs: { + rows: &testRows[[]any]{ + data: testTraceIDsData, + scanFn: scanFn, + }, + err: nil, + }, }, } @@ -880,22 +1007,27 @@ func TestFindTraceIDs_ScanErrorContinues(t *testing.T) { func TestFindTraceIDs_DecodeErrorContinues(t *testing.T) { conn := &testDriver{ t: t, - rows: &testRows[[]any]{ - data: [][]any{ - testTraceIDsData[0], - { - "0x", - time.Now().Add(-2 * time.Hour), - time.Now().Add(-2 * time.Minute), - }, - { - "invalid", - time.Now().Add(-3 * time.Hour), - time.Now().Add(-3 * time.Minute), + queryResponses: map[string]*testQueryResponse{ + sql.SearchTraceIDs: { + rows: &testRows[[]any]{ + data: [][]any{ + testTraceIDsData[0], + { + "0x", + time.Now().Add(-2 * time.Hour), + time.Now().Add(-2 * time.Minute), + }, + { + "invalid", + time.Now().Add(-3 * time.Hour), + time.Now().Add(-3 * time.Minute), + }, + testTraceIDsData[1], + }, + scanFn: scanTraceIDFn(), }, - testTraceIDsData[1], + err: nil, }, - scanFn: scanTraceIDFn(), }, } @@ -940,8 +1072,13 @@ func TestFindTraceIDs_ErrorCases(t *testing.T) { { name: "QueryError", driver: &testDriver{ - t: t, - err: assert.AnError, + t: t, + queryResponses: map[string]*testQueryResponse{ + sql.SearchTraceIDs: { + rows: nil, + err: assert.AnError, + }, + }, }, expectedErr: "failed to query trace IDs", }, @@ -949,9 +1086,14 @@ func TestFindTraceIDs_ErrorCases(t *testing.T) { name: "ScanError", driver: &testDriver{ t: t, - rows: &testRows[[]any]{ - data: testTraceIDsData, - scanErr: assert.AnError, + queryResponses: map[string]*testQueryResponse{ + sql.SearchTraceIDs: { + rows: &testRows[[]any]{ + data: testTraceIDsData, + scanErr: assert.AnError, + }, + err: nil, + }, }, }, expectedErr: "failed to scan row", @@ -960,15 +1102,20 @@ func TestFindTraceIDs_ErrorCases(t *testing.T) { name: "DecodeError", driver: &testDriver{ t: t, - rows: &testRows[[]any]{ - data: [][]any{ - { - "0x", - time.Now().Add(-1 * time.Hour), - time.Now().Add(-1 * time.Minute), + queryResponses: map[string]*testQueryResponse{ + sql.SearchTraceIDs: { + rows: &testRows[[]any]{ + data: [][]any{ + { + "0x", + time.Now().Add(-1 * time.Hour), + time.Now().Add(-1 * time.Minute), + }, + }, + scanFn: scanTraceIDFn(), }, + err: nil, }, - scanFn: scanTraceIDFn(), }, }, expectedErr: "failed to decode trace ID", diff --git a/internal/storage/v2/clickhouse/tracestore/writer_test.go b/internal/storage/v2/clickhouse/tracestore/writer_test.go index 381b5860d89..04ad46dd3c6 100644 --- a/internal/storage/v2/clickhouse/tracestore/writer_test.go +++ b/internal/storage/v2/clickhouse/tracestore/writer_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/collector/pdata/ptrace" + "github.com/jaegertracing/jaeger/internal/storage/v2/clickhouse/sql" "github.com/jaegertracing/jaeger/internal/storage/v2/clickhouse/tracestore/dbmodel" ) @@ -29,9 +30,14 @@ func tracesFromSpanRows(rows []*dbmodel.SpanRow) ptrace.Traces { } func TestWriter_Success(t *testing.T) { + b := &testBatch{t: t} conn := &testDriver{ - t: t, - batch: &testBatch{t: t}, + t: t, + batchResponses: map[string]*testBatchResponse{ + sql.InsertSpan: { + batch: b, + }, + }, } w := NewWriter(conn) @@ -40,12 +46,13 @@ func TestWriter_Success(t *testing.T) { err := w.WriteTraces(context.Background(), td) require.NoError(t, err) - verifyQuerySnapshot(t, conn.recordedQuery) - require.True(t, conn.batch.sendCalled) - require.Len(t, conn.batch.appended, len(multipleSpans)) + require.Len(t, conn.recordedQueries, 1) + verifyQuerySnapshot(t, conn.recordedQueries[0]) + require.True(t, b.sendCalled) + require.Len(t, b.appended, len(multipleSpans)) for i, expected := range multipleSpans { - row := conn.batch.appended[i] + row := b.appended[i] require.Equal(t, expected.ID, row[0]) // SpanID require.Equal(t, expected.TraceID, row[1]) // TraceID @@ -140,39 +147,52 @@ func TestWriter_Success(t *testing.T) { func TestWriter_PrepareBatchError(t *testing.T) { conn := &testDriver{ - t: t, - err: assert.AnError, - batch: &testBatch{t: t}, + t: t, + batchResponses: map[string]*testBatchResponse{ + sql.InsertSpan: { + batch: nil, + err: assert.AnError, + }, + }, } w := NewWriter(conn) err := w.WriteTraces(context.Background(), tracesFromSpanRows(multipleSpans)) require.ErrorContains(t, err, "failed to prepare batch") require.ErrorIs(t, err, assert.AnError) - require.False(t, conn.batch.sendCalled) } func TestWriter_AppendBatchError(t *testing.T) { + b := &testBatch{t: t, appendErr: assert.AnError} conn := &testDriver{ - t: t, - batch: &testBatch{t: t, appendErr: assert.AnError}, + t: t, + batchResponses: map[string]*testBatchResponse{ + sql.InsertSpan: { + batch: b, + }, + }, } w := NewWriter(conn) err := w.WriteTraces(context.Background(), tracesFromSpanRows(multipleSpans)) require.ErrorContains(t, err, "failed to append span to batch") require.ErrorIs(t, err, assert.AnError) - require.False(t, conn.batch.sendCalled) + require.False(t, b.sendCalled) } func TestWriter_SendError(t *testing.T) { + b := &testBatch{t: t, sendErr: assert.AnError} conn := &testDriver{ - t: t, - batch: &testBatch{t: t, sendErr: assert.AnError}, + t: t, + batchResponses: map[string]*testBatchResponse{ + sql.InsertSpan: { + batch: b, + }, + }, } w := NewWriter(conn) err := w.WriteTraces(context.Background(), tracesFromSpanRows(multipleSpans)) require.ErrorContains(t, err, "failed to send batch") require.ErrorIs(t, err, assert.AnError) - require.False(t, conn.batch.sendCalled) + require.False(t, b.sendCalled) } func TestToTuple(t *testing.T) {