Commit f56473fb authored by Steven Allen's avatar Steven Allen

make sure reset works on half-closed streams

parent d9712a3b
...@@ -22,7 +22,7 @@ type stream struct { ...@@ -22,7 +22,7 @@ type stream struct {
close chan struct{} close chan struct{}
closed chan struct{} closed chan struct{}
state error writeErr error
protocol protocol.ID protocol protocol.ID
} }
...@@ -56,7 +56,7 @@ func (s *stream) Write(p []byte) (n int, err error) { ...@@ -56,7 +56,7 @@ func (s *stream) Write(p []byte) (n int, err error) {
t := time.Now().Add(delay) t := time.Now().Add(delay)
select { select {
case <-s.closed: // bail out if we're closing. case <-s.closed: // bail out if we're closing.
return 0, s.state return 0, s.writeErr
case s.toDeliver <- &transportObject{msg: p, arrivalTime: t}: case s.toDeliver <- &transportObject{msg: p, arrivalTime: t}:
} }
return len(p), nil return len(p), nil
...@@ -76,30 +76,28 @@ func (s *stream) Close() error { ...@@ -76,30 +76,28 @@ func (s *stream) Close() error {
default: default:
} }
<-s.closed <-s.closed
if s.state != ErrClosed { if s.writeErr != ErrClosed {
return s.state return s.writeErr
} }
return nil return nil
} }
func (s *stream) Reset() error { func (s *stream) Reset() error {
// Cancel any pending writes. // Cancel any pending reads/writes with an error.
s.write.Close() s.write.CloseWithError(ErrReset)
s.read.CloseWithError(ErrReset)
select { select {
case s.reset <- struct{}{}: case s.reset <- struct{}{}:
default: default:
} }
<-s.closed <-s.closed
if s.state != ErrReset {
return s.state // No meaningful error case here.
}
return nil return nil
} }
func (s *stream) teardown() { func (s *stream) teardown() {
s.write.Close()
// at this point, no streams are writing. // at this point, no streams are writing.
s.conn.removeStream(s) s.conn.removeStream(s)
...@@ -151,20 +149,21 @@ func (s *stream) transport() { ...@@ -151,20 +149,21 @@ func (s *stream) transport() {
// writeBuf writes the contents of buf through to the s.Writer. // writeBuf writes the contents of buf through to the s.Writer.
// done only when arrival time makes sense. // done only when arrival time makes sense.
drainBuf := func() { drainBuf := func() error {
if buf.Len() > 0 { if buf.Len() > 0 {
_, err := s.write.Write(buf.Bytes()) _, err := s.write.Write(buf.Bytes())
if err != nil { if err != nil {
return return err
} }
buf.Reset() buf.Reset()
} }
return nil
} }
// deliverOrWait is a helper func that processes // deliverOrWait is a helper func that processes
// an incoming packet. it waits until the arrival time, // an incoming packet. it waits until the arrival time,
// and then writes things out. // and then writes things out.
deliverOrWait := func(o *transportObject) { deliverOrWait := func(o *transportObject) error {
buffered := len(o.msg) + buf.Len() buffered := len(o.msg) + buf.Len()
// Yes, we can end up extending a timer multiple times if we // Yes, we can end up extending a timer multiple times if we
...@@ -189,43 +188,65 @@ func (s *stream) transport() { ...@@ -189,43 +188,65 @@ func (s *stream) transport() {
select { select {
case <-timer.C: case <-timer.C:
case <-s.reset: case <-s.reset:
s.reset <- struct{}{} select {
return case s.reset <- struct{}{}:
default:
}
return ErrReset
}
if err := drainBuf(); err != nil {
return err
} }
drainBuf()
// write this message. // write this message.
_, err := s.write.Write(o.msg) _, err := s.write.Write(o.msg)
if err != nil { if err != nil {
log.Error("mock_stream", err) return err
} }
} else { } else {
buf.Write(o.msg) buf.Write(o.msg)
} }
return nil
} }
for { for {
// Reset takes precedent. // Reset takes precedent.
select { select {
case <-s.reset: case <-s.reset:
s.state = ErrReset s.writeErr = ErrReset
s.read.CloseWithError(ErrReset)
return return
default: default:
} }
select { select {
case <-s.reset: case <-s.reset:
s.state = ErrReset s.writeErr = ErrReset
s.read.CloseWithError(ErrReset)
return return
case <-s.close: case <-s.close:
s.state = ErrClosed if err := drainBuf(); err != nil {
drainBuf() s.resetWith(err)
return
}
s.writeErr = s.write.Close()
if s.writeErr == nil {
s.writeErr = ErrClosed
}
return return
case o := <-s.toDeliver: case o := <-s.toDeliver:
deliverOrWait(o) if err := deliverOrWait(o); err != nil {
s.resetWith(err)
return
}
case <-timer.C: // ok, due to write it out. case <-timer.C: // ok, due to write it out.
drainBuf() if err := drainBuf(); err != nil {
s.resetWith(err)
return
}
} }
} }
} }
func (s *stream) resetWith(err error) {
s.write.CloseWithError(err)
s.read.CloseWithError(err)
s.writeErr = err
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment