From e7215979ddf882ec0dc38956040ea4fe7e55589a Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 11:07:58 +0530 Subject: [PATCH] [release-20.0] Fix Data race in semi-join (#17417) (#17446) Signed-off-by: Manan Gupta Co-authored-by: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com> Co-authored-by: Manan Gupta --- .../endtoend/vtgate/queries/misc/misc_test.go | 20 +++++ go/vt/vtgate/engine/fake_primitive_test.go | 7 +- go/vt/vtgate/engine/semi_join.go | 13 +-- go/vt/vtgate/engine/semi_join_test.go | 79 +++++++++++++++++++ 4 files changed, 112 insertions(+), 7 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index bee98096fab..7280ea81d8e 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -480,3 +480,23 @@ func TestEnumSetVals(t *testing.T) { mcmp.AssertMatches("select id, enum_col, cast(enum_col as signed) from tbl_enum_set order by enum_col, id", `[[INT64(4) ENUM("xsmall") INT64(1)] [INT64(2) ENUM("small") INT64(2)] [INT64(1) ENUM("medium") INT64(3)] [INT64(5) ENUM("medium") INT64(3)] [INT64(3) ENUM("large") INT64(4)]]`) mcmp.AssertMatches("select id, set_col, cast(set_col as unsigned) from tbl_enum_set order by set_col, id", `[[INT64(4) SET("a,b") UINT64(3)] [INT64(3) SET("c") UINT64(4)] [INT64(5) SET("a,d") UINT64(9)] [INT64(1) SET("a,b,e") UINT64(19)] [INT64(2) SET("e,f,g") UINT64(112)]]`) } + +// TestSemiJoin tests that the semi join works as intended. 
+func TestSemiJoin(t *testing.T) {
+	mcmp, closer := start(t)
+	defer closer()
+
+	for i := 1; i <= 1000; i++ {
+		mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", i, 2*i))
+		mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", i, 2*i, 3*i))
+	}
+
+	// Run the correlated EXISTS query once per workload setting (oltp, then olap).
+	for _, mode := range []string{"oltp", "olap"} {
+		mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
+			utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))
+
+			mcmp.Exec("select id1, id2 from t1 where exists (select id from tbl where nonunq_col = t1.id2) order by id1")
+		})
+	}
+}
diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go
index 532d3ebb970..1a2b2f57120 100644
--- a/go/vt/vtgate/engine/fake_primitive_test.go
+++ b/go/vt/vtgate/engine/fake_primitive_test.go
@@ -40,7 +40,8 @@ type fakePrimitive struct {
 	// sendErr is sent at the end of the stream if it's set.
 	sendErr error
 
-	log []string
+	noLog bool
+	log   []string
 
 	allResultsInOneCall bool
 
@@ -85,7 +86,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar
 }
 
 func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
-	f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
+	if !f.noLog {
+		f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
+	}
 	if f.results == nil {
 		return f.sendErr
 	}
diff --git a/go/vt/vtgate/engine/semi_join.go b/go/vt/vtgate/engine/semi_join.go
index de8eeef5a32..277afc863c0 100644
--- a/go/vt/vtgate/engine/semi_join.go
+++ b/go/vt/vtgate/engine/semi_join.go
@@ -18,6 +18,7 @@ package engine
 
 import (
 	"context"
+	"sync/atomic"
 
 	"vitess.io/vitess/go/sqltypes"
 	querypb "vitess.io/vitess/go/vt/proto/query"
@@ -62,24 +63,26 @@ func (jn *SemiJoin) TryExecute(ctx 
context.Context, vcursor VCursor, bindVars ma // TryStreamExecute performs a streaming exec. func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - joinVars := make(map[string]*querypb.BindVariable) err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { + joinVars := make(map[string]*querypb.BindVariable) result := &sqltypes.Result{Fields: lresult.Fields} for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) } - rowAdded := false + var rowAdded atomic.Bool err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error { - if len(rresult.Rows) > 0 && !rowAdded { - result.Rows = append(result.Rows, lrow) - rowAdded = true + if len(rresult.Rows) > 0 { + rowAdded.Store(true) } return nil }) if err != nil { return err } + if rowAdded.Load() { + result.Rows = append(result.Rows, lrow) + } } return callback(result) }) diff --git a/go/vt/vtgate/engine/semi_join_test.go b/go/vt/vtgate/engine/semi_join_test.go index 8fee0490415..a103b0686b2 100644 --- a/go/vt/vtgate/engine/semi_join_test.go +++ b/go/vt/vtgate/engine/semi_join_test.go @@ -18,6 +18,7 @@ package engine import ( "context" + "sync" "testing" "vitess.io/vitess/go/test/utils" @@ -159,3 +160,81 @@ func TestSemiJoinStreamExecute(t *testing.T) { "4|d|dd", )) } + +// TestSemiJoinStreamExecuteParallelExecution tests SemiJoin stream execution with parallel execution +// to ensure we have no data races. 
+func TestSemiJoinStreamExecuteParallelExecution(t *testing.T) {
+	leftPrim := &fakePrimitive{
+		results: []*sqltypes.Result{
+			sqltypes.MakeTestResult(
+				sqltypes.MakeTestFields(
+					"col1|col2|col3",
+					"int64|varchar|varchar",
+				),
+				"1|a|aa",
+				"2|b|bb",
+			), sqltypes.MakeTestResult(
+				sqltypes.MakeTestFields(
+					"col1|col2|col3",
+					"int64|varchar|varchar",
+				),
+				"3|c|cc",
+				"4|d|dd",
+			),
+		},
+		async: true,
+	}
+	rightFields := sqltypes.MakeTestFields(
+		"col4|col5|col6",
+		"int64|varchar|varchar",
+	)
+	rightPrim := &fakePrimitive{
+		// the same result set (which includes non-empty results) is returned for every right-side call
+		results: sqltypes.MakeTestStreamingResults(rightFields,
+			"4|d|dd",
+			"---",
+			"---",
+			"5|e|ee",
+			"6|f|ff",
+			"7|g|gg",
+		),
+		async: true,
+		noLog: true,
+	}
+
+	jn := &SemiJoin{
+		Left:  leftPrim,
+		Right: rightPrim,
+		Vars: map[string]int{
+			"bv": 1,
+		},
+	}
+	var res *sqltypes.Result
+	var mu sync.Mutex
+	err := jn.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error {
+		mu.Lock()
+		defer mu.Unlock()
+		if res == nil {
+			res = result
+		} else {
+			res.Rows = append(res.Rows, result.Rows...)
+		}
+		return nil
+	})
+	require.NoError(t, err)
+	leftPrim.ExpectLog(t, []string{
+		`StreamExecute true`,
+	})
+	// All four left rows are expected back: the right primitive returns the same set of rows
+	// for every left row, and that set is non-empty, so every left row qualifies.
+	expectResultAnyOrder(t, res, sqltypes.MakeTestResult(
+		sqltypes.MakeTestFields(
+			"col1|col2|col3",
+			"int64|varchar|varchar",
+		),
+		"1|a|aa",
+		"2|b|bb",
+		"3|c|cc",
+		"4|d|dd",
+	))
+}