Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Dec 14, 2024
1 parent 962608a commit 59d0d57
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
19 changes: 13 additions & 6 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,8 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
t.Fatalf("connecting %q: %s", dsn, err)
}
defer db.Close()

cleanup := func() {
db.Exec("DROP TABLE IF EXISTS test")
if err = db.Ping(); err != nil {
t.Fatalf("connecting %q: %s", dsn, err)
}

dsn2 := dsn + "&interpolateParams=true"
Expand All @@ -173,23 +172,31 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
}
defer db3.Close()

cleanupSql := "DROP TABLE IF EXISTS test"

for _, test := range tests {
test := test
t.Run("default", func(t *testing.T) {
dbt := &DBTest{t, db}
t.Cleanup(cleanup)
t.Cleanup(func() {
db.Exec(cleanupSql)
})
test(dbt)
})
if db2 != nil {
t.Run("interpolateParams", func(t *testing.T) {
dbt2 := &DBTest{t, db2}
t.Cleanup(cleanup)
t.Cleanup(func() {
db2.Exec(cleanupSql)
})
test(dbt2)
})
}
t.Run("compress", func(t *testing.T) {
dbt3 := &DBTest{t, db3}
t.Cleanup(cleanup)
t.Cleanup(func() {
db3.Exec(cleanupSql)
})
test(dbt3)
})
}
Expand Down
26 changes: 12 additions & 14 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte
invalid := false
invalidSequence := false

readNext := mc.buf.readNext
if mc.compress {
Expand Down Expand Up @@ -67,8 +67,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
mc.close()
return nil, ErrPktSyncMul
}
// TODO(methane): report error when the packet is not an error packet.
invalid = true
invalidSequence = true
}
mc.sequence++
}
Expand Down Expand Up @@ -99,19 +98,18 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// return data if this was the last packet
if pktLen < maxPacketSize {
// zero allocations for non-split packets
if prevData == nil {
if invalid {
mc.close()
// return sync error only for regular packet.
// error packets may have wrong sequence number.
if data[0] != iERR {
return nil, ErrPktSync
}
if prevData != nil {
data = append(prevData, data...)
}
if invalidSequence {
mc.close()
// return sync error only for regular packet.
// error packets may have wrong sequence number.
if data[0] != iERR {
return nil, ErrPktSync
}
return data, nil
}

return append(prevData, data...), nil
return data, nil
}

prevData = append(prevData, data...)
Expand Down

0 comments on commit 59d0d57

Please sign in to comment.