multiplex.go 10.8 KB
Newer Older
Jeromy's avatar
Jeromy committed
1 2 3 4
package multiplex

import (
	"bufio"
5
	"context"
Jeromy's avatar
Jeromy committed
6
	"encoding/binary"
Jeromy's avatar
Jeromy committed
7 8 9
	"errors"
	"fmt"
	"io"
Jeromy's avatar
Jeromy committed
10
	"net"
11
	"sync"
Jeromy's avatar
Jeromy committed
12
	"time"
Jeromy's avatar
Jeromy committed
13

14
	"github.com/multiformats/go-varint"
15
	logging "gitlab.dms3.io/dms3/go-log"
16
	pool "gitlab.dms3.io/p2p/go-buffer-pool"
Jeromy's avatar
Jeromy committed
17 18
)

Steven Allen's avatar
Steven Allen committed
19
// log is the package-level logger, created under the name "mplex".
var log = logging.Logger("mplex")
Jeromy's avatar
Jeromy committed
20

Jeromy's avatar
Jeromy committed
21 22
// Session-wide limits and the sentinel errors reported to callers.
var (
	// MaxMessageSize is the largest payload length, in bytes, that readNext
	// accepts from the peer. A frame declaring more than this is an error.
	MaxMessageSize = 1 << 20

	// ReceiveTimeout is the maximum time to block waiting for a slow reader
	// to read from a stream before resetting it. Preferably, we'd have some
	// form of back-pressure mechanism but we don't have that in this
	// protocol.
	ReceiveTimeout = 5 * time.Second

	// ErrShutdown is returned when operating on a shutdown session.
	ErrShutdown = errors.New("session shut down")

	// ErrTwoInitiators is returned when both sides think they're the
	// initiator.
	ErrTwoInitiators = errors.New("two initiators")

	// ErrInvalidState is returned when the other side does something it
	// shouldn't. In this case, we close the connection to be safe.
	ErrInvalidState = errors.New("received an unexpected message from the peer")
)

38
var errTimeout = timeout{}
vyzo's avatar
vyzo committed
39
var errStreamClosed = errors.New("stream closed")
40

vyzo's avatar
vyzo committed
41
var (
vyzo's avatar
vyzo committed
42
	ResetStreamTimeout = 2 * time.Minute
43

44
	WriteCoalesceDelay = 100 * time.Microsecond
vyzo's avatar
vyzo committed
45 46
)

47 48 49
// timeout is a stateless error type reporting an expired deadline. Besides
// Error it provides Timeout and Temporary, the method set of net.Error, and
// answers true to both.
type timeout struct{}

// Error returns the fixed deadline-exceeded message.
func (timeout) Error() string { return "i/o deadline exceeded" }

// Temporary always reports true.
func (timeout) Temporary() bool { return true }

// Timeout always reports true.
func (timeout) Timeout() bool { return true }

Steven Allen's avatar
Steven Allen committed
61
// Message tags carried in the low three bits of a frame header. Only the even
// values are named here: handleIncoming takes an even tag to mean the remote
// side initiated the stream, and rounds an odd tag up to the next even one
// after noting that the local side is the initiator.
const (
	// newStreamTag (0) opens a stream.
	newStreamTag = iota << 1
	// messageTag (2) carries stream data.
	messageTag
	// closeTag (4) signals that no more data will follow on the stream.
	closeTag
	// resetTag (6) aborts the stream.
	resetTag
)

Steven Allen's avatar
Steven Allen committed
69
// Multiplex is a mplex session.
Jeromy's avatar
Jeromy committed
70 71 72 73 74 75
type Multiplex struct {
	con       net.Conn
	buf       *bufio.Reader
	nextID    uint64
	initiator bool

76
	closed       chan struct{}
Steven Allen's avatar
Steven Allen committed
77 78
	shutdown     chan struct{}
	shutdownErr  error
79 80
	shutdownLock sync.Mutex

vyzo's avatar
vyzo committed
81 82 83
	writeCh         chan []byte
	writeTimer      *time.Timer
	writeTimerFired bool
Jeromy's avatar
Jeromy committed
84 85 86

	nstreams chan *Stream

Steven Allen's avatar
Steven Allen committed
87
	channels map[streamID]*Stream
Jeromy's avatar
Jeromy committed
88 89
	chLock   sync.Mutex
}
Jeromy's avatar
Jeromy committed
90

Steven Allen's avatar
Steven Allen committed
91
// NewMultiplex creates a new multiplexer session.
Jeromy's avatar
Jeromy committed
92 93
func NewMultiplex(con net.Conn, initiator bool) *Multiplex {
	mp := &Multiplex{
vyzo's avatar
vyzo committed
94 95 96 97 98 99 100 101 102
		con:        con,
		initiator:  initiator,
		buf:        bufio.NewReader(con),
		channels:   make(map[streamID]*Stream),
		closed:     make(chan struct{}),
		shutdown:   make(chan struct{}),
		writeCh:    make(chan []byte, 16),
		writeTimer: time.NewTimer(0),
		nstreams:   make(chan *Stream, 16),
Jeromy's avatar
Jeromy committed
103
	}
Jeromy's avatar
Jeromy committed
104

Jeromy's avatar
Jeromy committed
105
	go mp.handleIncoming()
vyzo's avatar
vyzo committed
106
	go mp.handleOutgoing()
107

Jeromy's avatar
Jeromy committed
108
	return mp
Jeromy's avatar
Jeromy committed
109 110
}

111 112
func (mp *Multiplex) newStream(id streamID, name string) (s *Stream) {
	s = &Stream{
113 114 115 116 117 118 119 120
		id:          id,
		name:        name,
		dataIn:      make(chan []byte, 8),
		rDeadline:   makePipeDeadline(),
		wDeadline:   makePipeDeadline(),
		mp:          mp,
		writeCancel: make(chan struct{}),
		readCancel:  make(chan struct{}),
Jeromy's avatar
Jeromy committed
121
	}
122
	return
Jeromy's avatar
Jeromy committed
123 124
}

Steven Allen's avatar
Steven Allen committed
125
// Accept accepts the next stream from the connection.
Jeromy's avatar
Jeromy committed
126 127 128 129 130 131 132 133
func (m *Multiplex) Accept() (*Stream, error) {
	select {
	case s, ok := <-m.nstreams:
		if !ok {
			return nil, errors.New("multiplex closed")
		}
		return s, nil
	case <-m.closed:
Steven Allen's avatar
Steven Allen committed
134
		return nil, m.shutdownErr
Jeromy's avatar
Jeromy committed
135 136 137
	}
}

Steven Allen's avatar
Steven Allen committed
138
// Close closes the session.
Jeromy's avatar
Jeromy committed
139
func (mp *Multiplex) Close() error {
Steven Allen's avatar
Steven Allen committed
140
	mp.closeNoWait()
141

Steven Allen's avatar
Steven Allen committed
142 143
	// Wait for the receive loop to finish.
	<-mp.closed
144

Steven Allen's avatar
Steven Allen committed
145 146
	return nil
}
147

Steven Allen's avatar
Steven Allen committed
148 149 150 151 152 153 154
func (mp *Multiplex) closeNoWait() {
	mp.shutdownLock.Lock()
	select {
	case <-mp.shutdown:
	default:
		mp.con.Close()
		close(mp.shutdown)
Jeromy's avatar
Jeromy committed
155
	}
Steven Allen's avatar
Steven Allen committed
156
	mp.shutdownLock.Unlock()
Jeromy's avatar
Jeromy committed
157 158
}

Steven Allen's avatar
Steven Allen committed
159
// IsClosed returns true if the session is closed.
Jeromy's avatar
Jeromy committed
160 161 162 163 164 165
func (mp *Multiplex) IsClosed() bool {
	select {
	case <-mp.closed:
		return true
	default:
		return false
Jeromy's avatar
Jeromy committed
166 167 168
	}
}

169
func (mp *Multiplex) sendMsg(timeout, cancel <-chan struct{}, header uint64, data []byte) error {
170 171 172 173 174 175 176
	buf := pool.Get(len(data) + 20)

	n := 0
	n += binary.PutUvarint(buf[n:], header)
	n += binary.PutUvarint(buf[n:], uint64(len(data)))
	n += copy(buf[n:], data)

177
	select {
vyzo's avatar
vyzo committed
178 179 180 181
	case mp.writeCh <- buf[:n]:
		return nil
	case <-mp.shutdown:
		return ErrShutdown
182
	case <-timeout:
183
		return errTimeout
184 185
	case <-cancel:
		return ErrStreamClosed
186
	}
vyzo's avatar
vyzo committed
187
}
188

vyzo's avatar
vyzo committed
189 190 191 192 193 194 195
func (mp *Multiplex) handleOutgoing() {
	for {
		select {
		case <-mp.shutdown:
			return

		case data := <-mp.writeCh:
196
			// FIXME: https://gitlab.dms3.io/p2p/go-p2p/issues/644
197 198 199 200
			// write coalescing disabled until this can be fixed.
			//err := mp.writeMsg(data)
			err := mp.doWriteMsg(data)
			pool.Put(data)
vyzo's avatar
vyzo committed
201
			if err != nil {
vyzo's avatar
vyzo committed
202
				// the connection is closed by this time
203
				log.Warnf("error writing data: %s", err.Error())
vyzo's avatar
vyzo committed
204
				return
vyzo's avatar
vyzo committed
205 206
			}
		}
207
	}
vyzo's avatar
vyzo committed
208
}
209

vyzo's avatar
vyzo committed
210 211
func (mp *Multiplex) writeMsg(data []byte) error {
	if len(data) >= 512 {
212 213 214
		err := mp.doWriteMsg(data)
		pool.Put(data)
		return err
vyzo's avatar
vyzo committed
215 216 217 218 219 220 221 222
	}

	buf := pool.Get(4096)
	defer pool.Put(buf)

	n := copy(buf, data)
	pool.Put(data)

vyzo's avatar
vyzo committed
223 224 225 226 227
	if !mp.writeTimerFired {
		if !mp.writeTimer.Stop() {
			<-mp.writeTimer.C
		}
	}
228
	mp.writeTimer.Reset(WriteCoalesceDelay)
vyzo's avatar
vyzo committed
229
	mp.writeTimerFired = false
vyzo's avatar
vyzo committed
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250

	for {
		select {
		case data = <-mp.writeCh:
			wr := copy(buf[n:], data)
			if wr < len(data) {
				// we filled the buffer, send it
				err := mp.doWriteMsg(buf)
				if err != nil {
					pool.Put(data)
					return err
				}

				if len(data)-wr >= 512 {
					// the remaining data is not a small write, send it
					err := mp.doWriteMsg(data[wr:])
					pool.Put(data)
					return err
				}

				n = copy(buf, data[wr:])
251 252

				// we've written some, reset the timer to coalesce the rest
vyzo's avatar
vyzo committed
253 254
				if !mp.writeTimer.Stop() {
					<-mp.writeTimer.C
255
				}
256
				mp.writeTimer.Reset(WriteCoalesceDelay)
vyzo's avatar
vyzo committed
257 258 259 260 261 262
			} else {
				n += wr
			}

			pool.Put(data)

vyzo's avatar
vyzo committed
263 264
		case <-mp.writeTimer.C:
			mp.writeTimerFired = true
vyzo's avatar
vyzo committed
265 266 267 268
			return mp.doWriteMsg(buf[:n])

		case <-mp.shutdown:
			return ErrShutdown
Jeromy's avatar
Jeromy committed
269 270
		}
	}
vyzo's avatar
vyzo committed
271
}
272

vyzo's avatar
vyzo committed
273 274 275
func (mp *Multiplex) doWriteMsg(data []byte) error {
	if mp.isShutdown() {
		return ErrShutdown
Jeromy's avatar
Jeromy committed
276
	}
277

vyzo's avatar
vyzo committed
278 279 280
	_, err := mp.con.Write(data)
	if err != nil {
		mp.closeNoWait()
Jeromy's avatar
Jeromy committed
281
	}
Jeromy's avatar
Jeromy committed
282

283
	return err
Jeromy's avatar
Jeromy committed
284 285
}

Steven Allen's avatar
Steven Allen committed
286 287 288 289
func (mp *Multiplex) nextChanID() uint64 {
	out := mp.nextID
	mp.nextID++
	return out
290 291
}

Steven Allen's avatar
Steven Allen committed
292
// NewStream creates a new stream.
293 294
func (mp *Multiplex) NewStream(ctx context.Context) (*Stream, error) {
	return mp.NewNamedStream(ctx, "")
Jeromy's avatar
Jeromy committed
295 296
}

Steven Allen's avatar
Steven Allen committed
297
// NewNamedStream creates a new named stream.
298
func (mp *Multiplex) NewNamedStream(ctx context.Context, name string) (*Stream, error) {
Jeromy's avatar
Jeromy committed
299
	mp.chLock.Lock()
300 301 302 303

	// We could call IsClosed but this is faster (given that we already have
	// the lock).
	if mp.channels == nil {
304
		mp.chLock.Unlock()
305 306 307
		return nil, ErrShutdown
	}

308
	sid := mp.nextChanID()
Steven Allen's avatar
Steven Allen committed
309
	header := (sid << 3) | newStreamTag
Jeromy's avatar
Jeromy committed
310 311 312 313

	if name == "" {
		name = fmt.Sprint(sid)
	}
Steven Allen's avatar
Steven Allen committed
314 315 316 317 318
	s := mp.newStream(streamID{
		id:        sid,
		initiator: true,
	}, name)
	mp.channels[s.id] = s
Jeromy's avatar
Jeromy committed
319
	mp.chLock.Unlock()
320

321
	err := mp.sendMsg(ctx.Done(), nil, header, []byte(name))
Jeromy's avatar
Jeromy committed
322
	if err != nil {
323 324 325
		if err == errTimeout {
			return nil, ctx.Err()
		}
Jeromy's avatar
Jeromy committed
326
		return nil, err
Jeromy's avatar
Jeromy committed
327
	}
328

Jeromy's avatar
Jeromy committed
329
	return s, nil
Jeromy's avatar
Jeromy committed
330 331
}

Steven Allen's avatar
Steven Allen committed
332 333
func (mp *Multiplex) cleanup() {
	mp.closeNoWait()
334 335

	// Take the channels.
Steven Allen's avatar
Steven Allen committed
336
	mp.chLock.Lock()
337 338 339
	channels := mp.channels
	mp.channels = nil
	mp.chLock.Unlock()
340

341 342 343 344
	// Cancel any reads/writes
	for _, msch := range channels {
		msch.cancelRead(ErrStreamReset)
		msch.cancelWrite(ErrStreamReset)
Jeromy's avatar
Jeromy committed
345
	}
346 347

	// And... shutdown!
Steven Allen's avatar
Steven Allen committed
348 349 350 351
	if mp.shutdownErr == nil {
		mp.shutdownErr = ErrShutdown
	}
	close(mp.closed)
Jeromy's avatar
Jeromy committed
352
}
Jeromy's avatar
Jeromy committed
353

Jeromy's avatar
Jeromy committed
354
func (mp *Multiplex) handleIncoming() {
Steven Allen's avatar
Steven Allen committed
355
	defer mp.cleanup()
356 357 358 359 360 361 362 363

	recvTimeout := time.NewTimer(0)
	defer recvTimeout.Stop()

	if !recvTimeout.Stop() {
		<-recvTimeout.C
	}

Jeromy's avatar
Jeromy committed
364
	for {
Steven Allen's avatar
Steven Allen committed
365
		chID, tag, err := mp.readNextHeader()
Jeromy's avatar
Jeromy committed
366
		if err != nil {
Steven Allen's avatar
Steven Allen committed
367
			mp.shutdownErr = err
Jeromy's avatar
Jeromy committed
368
			return
Jeromy's avatar
Jeromy committed
369 370
		}

Steven Allen's avatar
Steven Allen committed
371 372 373 374 375 376 377 378 379 380 381 382 383 384
		remoteIsInitiator := tag&1 == 0
		ch := streamID{
			// true if *I'm* the initiator.
			initiator: !remoteIsInitiator,
			id:        chID,
		}
		// Rounds up the tag:
		// 0 -> 0
		// 1 -> 2
		// 2 -> 2
		// 3 -> 4
		// etc...
		tag += (tag & 1)

Jeromy's avatar
Jeromy committed
385 386
		b, err := mp.readNext()
		if err != nil {
Steven Allen's avatar
Steven Allen committed
387
			mp.shutdownErr = err
Jeromy's avatar
Jeromy committed
388
			return
Jeromy's avatar
Jeromy committed
389 390
		}

Jeromy's avatar
Jeromy committed
391
		mp.chLock.Lock()
Jeromy's avatar
Jeromy committed
392
		msch, ok := mp.channels[ch]
Jeromy's avatar
Jeromy committed
393
		mp.chLock.Unlock()
Steven Allen's avatar
Steven Allen committed
394

Jeromy's avatar
Jeromy committed
395
		switch tag {
Steven Allen's avatar
Steven Allen committed
396
		case newStreamTag:
Jeromy's avatar
Jeromy committed
397
			if ok {
Jeromy's avatar
Jeromy committed
398
				log.Debugf("received NewStream message for existing stream: %d", ch)
Steven Allen's avatar
Steven Allen committed
399 400
				mp.shutdownErr = ErrInvalidState
				return
Jeromy's avatar
Jeromy committed
401
			}
Jeromy's avatar
Jeromy committed
402 403

			name := string(b)
404 405
			pool.Put(b)

Steven Allen's avatar
Steven Allen committed
406
			msch = mp.newStream(ch, name)
Jeromy's avatar
Jeromy committed
407
			mp.chLock.Lock()
Jeromy's avatar
Jeromy committed
408
			mp.channels[ch] = msch
Jeromy's avatar
Jeromy committed
409
			mp.chLock.Unlock()
Jeromy's avatar
Jeromy committed
410 411
			select {
			case mp.nstreams <- msch:
Steven Allen's avatar
Steven Allen committed
412
			case <-mp.shutdown:
Jeromy's avatar
Jeromy committed
413 414
				return
			}
Jeromy's avatar
Jeromy committed
415

Steven Allen's avatar
Steven Allen committed
416
		case resetTag:
Steven Allen's avatar
Steven Allen committed
417
			if !ok {
Steven Allen's avatar
Steven Allen committed
418
				// This is *ok*. We forget the stream on reset.
Steven Allen's avatar
Steven Allen committed
419 420
				continue
			}
vyzo's avatar
vyzo committed
421

422 423 424
			// Cancel any ongoing reads/writes.
			msch.cancelRead(ErrStreamReset)
			msch.cancelWrite(ErrStreamReset)
Steven Allen's avatar
Steven Allen committed
425
		case closeTag:
Jeromy's avatar
Jeromy committed
426
			if !ok {
427
				// may have canceled our reads already.
Jeromy's avatar
Jeromy committed
428 429 430
				continue
			}

431 432 433 434
			// unregister and throw away future data.
			mp.chLock.Lock()
			delete(mp.channels, ch)
			mp.chLock.Unlock()
435

436
			// close data channel, there will be no more data.
437 438
			close(msch.dataIn)

439 440 441 442 443
			// We intentionally don't cancel any deadlines, cancel reads, cancel
			// writes, etc. We just deliver the EOF by closing the
			// data channel, and unregister the channel so we don't
			// receive any more data. The user still needs to call
			// `Close()` or `Reset()`.
Steven Allen's avatar
Steven Allen committed
444
		case messageTag:
Jeromy's avatar
Jeromy committed
445
			if !ok {
446 447 448
				// We're not accepting data on this stream, for
				// some reason. It's likely that we reset it, or
				// simply canceled reads (e.g., called Close).
449
				pool.Put(b)
450 451
				continue
			}
Steven Allen's avatar
Steven Allen committed
452

453
			recvTimeout.Reset(ReceiveTimeout)
454 455
			select {
			case msch.dataIn <- b:
456 457
			case <-msch.readCancel:
				// the user has canceled reading. walk away.
458
				pool.Put(b)
459
			case <-recvTimeout.C:
460
				pool.Put(b)
461
				log.Warnf("timed out receiving message into stream queue.")
462 463 464 465
				// Do not do this asynchronously. Otherwise, we
				// could drop a message, then receive a message,
				// then reset.
				msch.Reset()
Steven Allen's avatar
Steven Allen committed
466
				continue
Steven Allen's avatar
Steven Allen committed
467
			case <-mp.shutdown:
468
				pool.Put(b)
469 470
				return
			}
Steven Allen's avatar
Steven Allen committed
471 472 473
			if !recvTimeout.Stop() {
				<-recvTimeout.C
			}
Steven Allen's avatar
Steven Allen committed
474 475 476 477 478
		default:
			log.Debugf("message with unknown header on stream %s", ch)
			if ok {
				msch.Reset()
			}
Jeromy's avatar
Jeromy committed
479 480 481 482
		}
	}
}

vyzo's avatar
vyzo committed
483 484 485 486 487 488 489 490 491
// isShutdown reports, without blocking, whether the shutdown channel has been
// closed.
func (mp *Multiplex) isShutdown() bool {
	select {
	case <-mp.shutdown:
	default:
		return false
	}
	return true
}

492
func (mp *Multiplex) sendResetMsg(header uint64, hard bool) {
493
	ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
494 495
	defer cancel()

496
	err := mp.sendMsg(ctx.Done(), nil, header, nil)
vyzo's avatar
vyzo committed
497
	if err != nil && !mp.isShutdown() {
498
		if hard {
499
			log.Warnf("error sending reset message: %s; killing connection", err.Error())
500 501 502 503
			mp.Close()
		} else {
			log.Debugf("error sending reset message: %s", err.Error())
		}
504 505 506
	}
}

Jeromy's avatar
Jeromy committed
507
func (mp *Multiplex) readNextHeader() (uint64, uint64, error) {
508
	h, err := varint.ReadUvarint(mp.buf)
Jeromy's avatar
Jeromy committed
509 510 511 512 513 514 515 516 517 518 519 520 521 522
	if err != nil {
		return 0, 0, err
	}

	// get channel ID
	ch := h >> 3

	rem := h & 7

	return ch, rem, nil
}

func (mp *Multiplex) readNext() ([]byte, error) {
	// get length
523
	l, err := varint.ReadUvarint(mp.buf)
Jeromy's avatar
Jeromy committed
524 525 526 527
	if err != nil {
		return nil, err
	}

Jeromy's avatar
Jeromy committed
528 529 530 531
	if l > uint64(MaxMessageSize) {
		return nil, fmt.Errorf("message size too large!")
	}

532 533 534 535
	if l == 0 {
		return nil, nil
	}

Steven Allen's avatar
Steven Allen committed
536
	buf := pool.Get(int(l))
537
	n, err := io.ReadFull(mp.buf, buf)
Jeromy's avatar
Jeromy committed
538 539 540 541
	if err != nil {
		return nil, err
	}

Jeromy's avatar
Jeromy committed
542
	return buf[:n], nil
Jeromy's avatar
Jeromy committed
543
}
vyzo's avatar
vyzo committed
544 545 546 547 548 549 550 551

// isFatalNetworkError reports whether err is a net.Error that is neither a
// timeout nor temporary. Any error that is not a net.Error, including nil,
// yields false.
func isFatalNetworkError(err error) bool {
	nerr, isNetErr := err.(net.Error)
	if !isNetErr {
		return false
	}
	return !nerr.Timeout() && !nerr.Temporary()
}