Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
10
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
p2p
go-yamux
Commits
10c91193
Unverified
Commit
10c91193
authored
Sep 01, 2020
by
Steven Allen
Committed by
GitHub
Sep 01, 2020
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #5 from libp2p/feat/rw-close
implement CloseRead/CloseWrite
parents
fd43d7f1
fa4a0b4c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
184 additions
and
98 deletions
+184
-98
session.go
session.go
+11
-7
session_test.go
session_test.go
+46
-12
stream.go
stream.go
+127
-79
No files found.
session.go
View file @
10c91193
...
@@ -220,14 +220,18 @@ func (s *Session) Accept() (net.Conn, error) {
...
@@ -220,14 +220,18 @@ func (s *Session) Accept() (net.Conn, error) {
// AcceptStream is used to block until the next available stream
// AcceptStream is used to block until the next available stream
// is ready to be accepted.
// is ready to be accepted.
func
(
s
*
Session
)
AcceptStream
()
(
*
Stream
,
error
)
{
func
(
s
*
Session
)
AcceptStream
()
(
*
Stream
,
error
)
{
select
{
for
{
case
stream
:=
<-
s
.
acceptCh
:
select
{
if
err
:=
stream
.
sendWindowUpdate
();
err
!=
nil
{
case
stream
:=
<-
s
.
acceptCh
:
return
nil
,
err
if
err
:=
stream
.
sendWindowUpdate
();
err
!=
nil
{
// don't return accept errors.
s
.
logger
.
Printf
(
"[WARN] error sending window update before accepting: %s"
,
err
)
continue
}
return
stream
,
nil
case
<-
s
.
shutdownCh
:
return
nil
,
s
.
shutdownErr
}
}
return
stream
,
nil
case
<-
s
.
shutdownCh
:
return
nil
,
s
.
shutdownErr
}
}
}
}
...
...
session_test.go
View file @
10c91193
...
@@ -407,6 +407,7 @@ func TestSendData_Small(t *testing.T) {
...
@@ -407,6 +407,7 @@ func TestSendData_Small(t *testing.T) {
t
.
Errorf
(
"err: %v"
,
err
)
t
.
Errorf
(
"err: %v"
,
err
)
return
return
}
}
defer
stream
.
Close
()
if
server
.
NumStreams
()
!=
1
{
if
server
.
NumStreams
()
!=
1
{
t
.
Errorf
(
"bad"
)
t
.
Errorf
(
"bad"
)
...
@@ -430,7 +431,7 @@ func TestSendData_Small(t *testing.T) {
...
@@ -430,7 +431,7 @@ func TestSendData_Small(t *testing.T) {
}
}
}
}
if
err
:=
stream
.
Close
();
err
!=
nil
{
if
err
:=
stream
.
Close
Write
();
err
!=
nil
{
t
.
Errorf
(
"err: %v"
,
err
)
t
.
Errorf
(
"err: %v"
,
err
)
return
return
}
}
...
@@ -442,11 +443,12 @@ func TestSendData_Small(t *testing.T) {
...
@@ -442,11 +443,12 @@ func TestSendData_Small(t *testing.T) {
go
func
()
{
go
func
()
{
defer
wg
.
Done
()
defer
wg
.
Done
()
stream
,
err
:=
client
.
Open
()
stream
,
err
:=
client
.
Open
Stream
()
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Errorf
(
"err: %v"
,
err
)
t
.
Errorf
(
"err: %v"
,
err
)
return
return
}
}
defer
stream
.
Close
()
if
client
.
NumStreams
()
!=
1
{
if
client
.
NumStreams
()
!=
1
{
t
.
Errorf
(
"bad"
)
t
.
Errorf
(
"bad"
)
...
@@ -465,7 +467,7 @@ func TestSendData_Small(t *testing.T) {
...
@@ -465,7 +467,7 @@ func TestSendData_Small(t *testing.T) {
}
}
}
}
if
err
:=
stream
.
Close
();
err
!=
nil
{
if
err
:=
stream
.
Close
Write
();
err
!=
nil
{
t
.
Errorf
(
"err: %v"
,
err
)
t
.
Errorf
(
"err: %v"
,
err
)
return
return
}
}
...
@@ -785,12 +787,12 @@ func TestManyStreams_PingPong(t *testing.T) {
...
@@ -785,12 +787,12 @@ func TestManyStreams_PingPong(t *testing.T) {
wg
.
Wait
()
wg
.
Wait
()
}
}
func
Test
Half
Close
(
t
*
testing
.
T
)
{
func
TestClose
Read
(
t
*
testing
.
T
)
{
client
,
server
:=
testClientServer
()
client
,
server
:=
testClientServer
()
defer
client
.
Close
()
defer
client
.
Close
()
defer
server
.
Close
()
defer
server
.
Close
()
stream
,
err
:=
client
.
Open
()
stream
,
err
:=
client
.
Open
Stream
()
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"err: %v"
,
err
)
t
.
Fatalf
(
"err: %v"
,
err
)
}
}
...
@@ -798,17 +800,43 @@ func TestHalfClose(t *testing.T) {
...
@@ -798,17 +800,43 @@ func TestHalfClose(t *testing.T) {
t
.
Fatalf
(
"err: %v"
,
err
)
t
.
Fatalf
(
"err: %v"
,
err
)
}
}
stream2
,
err
:=
server
.
Accept
()
stream2
,
err
:=
server
.
Accept
Stream
()
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"err: %v"
,
err
)
t
.
Fatalf
(
"err: %v"
,
err
)
}
}
stream2
.
Close
()
// Half close
stream2
.
Close
Read
()
buf
:=
make
([]
byte
,
4
)
buf
:=
make
([]
byte
,
4
)
n
,
err
:=
stream2
.
Read
(
buf
)
n
,
err
:=
stream2
.
Read
(
buf
)
if
n
!=
0
||
err
==
nil
{
t
.
Fatalf
(
"read after close: %d %s"
,
n
,
err
)
}
}
func
TestHalfClose
(
t
*
testing
.
T
)
{
client
,
server
:=
testClientServer
()
defer
client
.
Close
()
defer
server
.
Close
()
stream
,
err
:=
client
.
OpenStream
()
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Fatalf
(
"err: %v"
,
err
)
t
.
Fatalf
(
"err: %v"
,
err
)
}
}
if
_
,
err
=
stream
.
Write
([]
byte
(
"a"
));
err
!=
nil
{
t
.
Fatalf
(
"err: %v"
,
err
)
}
stream2
,
err
:=
server
.
AcceptStream
()
if
err
!=
nil
{
t
.
Fatalf
(
"err: %v"
,
err
)
}
stream2
.
CloseWrite
()
// Half close
buf
:=
make
([]
byte
,
4
)
n
,
err
:=
io
.
ReadAtLeast
(
stream2
,
buf
,
1
)
if
err
!=
nil
&&
err
!=
io
.
EOF
{
t
.
Fatalf
(
"err: %v"
,
err
)
}
if
n
!=
1
{
if
n
!=
1
{
t
.
Fatalf
(
"bad: %v"
,
n
)
t
.
Fatalf
(
"bad: %v"
,
n
)
}
}
...
@@ -817,11 +845,17 @@ func TestHalfClose(t *testing.T) {
...
@@ -817,11 +845,17 @@ func TestHalfClose(t *testing.T) {
if
_
,
err
=
stream
.
Write
([]
byte
(
"bcd"
));
err
!=
nil
{
if
_
,
err
=
stream
.
Write
([]
byte
(
"bcd"
));
err
!=
nil
{
t
.
Fatalf
(
"err: %v"
,
err
)
t
.
Fatalf
(
"err: %v"
,
err
)
}
}
stream
.
Close
()
stream
.
CloseWrite
()
// write after close
n
,
err
=
stream
.
Write
([]
byte
(
"foobar"
))
if
n
!=
0
||
err
==
nil
{
t
.
Fatalf
(
"wrote after close: %d %s"
,
n
,
err
)
}
// Read after close
// Read after close
n
,
err
=
stream2
.
Read
(
buf
)
n
,
err
=
io
.
ReadAtLeast
(
stream2
,
buf
,
3
)
if
err
!=
nil
{
if
err
!=
nil
&&
err
!=
io
.
EOF
{
t
.
Fatalf
(
"err: %v"
,
err
)
t
.
Fatalf
(
"err: %v"
,
err
)
}
}
if
n
!=
3
{
if
n
!=
3
{
...
@@ -1131,7 +1165,6 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) {
...
@@ -1131,7 +1165,6 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) {
t
.
Errorf
(
"err: %v"
,
err
)
t
.
Errorf
(
"err: %v"
,
err
)
return
return
}
}
defer
wr
.
Close
()
sendWindow
:=
atomic
.
LoadUint32
(
&
wr
.
sendWindow
)
sendWindow
:=
atomic
.
LoadUint32
(
&
wr
.
sendWindow
)
if
sendWindow
!=
client
.
config
.
MaxStreamWindowSize
{
if
sendWindow
!=
client
.
config
.
MaxStreamWindowSize
{
...
@@ -1352,8 +1385,9 @@ func TestStreamHalfClose2(t *testing.T) {
...
@@ -1352,8 +1385,9 @@ func TestStreamHalfClose2(t *testing.T) {
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
}
}
defer
stream
.
Close
()
stream
.
Close
()
stream
.
Close
Write
()
wait
<-
struct
{}{}
wait
<-
struct
{}{}
buf
,
err
:=
ioutil
.
ReadAll
(
stream
)
buf
,
err
:=
ioutil
.
ReadAll
(
stream
)
...
...
stream.go
View file @
10c91193
...
@@ -14,10 +14,15 @@ const (
...
@@ -14,10 +14,15 @@ const (
streamSYNSent
streamSYNSent
streamSYNReceived
streamSYNReceived
streamEstablished
streamEstablished
streamLocalClose
streamFinished
streamRemoteClose
)
streamClosed
streamReset
type
halfStreamState
int
const
(
halfOpen
halfStreamState
=
iota
halfClosed
halfReset
)
)
// Stream is used to represent a logical stream
// Stream is used to represent a logical stream
...
@@ -28,8 +33,9 @@ type Stream struct {
...
@@ -28,8 +33,9 @@ type Stream struct {
id
uint32
id
uint32
session
*
Session
session
*
Session
state
streamState
state
streamState
stateLock
sync
.
Mutex
writeState
,
readState
halfStreamState
stateLock
sync
.
Mutex
recvLock
sync
.
Mutex
recvLock
sync
.
Mutex
recvBuf
segmentedBuffer
recvBuf
segmentedBuffer
...
@@ -74,19 +80,22 @@ func (s *Stream) Read(b []byte) (n int, err error) {
...
@@ -74,19 +80,22 @@ func (s *Stream) Read(b []byte) (n int, err error) {
defer
asyncNotify
(
s
.
recvNotifyCh
)
defer
asyncNotify
(
s
.
recvNotifyCh
)
START
:
START
:
s
.
stateLock
.
Lock
()
s
.
stateLock
.
Lock
()
state
:=
s
.
s
tate
state
:=
s
.
readS
tate
s
.
stateLock
.
Unlock
()
s
.
stateLock
.
Unlock
()
switch
state
{
switch
state
{
case
streamRemoteClose
:
case
halfOpen
:
fallthrough
// Open -> read
case
stream
Closed
:
case
half
Closed
:
empty
:=
s
.
recvBuf
.
Len
()
==
0
empty
:=
s
.
recvBuf
.
Len
()
==
0
if
empty
{
if
empty
{
return
0
,
io
.
EOF
return
0
,
io
.
EOF
}
}
case
streamReset
:
// Closed, but we have data pending -> read.
case
halfReset
:
return
0
,
ErrStreamReset
return
0
,
ErrStreamReset
default
:
panic
(
"unknown state"
)
}
}
// If there is no data available, block
// If there is no data available, block
...
@@ -138,16 +147,18 @@ func (s *Stream) write(b []byte) (n int, err error) {
...
@@ -138,16 +147,18 @@ func (s *Stream) write(b []byte) (n int, err error) {
START
:
START
:
s
.
stateLock
.
Lock
()
s
.
stateLock
.
Lock
()
state
:=
s
.
s
tate
state
:=
s
.
writeS
tate
s
.
stateLock
.
Unlock
()
s
.
stateLock
.
Unlock
()
switch
state
{
switch
state
{
case
streamLocalClose
:
case
halfOpen
:
fallthrough
// Open for writing -> write
case
stream
Closed
:
case
half
Closed
:
return
0
,
ErrStreamClosed
return
0
,
ErrStreamClosed
case
stream
Reset
:
case
half
Reset
:
return
0
,
ErrStreamReset
return
0
,
ErrStreamReset
default
:
panic
(
"unknown state"
)
}
}
// If there is no data available, block
// If there is no data available, block
...
@@ -239,75 +250,117 @@ func (s *Stream) sendReset() error {
...
@@ -239,75 +250,117 @@ func (s *Stream) sendReset() error {
// Reset resets the stream (forcibly closes the stream)
// Reset resets the stream (forcibly closes the stream)
func
(
s
*
Stream
)
Reset
()
error
{
func
(
s
*
Stream
)
Reset
()
error
{
sendReset
:=
false
s
.
stateLock
.
Lock
()
s
.
stateLock
.
Lock
()
switch
s
.
state
{
switch
s
.
state
{
case
streamInit
:
case
streamFinished
:
// No need to send anything.
s
.
state
=
streamReset
s
.
stateLock
.
Unlock
()
return
nil
case
streamClosed
,
streamReset
:
s
.
stateLock
.
Unlock
()
s
.
stateLock
.
Unlock
()
return
nil
return
nil
case
streamInit
:
// we haven't sent anything, so we don't need to send a reset.
case
streamSYNSent
,
streamSYNReceived
,
streamEstablished
:
case
streamSYNSent
,
streamSYNReceived
,
streamEstablished
:
case
streamLocalClose
,
streamRemoteClose
:
sendReset
=
true
default
:
default
:
panic
(
"unhandled state"
)
panic
(
"unhandled state"
)
}
}
s
.
state
=
streamReset
s
.
stateLock
.
Unlock
()
err
:=
s
.
sendReset
()
// at least one direction is open, we need to reset.
// If we've already sent/received an EOF, no need to reset that side.
if
s
.
writeState
==
halfOpen
{
s
.
writeState
=
halfReset
}
if
s
.
readState
==
halfOpen
{
s
.
readState
=
halfReset
}
s
.
state
=
streamFinished
s
.
notifyWaiting
()
s
.
notifyWaiting
()
s
.
stateLock
.
Unlock
()
if
sendReset
{
_
=
s
.
sendReset
()
}
s
.
cleanup
()
s
.
cleanup
()
return
nil
return
err
}
}
// Close is used to close the stream
// CloseWrite is used to close the stream for writing.
func
(
s
*
Stream
)
Close
()
error
{
func
(
s
*
Stream
)
CloseWrite
()
error
{
closeStream
:=
false
s
.
stateLock
.
Lock
()
s
.
stateLock
.
Lock
()
switch
s
.
state
{
switch
s
.
writeState
{
case
streamInit
,
streamSYNSent
,
streamSYNReceived
,
streamEstablished
:
case
halfOpen
:
s
.
state
=
streamLocalClose
// Open for writing -> close write
goto
SEND_CLOSE
case
halfClosed
:
s
.
stateLock
.
Unlock
()
return
nil
case
halfReset
:
s
.
stateLock
.
Unlock
()
return
ErrStreamReset
default
:
panic
(
"invalid state"
)
}
s
.
writeState
=
halfClosed
cleanup
:=
s
.
readState
!=
halfOpen
if
cleanup
{
s
.
state
=
streamFinished
}
s
.
stateLock
.
Unlock
()
s
.
notifyWaiting
()
case
streamLocalClose
:
err
:=
s
.
sendClose
()
case
streamRemoteClose
:
if
cleanup
{
s
.
state
=
streamClosed
// we're fully closed, might as well be nice to the user and
closeStream
=
true
// free everything early.
goto
SEND_CLOSE
s
.
cleanup
()
}
return
err
}
case
streamClosed
:
// CloseRead is used to close the stream for writing.
case
streamReset
:
func
(
s
*
Stream
)
CloseRead
()
error
{
cleanup
:=
false
s
.
stateLock
.
Lock
()
switch
s
.
readState
{
case
halfOpen
:
// Open for reading -> close read
case
halfClosed
,
halfReset
:
s
.
stateLock
.
Unlock
()
return
nil
default
:
default
:
panic
(
"unhandled state"
)
panic
(
"invalid state"
)
}
s
.
readState
=
halfReset
cleanup
=
s
.
writeState
!=
halfOpen
if
cleanup
{
s
.
state
=
streamFinished
}
}
s
.
stateLock
.
Unlock
()
s
.
stateLock
.
Unlock
()
return
nil
SEND_CLOSE
:
s
.
stateLock
.
Unlock
()
err
:=
s
.
sendClose
()
s
.
notifyWaiting
()
s
.
notifyWaiting
()
if
closeStream
{
if
cleanup
{
// we're fully closed, might as well be nice to the user and
// free everything early.
s
.
cleanup
()
s
.
cleanup
()
}
}
return
err
return
nil
}
// Close is used to close the stream.
func
(
s
*
Stream
)
Close
()
error
{
_
=
s
.
CloseRead
()
// can't fail.
return
s
.
CloseWrite
()
}
}
// forceClose is used for when the session is exiting
// forceClose is used for when the session is exiting
func
(
s
*
Stream
)
forceClose
()
{
func
(
s
*
Stream
)
forceClose
()
{
s
.
stateLock
.
Lock
()
s
.
stateLock
.
Lock
()
switch
s
.
state
{
if
s
.
readState
==
halfOpen
{
case
streamClosed
:
s
.
readState
=
halfReset
// Already successfully closed. It just hasn't been removed from
// the list of streams yet.
default
:
s
.
state
=
streamReset
}
}
s
.
stateLock
.
Unlock
()
if
s
.
writeState
==
halfOpen
{
s
.
writeState
=
halfReset
}
s
.
state
=
streamFinished
s
.
notifyWaiting
()
s
.
notifyWaiting
()
s
.
stateLock
.
Unlock
()
s
.
readDeadline
.
set
(
time
.
Time
{})
s
.
readDeadline
.
set
(
time
.
Time
{})
s
.
writeDeadline
.
set
(
time
.
Time
{})
s
.
writeDeadline
.
set
(
time
.
Time
{})
...
@@ -340,25 +393,24 @@ func (s *Stream) processFlags(flags uint16) error {
...
@@ -340,25 +393,24 @@ func (s *Stream) processFlags(flags uint16) error {
s
.
session
.
establishStream
(
s
.
id
)
s
.
session
.
establishStream
(
s
.
id
)
}
}
if
flags
&
flagFIN
==
flagFIN
{
if
flags
&
flagFIN
==
flagFIN
{
switch
s
.
state
{
if
s
.
readState
==
halfOpen
{
case
streamSYNSent
:
s
.
readState
=
halfClosed
fallthrough
if
s
.
writeState
!=
halfOpen
{
case
streamSYNReceiv
ed
:
// We're now fully clos
ed
.
fallthrough
closeStream
=
true
case
stream
Establ
ished
:
s
.
state
=
stream
Fin
ished
s
.
state
=
streamRemoteClose
}
s
.
notifyWaiting
()
s
.
notifyWaiting
()
case
streamLocalClose
:
s
.
state
=
streamClosed
closeStream
=
true
s
.
notifyWaiting
()
default
:
s
.
session
.
logger
.
Printf
(
"[ERR] yamux: unexpected FIN flag in state %d"
,
s
.
state
)
return
ErrUnexpectedFlag
}
}
}
}
if
flags
&
flagRST
==
flagRST
{
if
flags
&
flagRST
==
flagRST
{
s
.
state
=
streamReset
if
s
.
readState
==
halfOpen
{
s
.
readState
=
halfReset
}
if
s
.
writeState
==
halfOpen
{
s
.
writeState
=
halfReset
}
s
.
state
=
streamFinished
closeStream
=
true
closeStream
=
true
s
.
notifyWaiting
()
s
.
notifyWaiting
()
}
}
...
@@ -426,11 +478,9 @@ func (s *Stream) SetDeadline(t time.Time) error {
...
@@ -426,11 +478,9 @@ func (s *Stream) SetDeadline(t time.Time) error {
func
(
s
*
Stream
)
SetReadDeadline
(
t
time
.
Time
)
error
{
func
(
s
*
Stream
)
SetReadDeadline
(
t
time
.
Time
)
error
{
s
.
stateLock
.
Lock
()
s
.
stateLock
.
Lock
()
defer
s
.
stateLock
.
Unlock
()
defer
s
.
stateLock
.
Unlock
()
switch
s
.
state
{
if
s
.
readState
==
halfOpen
{
case
streamClosed
,
streamRemoteClose
,
streamReset
:
s
.
readDeadline
.
set
(
t
)
return
nil
}
}
s
.
readDeadline
.
set
(
t
)
return
nil
return
nil
}
}
...
@@ -438,11 +488,9 @@ func (s *Stream) SetReadDeadline(t time.Time) error {
...
@@ -438,11 +488,9 @@ func (s *Stream) SetReadDeadline(t time.Time) error {
func
(
s
*
Stream
)
SetWriteDeadline
(
t
time
.
Time
)
error
{
func
(
s
*
Stream
)
SetWriteDeadline
(
t
time
.
Time
)
error
{
s
.
stateLock
.
Lock
()
s
.
stateLock
.
Lock
()
defer
s
.
stateLock
.
Unlock
()
defer
s
.
stateLock
.
Unlock
()
switch
s
.
state
{
if
s
.
writeState
==
halfOpen
{
case
streamClosed
,
streamLocalClose
,
streamReset
:
s
.
writeDeadline
.
set
(
t
)
return
nil
}
}
s
.
writeDeadline
.
set
(
t
)
return
nil
return
nil
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment