Skip to content

Commit

Permalink
Merge pull request #21 from libp2p/fix/keepalive-race
Browse files Browse the repository at this point in the history
fix: synchronize when resetting the keepalive timer
  • Loading branch information
Stebalien authored Mar 16, 2020
2 parents 51522d4 + 345f639 commit 97856b4
Show file tree
Hide file tree
Showing 5 changed files with 414 additions and 258 deletions.
4 changes: 3 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ go:

env:
global:
- GOTFLAGS="-race"
- BUILD_DEPTYPE=gomod
matrix:
- GOTFLAGS="-race"
- GOTFLAGS="-count 5"


# disable travis install
Expand Down
13 changes: 8 additions & 5 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,25 @@ func BenchmarkAccept(b *testing.B) {
func BenchmarkSendRecv(b *testing.B) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

sendBuf := make([]byte, 512)
recvBuf := make([]byte, 512)

doneCh := make(chan struct{})
go func() {
defer close(doneCh)
defer server.Close()
stream, err := server.AcceptStream()
if err != nil {
return
}
defer stream.Close()
for i := 0; i < b.N; i++ {
if _, err := io.ReadFull(stream, recvBuf); err != nil {
b.Fatalf("err: %v", err)
b.Errorf("err: %v", err)
return
}
}
close(doneCh)
}()

stream, err := client.Open()
Expand Down Expand Up @@ -95,6 +96,8 @@ func BenchmarkSendRecvLarge(b *testing.B) {
recvDone := make(chan struct{})

go func() {
defer close(recvDone)
defer server.Close()
stream, err := server.AcceptStream()
if err != nil {
return
Expand All @@ -103,11 +106,11 @@ func BenchmarkSendRecvLarge(b *testing.B) {
for i := 0; i < b.N; i++ {
for j := 0; j < sendSize/recvSize; j++ {
if _, err := io.ReadFull(stream, recvBuf); err != nil {
b.Fatalf("err: %v", err)
b.Errorf("err: %v", err)
return
}
}
}
close(recvDone)
}()

stream, err := client.Open()
Expand Down
53 changes: 34 additions & 19 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,11 @@ type Session struct {

// keepaliveTimer is a periodic timer for keepalive messages. It's nil
// when keepalives are disabled.
keepaliveLock sync.Mutex
keepaliveTimer *time.Timer
keepaliveLock sync.Mutex
keepaliveTimer *time.Timer
keepaliveActive bool
}

const (
stageInitial uint32 = iota
stageFinal
)

// newSession is used to construct a new session
func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Session {
var reader io.Reader = conn
Expand Down Expand Up @@ -327,23 +323,27 @@ func (s *Session) startKeepalive() {
defer s.keepaliveLock.Unlock()
s.keepaliveTimer = time.AfterFunc(s.config.KeepAliveInterval, func() {
s.keepaliveLock.Lock()

if s.keepaliveTimer == nil {
if s.keepaliveTimer == nil || s.keepaliveActive {
// keepalives have been stopped or a keepalive is active.
s.keepaliveLock.Unlock()
// keepalives have been stopped.
return
}
s.keepaliveActive = true
s.keepaliveLock.Unlock()

_, err := s.Ping()

s.keepaliveLock.Lock()
s.keepaliveActive = false
if s.keepaliveTimer != nil {
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
}
s.keepaliveLock.Unlock()

if err != nil {
// Make sure to unlock before exiting so we don't
// deadlock trying to shutdown keepalives.
s.keepaliveLock.Unlock()
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
s.exitErr(ErrKeepAliveTimeout)
return
}
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
s.keepaliveLock.Unlock()
})
}

Expand All @@ -353,7 +353,24 @@ func (s *Session) stopKeepalive() {
defer s.keepaliveLock.Unlock()
if s.keepaliveTimer != nil {
s.keepaliveTimer.Stop()
s.keepaliveTimer = nil
}
}

// extendKeepalive pushes the next keepalive out by a full
// KeepAliveInterval. It is a no-op when keepalives are disabled (nil
// timer) or when a keepalive ping is currently in flight.
func (s *Session) extendKeepalive() {
	s.keepaliveLock.Lock()
	defer s.keepaliveLock.Unlock()

	if s.keepaliveTimer == nil || s.keepaliveActive {
		return
	}
	// Resetting without first stopping and draining is safe here: this
	// is a time.AfterFunc timer, so it has no channel to drain (any
	// attempt to receive from one would block forever).
	//
	// Go stops the timer internally when Reset is called. The
	// stop-before-reset advice in the docs exists only to keep the
	// timer from firing immediately after Reset, which the
	// keepaliveActive flag already guards against.
	s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
}

// send sends the header and body.
Expand Down Expand Up @@ -512,9 +529,7 @@ func (s *Session) recvLoop() error {
// There's no reason to keepalive if we're active. Worse, if the
// peer is busy sending us stuff, the pong might get stuck
// behind a bunch of data.
if s.keepaliveTimer != nil {
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
}
s.extendKeepalive()

// Verify the version
if hdr.Version() != protoVersion {
Expand Down
163 changes: 163 additions & 0 deletions session_norace_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
//+build !race

package yamux

import (
"bytes"
"io"
"io/ioutil"
"sync"
"testing"
"time"
)

// TestSession_PingOfDeath launches a large number of simultaneous pings
// in both directions to shake out races in ping/pong handling.
func TestSession_PingOfDeath(t *testing.T) {
	client, server := testClientServerConfig(testConfNoKeepAlive())
	defer client.Close()
	defer server.Close()

	const count = 10000

	// Every goroutine parks on start so all pings are released at once.
	start := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(2 * count)
	for i := 0; i < count; i++ {
		go func() {
			defer wg.Done()
			<-start
			if _, err := server.Ping(); err != nil {
				t.Error(err)
			}
		}()
		go func() {
			defer wg.Done()
			<-start
			if _, err := client.Ping(); err != nil {
				t.Error(err)
			}
		}()
	}
	close(start)
	wg.Wait()
}

// TestSendData_VeryLarge pushes 1 GiB through each of 16 concurrent
// streams and verifies every byte arrives, with a 20s deadline.
// It lives in the norace build because the volume of data makes it far
// too slow under the race detector.
func TestSendData_VeryLarge(t *testing.T) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	// Bytes sent per stream (1 GiB) and number of concurrent streams.
	var n int64 = 1 * 1024 * 1024 * 1024
	var workers int = 16

	// One sender and one receiver goroutine per worker.
	wg := &sync.WaitGroup{}
	wg.Add(workers * 2)

	// Receivers: accept a stream, check the 4-byte header, then drain
	// and count the remaining payload.
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			stream, err := server.AcceptStream()
			if err != nil {
				t.Errorf("err: %v", err)
				return
			}
			defer stream.Close()

			buf := make([]byte, 4)
			_, err = io.ReadFull(stream, buf)
			if err != nil {
				t.Errorf("err: %v", err)
				return
			}
			if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
				t.Errorf("bad header")
				return
			}

			// Discard the payload; only the byte count matters.
			recv, err := io.Copy(ioutil.Discard, stream)
			if err != nil {
				t.Errorf("err: %v", err)
				return
			}
			if recv != n {
				t.Errorf("bad: %v", recv)
				return
			}
		}()
	}
	// Senders: open a stream, write the header, then stream exactly n
	// bytes from an unbounded reader.
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			stream, err := client.Open()
			if err != nil {
				t.Errorf("err: %v", err)
				return
			}
			defer stream.Close()

			_, err = stream.Write([]byte{0, 1, 2, 3})
			if err != nil {
				t.Errorf("err: %v", err)
				return
			}

			// UnlimitedReader is a test helper defined elsewhere in the
			// package; presumably it yields bytes forever — the
			// LimitReader caps it at n.
			unlimited := &UnlimitedReader{}
			sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
			if err != nil {
				t.Errorf("err: %v", err)
				return
			}
			if sent != n {
				t.Errorf("bad: %v", sent)
				return
			}
		}()
	}

	// Wait for all workers, but bail out after 20s. On timeout, close
	// both sessions first so the blocked workers unwind, then wait for
	// them before failing (t.Fatal must run on the test goroutine).
	doneCh := make(chan struct{})
	go func() {
		wg.Wait()
		close(doneCh)
	}()
	select {
	case <-doneCh:
	case <-time.After(20 * time.Second):
		server.Close()
		client.Close()
		wg.Wait()
		t.Fatal("timeout")
	}
}

// TestLargeWindow doubles the stream window and checks that a write of
// one full (doubled) window completes without blocking on flow control,
// using a short write deadline as the failure detector.
func TestLargeWindow(t *testing.T) {
	conf := DefaultConfig()
	conf.MaxStreamWindowSize *= 2

	client, server := testClientServerConfig(conf)
	defer client.Close()
	defer server.Close()

	sendStream, err := client.Open()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer sendStream.Close()

	recvStream, err := server.Accept()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer recvStream.Close()

	// If the enlarged window isn't honored, the write will stall on
	// flow control and this deadline will trip.
	if err := sendStream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
		t.Fatal(err)
	}
	payload := make([]byte, conf.MaxStreamWindowSize)
	n, err := sendStream.Write(payload)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if n != len(payload) {
		t.Fatalf("short write: %d", n)
	}
}
Loading

0 comments on commit 97856b4

Please sign in to comment.