Commit f56473fb authored by Steven Allen's avatar Steven Allen

make sure reset works on half-closed streams

parent d9712a3b
......@@ -22,7 +22,7 @@ type stream struct {
close chan struct{}
closed chan struct{}
state error
writeErr error
protocol protocol.ID
}
......@@ -56,7 +56,7 @@ func (s *stream) Write(p []byte) (n int, err error) {
t := time.Now().Add(delay)
select {
case <-s.closed: // bail out if we're closing.
return 0, s.state
return 0, s.writeErr
case s.toDeliver <- &transportObject{msg: p, arrivalTime: t}:
}
return len(p), nil
......@@ -76,30 +76,28 @@ func (s *stream) Close() error {
default:
}
<-s.closed
if s.state != ErrClosed {
return s.state
if s.writeErr != ErrClosed {
return s.writeErr
}
return nil
}
func (s *stream) Reset() error {
// Cancel any pending writes.
s.write.Close()
// Cancel any pending reads/writes with an error.
s.write.CloseWithError(ErrReset)
s.read.CloseWithError(ErrReset)
select {
case s.reset <- struct{}{}:
default:
}
<-s.closed
if s.state != ErrReset {
return s.state
}
// No meaningful error case here.
return nil
}
func (s *stream) teardown() {
s.write.Close()
// at this point, no streams are writing.
s.conn.removeStream(s)
......@@ -151,20 +149,21 @@ func (s *stream) transport() {
// writeBuf writes the contents of buf through to the s.Writer.
// done only when arrival time makes sense.
drainBuf := func() {
drainBuf := func() error {
if buf.Len() > 0 {
_, err := s.write.Write(buf.Bytes())
if err != nil {
return
return err
}
buf.Reset()
}
return nil
}
// deliverOrWait is a helper func that processes
// an incoming packet. it waits until the arrival time,
// and then writes things out.
deliverOrWait := func(o *transportObject) {
deliverOrWait := func(o *transportObject) error {
buffered := len(o.msg) + buf.Len()
// Yes, we can end up extending a timer multiple times if we
......@@ -189,43 +188,65 @@ func (s *stream) transport() {
select {
case <-timer.C:
case <-s.reset:
s.reset <- struct{}{}
return
select {
case s.reset <- struct{}{}:
default:
}
return ErrReset
}
if err := drainBuf(); err != nil {
return err
}
drainBuf()
// write this message.
_, err := s.write.Write(o.msg)
if err != nil {
log.Error("mock_stream", err)
return err
}
} else {
buf.Write(o.msg)
}
return nil
}
for {
// Reset takes precedent.
select {
case <-s.reset:
s.state = ErrReset
s.read.CloseWithError(ErrReset)
s.writeErr = ErrReset
return
default:
}
select {
case <-s.reset:
s.state = ErrReset
s.read.CloseWithError(ErrReset)
s.writeErr = ErrReset
return
case <-s.close:
s.state = ErrClosed
drainBuf()
if err := drainBuf(); err != nil {
s.resetWith(err)
return
}
s.writeErr = s.write.Close()
if s.writeErr == nil {
s.writeErr = ErrClosed
}
return
case o := <-s.toDeliver:
deliverOrWait(o)
if err := deliverOrWait(o); err != nil {
s.resetWith(err)
return
}
case <-timer.C: // ok, due to write it out.
drainBuf()
if err := drainBuf(); err != nil {
s.resetWith(err)
return
}
}
}
}
func (s *stream) resetWith(err error) {
s.write.CloseWithError(err)
s.read.CloseWithError(err)
s.writeErr = err
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment