Skip to content

Commit

Permalink
Reference Table DML Join Fix (#17414)
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal authored Jan 7, 2025
1 parent 68861b1 commit 1ce7550
Show file tree
Hide file tree
Showing 18 changed files with 353 additions and 93 deletions.
12 changes: 12 additions & 0 deletions go/test/endtoend/utils/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,18 @@ func (mcmp *MySQLCompare) Exec(query string) *sqltypes.Result {
return vtQr
}

// ExecVitessAndMySQL executes Vitess and MySQL with the queries provided.
func (mcmp *MySQLCompare) ExecVitessAndMySQL(vtQ, mQ string) *sqltypes.Result {
mcmp.t.Helper()
vtQr, err := mcmp.VtConn.ExecuteFetch(vtQ, 1000, true)
require.NoError(mcmp.t, err, "[Vitess Error] for query: "+vtQ)

mysqlQr, err := mcmp.MySQLConn.ExecuteFetch(mQ, 1000, true)
require.NoError(mcmp.t, err, "[MySQL Error] for query: "+mQ)
compareVitessAndMySQLResults(mcmp.t, vtQ, mcmp.VtConn, vtQr, mysqlQr, CompareOptions{})
return vtQr
}

// ExecAssert is the same as Exec, but it only does assertions, it won't FailNow
func (mcmp *MySQLCompare) ExecAssert(query string) *sqltypes.Result {
mcmp.t.Helper()
Expand Down
5 changes: 3 additions & 2 deletions go/test/endtoend/vtgate/plan_tests/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql"
Expand Down Expand Up @@ -86,7 +87,7 @@ func TestMain(m *testing.M) {
// TODO: (@GuptaManan100/@systay): Also run the tests with normalizer on.
clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs,
"--normalize_queries=false",
"--schema_change_signal=false",
"--schema_change_signal=true",
)

// Start vtgate
Expand Down Expand Up @@ -178,7 +179,7 @@ func verifyTestExpectations(t *testing.T, pd engine.PrimitiveDescription, test p
// 1. Verify that the Join primitive sees atleast 1 row on the left side.
engine.WalkPrimitiveDescription(pd, func(description engine.PrimitiveDescription) {
if description.OperatorType == "Join" {
require.NotZero(t, description.Inputs[0].RowsReceived[0])
assert.NotZero(t, description.Inputs[0].RowsReceived[0])
}
})

Expand Down
19 changes: 17 additions & 2 deletions go/test/endtoend/vtgate/plan_tests/plan_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,22 @@ package plan_tests
import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/utils"
"vitess.io/vitess/go/vt/sqlparser"
)

func TestE2ECases(t *testing.T) {
e2eTestCaseFiles := []string{"select_cases.json", "filter_cases.json", "dml_cases.json"}
err := utils.WaitForAuthoritative(t, "main", "source_of_ref", clusterInstance.VtgateProcess.ReadVSchema)
require.NoError(t, err)

e2eTestCaseFiles := []string{
"select_cases.json",
"filter_cases.json",
"dml_cases.json",
"reference_cases.json",
}
mcmp, closer := start(t)
defer closer()
loadSampleData(t, mcmp)
Expand All @@ -34,7 +45,11 @@ func TestE2ECases(t *testing.T) {
if test.SkipE2E {
mcmp.AsT().Skip(test.Query)
}
mcmp.Exec(test.Query)
stmt, err := sqlparser.NewTestParser().Parse(test.Query)
require.NoError(mcmp.AsT(), err)
sqlparser.RemoveKeyspaceIgnoreSysSchema(stmt)

mcmp.ExecVitessAndMySQL(test.Query, sqlparser.String(stmt))
pd := utils.ExecTrace(mcmp.AsT(), mcmp.VtConn, test.Query)
verifyTestExpectations(mcmp.AsT(), pd, test)
if mcmp.VtConn.IsClosed() {
Expand Down
19 changes: 19 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2440,6 +2440,25 @@ func RemoveKeyspace(in SQLNode) {
})
}

// RemoveKeyspaceIgnoreSysSchema removes the Qualifier.Qualifier on all ColNames and Qualifier on all TableNames in the AST
// except for the system schema.
func RemoveKeyspaceIgnoreSysSchema(in SQLNode) {
Rewrite(in, nil, func(cursor *Cursor) bool {
switch expr := cursor.Node().(type) {
case *ColName:
if expr.Qualifier.Qualifier.NotEmpty() && !SystemSchema(expr.Qualifier.Qualifier.String()) {
expr.Qualifier.Qualifier = NewIdentifierCS("")
}
case TableName:
if expr.Qualifier.NotEmpty() && !SystemSchema(expr.Qualifier.String()) {
expr.Qualifier = NewIdentifierCS("")
cursor.Replace(expr)
}
}
return true
})
}

func convertStringToInt(integer string) int {
val, _ := strconv.Atoi(integer)
return val
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/cte_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator
}

func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *RecurseCTE) *Route {
seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term)
seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, seed, term)
if seedRoute == nil {
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func createDeleteWithInputOp(ctx *plancontext.PlanningContext, del *sqlparser.De
}

var delOps []dmlOp
for _, target := range ctx.SemTable.Targets.Constituents() {
for _, target := range ctx.SemTable.DMLTargets.Constituents() {
op := createDeleteOpWithTarget(ctx, target, del.Ignore)
delOps = append(delOps, op)
}
Expand Down Expand Up @@ -322,7 +322,7 @@ func updateQueryGraphWithSource(ctx *plancontext.PlanningContext, input Operator
return op, NoRewrite
}
if len(qg.Tables) > 1 {
panic(vterrors.VT12001("DELETE on reference table with join"))
panic(vterrors.VT12001("DML on reference table with join"))
}
for _, tbl := range qg.Tables {
if tbl.ID != tblID {
Expand Down
16 changes: 10 additions & 6 deletions go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// If they can be merged, a new operator with the merged routing is returned
// If they cannot be merged, nil is returned.
func (jm *joinMerger) mergeJoinInputs(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr) *Route {
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs)
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs)
if lhsRoute == nil {
return nil
}
Expand Down Expand Up @@ -102,13 +102,13 @@ func mergeAnyShardRoutings(ctx *plancontext.PlanningContext, a, b *AnyShardRouti
}
}

func prepareInputRoutes(lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) {
func prepareInputRoutes(ctx *plancontext.PlanningContext, lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) {
lhsRoute, rhsRoute := operatorsToRoutes(lhs, rhs)
if lhsRoute == nil || rhsRoute == nil {
return nil, nil, nil, nil, 0, 0, false
}

lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(lhsRoute, rhsRoute)
lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(ctx, lhsRoute, rhsRoute)

a, b := getRoutingType(routingA), getRoutingType(routingB)
return lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace
Expand Down Expand Up @@ -159,7 +159,7 @@ func (rt routingType) String() string {

// getRoutesOrAlternates gets the Routings from each Route. If they are from different keyspaces,
// we check if this is a table with alternates in other keyspaces that we can use
func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) {
func getRoutesOrAlternates(ctx *plancontext.PlanningContext, lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) {
routingA := lhsRoute.Routing
routingB := rhsRoute.Routing
sameKeyspace := routingA.Keyspace() == routingB.Keyspace()
Expand All @@ -171,13 +171,17 @@ func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing,
return lhsRoute, rhsRoute, routingA, routingB, sameKeyspace
}

if refA, ok := routingA.(*AnyShardRouting); ok {
// If we have a reference route, we will try to find an alternate route in same keyspace as other routing keyspace.
// If the reference route is part of DML table update target, alternate keyspace route cannot be considered.
if refA, ok := routingA.(*AnyShardRouting); ok &&
!TableID(lhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) {
if altARoute := refA.AlternateInKeyspace(routingB.Keyspace()); altARoute != nil {
return altARoute, rhsRoute, altARoute.Routing, routingB, true
}
}

if refB, ok := routingB.(*AnyShardRouting); ok {
if refB, ok := routingB.(*AnyShardRouting); ok &&
!TableID(rhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) {
if altBRoute := refB.AlternateInKeyspace(routingA.Keyspace()); altBRoute != nil {
return lhsRoute, altBRoute, routingA, altBRoute.Routing, true
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ func mergeSubqueryInputs(ctx *plancontext.PlanningContext, in, out Operator, joi
return nil
}

inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(inRoute, outRoute)
inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(ctx, inRoute, outRoute)
inner, outer := getRoutingType(inRouting), getRoutingType(outRouting)

switch {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/union_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func mergeUnionInputs(
lhsExprs, rhsExprs sqlparser.SelectExprs,
distinct bool,
) (Operator, sqlparser.SelectExprs) {
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs)
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs)
if lhsRoute == nil {
return nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up
ueMap := prepareUpdateExpressionList(ctx, upd)

var updOps []dmlOp
for _, target := range ctx.SemTable.Targets.Constituents() {
for _, target := range ctx.SemTable.DMLTargets.Constituents() {
op := createUpdateOpWithTarget(ctx, upd, target, ueMap[target])
updOps = append(updOps, op)
}
Expand Down Expand Up @@ -308,7 +308,7 @@ func errIfUpdateNotSupported(ctx *plancontext.PlanningContext, stmt *sqlparser.U
}
}

// Now we check if any of the foreign key columns that are being udpated have dependencies on other updated columns.
// Now we check if any of the foreign key columns that are being updated have dependencies on other updated columns.
// This is unsafe, and we currently don't support this in Vitess.
if err := ctx.SemTable.ErrIfFkDependentColumnUpdated(stmt.Exprs); err != nil {
panic(err)
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (s *planTestSuite) TestPlan() {
s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"})
s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"})
s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"})
s.addPKsProvided(vschema, "main", []string{"source_of_ref"}, []string{"id"})

// You will notice that some tests expect user.Id instead of user.id.
// This is because we now pre-create vindex columns in the symbol
Expand Down Expand Up @@ -305,6 +306,7 @@ func (s *planTestSuite) TestOne() {
s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"})
s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"})
s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"})
s.addPKsProvided(vschema, "main", []string{"source_of_ref"}, []string{"id"})

s.testFile("onecase.json", vw, false)
}
Expand Down Expand Up @@ -666,7 +668,7 @@ func (s *planTestSuite) testFile(filename string, vschema *vschemawrapper.VSchem
current := PlanTest{
Comment: tcase.Comment,
Query: tcase.Query,
SkipE2E: true,
SkipE2E: tcase.SkipE2E,
}
vschema.Version = Gen4
out := getPlanOutput(tcase, vschema, render)
Expand Down
Loading

0 comments on commit 1ce7550

Please sign in to comment.