Unverified Commit 7bf6d4a1 authored by Will, committed by GitHub

Merge pull request #28 from libp2p/fix/lock

tighten lock around appending new chunks of read data in stream
parents 29a0d6e9 7fbda6ba
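
For orientation, a minimal sketch (not part of this commit; the package and names are hypothetical) of the locking shape the change moves toward: the network read is performed before the buffer's mutex is taken, so the lock is held only for the cheap bookkeeping and append, rather than across the whole copy from the connection.

package sketch // illustrative only; not part of go-yamux

import (
	"io"
	"sync"
)

// appendOutsideLock shows the pattern: read the full frame from the
// connection first, then take the mutex only to append the completed segment.
func appendOutsideLock(mu *sync.Mutex, segments *[][]byte, conn io.Reader, length int) error {
	dst := make([]byte, length)
	if _, err := io.ReadFull(conn, dst); err != nil {
		return err // nothing appended on a failed read
	}
	mu.Lock()
	*segments = append(*segments, dst)
	mu.Unlock()
	return nil
}
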
......@@ -54,7 +54,7 @@ func (s *Stream) LocalAddr() net.Addr {
return s.session.LocalAddr()
}
// LocalAddr returns the remote address
// RemoteAddr returns the remote address
func (s *Stream) RemoteAddr() net.Addr {
return s.session.RemoteAddr()
}
......@@ -2,4 +2,6 @@ module github.com/libp2p/go-yamux
go 1.12
require github.com/libp2p/go-buffer-pool v0.0.2
require (
github.com/libp2p/go-buffer-pool v0.0.2
)
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/libp2p/go-buffer-pool v0.0.2 h1:QNK2iAFa8gjAe1SPz6mHSMuCcjs+X1wlHzeOSqcmlfs=
github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
......@@ -548,8 +548,6 @@ func TestSendData_Large(t *testing.T) {
t.Errorf("err: %v", err)
return
}
t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
}()
go func() {
......
......@@ -5,8 +5,6 @@ import (
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-buffer-pool"
)
type streamState int
......@@ -25,7 +23,6 @@ const (
// Stream is used to represent a logical stream
// within a session.
type Stream struct {
recvWindow uint32
sendWindow uint32
id uint32
......@@ -35,7 +32,7 @@ type Stream struct {
stateLock sync.Mutex
recvLock sync.Mutex
recvBuf pool.Buffer
recvBuf segmentedBuffer
sendLock sync.Mutex
......@@ -52,10 +49,10 @@ func newStream(session *Session, id uint32, state streamState) *Stream {
id: id,
session: session,
state: state,
recvWindow: initialStreamWindow,
sendWindow: initialStreamWindow,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
recvBuf: NewSegmentedBuffer(initialStreamWindow),
recvNotifyCh: make(chan struct{}, 1),
sendNotifyCh: make(chan struct{}, 1),
}
......@@ -84,9 +81,7 @@ START:
case streamRemoteClose:
fallthrough
case streamClosed:
s.recvLock.Lock()
empty := s.recvBuf.Len() == 0
s.recvLock.Unlock()
if empty {
return 0, io.EOF
}
......@@ -213,19 +208,13 @@ func (s *Stream) sendWindowUpdate() error {
// Determine the delta update
max := s.session.config.MaxStreamWindowSize
s.recvLock.Lock()
delta := (max - uint32(s.recvBuf.Len())) - s.recvWindow
// Check if we can omit the update
if delta < (max/2) && flags == 0 {
s.recvLock.Unlock()
// Update our window
needed, delta := s.recvBuf.GrowTo(max, flags != 0)
if !needed {
return nil
}
// Update our window
s.recvWindow += delta
s.recvLock.Unlock()
// Send the header
hdr := encode(typeWindowUpdate, flags, s.id, delta)
if err := s.session.sendMsg(hdr, nil, nil); err != nil {
......@@ -406,29 +395,17 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
return nil
}
// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}
// Copy into buffer
s.recvLock.Lock()
if length > s.recvWindow {
s.recvLock.Unlock()
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
// Validate it's okay to copy
if !s.recvBuf.TryReserve(length) {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvBuf.Cap(), length)
return ErrRecvWindowExceeded
}
s.recvBuf.Grow(int(length))
if _, err := io.Copy(&s.recvBuf, conn); err != nil {
s.recvLock.Unlock()
// Copy into buffer
if err := s.recvBuf.Append(conn, int(length)); err != nil {
s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
return err
}
// Decrement the receive window
s.recvWindow -= length
s.recvLock.Unlock()
// Unblock any readers
asyncNotify(s.recvNotifyCh)
return nil
......
package yamux
import (
"io"
"sync"
"sync/atomic"
pool "github.com/libp2p/go-buffer-pool"
)
// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {
......@@ -29,3 +37,111 @@ func min(values ...uint32) uint32 {
}
return m
}
type segmentedBuffer struct {
cap uint32
pending uint32
len uint32
bm sync.Mutex
b [][]byte
}
// NewSegmentedBuffer allocates a segmented buffer with the given initial window capacity.
func NewSegmentedBuffer(initialCapacity uint32) segmentedBuffer {
return segmentedBuffer{cap: initialCapacity, b: make([][]byte, 0)}
}
func (s *segmentedBuffer) Len() int {
return int(atomic.LoadUint32(&s.len))
}
func (s *segmentedBuffer) Cap() uint32 {
return atomic.LoadUint32(&s.cap)
}
// GrowTo grows the receive window toward max: if the unaccounted-for space (max minus
// buffered data, remaining capacity, and pending reservations) is at least half of max,
// or force is set, the capacity is increased by that amount. It reports whether a
// window update is needed and the delta to advertise.
func (s *segmentedBuffer) GrowTo(max uint32, force bool) (bool, uint32) {
s.bm.Lock()
defer s.bm.Unlock()
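// currentWindow: buffered-but-unread data (len), plus the remaining advertised
// capacity (cap), plus reservations for frames still being read off the wire (pending).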
currentWindow := atomic.LoadUint32(&s.len) + atomic.LoadUint32(&s.cap) + s.pending
if currentWindow > max {
// Somewhat counter-intuitively, this is not an error:
// len+cap is the 'window' that shouldn't exceed max, or a reservation
// would fail and trigger an error.
// We pre-count 'pending' data (frames whose header has been read and whose
// body is still being read into the buffer) so that we don't undercount the
// remaining window size, but that can mean this sum ends up larger than max.
return false, 0
}
delta := max - currentWindow
if delta < (max/2) && !force {
return false, 0
}
atomic.AddUint32(&s.cap, delta)
return true, delta
}
func (s *segmentedBuffer) TryReserve(space uint32) bool {
// Note that the check-and-set of pending is not atomic;
// because of this, accesses to pending are protected by bm.
s.bm.Lock()
defer s.bm.Unlock()
if atomic.LoadUint32(&s.cap) < s.pending+space {
return false
}
s.pending += space
return true
}
func (s *segmentedBuffer) Read(b []byte) (int, error) {
s.bm.Lock()
defer s.bm.Unlock()
if len(s.b) == 0 {
return 0, io.EOF
}
n := copy(b, s.b[0])
if n == len(s.b[0]) {
pool.Put(s.b[0])
s.b[0] = nil
s.b = s.b[1:]
} else {
s.b[0] = s.b[0][n:]
}
if n > 0 {
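// len -= n: adding ^uint32(n-1) is the two's-complement way to atomically subtract n.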
atomic.AddUint32(&s.len, ^uint32(n-1))
}
return n, nil
}
func (s *segmentedBuffer) Append(input io.Reader, length int) error {
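// dst is borrowed from the shared buffer pool; Read returns fully consumed
// segments to the pool via pool.Put.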
dst := pool.Get(length)
n := 0
read := 0
var err error
for n < length && err == nil {
read, err = input.Read(dst[n:])
n += read
}
if err == io.EOF {
if length == n {
err = nil
} else {
err = ErrStreamReset
}
}
s.bm.Lock()
defer s.bm.Unlock()
if n > 0 {
atomic.AddUint32(&s.len, uint32(n))
// cap -= n
atomic.AddUint32(&s.cap, ^uint32(n-1))
s.pending = s.pending - uint32(length)
s.b = append(s.b, dst[0:n])
}
return err
}
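
As an illustration only (not part of the diff), a rough usage sketch of the type above. It is assumed to sit alongside segmentedBuffer in the yamux package and needs the standard library "bytes" import; it mirrors the readData/Read flow shown earlier.

// exampleReceive is a hypothetical helper, not part of the commit. It walks the
// receive path: reserve window space for an incoming frame, append the payload
// from the wire, then drain it for the application.
func exampleReceive(payload []byte) ([]byte, error) {
	buf := NewSegmentedBuffer(initialStreamWindow)
	if !buf.TryReserve(uint32(len(payload))) {
		return nil, ErrRecvWindowExceeded
	}
	if err := buf.Append(bytes.NewReader(payload), len(payload)); err != nil {
		return nil, err
	}
	out := make([]byte, len(payload))
	n, err := buf.Read(out)
	return out[:n], err
}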