Skip to content

Commit

Permalink
Handle panics during parallel execution (#15450)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
Signed-off-by: Harshit Gangal <[email protected]>
Co-authored-by: Harshit Gangal <[email protected]>
  • Loading branch information
systay and harshit-gangal authored Mar 27, 2024
1 parent 308f1fc commit 7aec15f
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 0 deletions.
21 changes: 21 additions & 0 deletions go/vt/vtgate/scatter_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package vtgate
import (
"context"
"io"
"runtime/debug"
"sync"
"sync/atomic"
"time"

"vitess.io/vitess/go/mysql/sqlerror"
Expand Down Expand Up @@ -603,6 +605,12 @@ func (stc *ScatterConn) multiGo(
return allErrors
}

// panicData is used to capture panics during parallel execution.
type panicData struct {
p any
trace []byte
}

// multiGoTransaction performs the requested 'action' on the specified
// ResolvedShards in parallel. For each shard, if the requested
// session is in a transaction, it opens a new transactions on the connection,
Expand Down Expand Up @@ -660,15 +668,28 @@ func (stc *ScatterConn) multiGoTransaction(
oneShard(rs, i)
}
} else {
var panicRecord atomic.Value
var wg sync.WaitGroup
for i, rs := range rss {
wg.Add(1)
go func(rs *srvtopo.ResolvedShard, i int) {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
panicRecord.Store(&panicData{
p: r,
trace: debug.Stack(),
})
}
}()
oneShard(rs, i)
}(rs, i)
}
wg.Wait()
if pr, ok := panicRecord.Load().(*panicData); ok {
log.Errorf("caught a panic during parallel execution:\n%s", string(pr.trace))
panic(pr.p) // rethrow the captured panic in the main thread
}
}

if session.MustRollback() {
Expand Down
82 changes: 82 additions & 0 deletions go/vt/vtgate/scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ limitations under the License.
package vtgate

import (
"fmt"
"testing"

"vitess.io/vitess/go/vt/log"

"vitess.io/vitess/go/mysql/sqlerror"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"

Expand Down Expand Up @@ -105,6 +108,85 @@ func TestExecuteFailOnAutocommit(t *testing.T) {
utils.MustMatch(t, []*querypb.BoundQuery{queries[1]}, sbc1.Queries, "")
}

func TestExecutePanic(t *testing.T) {
ctx := utils.LeakCheckContext(t)

createSandbox("TestExecutePanic")
hc := discovery.NewFakeHealthCheck(nil)
sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc0 := hc.AddTestTablet("aa", "0", 1, "TestExecutePanic", "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
sbc1 := hc.AddTestTablet("aa", "1", 1, "TestExecutePanic", "1", topodatapb.TabletType_PRIMARY, true, 1, nil)
sbc0.SetPanic(42)
sbc1.SetPanic(42)
rss := []*srvtopo.ResolvedShard{
{
Target: &querypb.Target{
Keyspace: "TestExecutePanic",
Shard: "0",
TabletType: topodatapb.TabletType_PRIMARY,
},
Gateway: sbc0,
},
{
Target: &querypb.Target{
Keyspace: "TestExecutePanic",
Shard: "1",
TabletType: topodatapb.TabletType_PRIMARY,
},
Gateway: sbc1,
},
}
queries := []*querypb.BoundQuery{
{
// This will fail to go to shard. It will be rejected at vtgate.
Sql: "query1",
BindVariables: map[string]*querypb.BindVariable{
"bv0": sqltypes.Int64BindVariable(0),
},
},
{
// This will go to shard.
Sql: "query2",
BindVariables: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(1),
},
},
}
// shard 0 - has transaction
// shard 1 - does not have transaction.
session := &vtgatepb.Session{
InTransaction: true,
ShardSessions: []*vtgatepb.Session_ShardSession{
{
Target: &querypb.Target{Keyspace: "TestExecutePanic", Shard: "0", TabletType: topodatapb.TabletType_PRIMARY, Cell: "aa"},
TransactionId: 123,
TabletAlias: nil,
},
},
Autocommit: false,
}

original := log.Errorf
defer func() {
log.Errorf = original
}()

var logMessage string
log.Errorf = func(format string, args ...any) {
logMessage = fmt.Sprintf(format, args...)
}

defer func() {
r := recover()
require.NotNil(t, r, "The code did not panic")
// assert we are seeing the stack trace
require.Contains(t, logMessage, "(*ScatterConn).multiGoTransaction")
}()

_, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, NewSafeSession(session), true /*autocommit*/, false)

}

func TestReservedOnMultiReplica(t *testing.T) {
ctx := utils.LeakCheckContext(t)

Expand Down
28 changes: 28 additions & 0 deletions go/vt/vttablet/sandboxconn/sandboxconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ type SandboxConn struct {
// this error will only happen once
EphemeralShardErr error

// if this is not nil, any calls will panic the tablet
panicThis interface{}

NotServing bool

getSchemaResult []map[string]string
Expand Down Expand Up @@ -206,6 +209,7 @@ func (sbc *SandboxConn) SetSchemaResult(r []map[string]string) {

// Execute is part of the QueryService interface.
func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) {
sbc.panicIfNeeded()
sbc.execMu.Lock()
defer sbc.execMu.Unlock()
sbc.ExecCount.Add(1)
Expand Down Expand Up @@ -238,6 +242,7 @@ func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, que

// StreamExecute is part of the QueryService interface.
func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error {
sbc.panicIfNeeded()
sbc.sExecMu.Lock()
sbc.ExecCount.Add(1)
bv := make(map[string]*querypb.BindVariable)
Expand Down Expand Up @@ -278,6 +283,7 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe

// Begin is part of the QueryService interface.
func (sbc *SandboxConn) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (queryservice.TransactionState, error) {
sbc.panicIfNeeded()
return sbc.begin(ctx, target, nil, 0, options)
}

Expand All @@ -303,6 +309,7 @@ func (sbc *SandboxConn) begin(ctx context.Context, target *querypb.Target, preQu

// Commit is part of the QueryService interface.
func (sbc *SandboxConn) Commit(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) {
sbc.panicIfNeeded()
sbc.CommitCount.Add(1)
reservedID := sbc.getTxReservedID(transactionID)
if reservedID != 0 {
Expand All @@ -323,6 +330,7 @@ func (sbc *SandboxConn) Rollback(ctx context.Context, target *querypb.Target, tr

// Prepare prepares the specified transaction.
func (sbc *SandboxConn) Prepare(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) (err error) {
sbc.panicIfNeeded()
sbc.PrepareCount.Add(1)
if sbc.MustFailPrepare > 0 {
sbc.MustFailPrepare--
Expand All @@ -333,6 +341,7 @@ func (sbc *SandboxConn) Prepare(ctx context.Context, target *querypb.Target, tra

// CommitPrepared commits the prepared transaction.
func (sbc *SandboxConn) CommitPrepared(ctx context.Context, target *querypb.Target, dtid string) (err error) {
sbc.panicIfNeeded()
sbc.CommitPreparedCount.Add(1)
if sbc.MustFailCommitPrepared > 0 {
sbc.MustFailCommitPrepared--
Expand All @@ -343,6 +352,7 @@ func (sbc *SandboxConn) CommitPrepared(ctx context.Context, target *querypb.Targ

// RollbackPrepared rolls back the prepared transaction.
func (sbc *SandboxConn) RollbackPrepared(ctx context.Context, target *querypb.Target, dtid string, originalID int64) (err error) {
sbc.panicIfNeeded()
sbc.RollbackPreparedCount.Add(1)
if sbc.MustFailRollbackPrepared > 0 {
sbc.MustFailRollbackPrepared--
Expand All @@ -364,6 +374,7 @@ func (sbc *SandboxConn) CreateTransaction(ctx context.Context, target *querypb.T
// StartCommit atomically commits the transaction along with the
// decision to commit the associated 2pc transaction.
func (sbc *SandboxConn) StartCommit(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) (err error) {
sbc.panicIfNeeded()
sbc.StartCommitCount.Add(1)
if sbc.MustFailStartCommit > 0 {
sbc.MustFailStartCommit--
Expand All @@ -375,6 +386,7 @@ func (sbc *SandboxConn) StartCommit(ctx context.Context, target *querypb.Target,
// SetRollback transitions the 2pc transaction to the Rollback state.
// If a transaction id is provided, that transaction is also rolled back.
func (sbc *SandboxConn) SetRollback(ctx context.Context, target *querypb.Target, dtid string, transactionID int64) (err error) {
sbc.panicIfNeeded()
sbc.SetRollbackCount.Add(1)
if sbc.MustFailSetRollback > 0 {
sbc.MustFailSetRollback--
Expand Down Expand Up @@ -410,6 +422,7 @@ func (sbc *SandboxConn) ReadTransaction(ctx context.Context, target *querypb.Tar

// BeginExecute is part of the QueryService interface.
func (sbc *SandboxConn) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) {
sbc.panicIfNeeded()
state, err := sbc.begin(ctx, target, preQueries, reservedID, options)
if state.TransactionID != 0 {
sbc.setTxReservedID(state.TransactionID, reservedID)
Expand All @@ -423,6 +436,7 @@ func (sbc *SandboxConn) BeginExecute(ctx context.Context, target *querypb.Target

// BeginStreamExecute is part of the QueryService interface.
func (sbc *SandboxConn) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.TransactionState, error) {
sbc.panicIfNeeded()
state, err := sbc.begin(ctx, target, preQueries, reservedID, options)
if state.TransactionID != 0 {
sbc.setTxReservedID(state.TransactionID, reservedID)
Expand Down Expand Up @@ -567,6 +581,7 @@ func (sbc *SandboxConn) HandlePanic(err *error) {

// ReserveBeginExecute implements the QueryService interface
func (sbc *SandboxConn) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (queryservice.ReservedTransactionState, *sqltypes.Result, error) {
sbc.panicIfNeeded()
reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, 0, options)
state, result, err := sbc.BeginExecute(ctx, target, postBeginQueries, sql, bindVariables, reservedID, options)
if state.TransactionID != 0 {
Expand All @@ -581,6 +596,7 @@ func (sbc *SandboxConn) ReserveBeginExecute(ctx context.Context, target *querypb

// ReserveBeginStreamExecute is part of the QueryService interface.
func (sbc *SandboxConn) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedTransactionState, error) {
sbc.panicIfNeeded()
reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, 0, options)
state, err := sbc.BeginStreamExecute(ctx, target, postBeginQueries, sql, bindVariables, reservedID, options, callback)
if state.TransactionID != 0 {
Expand All @@ -595,6 +611,7 @@ func (sbc *SandboxConn) ReserveBeginStreamExecute(ctx context.Context, target *q

// ReserveExecute implements the QueryService interface
func (sbc *SandboxConn) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (queryservice.ReservedState, *sqltypes.Result, error) {
sbc.panicIfNeeded()
reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, transactionID, options)
result, err := sbc.Execute(ctx, target, sql, bindVariables, transactionID, reservedID, options)
if transactionID != 0 {
Expand All @@ -608,6 +625,7 @@ func (sbc *SandboxConn) ReserveExecute(ctx context.Context, target *querypb.Targ

// ReserveStreamExecute is part of the QueryService interface.
func (sbc *SandboxConn) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedState, error) {
sbc.panicIfNeeded()
reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, transactionID, options)
err := sbc.StreamExecute(ctx, target, sql, bindVariables, transactionID, reservedID, options, callback)
if transactionID != 0 {
Expand Down Expand Up @@ -769,3 +787,13 @@ var StreamRowResult = &sqltypes.Result{
sqltypes.NewVarChar("foo"),
}},
}

func (sbc *SandboxConn) SetPanic(i interface{}) {
sbc.panicThis = i
}

func (sbc *SandboxConn) panicIfNeeded() {
if sbc.panicThis != nil {
panic(sbc.panicThis)
}
}

0 comments on commit 7aec15f

Please sign in to comment.