Unverified Commit 10c91193 authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #5 from libp2p/feat/rw-close

implement CloseRead/CloseWrite
parents fd43d7f1 fa4a0b4c
...@@ -220,14 +220,18 @@ func (s *Session) Accept() (net.Conn, error) { ...@@ -220,14 +220,18 @@ func (s *Session) Accept() (net.Conn, error) {
// AcceptStream is used to block until the next available stream // AcceptStream is used to block until the next available stream
// is ready to be accepted. // is ready to be accepted.
func (s *Session) AcceptStream() (*Stream, error) { func (s *Session) AcceptStream() (*Stream, error) {
select { for {
case stream := <-s.acceptCh: select {
if err := stream.sendWindowUpdate(); err != nil { case stream := <-s.acceptCh:
return nil, err if err := stream.sendWindowUpdate(); err != nil {
// don't return accept errors.
s.logger.Printf("[WARN] error sending window update before accepting: %s", err)
continue
}
return stream, nil
case <-s.shutdownCh:
return nil, s.shutdownErr
} }
return stream, nil
case <-s.shutdownCh:
return nil, s.shutdownErr
} }
} }
......
...@@ -407,6 +407,7 @@ func TestSendData_Small(t *testing.T) { ...@@ -407,6 +407,7 @@ func TestSendData_Small(t *testing.T) {
t.Errorf("err: %v", err) t.Errorf("err: %v", err)
return return
} }
defer stream.Close()
if server.NumStreams() != 1 { if server.NumStreams() != 1 {
t.Errorf("bad") t.Errorf("bad")
...@@ -430,7 +431,7 @@ func TestSendData_Small(t *testing.T) { ...@@ -430,7 +431,7 @@ func TestSendData_Small(t *testing.T) {
} }
} }
if err := stream.Close(); err != nil { if err := stream.CloseWrite(); err != nil {
t.Errorf("err: %v", err) t.Errorf("err: %v", err)
return return
} }
...@@ -442,11 +443,12 @@ func TestSendData_Small(t *testing.T) { ...@@ -442,11 +443,12 @@ func TestSendData_Small(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
stream, err := client.Open() stream, err := client.OpenStream()
if err != nil { if err != nil {
t.Errorf("err: %v", err) t.Errorf("err: %v", err)
return return
} }
defer stream.Close()
if client.NumStreams() != 1 { if client.NumStreams() != 1 {
t.Errorf("bad") t.Errorf("bad")
...@@ -465,7 +467,7 @@ func TestSendData_Small(t *testing.T) { ...@@ -465,7 +467,7 @@ func TestSendData_Small(t *testing.T) {
} }
} }
if err := stream.Close(); err != nil { if err := stream.CloseWrite(); err != nil {
t.Errorf("err: %v", err) t.Errorf("err: %v", err)
return return
} }
...@@ -785,12 +787,12 @@ func TestManyStreams_PingPong(t *testing.T) { ...@@ -785,12 +787,12 @@ func TestManyStreams_PingPong(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestHalfClose(t *testing.T) { func TestCloseRead(t *testing.T) {
client, server := testClientServer() client, server := testClientServer()
defer client.Close() defer client.Close()
defer server.Close() defer server.Close()
stream, err := client.Open() stream, err := client.OpenStream()
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
...@@ -798,17 +800,43 @@ func TestHalfClose(t *testing.T) { ...@@ -798,17 +800,43 @@ func TestHalfClose(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
stream2, err := server.Accept() stream2, err := server.AcceptStream()
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
stream2.Close() // Half close stream2.CloseRead()
buf := make([]byte, 4) buf := make([]byte, 4)
n, err := stream2.Read(buf) n, err := stream2.Read(buf)
if n != 0 || err == nil {
t.Fatalf("read after close: %d %s", n, err)
}
}
func TestHalfClose(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
stream, err := client.OpenStream()
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if _, err = stream.Write([]byte("a")); err != nil {
t.Fatalf("err: %v", err)
}
stream2, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
stream2.CloseWrite() // Half close
buf := make([]byte, 4)
n, err := io.ReadAtLeast(stream2, buf, 1)
if err != nil && err != io.EOF {
t.Fatalf("err: %v", err)
}
if n != 1 { if n != 1 {
t.Fatalf("bad: %v", n) t.Fatalf("bad: %v", n)
} }
...@@ -817,11 +845,17 @@ func TestHalfClose(t *testing.T) { ...@@ -817,11 +845,17 @@ func TestHalfClose(t *testing.T) {
if _, err = stream.Write([]byte("bcd")); err != nil { if _, err = stream.Write([]byte("bcd")); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
stream.Close() stream.CloseWrite()
// write after close
n, err = stream.Write([]byte("foobar"))
if n != 0 || err == nil {
t.Fatalf("wrote after close: %d %s", n, err)
}
// Read after close // Read after close
n, err = stream2.Read(buf) n, err = io.ReadAtLeast(stream2, buf, 3)
if err != nil { if err != nil && err != io.EOF {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if n != 3 { if n != 3 {
...@@ -1131,7 +1165,6 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) { ...@@ -1131,7 +1165,6 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) {
t.Errorf("err: %v", err) t.Errorf("err: %v", err)
return return
} }
defer wr.Close()
sendWindow := atomic.LoadUint32(&wr.sendWindow) sendWindow := atomic.LoadUint32(&wr.sendWindow)
if sendWindow != client.config.MaxStreamWindowSize { if sendWindow != client.config.MaxStreamWindowSize {
...@@ -1352,8 +1385,9 @@ func TestStreamHalfClose2(t *testing.T) { ...@@ -1352,8 +1385,9 @@ func TestStreamHalfClose2(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
defer stream.Close()
stream.Close() stream.CloseWrite()
wait <- struct{}{} wait <- struct{}{}
buf, err := ioutil.ReadAll(stream) buf, err := ioutil.ReadAll(stream)
......
...@@ -14,10 +14,15 @@ const ( ...@@ -14,10 +14,15 @@ const (
streamSYNSent streamSYNSent
streamSYNReceived streamSYNReceived
streamEstablished streamEstablished
streamLocalClose streamFinished
streamRemoteClose )
streamClosed
streamReset type halfStreamState int
const (
halfOpen halfStreamState = iota
halfClosed
halfReset
) )
// Stream is used to represent a logical stream // Stream is used to represent a logical stream
...@@ -28,8 +33,9 @@ type Stream struct { ...@@ -28,8 +33,9 @@ type Stream struct {
id uint32 id uint32
session *Session session *Session
state streamState state streamState
stateLock sync.Mutex writeState, readState halfStreamState
stateLock sync.Mutex
recvLock sync.Mutex recvLock sync.Mutex
recvBuf segmentedBuffer recvBuf segmentedBuffer
...@@ -74,19 +80,22 @@ func (s *Stream) Read(b []byte) (n int, err error) { ...@@ -74,19 +80,22 @@ func (s *Stream) Read(b []byte) (n int, err error) {
defer asyncNotify(s.recvNotifyCh) defer asyncNotify(s.recvNotifyCh)
START: START:
s.stateLock.Lock() s.stateLock.Lock()
state := s.state state := s.readState
s.stateLock.Unlock() s.stateLock.Unlock()
switch state { switch state {
case streamRemoteClose: case halfOpen:
fallthrough // Open -> read
case streamClosed: case halfClosed:
empty := s.recvBuf.Len() == 0 empty := s.recvBuf.Len() == 0
if empty { if empty {
return 0, io.EOF return 0, io.EOF
} }
case streamReset: // Closed, but we have data pending -> read.
case halfReset:
return 0, ErrStreamReset return 0, ErrStreamReset
default:
panic("unknown state")
} }
// If there is no data available, block // If there is no data available, block
...@@ -138,16 +147,18 @@ func (s *Stream) write(b []byte) (n int, err error) { ...@@ -138,16 +147,18 @@ func (s *Stream) write(b []byte) (n int, err error) {
START: START:
s.stateLock.Lock() s.stateLock.Lock()
state := s.state state := s.writeState
s.stateLock.Unlock() s.stateLock.Unlock()
switch state { switch state {
case streamLocalClose: case halfOpen:
fallthrough // Open for writing -> write
case streamClosed: case halfClosed:
return 0, ErrStreamClosed return 0, ErrStreamClosed
case streamReset: case halfReset:
return 0, ErrStreamReset return 0, ErrStreamReset
default:
panic("unknown state")
} }
// If there is no data available, block // If there is no data available, block
...@@ -239,75 +250,117 @@ func (s *Stream) sendReset() error { ...@@ -239,75 +250,117 @@ func (s *Stream) sendReset() error {
// Reset resets the stream (forcibly closes the stream) // Reset resets the stream (forcibly closes the stream)
func (s *Stream) Reset() error { func (s *Stream) Reset() error {
sendReset := false
s.stateLock.Lock() s.stateLock.Lock()
switch s.state { switch s.state {
case streamInit: case streamFinished:
// No need to send anything.
s.state = streamReset
s.stateLock.Unlock()
return nil
case streamClosed, streamReset:
s.stateLock.Unlock() s.stateLock.Unlock()
return nil return nil
case streamInit:
// we haven't sent anything, so we don't need to send a reset.
case streamSYNSent, streamSYNReceived, streamEstablished: case streamSYNSent, streamSYNReceived, streamEstablished:
case streamLocalClose, streamRemoteClose: sendReset = true
default: default:
panic("unhandled state") panic("unhandled state")
} }
s.state = streamReset
s.stateLock.Unlock()
err := s.sendReset() // at least one direction is open, we need to reset.
// If we've already sent/received an EOF, no need to reset that side.
if s.writeState == halfOpen {
s.writeState = halfReset
}
if s.readState == halfOpen {
s.readState = halfReset
}
s.state = streamFinished
s.notifyWaiting() s.notifyWaiting()
s.stateLock.Unlock()
if sendReset {
_ = s.sendReset()
}
s.cleanup() s.cleanup()
return nil
return err
} }
// Close is used to close the stream // CloseWrite is used to close the stream for writing.
func (s *Stream) Close() error { func (s *Stream) CloseWrite() error {
closeStream := false
s.stateLock.Lock() s.stateLock.Lock()
switch s.state { switch s.writeState {
case streamInit, streamSYNSent, streamSYNReceived, streamEstablished: case halfOpen:
s.state = streamLocalClose // Open for writing -> close write
goto SEND_CLOSE case halfClosed:
s.stateLock.Unlock()
return nil
case halfReset:
s.stateLock.Unlock()
return ErrStreamReset
default:
panic("invalid state")
}
s.writeState = halfClosed
cleanup := s.readState != halfOpen
if cleanup {
s.state = streamFinished
}
s.stateLock.Unlock()
s.notifyWaiting()
case streamLocalClose: err := s.sendClose()
case streamRemoteClose: if cleanup {
s.state = streamClosed // we're fully closed, might as well be nice to the user and
closeStream = true // free everything early.
goto SEND_CLOSE s.cleanup()
}
return err
}
case streamClosed: // CloseRead is used to close the stream for writing.
case streamReset: func (s *Stream) CloseRead() error {
cleanup := false
s.stateLock.Lock()
switch s.readState {
case halfOpen:
// Open for reading -> close read
case halfClosed, halfReset:
s.stateLock.Unlock()
return nil
default: default:
panic("unhandled state") panic("invalid state")
}
s.readState = halfReset
cleanup = s.writeState != halfOpen
if cleanup {
s.state = streamFinished
} }
s.stateLock.Unlock() s.stateLock.Unlock()
return nil
SEND_CLOSE:
s.stateLock.Unlock()
err := s.sendClose()
s.notifyWaiting() s.notifyWaiting()
if closeStream { if cleanup {
// we're fully closed, might as well be nice to the user and
// free everything early.
s.cleanup() s.cleanup()
} }
return err return nil
}
// Close is used to close the stream.
func (s *Stream) Close() error {
_ = s.CloseRead() // can't fail.
return s.CloseWrite()
} }
// forceClose is used for when the session is exiting // forceClose is used for when the session is exiting
func (s *Stream) forceClose() { func (s *Stream) forceClose() {
s.stateLock.Lock() s.stateLock.Lock()
switch s.state { if s.readState == halfOpen {
case streamClosed: s.readState = halfReset
// Already successfully closed. It just hasn't been removed from
// the list of streams yet.
default:
s.state = streamReset
} }
s.stateLock.Unlock() if s.writeState == halfOpen {
s.writeState = halfReset
}
s.state = streamFinished
s.notifyWaiting() s.notifyWaiting()
s.stateLock.Unlock()
s.readDeadline.set(time.Time{}) s.readDeadline.set(time.Time{})
s.writeDeadline.set(time.Time{}) s.writeDeadline.set(time.Time{})
...@@ -340,25 +393,24 @@ func (s *Stream) processFlags(flags uint16) error { ...@@ -340,25 +393,24 @@ func (s *Stream) processFlags(flags uint16) error {
s.session.establishStream(s.id) s.session.establishStream(s.id)
} }
if flags&flagFIN == flagFIN { if flags&flagFIN == flagFIN {
switch s.state { if s.readState == halfOpen {
case streamSYNSent: s.readState = halfClosed
fallthrough if s.writeState != halfOpen {
case streamSYNReceived: // We're now fully closed.
fallthrough closeStream = true
case streamEstablished: s.state = streamFinished
s.state = streamRemoteClose }
s.notifyWaiting() s.notifyWaiting()
case streamLocalClose:
s.state = streamClosed
closeStream = true
s.notifyWaiting()
default:
s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
return ErrUnexpectedFlag
} }
} }
if flags&flagRST == flagRST { if flags&flagRST == flagRST {
s.state = streamReset if s.readState == halfOpen {
s.readState = halfReset
}
if s.writeState == halfOpen {
s.writeState = halfReset
}
s.state = streamFinished
closeStream = true closeStream = true
s.notifyWaiting() s.notifyWaiting()
} }
...@@ -426,11 +478,9 @@ func (s *Stream) SetDeadline(t time.Time) error { ...@@ -426,11 +478,9 @@ func (s *Stream) SetDeadline(t time.Time) error {
func (s *Stream) SetReadDeadline(t time.Time) error { func (s *Stream) SetReadDeadline(t time.Time) error {
s.stateLock.Lock() s.stateLock.Lock()
defer s.stateLock.Unlock() defer s.stateLock.Unlock()
switch s.state { if s.readState == halfOpen {
case streamClosed, streamRemoteClose, streamReset: s.readDeadline.set(t)
return nil
} }
s.readDeadline.set(t)
return nil return nil
} }
...@@ -438,11 +488,9 @@ func (s *Stream) SetReadDeadline(t time.Time) error { ...@@ -438,11 +488,9 @@ func (s *Stream) SetReadDeadline(t time.Time) error {
func (s *Stream) SetWriteDeadline(t time.Time) error { func (s *Stream) SetWriteDeadline(t time.Time) error {
s.stateLock.Lock() s.stateLock.Lock()
defer s.stateLock.Unlock() defer s.stateLock.Unlock()
switch s.state { if s.writeState == halfOpen {
case streamClosed, streamLocalClose, streamReset: s.writeDeadline.set(t)
return nil
} }
s.writeDeadline.set(t)
return nil return nil
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment