Commit 81349d05 authored by Steven Allen's avatar Steven Allen

Implement new CloseWrite/CloseRead interface

Close now closes the stream in both directions while CloseRead discards inbound
bytes and CloseWrite sends an EOF. This matches the user's expectation where
Close actually closes the stream.

part of https://github.com/libp2p/go-libp2p-core/pull/166
parent 30712bb8
......@@ -90,6 +90,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
wg.Add(1)
go func() {
defer wg.Done()
defer localB.Close()
receiveBuf := make([]byte, 2048)
for {
......@@ -103,7 +104,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
atomic.AddUint64(&receivedBytes, uint64(n))
}
}()
defer localA.Close()
i := 0
for {
n, err := localA.Write(msgs[i])
......@@ -116,7 +117,6 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
break
}
}
localA.Close()
})
b.StopTimer()
wg.Wait()
......
......@@ -32,6 +32,11 @@ func (d *pipeDeadline) set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()
// deadline closed
if d.cancel == nil {
return
}
if d.timer != nil && !d.timer.Stop() {
<-d.cancel // Wait for the timer callback to finish and close cancel
}
......@@ -70,6 +75,18 @@ func (d *pipeDeadline) wait() chan struct{} {
return d.cancel
}
// close closes, the deadline. Any future calls to `set` will do nothing.
func (d *pipeDeadline) close() {
d.mu.Lock()
defer d.mu.Unlock()
if d.timer != nil && !d.timer.Stop() {
<-d.cancel // Wait for the timer callback to finish and close cancel
}
d.timer = nil
d.cancel = nil
}
func isClosedChan(c <-chan struct{}) bool {
select {
case <-c:
......
......@@ -7,6 +7,7 @@ require (
github.com/libp2p/go-libp2p-testing v0.1.2-0.20200422005655-8775583591d8
github.com/multiformats/go-varint v0.0.6
github.com/opentracing/opentracing-go v1.2.0 // indirect
go.uber.org/multierr v1.5.0
go.uber.org/zap v1.15.0 // indirect
golang.org/x/crypto v0.0.0-20190618222545-ea8f1a30c443 // indirect
google.golang.org/grpc v1.28.1
......
......@@ -111,16 +111,15 @@ func NewMultiplex(con net.Conn, initiator bool) *Multiplex {
func (mp *Multiplex) newStream(id streamID, name string) (s *Stream) {
s = &Stream{
id: id,
name: name,
dataIn: make(chan []byte, 8),
reset: make(chan struct{}),
rDeadline: makePipeDeadline(),
wDeadline: makePipeDeadline(),
mp: mp,
id: id,
name: name,
dataIn: make(chan []byte, 8),
rDeadline: makePipeDeadline(),
wDeadline: makePipeDeadline(),
mp: mp,
writeCancel: make(chan struct{}),
readCancel: make(chan struct{}),
}
s.closedLocal, s.doCloseLocal = context.WithCancel(context.Background())
return
}
......@@ -168,7 +167,7 @@ func (mp *Multiplex) IsClosed() bool {
}
}
func (mp *Multiplex) sendMsg(done <-chan struct{}, header uint64, data []byte) error {
func (mp *Multiplex) sendMsg(timeout, cancel <-chan struct{}, header uint64, data []byte) error {
buf := pool.Get(len(data) + 20)
n := 0
......@@ -181,8 +180,10 @@ func (mp *Multiplex) sendMsg(done <-chan struct{}, header uint64, data []byte) e
return nil
case <-mp.shutdown:
return ErrShutdown
case <-done:
case <-timeout:
return errTimeout
case <-cancel:
return ErrStreamClosed
}
}
......@@ -321,7 +322,7 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
ctx, cancel := context.WithTimeout(context.Background(), NewStreamTimeout)
defer cancel()
err := mp.sendMsg(ctx.Done(), header, []byte(name))
err := mp.sendMsg(ctx.Done(), nil, header, []byte(name))
if err != nil {
return nil, err
}
......@@ -331,23 +332,20 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
func (mp *Multiplex) cleanup() {
mp.closeNoWait()
// Take the channels.
mp.chLock.Lock()
defer mp.chLock.Unlock()
for _, msch := range mp.channels {
msch.clLock.Lock()
if !msch.closedRemote {
msch.closedRemote = true
// Cancel readers
close(msch.reset)
}
channels := mp.channels
mp.channels = nil
mp.chLock.Unlock()
msch.doCloseLocal()
msch.clLock.Unlock()
// Cancel any reads/writes
for _, msch := range channels {
msch.cancelRead(ErrStreamReset)
msch.cancelWrite(ErrStreamReset)
}
// Don't remove this nil assignment. We check if this is nil to check if
// the connection is closed when we already have the lock (faster than
// checking if the stream is closed).
mp.channels = nil
// And... shutdown!
if mp.shutdownErr == nil {
mp.shutdownErr = ErrShutdown
}
......@@ -421,81 +419,43 @@ func (mp *Multiplex) handleIncoming() {
// This is *ok*. We forget the stream on reset.
continue
}
msch.clLock.Lock()
isClosed := msch.isClosed()
if !msch.closedRemote {
close(msch.reset)
msch.closedRemote = true
}
if !isClosed {
msch.doCloseLocal()
}
msch.clLock.Unlock()
msch.cancelDeadlines()
mp.chLock.Lock()
delete(mp.channels, ch)
mp.chLock.Unlock()
// Cancel any ongoing reads/writes.
msch.cancelRead(ErrStreamReset)
msch.cancelWrite(ErrStreamReset)
case closeTag:
if !ok {
// may have canceled our reads already.
continue
}
msch.clLock.Lock()
if msch.closedRemote {
msch.clLock.Unlock()
// Technically a bug on the other side. We
// should consider killing the connection.
continue
}
// unregister and throw away future data.
mp.chLock.Lock()
delete(mp.channels, ch)
mp.chLock.Unlock()
// close data channel, there will be no more data.
close(msch.dataIn)
msch.closedRemote = true
cleanup := msch.isClosed()
msch.clLock.Unlock()
if cleanup {
msch.cancelDeadlines()
mp.chLock.Lock()
delete(mp.channels, ch)
mp.chLock.Unlock()
}
// We intentionally don't cancel any deadlines, cancel reads, cancel
// writes, etc. We just deliver the EOF by closing the
// data channel, and unregister the channel so we don't
// receive any more data. The user still needs to call
// `Close()` or `Reset()`.
case messageTag:
if !ok {
// reset stream, return b
pool.Put(b)
// This is a perfectly valid case when we reset
// and forget about the stream.
log.Debugf("message for non-existant stream, dropping data: %d", ch)
// go mp.sendResetMsg(ch.header(resetTag), false)
continue
}
msch.clLock.Lock()
remoteClosed := msch.closedRemote
msch.clLock.Unlock()
if remoteClosed {
// closed stream, return b
// We're not accepting data on this stream, for
// some reason. It's likely that we reset it, or
// simply canceled reads (e.g., called Close).
pool.Put(b)
log.Warnf("Received data from remote after stream was closed by them. (len = %d)", len(b))
// go mp.sendResetMsg(msch.id.header(resetTag), false)
continue
}
recvTimeout.Reset(ReceiveTimeout)
select {
case msch.dataIn <- b:
case <-msch.reset:
case <-msch.readCancel:
// the user has canceled reading. walk away.
pool.Put(b)
case <-recvTimeout.C:
pool.Put(b)
......@@ -534,7 +494,7 @@ func (mp *Multiplex) sendResetMsg(header uint64, hard bool) {
ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
defer cancel()
err := mp.sendMsg(ctx.Done(), header, nil)
err := mp.sendMsg(ctx.Done(), nil, header, nil)
if err != nil && !mp.isShutdown() {
if hard {
log.Warnf("error sending reset message: %s; killing connection", err.Error())
......
......@@ -6,6 +6,8 @@ import (
"io/ioutil"
"math/rand"
"net"
"os"
"sync"
"testing"
"time"
)
......@@ -205,6 +207,53 @@ func TestEcho(t *testing.T) {
mpb.Close()
}
func TestFullClose(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
mes := make([]byte, 40960)
rand.Read(mes)
s, err := mpa.NewStream()
if err != nil {
t.Fatal(err)
}
{
s2, err := mpb.Accept()
if err != nil {
t.Error(err)
}
_, err = s.Write(mes)
if err != nil {
t.Error(err)
}
s2.Close()
}
err = s.Close()
if err != nil {
t.Fatal(err)
}
if n, err := s.Write([]byte("foo")); err != ErrStreamClosed {
t.Fatal("expected stream closed error on write to closed stream, got", err)
} else if n != 0 {
t.Fatal("should not have written any bytes to closed stream")
}
// We closed for reading, this should fail.
if n, err := s.Read([]byte{0}); err != ErrStreamClosed {
t.Fatal("expected stream closed error on read from closed stream, got", err)
} else if n != 0 {
t.Fatal("should not have read any bytes from closed stream, got", n)
}
mpa.Close()
mpb.Close()
}
func TestHalfClose(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
......@@ -216,15 +265,19 @@ func TestHalfClose(t *testing.T) {
go func() {
s, err := mpb.Accept()
if err != nil {
t.Fatal(err)
t.Error(err)
}
defer s.Close()
if err := s.CloseRead(); err != nil {
t.Error(err)
}
<-wait
_, err = s.Write(mes)
if err != nil {
t.Fatal(err)
t.Error(err)
}
}()
......@@ -232,8 +285,9 @@ func TestHalfClose(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer s.Close()
err = s.Close()
err = s.CloseWrite()
if err != nil {
t.Fatal(err)
}
......@@ -362,6 +416,184 @@ func TestReset(t *testing.T) {
}
}
func TestCancelRead(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
defer mpa.Close()
defer mpb.Close()
sa, err := mpa.NewStream()
if err != nil {
t.Fatal(err)
}
defer sa.Reset()
sb, err := mpb.Accept()
if err != nil {
t.Fatal(err)
}
defer sb.Reset()
// spin off a read
done := make(chan struct{})
go func() {
defer close(done)
_, err := sa.Read([]byte{0})
if err != ErrStreamClosed {
t.Error(err)
}
}()
// give it a chance to start.
time.Sleep(time.Millisecond)
// cancel it.
err = sa.CloseRead()
if err != nil {
t.Fatal(err)
}
// It should be canceled.
<-done
// Writing should still succeed.
_, err = sa.Write([]byte("foo"))
if err != nil {
t.Fatal(err)
}
err = sa.Close()
if err != nil {
t.Fatal(err)
}
// Data should still be sent.
buf, err := ioutil.ReadAll(sb)
if err != nil {
t.Fatal(err)
}
if string(buf) != "foo" {
t.Fatalf("expected foo, got %#v", err)
}
}
func TestCancelWrite(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
defer mpa.Close()
defer mpb.Close()
sa, err := mpa.NewStream()
if err != nil {
t.Fatal(err)
}
defer sa.Reset()
sb, err := mpb.Accept()
if err != nil {
t.Fatal(err)
}
defer sb.Reset()
// spin off a read
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for {
_, err := sa.Write([]byte("foo"))
if err != nil {
if err != ErrStreamClosed {
t.Error("unexpected error", err)
}
return
}
}
}()
// give it a chance to fill up.
time.Sleep(time.Millisecond)
go func() {
defer wg.Done()
// close it.
err := sa.CloseWrite()
if err != nil {
t.Error(err)
}
}()
_, err = ioutil.ReadAll(sb)
if err != nil {
t.Fatalf("expected stream to be closed correctly")
}
// It should be canceled.
wg.Wait()
// Reading should still succeed.
_, err = sb.Write([]byte("bar"))
if err != nil {
t.Fatal(err)
}
err = sb.Close()
if err != nil {
t.Fatal(err)
}
// Data should still be sent.
buf, err := ioutil.ReadAll(sa)
if err != nil {
t.Fatal(err)
}
if string(buf) != "bar" {
t.Fatalf("expected foo, got %#v", err)
}
}
func TestCancelReadAfterWrite(t *testing.T) {
a, b := net.Pipe()
mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true)
defer mpa.Close()
defer mpb.Close()
sa, err := mpa.NewStream()
if err != nil {
t.Fatal(err)
}
defer sa.Reset()
sb, err := mpb.Accept()
if err != nil {
t.Fatal(err)
}
defer sb.Reset()
// Write small messages till we would block.
sa.SetWriteDeadline(time.Now().Add(time.Millisecond))
for {
_, err = sa.Write([]byte("foo"))
if err != nil {
if os.IsTimeout(err) {
break
} else {
t.Fatal(err)
}
}
}
// Cancel inbound reads.
sb.CloseRead()
// We shouldn't read anything.
n, err := sb.Read([]byte{0})
if n != 0 || err != ErrStreamClosed {
t.Fatal("got data", err)
}
}
func TestResetAfterEOF(t *testing.T) {
a, b := net.Pipe()
......@@ -377,7 +609,7 @@ func TestResetAfterEOF(t *testing.T) {
}
sb, err := mpb.Accept()
if err := sa.Close(); err != nil {
if err := sa.CloseWrite(); err != nil {
t.Fatal(err)
}
......
......@@ -8,6 +8,7 @@ import (
"time"
pool "github.com/libp2p/go-buffer-pool"
"go.uber.org/multierr"
)
var (
......@@ -44,15 +45,9 @@ type Stream struct {
rDeadline, wDeadline pipeDeadline
clLock sync.Mutex
closedRemote bool
// Closed when the connection is reset.
reset chan struct{}
// Closed when the writer is closed (reset will also be closed)
closedLocal context.Context
doCloseLocal context.CancelFunc
clLock sync.Mutex
writeCancelErr, readCancelErr error
writeCancel, readCancel chan struct{}
}
func (s *Stream) Name() string {
......@@ -74,10 +69,6 @@ func (s *Stream) preloadData() {
func (s *Stream) waitForData() error {
select {
case <-s.reset:
// This is the only place where it's safe to return these.
s.returnBuffers()
return ErrStreamReset
case read, ok := <-s.dataIn:
if !ok {
return io.EOF
......@@ -85,6 +76,10 @@ func (s *Stream) waitForData() error {
s.extra = read
s.exbuf = read
return nil
case <-s.readCancel:
// This is the only place where it's safe to return these.
s.returnBuffers()
return s.readCancelErr
case <-s.rDeadline.wait():
return errTimeout
}
......@@ -114,10 +109,11 @@ func (s *Stream) returnBuffers() {
func (s *Stream) Read(b []byte) (int, error) {
select {
case <-s.reset:
return 0, ErrStreamReset
case <-s.readCancel:
return 0, s.readCancelErr
default:
}
if s.extra == nil {
err := s.waitForData()
if err != nil {
......@@ -162,134 +158,112 @@ func (s *Stream) Write(b []byte) (int, error) {
}
func (s *Stream) write(b []byte) (int, error) {
if s.isClosed() {
return 0, ErrStreamClosed
select {
case <-s.writeCancel:
return 0, s.writeCancelErr
default:
}
err := s.mp.sendMsg(s.wDeadline.wait(), s.id.header(messageTag), b)
err := s.mp.sendMsg(s.wDeadline.wait(), s.writeCancel, s.id.header(messageTag), b)
if err != nil {
if err == context.Canceled {
err = ErrStreamClosed
}
return 0, err
}
return len(b), nil
}
func (s *Stream) isClosed() bool {
return s.closedLocal.Err() != nil
}
func (s *Stream) cancelWrite(err error) bool {
s.wDeadline.close()
func (s *Stream) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
defer cancel()
s.clLock.Lock()
defer s.clLock.Unlock()
select {
case <-s.writeCancel:
return false
default:
s.writeCancelErr = err
close(s.writeCancel)
return true
}
}
err := s.mp.sendMsg(ctx.Done(), s.id.header(closeTag), nil)
func (s *Stream) cancelRead(err error) bool {
// Always unregister for reading first, even if we're already closed (or
// already closing). When handleIncoming calls this, it expects the
// stream to be unregistered by the time it returns.
s.mp.chLock.Lock()
delete(s.mp.channels, s.id)
s.mp.chLock.Unlock()
if s.isClosed() {
return nil
}
s.rDeadline.close()
s.clLock.Lock()
remote := s.closedRemote
s.clLock.Unlock()
s.doCloseLocal()
defer s.clLock.Unlock()
select {
case <-s.readCancel:
return false
default:
s.readCancelErr = err
close(s.readCancel)
return true
}
}
if remote {
s.cancelDeadlines()
s.mp.chLock.Lock()
delete(s.mp.channels, s.id)
s.mp.chLock.Unlock()
func (s *Stream) CloseWrite() error {
if !s.cancelWrite(ErrStreamClosed) {
// Check if we closed the stream _nicely_. If so, we don't need
// to report an error to the user.
if s.writeCancelErr == ErrStreamClosed {
return nil
}
// Closed for some other reason. Report it.
return s.writeCancelErr
}
ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
defer cancel()
err := s.mp.sendMsg(ctx.Done(), nil, s.id.header(closeTag), nil)
// We failed to close the stream after 2 minutes, something is probably wrong.
if err != nil && !s.mp.isShutdown() {
log.Warnf("Error closing stream: %s; killing connection", err.Error())
s.mp.Close()
}
return err
}
func (s *Stream) Reset() error {
s.clLock.Lock()
// Don't reset when fully closed.
if s.closedRemote && s.isClosed() {
s.clLock.Unlock()
return nil
}
// Don't reset twice.
select {
case <-s.reset:
s.clLock.Unlock()
return nil
default:
}
close(s.reset)
s.doCloseLocal()
s.closedRemote = true
s.cancelDeadlines()
go s.mp.sendResetMsg(s.id.header(resetTag), true)
s.clLock.Unlock()
s.mp.chLock.Lock()
delete(s.mp.channels, s.id)
s.mp.chLock.Unlock()
func (s *Stream) CloseRead() error {
s.cancelRead(ErrStreamClosed)
return nil
}
func (s *Stream) cancelDeadlines() {
s.rDeadline.set(time.Time{})
s.wDeadline.set(time.Time{})
func (s *Stream) Close() error {
return multierr.Combine(s.CloseRead(), s.CloseWrite())
}
func (s *Stream) SetDeadline(t time.Time) error {
s.clLock.Lock()
defer s.clLock.Unlock()
if s.closedRemote && s.isClosed() {
return errStreamClosed
}
func (s *Stream) Reset() error {
s.cancelRead(ErrStreamReset)
if !s.closedRemote {
s.rDeadline.set(t)
if s.cancelWrite(ErrStreamReset) {
// Send a reset in the background.
go s.mp.sendResetMsg(s.id.header(resetTag), true)
}
if !s.isClosed() {
s.wDeadline.set(t)
}
return nil
}
func (s *Stream) SetDeadline(t time.Time) error {
s.rDeadline.set(t)
s.wDeadline.set(t)
return nil
}
func (s *Stream) SetReadDeadline(t time.Time) error {
s.clLock.Lock()
defer s.clLock.Unlock()
if s.closedRemote {
return errStreamClosed
}
s.rDeadline.set(t)
return nil
}
func (s *Stream) SetWriteDeadline(t time.Time) error {
s.clLock.Lock()
defer s.clLock.Unlock()
if s.isClosed() {
return errStreamClosed
}
s.wDeadline.set(t)
return nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment