Unverified commit 67680fbd, authored by Steven Allen and committed by GitHub

Merge pull request #81 from libp2p/feat/rw-close

Implement new CloseWrite/CloseRead interface
parents 6ee3b241 7bfef51f
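For context: after this change each direction of a stream can be shut down independently. CloseWrite half-closes the write side (the remote reader sees EOF), CloseRead stops accepting inbound data, and Close now simply does both. A minimal usage sketch of the new interface, assuming a connected *Multiplex named `mp` and a hypothetical `request` payload:

	s, err := mp.NewStream()
	if err != nil {
		return err
	}
	if _, err := s.Write(request); err != nil {
		return err
	}
	// Half-close: the remote side reads EOF, but our read side stays open.
	if err := s.CloseWrite(); err != nil {
		return err
	}
	// We can still read the response after closing for writing.
	response, err := ioutil.ReadAll(s)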
......@@ -50,7 +50,7 @@ func TestSmallPackets(t *testing.T) {
if err != nil {
t.Fatal(err)
}
-if slowdown > 0.15 {
+if slowdown > 0.15 && !raceEnabled {
t.Fatalf("Slowdown from mplex was >15%%: %f", slowdown)
}
}
......@@ -90,6 +90,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
wg.Add(1)
go func() {
defer wg.Done()
defer localB.Close()
receiveBuf := make([]byte, 2048)
for {
......@@ -103,7 +104,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
atomic.AddUint64(&receivedBytes, uint64(n))
}
}()
+defer localA.Close()
i := 0
for {
n, err := localA.Write(msgs[i])
......@@ -116,7 +117,6 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
break
}
}
-localA.Close()
})
b.StopTimer()
wg.Wait()
......
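Note on the benchmark changes above: the reader goroutine now closes its own end via `defer localB.Close()`, and the writer's `localA.Close()` moved from an explicit call after the send loop to a `defer` right after spawning the reader, presumably so the write side is closed (and the reader unblocked with EOF) even on early exit or error paths before `wg.Wait()`.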
......@@ -32,6 +32,11 @@ func (d *pipeDeadline) set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()
// deadline already closed; set is a no-op.
if d.cancel == nil {
return
}
if d.timer != nil && !d.timer.Stop() {
<-d.cancel // Wait for the timer callback to finish and close cancel
}
......@@ -70,6 +75,18 @@ func (d *pipeDeadline) wait() chan struct{} {
return d.cancel
}
// close closes the deadline. Any future calls to `set` will do nothing.
func (d *pipeDeadline) close() {
d.mu.Lock()
defer d.mu.Unlock()
if d.timer != nil && !d.timer.Stop() {
<-d.cancel // Wait for the timer callback to finish and close cancel
}
d.timer = nil
d.cancel = nil
}
func isClosedChan(c <-chan struct{}) bool {
select {
case <-c:
......
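The `!d.timer.Stop()` dance above is the standard `time.Timer` idiom: when `Stop` returns false the callback has already fired (or is firing), so the code waits on `d.cancel` for the callback to finish before mutating shared state. A standalone, runnable sketch of the same idiom, with a hypothetical `fired` channel playing the role of `d.cancel`:

	package main

	import (
		"fmt"
		"time"
	)

	func main() {
		fired := make(chan struct{})
		t := time.AfterFunc(50*time.Millisecond, func() {
			fmt.Println("deadline hit")
			close(fired) // signal that the callback has finished
		})
		time.Sleep(100 * time.Millisecond)
		if !t.Stop() {
			// Stop lost the race: the callback already ran (or is running),
			// so wait for it to complete before touching shared state.
			<-fired
		}
	}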
......@@ -7,6 +7,7 @@ require (
github.com/libp2p/go-libp2p-testing v0.1.2-0.20200422005655-8775583591d8
github.com/multiformats/go-varint v0.0.6
github.com/opentracing/opentracing-go v1.2.0 // indirect
go.uber.org/multierr v1.5.0
go.uber.org/zap v1.15.0 // indirect
golang.org/x/crypto v0.0.0-20190618222545-ea8f1a30c443 // indirect
google.golang.org/grpc v1.28.1
......
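`go.uber.org/multierr` is pulled in for the new `Stream.Close`, which combines the results of `CloseRead` and `CloseWrite`. The behavior being relied on: `multierr.Combine` returns nil when every argument is nil and otherwise folds the non-nil errors into a single error. A small sketch:

	package main

	import (
		"errors"
		"fmt"

		"go.uber.org/multierr"
	)

	func main() {
		// All nil: Combine returns nil, so Close() reports success.
		fmt.Println(multierr.Combine(nil, nil)) // <nil>

		// Non-nil errors are folded into one value.
		errA := errors.New("read side failed")
		errB := errors.New("write side failed")
		fmt.Println(multierr.Combine(nil, errA, errB)) // read side failed; write side failed
	}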
......@@ -111,16 +111,15 @@ func NewMultiplex(con net.Conn, initiator bool) *Multiplex {
func (mp *Multiplex) newStream(id streamID, name string) (s *Stream) {
s = &Stream{
-id: id,
-name: name,
-dataIn: make(chan []byte, 8),
-reset: make(chan struct{}),
-rDeadline: makePipeDeadline(),
-wDeadline: makePipeDeadline(),
-mp: mp,
+id: id,
+name: name,
+dataIn: make(chan []byte, 8),
+rDeadline: makePipeDeadline(),
+wDeadline: makePipeDeadline(),
+mp: mp,
+writeCancel: make(chan struct{}),
+readCancel: make(chan struct{}),
}
-s.closedLocal, s.doCloseLocal = context.WithCancel(context.Background())
return
}
......@@ -168,7 +167,7 @@ func (mp *Multiplex) IsClosed() bool {
}
}
-func (mp *Multiplex) sendMsg(done <-chan struct{}, header uint64, data []byte) error {
+func (mp *Multiplex) sendMsg(timeout, cancel <-chan struct{}, header uint64, data []byte) error {
buf := pool.Get(len(data) + 20)
n := 0
......@@ -181,8 +180,10 @@ func (mp *Multiplex) sendMsg(done <-chan struct{}, header uint64, data []byte) e
return nil
case <-mp.shutdown:
return ErrShutdown
-case <-done:
+case <-timeout:
return errTimeout
+case <-cancel:
+return ErrStreamClosed
}
}
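Splitting the old single `done` channel into `timeout` and `cancel` lets `sendMsg` report why a send was abandoned: a fired deadline yields the retryable `errTimeout`, while a canceled stream yields the terminal `ErrStreamClosed`. Callers that pass `nil` for `cancel` (as `NewNamedStream` and `sendResetMsg` do below) rely on a nil channel never firing, so those sends are bounded only by their context timeout. A reduced sketch of the select shape; the `trySend` helper is hypothetical:

	// trySend mirrors the new sendMsg select: the two channels produce
	// different errors, so callers can tell a retryable deadline expiry
	// apart from a terminal cancellation.
	func trySend(out chan<- []byte, timeout, cancel <-chan struct{}, msg []byte) error {
		select {
		case out <- msg:
			return nil
		case <-timeout:
			return errTimeout // deadline fired; caller may extend it and retry
		case <-cancel:
			return ErrStreamClosed // write side canceled; retrying is pointless
		}
	}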
......@@ -321,7 +322,7 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
ctx, cancel := context.WithTimeout(context.Background(), NewStreamTimeout)
defer cancel()
-err := mp.sendMsg(ctx.Done(), header, []byte(name))
+err := mp.sendMsg(ctx.Done(), nil, header, []byte(name))
if err != nil {
return nil, err
}
......@@ -331,23 +332,20 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
func (mp *Multiplex) cleanup() {
mp.closeNoWait()
+// Take the channels.
mp.chLock.Lock()
-defer mp.chLock.Unlock()
-for _, msch := range mp.channels {
-msch.clLock.Lock()
-if !msch.closedRemote {
-msch.closedRemote = true
-// Cancel readers
-close(msch.reset)
-}
+channels := mp.channels
+mp.channels = nil
+mp.chLock.Unlock()
-msch.doCloseLocal()
-msch.clLock.Unlock()
+// Cancel any reads/writes
+for _, msch := range channels {
+msch.cancelRead(ErrStreamReset)
+msch.cancelWrite(ErrStreamReset)
}
-// Don't remove this nil assignment. We check if this is nil to check if
-// the connection is closed when we already have the lock (faster than
-// checking if the stream is closed).
-mp.channels = nil
+// And... shutdown!
if mp.shutdownErr == nil {
mp.shutdownErr = ErrShutdown
}
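The rewritten `cleanup` follows a common shape: hold the lock just long enough to take ownership of the map, then do the per-stream cancellation outside the critical section, since `cancelRead`/`cancelWrite` acquire per-stream locks of their own. The pattern in isolation, with hypothetical `mux`/stream types:

	func (m *mux) teardown() {
		m.mu.Lock()
		streams := m.streams // take ownership of the map
		m.streams = nil      // nothing new can be registered
		m.mu.Unlock()

		// The slow, lock-taking work happens outside the critical section.
		for _, s := range streams {
			s.cancel(errShutdown)
		}
	}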
......@@ -421,81 +419,43 @@ func (mp *Multiplex) handleIncoming() {
// This is *ok*. We forget the stream on reset.
continue
}
-msch.clLock.Lock()
-isClosed := msch.isClosed()
-if !msch.closedRemote {
-close(msch.reset)
-msch.closedRemote = true
-}
-if !isClosed {
-msch.doCloseLocal()
-}
-msch.clLock.Unlock()
-msch.cancelDeadlines()
-mp.chLock.Lock()
-delete(mp.channels, ch)
-mp.chLock.Unlock()
+// Cancel any ongoing reads/writes.
+msch.cancelRead(ErrStreamReset)
+msch.cancelWrite(ErrStreamReset)
case closeTag:
if !ok {
+// may have canceled our reads already.
continue
}
-msch.clLock.Lock()
-if msch.closedRemote {
-msch.clLock.Unlock()
-// Technically a bug on the other side. We
-// should consider killing the connection.
-continue
-}
+// unregister and throw away future data.
+mp.chLock.Lock()
+delete(mp.channels, ch)
+mp.chLock.Unlock()
+// close data channel, there will be no more data.
close(msch.dataIn)
-msch.closedRemote = true
-cleanup := msch.isClosed()
-msch.clLock.Unlock()
-if cleanup {
-msch.cancelDeadlines()
-mp.chLock.Lock()
-delete(mp.channels, ch)
-mp.chLock.Unlock()
-}
+// We intentionally don't cancel deadlines, reads, or
+// writes here. We just deliver the EOF by closing the
+// data channel, and unregister the channel so we don't
+// receive any more data. The user still needs to call
+// `Close()` or `Reset()`.
case messageTag:
if !ok {
-// reset stream, return b
pool.Put(b)
+// This is a perfectly valid case when we reset
+// and forget about the stream.
+log.Debugf("message for non-existent stream, dropping data: %d", ch)
+// go mp.sendResetMsg(ch.header(resetTag), false)
continue
}
-msch.clLock.Lock()
-remoteClosed := msch.closedRemote
-msch.clLock.Unlock()
-if remoteClosed {
-// closed stream, return b
+// We're not accepting data on this stream, for
+// some reason. It's likely that we reset it, or
+// simply canceled reads (e.g., called Close).
pool.Put(b)
log.Warnf("Received data from remote after stream was closed by them. (len = %d)", len(b))
// go mp.sendResetMsg(msch.id.header(resetTag), false)
continue
}
recvTimeout.Reset(ReceiveTimeout)
select {
case msch.dataIn <- b:
-case <-msch.reset:
+case <-msch.readCancel:
+// the user has canceled reading. walk away.
pool.Put(b)
case <-recvTimeout.C:
pool.Put(b)
......@@ -534,7 +494,7 @@ func (mp *Multiplex) sendResetMsg(header uint64, hard bool) {
ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
defer cancel()
-err := mp.sendMsg(ctx.Done(), header, nil)
+err := mp.sendMsg(ctx.Done(), nil, header, nil)
if err != nil && !mp.isShutdown() {
if hard {
log.Warnf("error sending reset message: %s; killing connection", err.Error())
......
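Taken together, the multiplexer now distinguishes two terminal stream errors: `ErrStreamClosed` for an orderly local close and `ErrStreamReset` for an abort (a local `Reset`, a remote reset, or connection cleanup). A hedged sketch of what a caller of `Read` might observe under the new semantics:

	n, err := s.Read(buf)
	switch err {
	case nil:
		// received n bytes
	case io.EOF:
		// remote half-closed its write side (closeTag); we can keep writing
	case ErrStreamReset:
		// stream was aborted; further reads and writes will fail
	case ErrStreamClosed:
		// we canceled our own read side (CloseRead / Close)
	default:
		// deadline expiry or a connection-level failure
	}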
......@@ -6,6 +6,8 @@ import (
"io/ioutil"
"math/rand"
"net"
"os"
"sync"
"testing"
"time"
)
......@@ -205,6 +207,53 @@ func TestEcho(t *testing.T) {
mpb.Close()
}
func TestFullClose(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
mes := make([]byte, 40960)
rand.Read(mes)
s, err := mpa.NewStream()
if err != nil {
t.Fatal(err)
}
{
s2, err := mpb.Accept()
if err != nil {
t.Error(err)
}
_, err = s.Write(mes)
if err != nil {
t.Error(err)
}
s2.Close()
}
err = s.Close()
if err != nil {
t.Fatal(err)
}
if n, err := s.Write([]byte("foo")); err != ErrStreamClosed {
t.Fatal("expected stream closed error on write to closed stream, got", err)
} else if n != 0 {
t.Fatal("should not have written any bytes to closed stream")
}
// We also closed for reading, so this should fail.
if n, err := s.Read([]byte{0}); err != ErrStreamClosed {
t.Fatal("expected stream closed error on read from closed stream, got", err)
} else if n != 0 {
t.Fatal("should not have read any bytes from closed stream, got", n)
}
mpa.Close()
mpb.Close()
}
func TestHalfClose(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
......@@ -216,15 +265,19 @@ func TestHalfClose(t *testing.T) {
go func() {
s, err := mpb.Accept()
if err != nil {
-t.Fatal(err)
+t.Error(err)
}
+defer s.Close()
+if err := s.CloseRead(); err != nil {
+t.Error(err)
+}
<-wait
_, err = s.Write(mes)
if err != nil {
-t.Fatal(err)
+t.Error(err)
}
}()
......@@ -232,8 +285,9 @@ func TestHalfClose(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+defer s.Close()
-err = s.Close()
+err = s.CloseWrite()
if err != nil {
t.Fatal(err)
}
......@@ -362,6 +416,184 @@ func TestReset(t *testing.T) {
}
}
func TestCancelRead(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
defer mpa.Close()
defer mpb.Close()
sa, err := mpa.NewStream()
if err != nil {
t.Fatal(err)
}
defer sa.Reset()
sb, err := mpb.Accept()
if err != nil {
t.Fatal(err)
}
defer sb.Reset()
// spin off a read
done := make(chan struct{})
go func() {
defer close(done)
_, err := sa.Read([]byte{0})
if err != ErrStreamClosed {
t.Error(err)
}
}()
// give it a chance to start.
time.Sleep(time.Millisecond)
// cancel it.
err = sa.CloseRead()
if err != nil {
t.Fatal(err)
}
// It should be canceled.
<-done
// Writing should still succeed.
_, err = sa.Write([]byte("foo"))
if err != nil {
t.Fatal(err)
}
err = sa.Close()
if err != nil {
t.Fatal(err)
}
// Data should still be sent.
buf, err := ioutil.ReadAll(sb)
if err != nil {
t.Fatal(err)
}
if string(buf) != "foo" {
t.Fatalf("expected foo, got %#v", err)
}
}
func TestCancelWrite(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
defer mpa.Close()
defer mpb.Close()
sa, err := mpa.NewStream()
if err != nil {
t.Fatal(err)
}
defer sa.Reset()
sb, err := mpb.Accept()
if err != nil {
t.Fatal(err)
}
defer sb.Reset()
// spin off a write
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for {
_, err := sa.Write([]byte("foo"))
if err != nil {
if err != ErrStreamClosed {
t.Error("unexpected error", err)
}
return
}
}
}()
// give it a chance to fill up.
time.Sleep(time.Millisecond)
go func() {
defer wg.Done()
// close it.
err := sa.CloseWrite()
if err != nil {
t.Error(err)
}
}()
_, err = ioutil.ReadAll(sb)
if err != nil {
t.Fatalf("expected stream to be closed correctly")
}
// It should be canceled.
wg.Wait()
// Writing should still succeed.
_, err = sb.Write([]byte("bar"))
if err != nil {
t.Fatal(err)
}
err = sb.Close()
if err != nil {
t.Fatal(err)
}
// Data should still be sent.
buf, err := ioutil.ReadAll(sa)
if err != nil {
t.Fatal(err)
}
if string(buf) != "bar" {
t.Fatalf("expected foo, got %#v", err)
}
}
func TestCancelReadAfterWrite(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
defer mpa.Close()
defer mpb.Close()
sa, err := mpa.NewStream()
if err != nil {
t.Fatal(err)
}
defer sa.Reset()
sb, err := mpb.Accept()
if err != nil {
t.Fatal(err)
}
defer sb.Reset()
// Write small messages till we would block.
sa.SetWriteDeadline(time.Now().Add(time.Millisecond))
for {
_, err = sa.Write([]byte("foo"))
if err != nil {
if os.IsTimeout(err) {
break
} else {
t.Fatal(err)
}
}
}
// Cancel inbound reads.
sb.CloseRead()
// We shouldn't read anything.
n, err := sb.Read([]byte{0})
if n != 0 || err != ErrStreamClosed {
t.Fatal("got data", err)
}
}
func TestResetAfterEOF(t *testing.T) {
a, b := net.Pipe()
......@@ -377,7 +609,7 @@ func TestResetAfterEOF(t *testing.T) {
}
sb, err := mpb.Accept()
-if err := sa.Close(); err != nil {
+if err := sa.CloseWrite(); err != nil {
t.Fatal(err)
}
......
// +build !race
package multiplex
var raceEnabled = false
// +build race
package multiplex
var raceEnabled = true
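These two one-line files use mutually exclusive build constraints to define `raceEnabled`, so code can tell at compile time whether the binary was built with `-race`. The benchmark above uses it to skip the 15% slowdown assertion, since the race detector slows everything down by far more than that.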
......@@ -8,6 +8,7 @@ import (
"time"
pool "github.com/libp2p/go-buffer-pool"
"go.uber.org/multierr"
)
var (
......@@ -44,15 +45,9 @@ type Stream struct {
rDeadline, wDeadline pipeDeadline
-clLock sync.Mutex
-closedRemote bool
-// Closed when the connection is reset.
-reset chan struct{}
-// Closed when the writer is closed (reset will also be closed)
-closedLocal context.Context
-doCloseLocal context.CancelFunc
+clLock sync.Mutex
+writeCancelErr, readCancelErr error
+writeCancel, readCancel chan struct{}
}
func (s *Stream) Name() string {
......@@ -74,10 +69,6 @@ func (s *Stream) preloadData() {
func (s *Stream) waitForData() error {
select {
-case <-s.reset:
-// This is the only place where it's safe to return these.
-s.returnBuffers()
-return ErrStreamReset
case read, ok := <-s.dataIn:
if !ok {
return io.EOF
......@@ -85,6 +76,10 @@ func (s *Stream) waitForData() error {
s.extra = read
s.exbuf = read
return nil
+case <-s.readCancel:
+// This is the only place where it's safe to return these.
+s.returnBuffers()
+return s.readCancelErr
case <-s.rDeadline.wait():
return errTimeout
}
......@@ -114,10 +109,11 @@ func (s *Stream) returnBuffers() {
func (s *Stream) Read(b []byte) (int, error) {
select {
-case <-s.reset:
-return 0, ErrStreamReset
+case <-s.readCancel:
+return 0, s.readCancelErr
default:
}
if s.extra == nil {
err := s.waitForData()
if err != nil {
......@@ -162,134 +158,112 @@ func (s *Stream) Write(b []byte) (int, error) {
}
func (s *Stream) write(b []byte) (int, error) {
-if s.isClosed() {
-return 0, ErrStreamClosed
+select {
+case <-s.writeCancel:
+return 0, s.writeCancelErr
+default:
}
-err := s.mp.sendMsg(s.wDeadline.wait(), s.id.header(messageTag), b)
+err := s.mp.sendMsg(s.wDeadline.wait(), s.writeCancel, s.id.header(messageTag), b)
if err != nil {
-if err == context.Canceled {
-err = ErrStreamClosed
-}
return 0, err
}
return len(b), nil
}
-func (s *Stream) isClosed() bool {
-return s.closedLocal.Err() != nil
-}
+func (s *Stream) cancelWrite(err error) bool {
+s.wDeadline.close()
-func (s *Stream) Close() error {
-ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
-defer cancel()
+s.clLock.Lock()
+defer s.clLock.Unlock()
+select {
+case <-s.writeCancel:
+return false
+default:
+s.writeCancelErr = err
+close(s.writeCancel)
+return true
+}
+}
-err := s.mp.sendMsg(ctx.Done(), s.id.header(closeTag), nil)
+func (s *Stream) cancelRead(err error) bool {
+// Always unregister for reading first, even if we're already closed (or
+// already closing). When handleIncoming calls this, it expects the
+// stream to be unregistered by the time it returns.
+s.mp.chLock.Lock()
+delete(s.mp.channels, s.id)
+s.mp.chLock.Unlock()
-if s.isClosed() {
-return nil
-}
+s.rDeadline.close()
s.clLock.Lock()
-remote := s.closedRemote
-s.clLock.Unlock()
-s.doCloseLocal()
+defer s.clLock.Unlock()
+select {
+case <-s.readCancel:
+return false
+default:
+s.readCancelErr = err
+close(s.readCancel)
+return true
+}
}
-if remote {
-s.cancelDeadlines()
-s.mp.chLock.Lock()
-delete(s.mp.channels, s.id)
-s.mp.chLock.Unlock()
+func (s *Stream) CloseWrite() error {
+if !s.cancelWrite(ErrStreamClosed) {
+// Check if we closed the stream _nicely_. If so, we don't need
+// to report an error to the user.
+if s.writeCancelErr == ErrStreamClosed {
+return nil
+}
+// Closed for some other reason. Report it.
+return s.writeCancelErr
+}
+ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
+defer cancel()
+err := s.mp.sendMsg(ctx.Done(), nil, s.id.header(closeTag), nil)
+// We failed to close the stream after 2 minutes, something is probably wrong.
+if err != nil && !s.mp.isShutdown() {
+log.Warnf("Error closing stream: %s; killing connection", err.Error())
+s.mp.Close()
+}
return err
}
-func (s *Stream) Reset() error {
-s.clLock.Lock()
-// Don't reset when fully closed.
-if s.closedRemote && s.isClosed() {
-s.clLock.Unlock()
-return nil
-}
-// Don't reset twice.
-select {
-case <-s.reset:
-s.clLock.Unlock()
-return nil
-default:
-}
-close(s.reset)
-s.doCloseLocal()
-s.closedRemote = true
-s.cancelDeadlines()
-go s.mp.sendResetMsg(s.id.header(resetTag), true)
-s.clLock.Unlock()
-s.mp.chLock.Lock()
-delete(s.mp.channels, s.id)
-s.mp.chLock.Unlock()
+func (s *Stream) CloseRead() error {
+s.cancelRead(ErrStreamClosed)
return nil
}
-func (s *Stream) cancelDeadlines() {
-s.rDeadline.set(time.Time{})
-s.wDeadline.set(time.Time{})
+func (s *Stream) Close() error {
+return multierr.Combine(s.CloseRead(), s.CloseWrite())
}
-func (s *Stream) SetDeadline(t time.Time) error {
-s.clLock.Lock()
-defer s.clLock.Unlock()
-if s.closedRemote && s.isClosed() {
-return errStreamClosed
-}
+func (s *Stream) Reset() error {
+s.cancelRead(ErrStreamReset)
-if !s.closedRemote {
-s.rDeadline.set(t)
+if s.cancelWrite(ErrStreamReset) {
+// Send a reset in the background.
+go s.mp.sendResetMsg(s.id.header(resetTag), true)
}
-if !s.isClosed() {
-s.wDeadline.set(t)
-}
return nil
}
+func (s *Stream) SetDeadline(t time.Time) error {
+s.rDeadline.set(t)
+s.wDeadline.set(t)
+return nil
+}
func (s *Stream) SetReadDeadline(t time.Time) error {
-s.clLock.Lock()
-defer s.clLock.Unlock()
-if s.closedRemote {
-return errStreamClosed
-}
s.rDeadline.set(t)
return nil
}
func (s *Stream) SetWriteDeadline(t time.Time) error {
-s.clLock.Lock()
-defer s.clLock.Unlock()
-if s.isClosed() {
-return errStreamClosed
-}
s.wDeadline.set(t)
return nil
}
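`cancelRead` and `cancelWrite` above share a once-only shape worth calling out: under the lock, a `select` with a `default` branch checks whether the cancel channel is already closed; if not, the error is recorded before the channel is closed, so anyone woken by the channel always observes a fully set error, and the boolean return tells the caller whether it won the race. A distilled, self-contained sketch of the pattern, using a hypothetical `gate` type:

	import "sync"

	type gate struct {
		mu   sync.Mutex
		err  error
		done chan struct{}
	}

	func newGate() *gate { return &gate{done: make(chan struct{})} }

	// cancel records err and closes the gate exactly once; it reports
	// whether this call won the race. Later callers keep the first error.
	func (g *gate) cancel(err error) bool {
		g.mu.Lock()
		defer g.mu.Unlock()
		select {
		case <-g.done:
			return false // already canceled
		default:
			g.err = err
			close(g.done)
			return true
		}
	}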