diff --git a/go/test/endtoend/utils/cmp.go b/go/test/endtoend/utils/cmp.go index 8d0b56ac6b3..678f4499f45 100644 --- a/go/test/endtoend/utils/cmp.go +++ b/go/test/endtoend/utils/cmp.go @@ -29,12 +29,17 @@ import ( "vitess.io/vitess/go/sqltypes" ) +type TestingT interface { + require.TestingT + Helper() +} + type MySQLCompare struct { - t *testing.T + t TestingT MySQLConn, VtConn *mysql.Conn } -func NewMySQLCompare(t *testing.T, vtParams, mysqlParams mysql.ConnParams) (MySQLCompare, error) { +func NewMySQLCompare(t TestingT, vtParams, mysqlParams mysql.ConnParams) (MySQLCompare, error) { ctx := context.Background() vtConn, err := mysql.Connect(ctx, &vtParams) if err != nil { @@ -53,6 +58,10 @@ func NewMySQLCompare(t *testing.T, vtParams, mysqlParams mysql.ConnParams) (MySQ }, nil } +func (mcmp *MySQLCompare) AsT() *testing.T { + return mcmp.t.(*testing.T) +} + func (mcmp *MySQLCompare) Close() { mcmp.VtConn.Close() mcmp.MySQLConn.Close() @@ -73,7 +82,7 @@ func (mcmp *MySQLCompare) AssertMatches(query, expected string) { // SkipIfBinaryIsBelowVersion should be used instead of using utils.SkipIfBinaryIsBelowVersion(t, // This is because we might be inside a Run block that has a different `t` variable func (mcmp *MySQLCompare) SkipIfBinaryIsBelowVersion(majorVersion int, binary string) { - SkipIfBinaryIsBelowVersion(mcmp.t, majorVersion, binary) + SkipIfBinaryIsBelowVersion(mcmp.t.(*testing.T), majorVersion, binary) } // AssertMatchesAny ensures the given query produces any one of the expected results. @@ -264,7 +273,7 @@ func (mcmp *MySQLCompare) ExecAndIgnore(query string) (*sqltypes.Result, error) } func (mcmp *MySQLCompare) Run(query string, f func(mcmp *MySQLCompare)) { - mcmp.t.Run(query, func(t *testing.T) { + mcmp.AsT().Run(query, func(t *testing.T) { inner := &MySQLCompare{ t: t, MySQLConn: mcmp.MySQLConn, diff --git a/go/test/endtoend/utils/mysql.go b/go/test/endtoend/utils/mysql.go index 41a70e2dfa4..790e1fc4ba1 100644 --- a/go/test/endtoend/utils/mysql.go +++ b/go/test/endtoend/utils/mysql.go @@ -22,7 +22,6 @@ import ( "fmt" "os" "path" - "testing" "time" "github.com/stretchr/testify/assert" @@ -169,18 +168,18 @@ func prepareMySQLWithSchema(params mysql.ConnParams, sql string) error { return nil } -func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error { +func compareVitessAndMySQLResults(t TestingT, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error { t.Helper() if vtQr == nil && mysqlQr == nil { return nil } if vtQr == nil { - t.Error("Vitess result is 'nil' while MySQL's is not.") + t.Errorf("Vitess result is 'nil' while MySQL's is not.") return errors.New("Vitess result is 'nil' while MySQL's is not.\n") } if mysqlQr == nil { - t.Error("MySQL result is 'nil' while Vitess' is not.") + t.Errorf("MySQL result is 'nil' while Vitess' is not.") return errors.New("MySQL result is 'nil' while Vitess' is not.\n") } @@ -209,7 +208,7 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn stmt, err := sqlparser.NewTestParser().Parse(query) if err != nil { - t.Error(err) + t.Errorf(err.Error()) return err } orderBy := false @@ -237,11 +236,11 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn errStr += fmt.Sprintf("query plan: \n%s\n", qr.Rows[0][0].ToString()) } } - t.Error(errStr) + t.Errorf(errStr) return errors.New(errStr) } -func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Field) { +func checkFields(t TestingT, columnName string, vtField, myField *querypb.Field) { t.Helper() if vtField.Type != myField.Type { t.Errorf("for column %s field types do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Type.String(), vtField.Type.String()) @@ -255,10 +254,9 @@ func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Fiel } } -func compareVitessAndMySQLErrors(t *testing.T, vtErr, mysqlErr error) { +func compareVitessAndMySQLErrors(t TestingT, vtErr, mysqlErr error) { if vtErr != nil && mysqlErr != nil || vtErr == nil && mysqlErr == nil { return } - out := fmt.Sprintf("Vitess and MySQL are not erroring the same way.\nVitess error: %v\nMySQL error: %v", vtErr, mysqlErr) - t.Error(out) + t.Errorf("Vitess and MySQL are not erroring the same way.\nVitess error: %v\nMySQL error: %v", vtErr, mysqlErr) } diff --git a/go/test/endtoend/utils/utils.go b/go/test/endtoend/utils/utils.go index d94da27377e..6098fc63eb6 100644 --- a/go/test/endtoend/utils/utils.go +++ b/go/test/endtoend/utils/utils.go @@ -187,7 +187,7 @@ func ExecCompareMySQL(t *testing.T, vtConn, mysqlConn *mysql.Conn, query string) // ExecAllowError executes the given query without failing the test if it produces // an error. The error is returned to the client, along with the result set. -func ExecAllowError(t testing.TB, conn *mysql.Conn, query string) (*sqltypes.Result, error) { +func ExecAllowError(t TestingT, conn *mysql.Conn, query string) (*sqltypes.Result, error) { t.Helper() return conn.ExecuteFetch(query, 1000, true) }