stream.go 4.83 KB
Newer Older
Jeromy's avatar
Jeromy committed
1 2 3 4
package multiplex

import (
	"context"
5
	"errors"
Jeromy's avatar
Jeromy committed
6 7 8 9
	"io"
	"sync"
	"time"

10
	pool "gitlab.dms3.io/p2p/go-buffer-pool"
11
	"go.uber.org/multierr"
Jeromy's avatar
Jeromy committed
12 13
)

14 15 16 17 18
// ErrStreamReset is the cancellation error recorded by Reset; Read and Write
// return it once the corresponding side has been reset.
var ErrStreamReset = errors.New("stream reset")

// ErrStreamClosed is the cancellation error recorded by CloseRead and
// CloseWrite; Read and Write return it once the corresponding side has been
// closed gracefully.
var ErrStreamClosed = errors.New("closed stream")

Steven Allen's avatar
Steven Allen committed
19 20
// streamID bundles a stream's numeric id with its initiator flag and knows
// how to turn them into a frame header.
type streamID struct {
	id        uint64
	initiator bool
}

// header packs the stream id and the given tag into a single header value:
// the id is shifted left by three bits and OR'd with tag. When the stream is
// not flagged as initiator, the result is reduced by one (wrapping on
// unsigned underflow, exactly as uint64 arithmetic does).
func (id *streamID) header(tag uint64) uint64 {
	packed := (id.id << 3) | tag
	if id.initiator {
		return packed
	}
	return packed - 1
}

type Stream struct {
	id     streamID
	name   string
	dataIn chan []byte
	mp     *Multiplex
Jeromy's avatar
Jeromy committed
39 40 41 42 43 44 45

	extra []byte

	// exbuf is for holding the reference to the beginning of the extra slice
	// for later memory pool freeing
	exbuf []byte

vyzo's avatar
vyzo committed
46
	rDeadline, wDeadline pipeDeadline
Jeromy's avatar
Jeromy committed
47

48 49 50
	clLock                        sync.Mutex
	writeCancelErr, readCancelErr error
	writeCancel, readCancel       chan struct{}
Jeromy's avatar
Jeromy committed
51 52 53 54 55 56
}

// Name returns the name recorded for this stream.
func (s *Stream) Name() string { return s.name }

57 58 59 60 61 62 63 64 65 66 67 68 69
// preloadData pulls the next queued message into s.extra/s.exbuf if one is
// already waiting on dataIn, without blocking. If nothing is queued, or dataIn
// has been closed, the stream state is left untouched.
func (s *Stream) preloadData() {
	var (
		next []byte
		open bool
	)
	select {
	case next, open = <-s.dataIn:
	default:
		return
	}
	if !open {
		return
	}
	s.extra, s.exbuf = next, next
}

vyzo's avatar
vyzo committed
70
func (s *Stream) waitForData() error {
Jeromy's avatar
Jeromy committed
71 72 73 74 75 76 77 78
	select {
	case read, ok := <-s.dataIn:
		if !ok {
			return io.EOF
		}
		s.extra = read
		s.exbuf = read
		return nil
79 80 81 82
	case <-s.readCancel:
		// This is the only place where it's safe to return these.
		s.returnBuffers()
		return s.readCancelErr
vyzo's avatar
vyzo committed
83
	case <-s.rDeadline.wait():
vyzo's avatar
vyzo committed
84
		return errTimeout
Jeromy's avatar
Jeromy committed
85 86 87
	}
}

88 89
func (s *Stream) returnBuffers() {
	if s.exbuf != nil {
Steven Allen's avatar
Steven Allen committed
90
		pool.Put(s.exbuf)
91 92 93 94 95 96 97 98 99 100 101 102
		s.exbuf = nil
		s.extra = nil
	}
	for {
		select {
		case read, ok := <-s.dataIn:
			if !ok {
				return
			}
			if read == nil {
				continue
			}
Steven Allen's avatar
Steven Allen committed
103
			pool.Put(read)
104 105 106 107 108 109
		default:
			return
		}
	}
}

Jeromy's avatar
Jeromy committed
110
func (s *Stream) Read(b []byte) (int, error) {
111
	select {
112 113
	case <-s.readCancel:
		return 0, s.readCancelErr
114 115
	default:
	}
116

Jeromy's avatar
Jeromy committed
117
	if s.extra == nil {
vyzo's avatar
vyzo committed
118
		err := s.waitForData()
Jeromy's avatar
Jeromy committed
119 120 121 122
		if err != nil {
			return 0, err
		}
	}
123 124 125 126 127 128 129 130 131 132 133 134 135
	n := 0
	for s.extra != nil && n < len(b) {
		read := copy(b[n:], s.extra)
		n += read
		if read < len(s.extra) {
			s.extra = s.extra[read:]
		} else {
			if s.exbuf != nil {
				pool.Put(s.exbuf)
			}
			s.extra = nil
			s.exbuf = nil
			s.preloadData()
Jeromy's avatar
Jeromy committed
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
		}
	}
	return n, nil
}

// Write sends b to the remote side, split into messages of at most
// MaxMessageSize bytes each. It returns how many bytes were sent before the
// first failure, together with that failure's error.
func (s *Stream) Write(b []byte) (int, error) {
	total := 0
	for rest := b; len(rest) > 0; rest = b[total:] {
		chunk := rest
		if len(chunk) > MaxMessageSize {
			chunk = chunk[:MaxMessageSize]
		}
		sent, err := s.write(chunk)
		if err != nil {
			return total, err
		}
		total += sent
	}
	return total, nil
}

func (s *Stream) write(b []byte) (int, error) {
161 162 163 164
	select {
	case <-s.writeCancel:
		return 0, s.writeCancelErr
	default:
165 166
	}

167
	err := s.mp.sendMsg(s.wDeadline.wait(), s.writeCancel, s.id.header(messageTag), b)
Jeromy's avatar
Jeromy committed
168 169 170 171 172 173 174
	if err != nil {
		return 0, err
	}

	return len(b), nil
}

175 176
func (s *Stream) cancelWrite(err error) bool {
	s.wDeadline.close()
Jeromy's avatar
Jeromy committed
177

178 179 180 181 182 183 184 185 186 187 188
	s.clLock.Lock()
	defer s.clLock.Unlock()
	select {
	case <-s.writeCancel:
		return false
	default:
		s.writeCancelErr = err
		close(s.writeCancel)
		return true
	}
}
189

190 191 192 193 194 195 196
func (s *Stream) cancelRead(err error) bool {
	// Always unregister for reading first, even if we're already closed (or
	// already closing). When handleIncoming calls this, it expects the
	// stream to be unregistered by the time it returns.
	s.mp.chLock.Lock()
	delete(s.mp.channels, s.id)
	s.mp.chLock.Unlock()
Jeromy's avatar
Jeromy committed
197

198
	s.rDeadline.close()
Jeromy's avatar
Jeromy committed
199

200
	s.clLock.Lock()
201 202 203 204 205 206 207 208 209 210
	defer s.clLock.Unlock()
	select {
	case <-s.readCancel:
		return false
	default:
		s.readCancelErr = err
		close(s.readCancel)
		return true
	}
}
211

212 213 214 215 216 217 218 219 220
func (s *Stream) CloseWrite() error {
	if !s.cancelWrite(ErrStreamClosed) {
		// Check if we closed the stream _nicely_. If so, we don't need
		// to report an error to the user.
		if s.writeCancelErr == ErrStreamClosed {
			return nil
		}
		// Closed for some other reason. Report it.
		return s.writeCancelErr
Jeromy's avatar
Jeromy committed
221 222
	}

223 224 225 226 227
	ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
	defer cancel()

	err := s.mp.sendMsg(ctx.Done(), nil, s.id.header(closeTag), nil)
	// We failed to close the stream after 2 minutes, something is probably wrong.
228
	if err != nil && !s.mp.isShutdown() {
229
		log.Warnf("Error closing stream: %s; killing connection", err.Error())
230 231
		s.mp.Close()
	}
Jeromy's avatar
Jeromy committed
232
	return err
Jeromy's avatar
Jeromy committed
233 234
}

235 236
func (s *Stream) CloseRead() error {
	s.cancelRead(ErrStreamClosed)
Steven Allen's avatar
Steven Allen committed
237 238 239
	return nil
}

240 241
func (s *Stream) Close() error {
	return multierr.Combine(s.CloseRead(), s.CloseWrite())
vyzo's avatar
vyzo committed
242 243
}

244 245
func (s *Stream) Reset() error {
	s.cancelRead(ErrStreamReset)
vyzo's avatar
vyzo committed
246

247 248 249
	if s.cancelWrite(ErrStreamReset) {
		// Send a reset in the background.
		go s.mp.sendResetMsg(s.id.header(resetTag), true)
vyzo's avatar
vyzo committed
250 251
	}

252 253
	return nil
}
vyzo's avatar
vyzo committed
254

255 256 257
func (s *Stream) SetDeadline(t time.Time) error {
	s.rDeadline.set(t)
	s.wDeadline.set(t)
Jeromy's avatar
Jeromy committed
258 259 260 261
	return nil
}

func (s *Stream) SetReadDeadline(t time.Time) error {
vyzo's avatar
vyzo committed
262
	s.rDeadline.set(t)
Jeromy's avatar
Jeromy committed
263 264 265 266
	return nil
}

func (s *Stream) SetWriteDeadline(t time.Time) error {
vyzo's avatar
vyzo committed
267
	s.wDeadline.set(t)
Jeromy's avatar
Jeromy committed
268 269
	return nil
}