diff --git a/pkg/proxy/backend.go b/pkg/proxy/backend.go index 498f81424..2d6aa7148 100644 --- a/pkg/proxy/backend.go +++ b/pkg/proxy/backend.go @@ -301,6 +301,7 @@ func (bc *BackendConn) loopReader(tasks <-chan *Request, c *redis.Conn, round in } } } + r.respForRelease = resp bc.setResponse(r, resp, nil) } return nil diff --git a/pkg/proxy/redis/decoder.go b/pkg/proxy/redis/decoder.go index a2ca69586..52c38f0c5 100644 --- a/pkg/proxy/redis/decoder.go +++ b/pkg/proxy/redis/decoder.go @@ -92,17 +92,6 @@ func (d *Decoder) Decode() (*Resp, error) { return r, d.Err } -func (d *Decoder) DecodeMultiBulk() ([]*Resp, error) { - if d.Err != nil { - return nil, errors.Trace(ErrFailedDecoder) - } - m, err := d.decodeMultiBulk() - if err != nil { - d.Err = err - } - return m, err -} - func Decode(r io.Reader) (*Resp, error) { return NewDecoder(r).Decode() } @@ -111,16 +100,12 @@ func DecodeFromBytes(p []byte) (*Resp, error) { return NewDecoder(bytes.NewReader(p)).Decode() } -func DecodeMultiBulkFromBytes(p []byte) ([]*Resp, error) { - return NewDecoder(bytes.NewReader(p)).DecodeMultiBulk() -} - func (d *Decoder) decodeResp() (*Resp, error) { b, err := d.br.ReadByte() if err != nil { return nil, errors.Trace(err) } - r := &Resp{} + r := AcquireResp() r.Type = RespType(b) switch r.Type { default: @@ -130,7 +115,7 @@ func (d *Decoder) decodeResp() (*Resp, error) { case TypeBulkBytes: r.Value, err = d.decodeBulkBytes() case TypeArray: - r.Array, err = d.decodeArray() + err = d.decodeArray(r) } return r, err } @@ -182,84 +167,30 @@ func (d *Decoder) decodeBulkBytes() ([]byte, error) { return b[:n], nil } -func (d *Decoder) decodeArray() ([]*Resp, error) { +func (d *Decoder) decodeArray(r *Resp) error { n, err := d.decodeInt() if err != nil { - return nil, err + return err } switch { case n < -1: - return nil, errors.Trace(ErrBadArrayLen) + return errors.Trace(ErrBadArrayLen) case n > MaxArrayLen: - return nil, errors.Trace(ErrBadArrayLenTooLong) + return errors.Trace(ErrBadArrayLenTooLong) case n == -1: - return nil, nil - } - array := make([]*Resp, n) - for i := range array { - r, err := d.decodeResp() - if err != nil { - return nil, err - } - array[i] = r - } - return array, nil -} - -func (d *Decoder) decodeSingleLineMultiBulk() ([]*Resp, error) { - b, err := d.decodeTextBytes() - if err != nil { - return nil, err - } - if len(b) == 0 { - return nil, nil - } - multi := make([]*Resp, 0, 8) - for l, r := 0, 0; r <= len(b); r++ { - if r == len(b) || b[r] == ' ' { - if l < r { - multi = append(multi, NewBulkBytes(b[l:r])) - } - l = r + 1 - } - } - if len(multi) == 0 { - return nil, errors.Trace(ErrBadMultiBulkLen) - } - return multi, nil -} - -func (d *Decoder) decodeMultiBulk() ([]*Resp, error) { - b, err := d.br.PeekByte() - if err != nil { - return nil, errors.Trace(err) - } - if RespType(b) != TypeArray { - return d.decodeSingleLineMultiBulk() + return nil } - if _, err := d.br.ReadByte(); err != nil { - return nil, errors.Trace(err) - } - n, err := d.decodeInt() - if err != nil { - return nil, errors.Trace(err) - } - switch { - case n <= 0: - return nil, errors.Trace(ErrBadArrayLen) - case n > MaxArrayLen: - return nil, errors.Trace(ErrBadArrayLenTooLong) + if r.Array == nil { + r.Array = make([]*Resp, 0, n+2) + } else { + r.Array = r.Array[:0] } - multi := make([]*Resp, n) - for i := range multi { - r, err := d.decodeResp() + for i := int64(0); i < n; i++ { + sub, err := d.decodeResp() if err != nil { - return nil, err - } - if r.Type != TypeBulkBytes { - return nil, errors.Trace(ErrBadMultiBulkContent) + return err } - multi[i] = r + r.Array = append(r.Array, sub) } - return multi, nil + return nil } diff --git a/pkg/proxy/redis/decoder_test.go b/pkg/proxy/redis/decoder_test.go index 0ede12c34..3db8d191c 100644 --- a/pkg/proxy/redis/decoder_test.go +++ b/pkg/proxy/redis/decoder_test.go @@ -53,23 +53,6 @@ func TestDecodeSimpleRequest1(t *testing.T) { } func TestDecodeSimpleRequest2(t *testing.T) { - test := []string{ - "hello world\r\n", - "hello world \r\n", - " hello world \r\n", - " hello world\r\n", - " hello world \r\n", - } - for _, s := range test { - a, err := DecodeMultiBulkFromBytes([]byte(s)) - assert.MustNoError(err) - assert.Must(len(a) == 2) - assert.Must(bytes.Equal(a[0].Value, []byte("hello"))) - assert.Must(bytes.Equal(a[1].Value, []byte("world"))) - } -} - -func TestDecodeSimpleRequest3(t *testing.T) { test := []string{"\r", "\n", " \n"} for _, s := range test { _, err := DecodeFromBytes([]byte(s)) @@ -139,9 +122,9 @@ func newBenchmarkDecoder(n int) *Decoder { func benchmarkDecode(b *testing.B, n int) { d := newBenchmarkDecoder(n) for i := 0; i < b.N; i++ { - multi, err := d.DecodeMultiBulk() + resp, err := d.Decode() assert.MustNoError(err) - assert.Must(len(multi) == 1 && len(multi[0].Value) == n) + assert.Must(len(resp.Array) == 1 && len(resp.Array[0].Value) == n) } } diff --git a/pkg/proxy/redis/resp.go b/pkg/proxy/redis/resp.go index c14435bb1..8482e59ba 100644 --- a/pkg/proxy/redis/resp.go +++ b/pkg/proxy/redis/resp.go @@ -3,7 +3,10 @@ package redis -import "fmt" +import ( + "fmt" + "sync" +) type RespType byte @@ -39,6 +42,31 @@ type Resp struct { Array []*Resp } +var respPool = &sync.Pool{ + New: func() interface{} { + return &Resp{Array: make([]*Resp, 0, 8)} + }, +} + +func AcquireResp() *Resp { + return respPool.Get().(*Resp) +} + +func ReleaseResp(r *Resp) { + r.Type = 0 + r.Value = nil + if r.Array == nil { + r.Array = make([]*Resp, 0, 8) + } else { + for i := 0; i < len(r.Array); i++ { + ReleaseResp(r.Array[i]) + r.Array[i] = nil + } + r.Array = r.Array[:0] + } + respPool.Put(r) +} + func (r *Resp) IsString() bool { return r.Type == TypeString } diff --git a/pkg/proxy/request.go b/pkg/proxy/request.go index a2f1ced2f..1b11d17a0 100644 --- a/pkg/proxy/request.go +++ b/pkg/proxy/request.go @@ -12,6 +12,9 @@ import ( ) type Request struct { + reqForRelease *redis.Resp + respForRelease *redis.Resp + Multi []*redis.Resp Batch *sync.WaitGroup Group *sync.WaitGroup @@ -30,6 +33,33 @@ type Request struct { Coalesce func() error } +var requestPool = &sync.Pool{ + New: func() interface{} { + return &Request{ + Batch: &sync.WaitGroup{}, + } + }, +} + +func AcquireRequest() *Request { + return requestPool.Get().(*Request) +} + +func ReleaseRequest(r *Request) { + r.reqForRelease = nil + r.respForRelease = nil + r.Multi = nil + r.Group = nil + r.OpStr = "" + r.OpFlag = OpFlag(0) + r.Database = 0 + r.UnixNano = 0 + r.Resp = nil + r.Err = nil + r.Coalesce = nil + requestPool.Put(r) +} + func (r *Request) IsBroken() bool { return r.Broken != nil && r.Broken.IsTrue() } diff --git a/pkg/proxy/session.go b/pkg/proxy/session.go index 616e057b2..b84c2bd38 100644 --- a/pkg/proxy/session.go +++ b/pkg/proxy/session.go @@ -160,11 +160,12 @@ func (s *Session) loopReader(tasks *RequestChan, d *Router) (err error) { ) for !s.quit { - multi, err := s.Conn.DecodeMultiBulk() + req, err := s.Conn.Decode() if err != nil { return err } - if len(multi) == 0 { + if len(req.Array) == 0 { + redis.ReleaseResp(req) continue } s.incrOpTotal() @@ -177,9 +178,9 @@ func (s *Session) loopReader(tasks *RequestChan, d *Router) (err error) { s.LastOpUnix = start.Unix() s.Ops++ - r := &Request{} - r.Multi = multi - r.Batch = &sync.WaitGroup{} + r := AcquireRequest() + r.reqForRelease = req + r.Multi = req.Array r.Database = s.database r.UnixNano = start.UnixNano() @@ -215,6 +216,14 @@ func (s *Session) loopWriter(tasks *RequestChan) (err error) { p.MaxBuffered = maxPipelineLen / 2 return tasks.PopFrontAll(func(r *Request) error { + defer func() { + redis.ReleaseResp(r.reqForRelease) + if r.respForRelease != nil { + redis.ReleaseResp(r.respForRelease) + } + ReleaseRequest(r) + }() + resp, err := s.handleResponse(r) if err != nil { resp = redis.NewErrorf("ERR handle response, %s", err)