diff --git a/session.go b/session.go index e93317e..29d7e20 100644 --- a/session.go +++ b/session.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "container/heap" ) const ( @@ -21,8 +22,10 @@ const ( ) type writeRequest struct { - frame Frame - result chan writeResult + niceness uint8 + sequence uint64 // Used to keep the heap ordered by time + frame Frame + result chan writeResult } type writeResult struct { @@ -30,6 +33,31 @@ type writeResult struct { err error } +type writeHeap []writeRequest + +func (h writeHeap) Len() int { return len(h) } +func (h writeHeap) Less(i, j int) bool { + if h[i].niceness == h[j].niceness { + return h[i].sequence < h[j].sequence + } + return h[i].niceness < h[j].niceness +} +func (h writeHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *writeHeap) Push(x interface{}) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. + *h = append(*h, x.(writeRequest)) +} + +func (h *writeHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + // Session defines a multiplexed connection for streams type Session struct { conn io.ReadWriteCloser @@ -54,7 +82,10 @@ type Session struct { deadline atomic.Value - writes chan writeRequest + writeTicket chan struct{} + writesLock sync.Mutex + writes writeHeap + writeSequenceNum uint64 } func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { @@ -66,7 +97,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { s.chAccepts = make(chan *Stream, defaultAcceptBacklog) s.bucket = int32(config.MaxReceiveBuffer) s.bucketNotify = make(chan struct{}, 1) - s.writes = make(chan writeRequest) + s.writeTicket = make(chan struct{}) if client { s.nextStreamID = 1 @@ -79,8 +110,12 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { return s } -// OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { + return s.OpenStreamOpt(100) +} + +// OpenStream is used to create a new stream +func (s *Session) OpenStreamOpt(niceness uint8) (*Stream, error) { if s.IsClosed() { return nil, errors.New(errBrokenPipe) } @@ -101,9 +136,9 @@ func (s *Session) OpenStream() (*Stream, error) { } s.nextStreamIDLock.Unlock() - stream := newStream(sid, s.config.MaxFrameSize, s) + stream := newStream(sid, niceness, s.config.MaxFrameSize, s) - if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { + if _, err := s.writeFrame(0, newFrame(cmdSYN, sid)); err != nil { return nil, errors.Wrap(err, "writeFrame") } @@ -113,9 +148,13 @@ func (s *Session) OpenStream() (*Stream, error) { return stream, nil } +func (s *Session) AcceptStream() (*Stream, error) { + return s.AcceptStreamOpt(100) +} + // AcceptStream is used to block until the next available stream // is ready to be accepted. -func (s *Session) AcceptStream() (*Stream, error) { +func (s *Session) AcceptStreamOpt(niceness uint8) (*Stream, error) { var deadline <-chan time.Time if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer(time.Until(d)) @@ -124,6 +163,7 @@ func (s *Session) AcceptStream() (*Stream, error) { } select { case stream := <-s.chAccepts: + stream.niceness = niceness return stream, nil case <-deadline: return nil, errTimeout @@ -247,7 +287,7 @@ func (s *Session) recvLoop() { case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[f.sid]; !ok { - stream := newStream(f.sid, s.config.MaxFrameSize, s) + stream := newStream(f.sid, 255, s.config.MaxFrameSize, s) s.streams[f.sid] = stream select { case s.chAccepts <- stream: @@ -289,7 +329,7 @@ func (s *Session) keepalive() { for { select { case <-tickerPing.C: - s.writeFrame(newFrame(cmdNOP, 0)) + s.writeFrame(0, newFrame(cmdNOP, 0)) s.notifyBucket() // force a signal to the recvLoop case <-tickerTimeout.C: if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { @@ -308,7 +348,11 @@ func (s *Session) sendLoop() { select { case <-s.die: return - case request := <-s.writes: + case <-s.writeTicket: + s.writesLock.Lock() + request := heap.Pop(&s.writes).(writeRequest) + s.writesLock.Unlock() + buf[0] = request.frame.ver buf[1] = request.frame.cmd binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) @@ -334,15 +378,21 @@ func (s *Session) sendLoop() { // writeFrame writes the frame to the underlying connection // and returns the number of bytes written if successful -func (s *Session) writeFrame(f Frame) (n int, err error) { +func (s *Session) writeFrame(niceness uint8, f Frame) (n int, err error) { req := writeRequest{ - frame: f, - result: make(chan writeResult, 1), + niceness: niceness, + sequence: atomic.AddUint64(&s.writeSequenceNum, 1), + frame: f, + result: make(chan writeResult, 1), } + + s.writesLock.Lock() + heap.Push(&s.writes, req) + s.writesLock.Unlock() select { case <-s.die: return 0, errors.New(errBrokenPipe) - case s.writes <- req: + case s.writeTicket <- struct{}{}: } result := <-req.result diff --git a/session_test.go b/session_test.go index 760642d..03e28db 100644 --- a/session_test.go +++ b/session_test.go @@ -11,6 +11,8 @@ import ( "sync" "testing" "time" + "container/heap" + "github.com/stretchr/testify/assert" ) // setupServer starts new server listening on a random localhost port and @@ -58,6 +60,19 @@ func handleConnection(conn net.Conn) { } } +func TestWriteHeap(t *testing.T) { + var reqs writeHeap + req1 := writeRequest{niceness: 1} + heap.Push(&reqs, req1) + req3 := writeRequest{niceness: 3} + heap.Push(&reqs, req3) + req2 := writeRequest{niceness: 2} + heap.Push(&reqs, req2) + assert.Equal(t, heap.Pop(&reqs), req1) + assert.Equal(t, heap.Pop(&reqs), req2) + assert.Equal(t, heap.Pop(&reqs), req3) +} + func TestEcho(t *testing.T) { _, stop, cli, err := setupServer(t) if err != nil { @@ -461,7 +476,7 @@ func TestRandomFrame(t *testing.T) { session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(cmdSYN, 1000) - session.writeFrame(f) + session.writeFrame(0, f) } cli.Close() @@ -474,7 +489,7 @@ func TestRandomFrame(t *testing.T) { session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32()) - session.writeFrame(f) + session.writeFrame(0, f) } cli.Close() @@ -486,7 +501,7 @@ func TestRandomFrame(t *testing.T) { session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(byte(rand.Uint32()), rand.Uint32()) - session.writeFrame(f) + session.writeFrame(0, f) } cli.Close() @@ -499,7 +514,7 @@ func TestRandomFrame(t *testing.T) { for i := 0; i < 100; i++ { f := newFrame(byte(rand.Uint32()), rand.Uint32()) f.ver = byte(rand.Uint32()) - session.writeFrame(f) + session.writeFrame(0, f) } cli.Close() diff --git a/stream.go b/stream.go index 57a0bc6..234291a 100644 --- a/stream.go +++ b/stream.go @@ -9,11 +9,13 @@ import ( "time" "github.com/pkg/errors" + "container/heap" ) // Stream implements net.Conn type Stream struct { id uint32 + niceness uint8 rstflag int32 sess *Session buffer bytes.Buffer @@ -27,8 +29,9 @@ type Stream struct { } // newStream initiates a Stream struct -func newStream(id uint32, frameSize int, sess *Session) *Stream { +func newStream(id uint32, niceness uint8, frameSize int, sess *Session) *Stream { s := new(Stream) + s.niceness = niceness s.id = id s.chReadEvent = make(chan struct{}, 1) s.frameSize = frameSize @@ -102,12 +105,18 @@ func (s *Stream) Write(b []byte) (n int, err error) { sent := 0 for k := range frames { req := writeRequest{ - frame: frames[k], - result: make(chan writeResult, 1), + niceness: s.niceness, + sequence: atomic.AddUint64(&s.sess.writeSequenceNum, 1), + frame: frames[k], + result: make(chan writeResult, 1), } + // TODO(jnewman): replace with session.writeFrame(..)? + s.sess.writesLock.Lock() + heap.Push(&s.sess.writes, req) + s.sess.writesLock.Unlock() select { - case s.sess.writes <- req: + case s.sess.writeTicket <- struct{}{}: case <-s.die: return sent, errors.New(errBrokenPipe) case <-deadline: @@ -141,7 +150,7 @@ func (s *Stream) Close() error { close(s.die) s.dieLock.Unlock() s.sess.streamClosed(s.id) - _, err := s.sess.writeFrame(newFrame(cmdFIN, s.id)) + _, err := s.sess.writeFrame(0, newFrame(cmdFIN, s.id)) return err } }