diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b1c1f2b3..2e07fea9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -83,7 +83,7 @@ jobs: my-cnf: | innodb_log_file_size=256MB innodb_buffer_pool_size=512MB - max_allowed_packet=16MB + max_allowed_packet=48MB ; TestConcurrent fails if max_connections is too large max_connections=50 local_infile=1 diff --git a/AUTHORS b/AUTHORS index 989f85d6..c5293e0f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -20,7 +20,9 @@ Andrew Reid Animesh Ray Arne Hormann Ariel Mashraki +Artur Melanchyk Asta Xie +B Lamarche Brian Hendriks Bulat Gaifullin Caine Jette @@ -62,6 +64,7 @@ Jennifer Purevsuren Jerome Meyer Jiajia Zhong Jian Zhen +Joe Mann Joshua Prunier Julien Lefevre Julien Schmidt @@ -92,6 +95,7 @@ Paul Bonser Paulius Lozys Peter Schultz Phil Porada +Minh Quang Rebecca Chin Reed Allman Richard Wilkes diff --git a/README.md b/README.md index e9d9222b..da4593cc 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation + * Supports zlib compression. ## Requirements @@ -267,6 +268,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `compress` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +Toggles zlib compression. false by default. + ##### `interpolateParams` ``` diff --git a/benchmark_test.go b/benchmark_test.go index a4ecc0a6..5c9a046b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -46,9 +46,13 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, queries ...string) *sql.DB { +func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open(driverNameTest, dsn)) + comprStr := "" + if useCompression { + comprStr = "&compress=1" + } + db := tb.checkDB(sql.Open(driverNameTest, dsn+comprStr)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -60,10 +64,18 @@ func initDB(b *testing.B, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { + benchmarkQueryHelper(b, false) +} + +func BenchmarkQueryCompression(b *testing.B) { + benchmarkQueryHelper(b, true) +} + +func benchmarkQueryHelper(b *testing.B, compr bool) { tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := initDB(b, + db := initDB(b, compr, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -222,7 +234,7 @@ func BenchmarkInterpolation(b *testing.B) { }, maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, - buf: newBuffer(nil), + buf: newBuffer(), } args := []driver.Value{ @@ -269,7 +281,7 @@ func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkQueryContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -305,7 +317,7 @@ func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkExecContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -323,7 +335,7 @@ func BenchmarkExecContext(b *testing.B) { // "size=" means size of each blobs. func BenchmarkQueryRawBytes(b *testing.B) { var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS bench_rawbytes", "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", ) @@ -376,7 +388,7 @@ func BenchmarkQueryRawBytes(b *testing.B) { // BenchmarkReceiveMassiveRows measures performance of receiving large number of rows. func BenchmarkReceiveMassiveRows(b *testing.B) { // Setup -- prepare 10000 rows. - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val TEXT)") defer db.Close() diff --git a/buffer.go b/buffer.go index 0774c5c8..a6532431 100644 --- a/buffer.go +++ b/buffer.go @@ -10,54 +10,42 @@ package mysql import ( "io" - "net" - "time" ) const defaultBufSize = 4096 const maxCachedBufSize = 256 * 1024 +// readerFunc is a function that compatible with io.Reader. +// We use this function type instead of io.Reader because we want to +// just pass mc.readWithTimeout. +type readerFunc func([]byte) (int, error) + // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. -// This buffer is backed by two byte slices in a double-buffering scheme type buffer struct { - buf []byte // buf is a byte buffer who's length and capacity are equal. - nc net.Conn - idx int - length int - timeout time.Duration - dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer - flipcnt uint // flipccnt is the current buffer counter for double-buffering + buf []byte // read buffer. + cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. } // newBuffer allocates and returns a new buffer. -func newBuffer(nc net.Conn) buffer { - fg := make([]byte, defaultBufSize) +func newBuffer() buffer { return buffer{ - buf: fg, - nc: nc, - dbuf: [2][]byte{fg, nil}, + cachedBuf: make([]byte, defaultBufSize), } } -// flip replaces the active buffer with the background buffer -// this is a delayed flip that simply increases the buffer counter; -// the actual flip will be performed the next time we call `buffer.fill` -func (b *buffer) flip() { - b.flipcnt += 1 +// busy returns true if the read buffer is not empty. +func (b *buffer) busy() bool { + return len(b.buf) > 0 } -// fill reads into the buffer until at least _need_ bytes are in it -func (b *buffer) fill(need int) error { - n := b.length - // fill data into its double-buffering target: if we've called - // flip on this buffer, we'll be copying to the background buffer, - // and then filling it with network data; otherwise we'll just move - // the contents of the current buffer to the front before filling it - dest := b.dbuf[b.flipcnt&1] +// fill reads into the read buffer until at least _need_ bytes are in it. +func (b *buffer) fill(need int, r readerFunc) error { + // we'll move the contents of the current buffer to dest before filling it. + dest := b.cachedBuf // grow buffer if necessary to fit the whole packet. if need > len(dest) { @@ -67,64 +55,48 @@ func (b *buffer) fill(need int) error { // if the allocated buffer is not too large, move it to backing storage // to prevent extra allocations on applications that perform large reads if len(dest) <= maxCachedBufSize { - b.dbuf[b.flipcnt&1] = dest + b.cachedBuf = dest } } - // if we're filling the fg buffer, move the existing data to the start of it. - // if we're filling the bg buffer, copy over the data - if n > 0 { - copy(dest[:n], b.buf[b.idx:]) - } - - b.buf = dest - b.idx = 0 + // move the existing data to the start of the buffer. + n := len(b.buf) + copy(dest[:n], b.buf) for { - if b.timeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { - return err - } + nn, err := r(dest[n:]) + n += nn + + if err == nil && n < need { + continue } - nn, err := b.nc.Read(b.buf[n:]) - n += nn + b.buf = dest[:n] - switch err { - case nil: + if err == io.EOF { if n < need { - continue + err = io.ErrUnexpectedEOF + } else { + err = nil } - b.length = n - return nil - - case io.EOF: - if n >= need { - b.length = n - return nil - } - return io.ErrUnexpectedEOF - - default: - return err } + return err } } // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int) ([]byte, error) { - if b.length < need { +func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) { + if len(b.buf) < need { // refill - if err := b.fill(need); err != nil { + if err := b.fill(need, r); err != nil { return nil, err } } - offset := b.idx - b.idx += need - b.length -= need - return b.buf[offset:b.idx], nil + data := b.buf[:need] + b.buf = b.buf[need:] + return data, nil } // takeBuffer returns a buffer with the requested size. @@ -132,18 +104,18 @@ func (b *buffer) readNext(need int) ([]byte, error) { // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. func (b *buffer) takeBuffer(length int) ([]byte, error) { - if b.length > 0 { + if b.busy() { return nil, ErrBusyBuffer } // test (cheap) general case first - if length <= cap(b.buf) { - return b.buf[:length], nil + if length <= len(b.cachedBuf) { + return b.cachedBuf[:length], nil } - if length < maxPacketSize { - b.buf = make([]byte, length) - return b.buf, nil + if length < maxCachedBufSize { + b.cachedBuf = make([]byte, length) + return b.cachedBuf, nil } // buffer is larger than we want to store. @@ -154,10 +126,10 @@ func (b *buffer) takeBuffer(length int) ([]byte, error) { // known to be smaller than defaultBufSize. // Only one buffer (total) can be used at a time. func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { - if b.length > 0 { + if b.busy() { return nil, ErrBusyBuffer } - return b.buf[:length], nil + return b.cachedBuf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. @@ -165,18 +137,15 @@ func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { // cap and len of the returned buffer will be equal. // Only one buffer (total) can be used at a time. func (b *buffer) takeCompleteBuffer() ([]byte, error) { - if b.length > 0 { + if b.busy() { return nil, ErrBusyBuffer } - return b.buf, nil + return b.cachedBuf, nil } // store stores buf, an updated buffer, if its suitable to do so. -func (b *buffer) store(buf []byte) error { - if b.length > 0 { - return ErrBusyBuffer - } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { - b.buf = buf[:cap(buf)] +func (b *buffer) store(buf []byte) { + if cap(buf) <= maxCachedBufSize && cap(buf) > cap(b.cachedBuf) { + b.cachedBuf = buf[:cap(buf)] } - return nil } diff --git a/compress.go b/compress.go new file mode 100644 index 00000000..fa42772a --- /dev/null +++ b/compress.go @@ -0,0 +1,214 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" +) + +var ( + zrPool *sync.Pool // Do not use directly. Use zDecompress() instead. + zwPool *sync.Pool // Do not use directly. Use zCompress() instead. +) + +func init() { + zrPool = &sync.Pool{ + New: func() any { return nil }, + } + zwPool = &sync.Pool{ + New: func() any { + zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2) + if err != nil { + panic(err) // compress/zlib return non-nil error only if level is invalid + } + return zw + }, + } +} + +func zDecompress(src []byte, dst *bytes.Buffer) (int, error) { + br := bytes.NewReader(src) + var zr io.ReadCloser + var err error + + if a := zrPool.Get(); a == nil { + if zr, err = zlib.NewReader(br); err != nil { + return 0, err + } + } else { + zr = a.(io.ReadCloser) + if err := zr.(zlib.Resetter).Reset(br, nil); err != nil { + return 0, err + } + } + + n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again. + err = zr.Close() // zr.Close() may return chuecksum error. + zrPool.Put(zr) + return int(n), err +} + +func zCompress(src []byte, dst io.Writer) error { + zw := zwPool.Get().(*zlib.Writer) + zw.Reset(dst) + if _, err := zw.Write(src); err != nil { + return err + } + err := zw.Close() + zwPool.Put(zw) + return err +} + +type compIO struct { + mc *mysqlConn + buff bytes.Buffer +} + +func newCompIO(mc *mysqlConn) *compIO { + return &compIO{ + mc: mc, + } +} + +func (c *compIO) reset() { + c.buff.Reset() +} + +func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) { + for c.buff.Len() < need { + if err := c.readCompressedPacket(r); err != nil { + return nil, err + } + } + data := c.buff.Next(need) + return data[:need:need], nil // prevent caller writes into c.buff +} + +func (c *compIO) readCompressedPacket(r readerFunc) error { + header, err := c.mc.buf.readNext(7, r) // size of compressed header + if err != nil { + return err + } + _ = header[6] // bounds check hint to compiler; guaranteed by readNext + + // compressed header structure + comprLength := getUint24(header[0:3]) + compressionSequence := uint8(header[3]) + uncompressedLength := getUint24(header[4:7]) + if debug { + fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", + comprLength, uncompressedLength, compressionSequence, c.mc.sequence) + } + // Do not return ErrPktSync here. + // Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) + // before receiving all packets from client. In this case, seqnr is younger than expected. + // NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it. + if debug && compressionSequence != c.mc.sequence { + fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", + c.mc.sequence, compressionSequence) + } + c.mc.sequence = compressionSequence + 1 + c.mc.compressSequence = c.mc.sequence + + comprData, err := c.mc.buf.readNext(comprLength, r) + if err != nil { + return err + } + + // if payload is uncompressed, its length will be specified as zero, and its + // true length is contained in comprLength + if uncompressedLength == 0 { + c.buff.Write(comprData) + return nil + } + + // use existing capacity in bytesBuf if possible + c.buff.Grow(uncompressedLength) + nread, err := zDecompress(comprData, &c.buff) + if err != nil { + return err + } + if nread != uncompressedLength { + return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d", + uncompressedLength, nread) + } + return nil +} + +const minCompressLength = 150 +const maxPayloadLen = maxPacketSize - 4 + +// writePackets sends one or some packets with compression. +// Use this instead of mc.netConn.Write() when mc.compress is true. +func (c *compIO) writePackets(packets []byte) (int, error) { + totalBytes := len(packets) + blankHeader := make([]byte, 7) + buf := &c.buff + + for len(packets) > 0 { + payloadLen := min(maxPayloadLen, len(packets)) + payload := packets[:payloadLen] + uncompressedLen := payloadLen + + buf.Reset() + buf.Write(blankHeader) // Buffer.Write() never returns error + + // If payload is less than minCompressLength, don't compress. + if uncompressedLen < minCompressLength { + buf.Write(payload) + uncompressedLen = 0 + } else { + err := zCompress(payload, buf) + if debug && err != nil { + fmt.Printf("zCompress error: %v", err) + } + // do not compress if compressed data is larger than uncompressed data + // I intentionally miss 7 byte header in the buf; zCompress must compress more than 7 bytes. + if err != nil || buf.Len() >= uncompressedLen { + buf.Reset() + buf.Write(blankHeader) + buf.Write(payload) + uncompressedLen = 0 + } + } + + if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + // To allow returning ErrBadConn when sending really 0 bytes, we sum + // up compressed bytes that is returned by underlying Write(). + return totalBytes - len(packets) + n, err + } + packets = packets[payloadLen:] + } + + return totalBytes, nil +} + +// writeCompressedPacket writes a compressed packet with header. +// data should start with 7 size space for header followed by payload. +func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) { + mc := c.mc + comprLength := len(data) - 7 + if debug { + fmt.Printf( + "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", + comprLength, uncompressedLen, mc.compressSequence) + } + + // compression header + putUint24(data[0:3], comprLength) + data[3] = mc.compressSequence + putUint24(data[4:7], uncompressedLen) + + mc.compressSequence++ + return mc.writeWithTimeout(data) +} diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 00000000..030deaef --- /dev/null +++ b/compress_test.go @@ -0,0 +1,119 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/rand" + "io" + "testing" +) + +func makeRandByteSlice(size int) []byte { + randBytes := make([]byte, size) + rand.Read(randBytes) + return randBytes +} + +// compressHelper compresses uncompressedPacket and checks state variables +func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { + conn := new(mockConn) + mc.netConn = conn + + err := mc.writePacket(append(make([]byte, 4), uncompressedPacket...)) + if err != nil { + t.Fatal(err) + } + + return conn.written +} + +// uncompressHelper uncompresses compressedPacket and checks state variables +func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []byte { + // mocking out buf variable + conn := new(mockConn) + conn.data = compressedPacket + mc.netConn = conn + + uncompressedPacket, err := mc.readPacket() + if err != nil { + if err != io.EOF { + t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error()) + } + } + return uncompressedPacket +} + +// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables +func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { + compressed := compressHelper(t, cSend, uncompressedPacket) + return uncompressHelper(t, cReceive, compressed) +} + +// TestRoundtrip tests two connections, where one is reading and the other is writing +func TestRoundtrip(t *testing.T) { + tests := []struct { + uncompressed []byte + desc string + }{ + {uncompressed: []byte("a"), + desc: "a"}, + {uncompressed: []byte("hello world"), + desc: "hello world"}, + {uncompressed: make([]byte, 100), + desc: "100 bytes"}, + {uncompressed: make([]byte, 32768), + desc: "32768 bytes"}, + {uncompressed: make([]byte, 330000), + desc: "33000 bytes"}, + {uncompressed: makeRandByteSlice(10), + desc: "10 rand bytes", + }, + {uncompressed: makeRandByteSlice(100), + desc: "100 rand bytes", + }, + {uncompressed: makeRandByteSlice(32768), + desc: "32768 rand bytes", + }, + {uncompressed: bytes.Repeat(makeRandByteSlice(100), 10000), + desc: "100 rand * 10000 repeat bytes", + }, + } + + _, cSend := newRWMockConn(0) + cSend.compress = true + cSend.compIO = newCompIO(cSend) + _, cReceive := newRWMockConn(0) + cReceive.compress = true + cReceive.compIO = newCompIO(cReceive) + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + cSend.resetSequence() + cReceive.resetSequence() + + uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if len(uncompressed) != len(test.uncompressed) { + t.Errorf("uncompressed size is unexpected. expected %d but got %d", + len(test.uncompressed), len(uncompressed)) + } + if !bytes.Equal(uncompressed, test.uncompressed) { + t.Errorf("roundtrip failed") + } + if cSend.sequence != cReceive.sequence { + t.Errorf("inconsistent sequence number: send=%v recv=%v", + cSend.sequence, cReceive.sequence) + } + if cSend.compressSequence != cReceive.compressSequence { + t.Errorf("inconsistent compress sequence number: send=%v recv=%v", + cSend.compressSequence, cReceive.compressSequence) + } + }) + } +} diff --git a/connection.go b/connection.go index ef6fc9e4..3e455a3f 100644 --- a/connection.go +++ b/connection.go @@ -28,15 +28,17 @@ type mysqlConn struct { netConn net.Conn rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). + compIO *compIO cfg *Config connector *connector maxAllowedPacket int maxWriteSize int - writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 + compressSequence uint8 parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -62,6 +64,43 @@ func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } +func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) { + to := mc.cfg.ReadTimeout + if to > 0 { + if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil { + return 0, err + } + } + return mc.netConn.Read(b) +} + +func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) { + to := mc.cfg.WriteTimeout + if to > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(to)); err != nil { + return 0, err + } + } + return mc.netConn.Write(b) +} + +func (mc *mysqlConn) resetSequence() { + mc.sequence = 0 + mc.compressSequence = 0 +} + +// syncSequence must be called when finished writing some packet and before start reading. +func (mc *mysqlConn) syncSequence() { + // Syncs compressionSequence to sequence. + // This is not documented but done in `net_flush()` in MySQL and MariaDB. + // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 + // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 + if mc.compress { + mc.sequence = mc.compressSequence + mc.compIO.reset() + } +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder @@ -121,10 +160,14 @@ func (mc *mysqlConn) Close() (err error) { if !mc.closed.Load() { err = mc.writeCommandPacket(comQuit) } + mc.close() + return +} +// close closes the network connection and clear results without sending COM_QUIT. +func (mc *mysqlConn) close() { mc.cleanup() mc.clearResult() - return } // Closes the network connection and unsets internal variables. Do not call this @@ -143,7 +186,7 @@ func (mc *mysqlConn) cleanup() { return } if err := conn.Close(); err != nil { - mc.log(err) + mc.log("closing connection:", err) } // This function can be called from multiple goroutines. // So we can not mc.clearResult() here. @@ -431,7 +474,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { return nil, err } -// finish is called when the query has canceled. +// cancel is called when the query has canceled. func (mc *mysqlConn) cancel(err error) { mc.canceled.Set(err) mc.cleanup() @@ -637,7 +680,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { // ResetSession implements driver.SessionResetter. // (From Go 1.10) func (mc *mysqlConn) ResetSession(ctx context.Context) error { - if mc.closed.Load() { + if mc.closed.Load() || mc.buf.busy() { return driver.ErrBadConn } @@ -671,7 +714,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { // IsValid implements driver.Validator interface // (From Go 1.15) func (mc *mysqlConn) IsValid() bool { - return !mc.closed.Load() + return !mc.closed.Load() && !mc.buf.busy() } var _ driver.SessionResetter = &mysqlConn{} diff --git a/connection_test.go b/connection_test.go index 6f8d2a6d..f7740898 100644 --- a/connection_test.go +++ b/connection_test.go @@ -19,7 +19,7 @@ import ( func TestInterpolateParams(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -39,7 +39,7 @@ func TestInterpolateParams(t *testing.T) { func TestInterpolateParamsJSONRawMessage(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -66,7 +66,7 @@ func TestInterpolateParamsJSONRawMessage(t *testing.T) { func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -83,7 +83,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { // https://github.com/go-sql-driver/mysql/pull/490 func TestInterpolateParamsPlaceholderInString(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -99,7 +99,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { func TestInterpolateParamsUint64(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -161,7 +161,7 @@ func TestPingMarkBadConnection(t *testing.T) { nc := badConnection{err: errors.New("boom")} mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: newBuffer(), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), @@ -178,7 +178,7 @@ func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: newBuffer(), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), diff --git a/connector.go b/connector.go index 769b3adc..a4f3655e 100644 --- a/connector.go +++ b/connector.go @@ -127,11 +127,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } defer mc.finish() - mc.buf = newBuffer(mc.netConn) - - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout + mc.buf = newBuffer() // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() @@ -170,6 +166,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + mc.compress = true + mc.compIO = newCompIO(mc) + } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/const.go b/const.go index 0cee9b2e..4aadcd64 100644 --- a/const.go +++ b/const.go @@ -11,6 +11,8 @@ package mysql import "runtime" const ( + debug = false // for debugging. Set true only in development. + defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 diff --git a/driver_test.go b/driver_test.go index 24d73c34..58b3cb38 100644 --- a/driver_test.go +++ b/driver_test.go @@ -147,12 +147,11 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db, err := sql.Open(driverNameTest, dsn) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn, err) } defer db.Close() - - cleanup := func() { - db.Exec("DROP TABLE IF EXISTS test") + if err = db.Ping(); err != nil { + t.Fatalf("connecting %q: %s", dsn, err) } dsn2 := dsn + "&interpolateParams=true" @@ -160,25 +159,46 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { db2, err = sql.Open(driverNameTest, dsn2) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn2, err) } defer db2.Close() } + dsn3 := dsn + "&compress=true" + var db3 *sql.DB + db3, err = sql.Open(driverNameTest, dsn3) + if err != nil { + t.Fatalf("connecting %q: %s", dsn3, err) + } + defer db3.Close() + + cleanupSql := "DROP TABLE IF EXISTS test" + for _, test := range tests { test := test t.Run("default", func(t *testing.T) { dbt := &DBTest{t, db} - t.Cleanup(cleanup) + t.Cleanup(func() { + db.Exec(cleanupSql) + }) test(dbt) }) if db2 != nil { t.Run("interpolateParams", func(t *testing.T) { dbt2 := &DBTest{t, db2} - t.Cleanup(cleanup) + t.Cleanup(func() { + db2.Exec(cleanupSql) + }) test(dbt2) }) } + t.Run("compress", func(t *testing.T) { + dbt3 := &DBTest{t, db3} + t.Cleanup(func() { + db3.Exec(cleanupSql) + }) + test(dbt3) + }) } } @@ -958,12 +978,16 @@ func TestDateTime(t *testing.T) { var err error rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) if err == nil { - rows.Scan(µsecsSupported) + if rows.Next() { + rows.Scan(µsecsSupported) + } rows.Close() } rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) if err == nil { - rows.Scan(&zeroDateSupported) + if rows.Next() { + rows.Scan(&zeroDateSupported) + } rows.Close() } for _, setups := range testcases { @@ -1265,8 +1289,7 @@ func TestLongData(t *testing.T) { var rows *sql.Rows // Long text data - const nonDataQueryLen = 28 // length query w/o value - inS := in[:maxAllowedPacketSize-nonDataQueryLen] + inS := in[:maxAllowedPacketSize-100] dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") rows = dbt.mustQuery("SELECT value FROM test") defer rows.Close() diff --git a/dsn.go b/dsn.go index f391a8fc..9b560b73 100644 --- a/dsn.go +++ b/dsn.go @@ -73,7 +73,10 @@ type Config struct { ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - // unexported fields. new options should be come here + // unexported fields. new options should be come here. + // boolean first. alphabetical order. + + compress bool // Enable zlib compression beforeConnect func(context.Context, *Config) error // Invoked before a connection is established pubKey *rsa.PublicKey // Server public key @@ -93,7 +96,6 @@ func NewConfig() *Config { AllowNativePasswords: true, CheckConnLiveness: true, } - return cfg } @@ -125,6 +127,14 @@ func BeforeConnect(fn func(context.Context, *Config) error) Option { } } +// EnableCompress sets the compression mode. +func EnableCompression(yes bool) Option { + return func(cfg *Config) error { + cfg.compress = yes + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil { @@ -297,6 +307,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") } + if cfg.compress { + writeDSNParam(&buf, &hasParam, "compress", "true") + } + if cfg.InterpolateParams { writeDSNParam(&buf, &hasParam, "interpolateParams", "true") } @@ -525,7 +539,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - return errors.New("compression not implemented yet") + var isBool bool + cfg.compress, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } // Enable client side placeholder substitution case "interpolateParams": diff --git a/infile.go b/infile.go index cf892bea..453ae091 100644 --- a/infile.go +++ b/infile.go @@ -17,7 +17,7 @@ import ( ) var ( - fileRegister map[string]bool + fileRegister map[string]struct{} fileRegisterLock sync.RWMutex readerRegister map[string]func() io.Reader readerRegisterLock sync.RWMutex @@ -37,10 +37,10 @@ func RegisterLocalFile(filePath string) { fileRegisterLock.Lock() // lazy map init if fileRegister == nil { - fileRegister = make(map[string]bool) + fileRegister = make(map[string]struct{}) } - fileRegister[strings.Trim(filePath, `"`)] = true + fileRegister[strings.Trim(filePath, `"`)] = struct{}{} fileRegisterLock.Unlock() } @@ -123,9 +123,9 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { } else { // File name = strings.Trim(name, `"`) fileRegisterLock.RLock() - fr := fileRegister[name] + _, exists := fileRegister[name] fileRegisterLock.RUnlock() - if mc.cfg.AllowAllFiles || fr { + if mc.cfg.AllowAllFiles || exists { var file *os.File var fi os.FileInfo @@ -172,6 +172,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } + mc.conn().syncSequence() // read OK packet if err == nil { diff --git a/packets.go b/packets.go index eb4e0cef..f3860c5f 100644 --- a/packets.go +++ b/packets.go @@ -28,30 +28,49 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte + invalidSequence := false + + readNext := mc.buf.readNext + if mc.compress { + readNext = mc.compIO.readNext + } + for { // read packet header - data, err := mc.buf.readNext(4) + data, err := readNext(4, mc.readWithTimeout) if err != nil { + mc.close() if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } mc.log(err) - mc.Close() return nil, ErrInvalidConn } // packet length [24 bit] - pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - - // check packet sync [8 bit] - if data[3] != mc.sequence { - mc.Close() - if data[3] > mc.sequence { - return nil, ErrPktSyncMul + pktLen := getUint24(data[:3]) + seq := data[3] + + if mc.compress { + // MySQL and MariaDB doesn't check packet nr in compressed packet. + if debug && seq != mc.compressSequence { + fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v", + mc.compressSequence, seq) } - return nil, ErrPktSync + mc.compressSequence = seq + 1 + } else { + // check packet sync [8 bit] + if seq != mc.sequence { + mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq)) + // For large packets, we stop reading as soon as sync error. + if len(prevData) > 0 { + mc.close() + return nil, ErrPktSyncMul + } + invalidSequence = true + } + mc.sequence++ } - mc.sequence++ // packets with length 0 terminate a previous packet which is a // multiple of (2^24)-1 bytes long @@ -59,32 +78,38 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // there was no previous packet if prevData == nil { mc.log(ErrMalformPkt) - mc.Close() + mc.close() return nil, ErrInvalidConn } - return prevData, nil } // read packet body [pktLen bytes] - data, err = mc.buf.readNext(pktLen) + data, err = readNext(pktLen, mc.readWithTimeout) if err != nil { + mc.close() if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } mc.log(err) - mc.Close() return nil, ErrInvalidConn } // return data if this was the last packet if pktLen < maxPacketSize { // zero allocations for non-split packets - if prevData == nil { - return data, nil + if prevData != nil { + data = append(prevData, data...) } - - return append(prevData, data...), nil + if invalidSequence { + mc.close() + // return sync error only for regular packet. + // error packets may have wrong sequence number. + if data[0] != iERR { + return nil, ErrPktSync + } + } + return data, nil } prevData = append(prevData, data...) @@ -94,41 +119,31 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // Write packet buffer 'data' func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 - if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } + writeFunc := mc.writeWithTimeout + if mc.compress { + writeFunc = mc.compIO.writePackets + } + for { - var size int - if pktLen >= maxPacketSize { - data[0] = 0xff - data[1] = 0xff - data[2] = 0xff - size = maxPacketSize - } else { - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - size = pktLen - } + size := min(maxPacketSize, pktLen) + putUint24(data[:3], size) data[3] = mc.sequence // Write packet - if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { - mc.cleanup() - mc.log(err) - return err - } + if debug { + fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } - n, err := mc.netConn.Write(data[:4+size]) + n, err := writeFunc(data[:4+size]) if err != nil { + mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { return cerr } - mc.cleanup() if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet mc.log(err) @@ -162,11 +177,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { data, err = mc.readPacket() if err != nil { - // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since - // in connection initialization we don't risk retrying non-idempotent actions. - if err == ErrInvalidConn { - return nil, "", driver.ErrBadConn - } return } @@ -210,10 +220,13 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 3 // capability flags (upper 2 bytes) [2 bytes] + mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + pos += 2 // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += 11 // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -261,13 +274,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientLocalFiles | clientPluginAuth | clientMultiResults | - clientConnectAttrs | + mc.flags&clientConnectAttrs | mc.flags&clientLongFlag + sendConnectAttrs := mc.flags&clientConnectAttrs != 0 + if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } - + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + clientFlags |= clientCompress + } // To enable TLS / SSL if mc.cfg.TLS != nil { clientFlags |= clientSSL @@ -296,30 +313,26 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // encode length of the connection attributes - var connAttrsLEIBuf [9]byte - connAttrsLen := len(mc.connector.encodedAttributes) - connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) - pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) + var connAttrsLEI []byte + if sendConnectAttrs { + var connAttrsLEIBuf [9]byte + connAttrsLen := len(mc.connector.encodedAttributes) + connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) + pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) + } // Calculate packet length and get buffer with that size data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // ClientFlags [32 bit] - data[4] = byte(clientFlags) - data[5] = byte(clientFlags >> 8) - data[6] = byte(clientFlags >> 16) - data[7] = byte(clientFlags >> 24) + binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags)) // MaxPacketSize [32 bit] (none) - data[8] = 0x00 - data[9] = 0x00 - data[10] = 0x00 - data[11] = 0x00 + binary.LittleEndian.PutUint32(data[8:], 0) // Collation ID [1 byte] data[12] = defaultCollationID @@ -356,7 +369,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } mc.netConn = tlsConn - mc.buf.nc = tlsConn } // User [null terminated string] @@ -382,8 +394,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pos++ // Connection Attributes - pos += copy(data[pos:], connAttrsLEI) - pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + if sendConnectAttrs { + pos += copy(data[pos:], connAttrsLEI) + pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + } // Send Auth packet return mc.writePacket(data[:pos]) @@ -394,9 +408,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) data, err := mc.buf.takeBuffer(pktLen) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add the auth data [EOF] @@ -410,13 +423,11 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -428,14 +439,12 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -445,28 +454,25 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequence() + return err } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte data[4] = command // Add arg [32 bit] - data[5] = byte(arg) - data[6] = byte(arg >> 8) - data[7] = byte(arg >> 16) - data[8] = byte(arg >> 24) + binary.LittleEndian.PutUint32(data[5:], arg) // Send CMD packet return mc.writePacket(data) @@ -935,19 +941,15 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.sequence = 0 + stmt.mc.resetSequence() // Add command byte [1 byte] data[4] = comStmtSendLongData // Add stmtID [32 bit] - data[5] = byte(stmt.id) - data[6] = byte(stmt.id >> 8) - data[7] = byte(stmt.id >> 16) - data[8] = byte(stmt.id >> 24) + binary.LittleEndian.PutUint32(data[5:], stmt.id) // Add paramID [16 bit] - data[9] = byte(paramID) - data[10] = byte(paramID >> 8) + binary.LittleEndian.PutUint16(data[9:], uint16(paramID)) // Send CMD packet err := stmt.mc.writePacket(data[:4+pktLen]) @@ -956,11 +958,10 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { continue } return err - } // Reset Packet Sequence - stmt.mc.sequence = 0 + stmt.mc.resetSequence() return nil } @@ -985,7 +986,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // Reset packet-sequence - mc.sequence = 0 + mc.resetSequence() var data []byte var err error @@ -997,28 +998,20 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In this case the len(data) == cap(data) which is used to optimise the flow below. } if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // command [1 byte] data[4] = comStmtExecute // statement_id [4 bytes] - data[5] = byte(stmt.id) - data[6] = byte(stmt.id >> 8) - data[7] = byte(stmt.id >> 16) - data[8] = byte(stmt.id >> 24) + binary.LittleEndian.PutUint32(data[5:], stmt.id) // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] data[9] = 0x00 // iteration_count (uint32(1)) [4 bytes] - data[10] = 0x01 - data[11] = 0x00 - data[12] = 0x00 - data[13] = 0x00 + binary.LittleEndian.PutUint32(data[10:], 1) if len(args) > 0 { pos := minPktLen @@ -1072,50 +1065,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case int64: paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 - - if cap(paramValues)-len(paramValues)-8 >= 0 { - paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64( - paramValues[len(paramValues)-8:], - uint64(v), - ) - } else { - paramValues = append(paramValues, - uint64ToBytes(uint64(v))..., - ) - } + paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v)) case uint64: paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x80 // type is unsigned - - if cap(paramValues)-len(paramValues)-8 >= 0 { - paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64( - paramValues[len(paramValues)-8:], - uint64(v), - ) - } else { - paramValues = append(paramValues, - uint64ToBytes(uint64(v))..., - ) - } + paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v)) case float64: paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 - - if cap(paramValues)-len(paramValues)-8 >= 0 { - paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64( - paramValues[len(paramValues)-8:], - math.Float64bits(v), - ) - } else { - paramValues = append(paramValues, - uint64ToBytes(math.Float64bits(v))..., - ) - } + paramValues = binary.LittleEndian.AppendUint64(paramValues, math.Float64bits(v)) case bool: paramTypes[i+i] = byte(fieldTypeTiny) @@ -1196,17 +1156,16 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - if err = mc.buf.store(data); err != nil { - mc.log(err) - return errBadConnNoWrite - } + mc.buf.store(data) // allow this buffer to be reused } pos += len(paramValues) data = data[:pos] } - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequence() + return err } // For each remaining resultset in the stream, discards its rows and updates diff --git a/packets_test.go b/packets_test.go index fa4683ea..694b0564 100644 --- a/packets_test.go +++ b/packets_test.go @@ -98,7 +98,7 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) connector := newConnector(NewConfig()) mc := &mysqlConn{ - buf: newBuffer(conn), + buf: newBuffer(), cfg: connector.cfg, connector: connector, netConn: conn, @@ -112,7 +112,9 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), + cfg: NewConfig(), } conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -143,12 +145,12 @@ func TestReadPacketWrongSequenceID(t *testing.T) { { ClientSequenceID: 0, ServerSequenceID: 0x42, - ExpectedErr: ErrPktSyncMul, + ExpectedErr: ErrPktSync, }, } { conn, mc := newRWMockConn(testCase.ClientSequenceID) - conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff} + conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0x22} _, err := mc.readPacket() if err != testCase.ExpectedErr { t.Errorf("expected %v, got %v", testCase.ExpectedErr, err) @@ -164,7 +166,9 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), + cfg: NewConfig(), } data := make([]byte, maxPacketSize*2+4*3) @@ -269,7 +273,8 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), closech: make(chan struct{}), cfg: NewConfig(), } @@ -285,7 +290,7 @@ func TestReadPacketFail(t *testing.T) { // reset conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + mc.buf = newBuffer() // fail to read header conn.closed = true @@ -298,7 +303,7 @@ func TestReadPacketFail(t *testing.T) { conn.closed = false conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + mc.buf = newBuffer() // fail to read body conn.maxReads = 1 @@ -313,7 +318,8 @@ func TestReadPacketFail(t *testing.T) { func TestRegression801(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), cfg: new(Config), sequence: 42, closech: make(chan struct{}), diff --git a/rows.go b/rows.go index 81fa6062..df98417b 100644 --- a/rows.go +++ b/rows.go @@ -111,13 +111,6 @@ func (rows *mysqlRows) Close() (err error) { return err } - // flip the buffer for this connection if we need to drain it. - // note that for a successful query (i.e. one where rows.next() - // has been called until it returns false), `rows.mc` will be nil - // by the time the user calls `(*Rows).Close`, so we won't reach this - // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 - mc.buf.flip() - // Remove unread packets from stream if !rows.rs.done { err = mc.readUntilEOF() diff --git a/statement.go b/statement.go index 35b02bbe..35df8545 100644 --- a/statement.go +++ b/statement.go @@ -24,11 +24,12 @@ type mysqlStmt struct { func (stmt *mysqlStmt) Close() error { if stmt.mc == nil || stmt.mc.closed.Load() { - // driver.Stmt.Close can be called more than once, thus this function - // has to be idempotent. - // See also Issue #450 and golang/go#16019. - //errLog.Print(ErrInvalidConn) - return driver.ErrBadConn + // driver.Stmt.Close could be called more than once, thus this function + // had to be idempotent. See also Issue #450 and golang/go#16019. + // This bug has been fixed in Go 1.8. + // https://github.com/golang/go/commit/90b8a0ca2d0b565c7c7199ffcf77b15ea6b6db3a + // But we keep this function idempotent because it is safer. + return nil } err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) diff --git a/utils.go b/utils.go index cda24fe7..44f43ef7 100644 --- a/utils.go +++ b/utils.go @@ -490,17 +490,16 @@ func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { * Convert from and to bytes * ******************************************************************************/ -func uint64ToBytes(n uint64) []byte { - return []byte{ - byte(n), - byte(n >> 8), - byte(n >> 16), - byte(n >> 24), - byte(n >> 32), - byte(n >> 40), - byte(n >> 48), - byte(n >> 56), - } +// 24bit integer: used for packet headers. + +func putUint24(data []byte, n int) { + data[2] = byte(n >> 16) + data[1] = byte(n >> 8) + data[0] = byte(n) +} + +func getUint24(data []byte) int { + return int(data[2])<<16 | int(data[1])<<8 | int(data[0]) } func uint64ToString(n uint64) []byte { @@ -586,18 +585,15 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { // 252: value of following 2 case 0xfc: - return uint64(b[1]) | uint64(b[2])<<8, false, 3 + return uint64(binary.LittleEndian.Uint16(b[1:])), false, 3 // 253: value of following 3 case 0xfd: - return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 + return uint64(getUint24(b[1:])), false, 4 // 254: value of following 8 case 0xfe: - return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | - uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | - uint64(b[7])<<48 | uint64(b[8])<<56, - false, 9 + return uint64(binary.LittleEndian.Uint64(b[1:])), false, 9 } // 0-250: value of first byte @@ -611,13 +607,14 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { return append(b, byte(n)) case n <= 0xffff: - return append(b, 0xfc, byte(n), byte(n>>8)) + b = append(b, 0xfc) + return binary.LittleEndian.AppendUint16(b, uint16(n)) case n <= 0xffffff: return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) } - return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), - byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) + b = append(b, 0xfe) + return binary.LittleEndian.AppendUint64(b, n) } func appendLengthEncodedString(b []byte, s string) []byte {