Unverified Commit 97856b4f authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #21 from libp2p/fix/keepalive-race

fix: synchronize when resetting the keepalive timer
parents 51522d41 345f6390
...@@ -8,8 +8,10 @@ go: ...@@ -8,8 +8,10 @@ go:
env: env:
global: global:
- GOTFLAGS="-race"
- BUILD_DEPTYPE=gomod - BUILD_DEPTYPE=gomod
matrix:
- GOTFLAGS="-race"
- GOTFLAGS="-count 5"
# disable travis install # disable travis install
......
...@@ -48,13 +48,14 @@ func BenchmarkAccept(b *testing.B) { ...@@ -48,13 +48,14 @@ func BenchmarkAccept(b *testing.B) {
func BenchmarkSendRecv(b *testing.B) { func BenchmarkSendRecv(b *testing.B) {
client, server := testClientServer() client, server := testClientServer()
defer client.Close() defer client.Close()
defer server.Close()
sendBuf := make([]byte, 512) sendBuf := make([]byte, 512)
recvBuf := make([]byte, 512) recvBuf := make([]byte, 512)
doneCh := make(chan struct{}) doneCh := make(chan struct{})
go func() { go func() {
defer close(doneCh)
defer server.Close()
stream, err := server.AcceptStream() stream, err := server.AcceptStream()
if err != nil { if err != nil {
return return
...@@ -62,10 +63,10 @@ func BenchmarkSendRecv(b *testing.B) { ...@@ -62,10 +63,10 @@ func BenchmarkSendRecv(b *testing.B) {
defer stream.Close() defer stream.Close()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if _, err := io.ReadFull(stream, recvBuf); err != nil { if _, err := io.ReadFull(stream, recvBuf); err != nil {
b.Fatalf("err: %v", err) b.Errorf("err: %v", err)
return
} }
} }
close(doneCh)
}() }()
stream, err := client.Open() stream, err := client.Open()
...@@ -95,6 +96,8 @@ func BenchmarkSendRecvLarge(b *testing.B) { ...@@ -95,6 +96,8 @@ func BenchmarkSendRecvLarge(b *testing.B) {
recvDone := make(chan struct{}) recvDone := make(chan struct{})
go func() { go func() {
defer close(recvDone)
defer server.Close()
stream, err := server.AcceptStream() stream, err := server.AcceptStream()
if err != nil { if err != nil {
return return
...@@ -103,11 +106,11 @@ func BenchmarkSendRecvLarge(b *testing.B) { ...@@ -103,11 +106,11 @@ func BenchmarkSendRecvLarge(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for j := 0; j < sendSize/recvSize; j++ { for j := 0; j < sendSize/recvSize; j++ {
if _, err := io.ReadFull(stream, recvBuf); err != nil { if _, err := io.ReadFull(stream, recvBuf); err != nil {
b.Fatalf("err: %v", err) b.Errorf("err: %v", err)
return
} }
} }
} }
close(recvDone)
}() }()
stream, err := client.Open() stream, err := client.Open()
......
...@@ -87,15 +87,11 @@ type Session struct { ...@@ -87,15 +87,11 @@ type Session struct {
// keepaliveTimer is a periodic timer for keepalive messages. It's nil // keepaliveTimer is a periodic timer for keepalive messages. It's nil
// when keepalives are disabled. // when keepalives are disabled.
keepaliveLock sync.Mutex keepaliveLock sync.Mutex
keepaliveTimer *time.Timer keepaliveTimer *time.Timer
keepaliveActive bool
} }
const (
stageInitial uint32 = iota
stageFinal
)
// newSession is used to construct a new session // newSession is used to construct a new session
func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Session { func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Session {
var reader io.Reader = conn var reader io.Reader = conn
...@@ -327,23 +323,27 @@ func (s *Session) startKeepalive() { ...@@ -327,23 +323,27 @@ func (s *Session) startKeepalive() {
defer s.keepaliveLock.Unlock() defer s.keepaliveLock.Unlock()
s.keepaliveTimer = time.AfterFunc(s.config.KeepAliveInterval, func() { s.keepaliveTimer = time.AfterFunc(s.config.KeepAliveInterval, func() {
s.keepaliveLock.Lock() s.keepaliveLock.Lock()
if s.keepaliveTimer == nil || s.keepaliveActive {
if s.keepaliveTimer == nil { // keepalives have been stopped or a keepalive is active.
s.keepaliveLock.Unlock() s.keepaliveLock.Unlock()
// keepalives have been stopped.
return return
} }
s.keepaliveActive = true
s.keepaliveLock.Unlock()
_, err := s.Ping() _, err := s.Ping()
s.keepaliveLock.Lock()
s.keepaliveActive = false
if s.keepaliveTimer != nil {
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
}
s.keepaliveLock.Unlock()
if err != nil { if err != nil {
// Make sure to unlock before exiting so we don't
// deadlock trying to shutdown keepalives.
s.keepaliveLock.Unlock()
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
s.exitErr(ErrKeepAliveTimeout) s.exitErr(ErrKeepAliveTimeout)
return
} }
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
s.keepaliveLock.Unlock()
}) })
} }
...@@ -353,7 +353,24 @@ func (s *Session) stopKeepalive() { ...@@ -353,7 +353,24 @@ func (s *Session) stopKeepalive() {
defer s.keepaliveLock.Unlock() defer s.keepaliveLock.Unlock()
if s.keepaliveTimer != nil { if s.keepaliveTimer != nil {
s.keepaliveTimer.Stop() s.keepaliveTimer.Stop()
s.keepaliveTimer = nil
}
}
func (s *Session) extendKeepalive() {
s.keepaliveLock.Lock()
if s.keepaliveTimer != nil && !s.keepaliveActive {
// Don't stop the timer and drain the channel. This is an
// AfterFunc, not a normal timer, and any attempts to drain the
// channel will block forever.
//
// Go will stop the timer for us internally anyways. The docs
// say one must stop the timer before calling reset but that's
// to ensure that the timer doesn't end up firing immediately
// after calling Reset.
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
} }
s.keepaliveLock.Unlock()
} }
// send sends the header and body. // send sends the header and body.
...@@ -512,9 +529,7 @@ func (s *Session) recvLoop() error { ...@@ -512,9 +529,7 @@ func (s *Session) recvLoop() error {
// There's no reason to keepalive if we're active. Worse, if the // There's no reason to keepalive if we're active. Worse, if the
// peer is busy sending us stuff, the pong might get stuck // peer is busy sending us stuff, the pong might get stuck
// behind a bunch of data. // behind a bunch of data.
if s.keepaliveTimer != nil { s.extendKeepalive()
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
}
// Verify the version // Verify the version
if hdr.Version() != protoVersion { if hdr.Version() != protoVersion {
......
//+build !race
package yamux
import (
"bytes"
"io"
"io/ioutil"
"sync"
"testing"
"time"
)
func TestSession_PingOfDeath(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
defer server.Close()
count := 10000
var wg sync.WaitGroup
begin := make(chan struct{})
for i := 0; i < count; i++ {
wg.Add(2)
go func() {
defer wg.Done()
<-begin
if _, err := server.Ping(); err != nil {
t.Error(err)
}
}()
go func() {
defer wg.Done()
<-begin
if _, err := client.Ping(); err != nil {
t.Error(err)
}
}()
}
close(begin)
wg.Wait()
}
func TestSendData_VeryLarge(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
var n int64 = 1 * 1024 * 1024 * 1024
var workers int = 16
wg := &sync.WaitGroup{}
wg.Add(workers * 2)
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Errorf("err: %v", err)
return
}
defer stream.Close()
buf := make([]byte, 4)
_, err = io.ReadFull(stream, buf)
if err != nil {
t.Errorf("err: %v", err)
return
}
if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
t.Errorf("bad header")
return
}
recv, err := io.Copy(ioutil.Discard, stream)
if err != nil {
t.Errorf("err: %v", err)
return
}
if recv != n {
t.Errorf("bad: %v", recv)
return
}
}()
}
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
stream, err := client.Open()
if err != nil {
t.Errorf("err: %v", err)
return
}
defer stream.Close()
_, err = stream.Write([]byte{0, 1, 2, 3})
if err != nil {
t.Errorf("err: %v", err)
return
}
unlimited := &UnlimitedReader{}
sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
if err != nil {
t.Errorf("err: %v", err)
return
}
if sent != n {
t.Errorf("bad: %v", sent)
return
}
}()
}
doneCh := make(chan struct{})
go func() {
wg.Wait()
close(doneCh)
}()
select {
case <-doneCh:
case <-time.After(20 * time.Second):
server.Close()
client.Close()
wg.Wait()
t.Fatal("timeout")
}
}
func TestLargeWindow(t *testing.T) {
conf := DefaultConfig()
conf.MaxStreamWindowSize *= 2
client, server := testClientServerConfig(conf)
defer client.Close()
defer server.Close()
stream, err := client.Open()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
stream2, err := server.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream2.Close()
err = stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
if err != nil {
t.Fatal(err)
}
buf := make([]byte, conf.MaxStreamWindowSize)
n, err := stream.Write(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != len(buf) {
t.Fatalf("short write: %d", n)
}
}
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment