Unverified Commit 35dde006 authored by Marten Seemann's avatar Marten Seemann Committed by GitHub

Merge pull request #49 from libp2p/cleanup-window-check

clean up the receive window check
parents fad228ec c6444def
...@@ -422,15 +422,9 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { ...@@ -422,15 +422,9 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
return nil return nil
} }
// Validate it's okay to copy
if !s.recvBuf.TryReserve(length) {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvBuf.Cap(), length)
return ErrRecvWindowExceeded
}
// Copy into buffer // Copy into buffer
if err := s.recvBuf.Append(conn, int(length)); err != nil { if err := s.recvBuf.Append(conn, length); err != nil {
s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) s.session.logger.Printf("[ERR] yamux: Failed to read stream data on stream %d: %v", s.id, err)
return err return err
} }
// Unblock the reader // Unblock the reader
......
package yamux package yamux
import ( import (
"fmt"
"io" "io"
"sync" "sync"
...@@ -42,14 +43,12 @@ func min(values ...uint32) uint32 { ...@@ -42,14 +43,12 @@ func min(values ...uint32) uint32 {
// | data | empty space | // | data | empty space |
// < window (10) > // < window (10) >
// < len (5) > < cap (5) > // < len (5) > < cap (5) >
// < pending (4) >
// //
// As data is read, the buffer gets updated like so: // As data is read, the buffer gets updated like so:
// //
// | data | empty space | // | data | empty space |
// < window (8) > // < window (8) >
// < len (3) > < cap (5) > // < len (3) > < cap (5) >
// < pending (4) >
// //
// It can then grow as follows (given a "max" of 10): // It can then grow as follows (given a "max" of 10):
// //
...@@ -57,21 +56,18 @@ func min(values ...uint32) uint32 { ...@@ -57,21 +56,18 @@ func min(values ...uint32) uint32 {
// | data | empty space | // | data | empty space |
// < window (10) > // < window (10) >
// < len (3) > < cap (7) > // < len (3) > < cap (7) >
// < pending (4) >
// //
// Data can then be written into the pending space, expanding len, and shrinking // Data can then be written into the empty space, expanding len,
// cap and pending: // and shrinking cap:
// //
// | data | empty space | // | data | empty space |
// < window (10) > // < window (10) >
// < len (5) > < cap (5) > // < len (5) > < cap (5) >
// < pending (2)>
// //
type segmentedBuffer struct { type segmentedBuffer struct {
cap uint32 cap uint32
pending uint32 len uint32
len uint32 bm sync.Mutex
bm sync.Mutex
// read position in b[0]. // read position in b[0].
// We must not reslice any of the buffers in b, as we need to put them back into the pool. // We must not reslice any of the buffers in b, as we need to put them back into the pool.
readPos int readPos int
...@@ -84,22 +80,10 @@ func newSegmentedBuffer(initialCapacity uint32) segmentedBuffer { ...@@ -84,22 +80,10 @@ func newSegmentedBuffer(initialCapacity uint32) segmentedBuffer {
} }
// Len is the amount of data in the receive buffer. // Len is the amount of data in the receive buffer.
func (s *segmentedBuffer) Len() int { func (s *segmentedBuffer) Len() uint32 {
s.bm.Lock() s.bm.Lock()
len := s.len defer s.bm.Unlock()
s.bm.Unlock() return s.len
return int(len)
}
// Cap is the remaining capacity in the receive buffer.
//
// Note: this is _not_ the same as go's 'cap' function. The total size of the
// buffer is len+cap.
func (s *segmentedBuffer) Cap() uint32 {
s.bm.Lock()
cap := s.cap
s.bm.Unlock()
return cap
} }
// If the space to write into + current buffer size has grown to half of the window size, // If the space to write into + current buffer size has grown to half of the window size,
...@@ -122,16 +106,6 @@ func (s *segmentedBuffer) GrowTo(max uint32, force bool) (bool, uint32) { ...@@ -122,16 +106,6 @@ func (s *segmentedBuffer) GrowTo(max uint32, force bool) (bool, uint32) {
return true, delta return true, delta
} }
func (s *segmentedBuffer) TryReserve(space uint32) bool {
s.bm.Lock()
defer s.bm.Unlock()
if s.cap < s.pending+space {
return false
}
s.pending += space
return true
}
func (s *segmentedBuffer) Read(b []byte) (int, error) { func (s *segmentedBuffer) Read(b []byte) (int, error) {
s.bm.Lock() s.bm.Lock()
defer s.bm.Unlock() defer s.bm.Unlock()
...@@ -154,8 +128,21 @@ func (s *segmentedBuffer) Read(b []byte) (int, error) { ...@@ -154,8 +128,21 @@ func (s *segmentedBuffer) Read(b []byte) (int, error) {
return n, nil return n, nil
} }
func (s *segmentedBuffer) Append(input io.Reader, length int) error { func (s *segmentedBuffer) checkOverflow(l uint32) error {
dst := pool.Get(length) s.bm.Lock()
defer s.bm.Unlock()
if s.cap < l {
return fmt.Errorf("receive window exceeded (remain: %d, recv: %d)", s.cap, l)
}
return nil
}
func (s *segmentedBuffer) Append(input io.Reader, length uint32) error {
if err := s.checkOverflow(length); err != nil {
return err
}
dst := pool.Get(int(length))
n, err := io.ReadFull(input, dst) n, err := io.ReadFull(input, dst)
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
...@@ -165,7 +152,6 @@ func (s *segmentedBuffer) Append(input io.Reader, length int) error { ...@@ -165,7 +152,6 @@ func (s *segmentedBuffer) Append(input io.Reader, length int) error {
if n > 0 { if n > 0 {
s.len += uint32(n) s.len += uint32(n)
s.cap -= uint32(n) s.cap -= uint32(n)
s.pending = s.pending - uint32(length)
s.b = append(s.b, dst[0:n]) s.b = append(s.b, dst[0:n])
} }
return err return err
......
...@@ -54,18 +54,17 @@ func TestMin(t *testing.T) { ...@@ -54,18 +54,17 @@ func TestMin(t *testing.T) {
func TestSegmentedBuffer(t *testing.T) { func TestSegmentedBuffer(t *testing.T) {
buf := newSegmentedBuffer(100) buf := newSegmentedBuffer(100)
assert := func(len, cap int) { assert := func(len, cap uint32) {
if buf.Len() != len { if buf.Len() != len {
t.Fatalf("expected length %d, got %d", len, buf.Len()) t.Fatalf("expected length %d, got %d", len, buf.Len())
} }
if buf.Cap() != uint32(cap) { buf.bm.Lock()
defer buf.bm.Unlock()
if buf.cap != cap {
t.Fatalf("expected length %d, got %d", len, buf.Len()) t.Fatalf("expected length %d, got %d", len, buf.Len())
} }
} }
assert(0, 100) assert(0, 100)
if !buf.TryReserve(3) {
t.Fatal("reservation should have worked")
}
if err := buf.Append(bytes.NewReader([]byte("fooo")), 3); err != nil { if err := buf.Append(bytes.NewReader([]byte("fooo")), 3); err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -87,9 +86,6 @@ func TestSegmentedBuffer(t *testing.T) { ...@@ -87,9 +86,6 @@ func TestSegmentedBuffer(t *testing.T) {
t.Fatal("should have grown by 2") t.Fatal("should have grown by 2")
} }
if !buf.TryReserve(50) {
t.Fatal("reservation should have worked")
}
if err := buf.Append(bytes.NewReader(make([]byte, 50)), 50); err != nil { if err := buf.Append(bytes.NewReader(make([]byte, 50)), 50); err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -104,9 +100,7 @@ func TestSegmentedBuffer(t *testing.T) { ...@@ -104,9 +100,7 @@ func TestSegmentedBuffer(t *testing.T) {
if read != 50 { if read != 50 {
t.Fatal("expected to read 50 bytes") t.Fatal("expected to read 50 bytes")
} }
if !buf.TryReserve(49) {
t.Fatal("should have been able to reserve rest of space")
}
assert(1, 49) assert(1, 49)
if grew, amount := buf.GrowTo(100, false); !grew || amount != 50 { if grew, amount := buf.GrowTo(100, false); !grew || amount != 50 {
t.Fatal("should have grown when below half, even with reserved space") t.Fatal("should have grown when below half, even with reserved space")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment