dht_net.go 9.22 KB
Newer Older
1 2 3
package dht

import (
4
	"bufio"
Jeromy's avatar
Jeromy committed
5
	"context"
6
	"fmt"
7
	"io"
8
	"sync"
9 10
	"time"

11 12
	ggio "github.com/gogo/protobuf/io"
	ctxio "github.com/jbenet/go-context/io"
13
	"github.com/libp2p/go-libp2p-kad-dht/metrics"
14
	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
15 16
	inet "github.com/libp2p/go-libp2p-net"
	peer "github.com/libp2p/go-libp2p-peer"
17 18
	"go.opencensus.io/stats"
	"go.opencensus.io/tag"
19 20
)

var (
	// dhtReadMessageTimeout bounds how long ctxReadMsg waits for a single
	// response message before giving up.
	dhtReadMessageTimeout = time.Minute

	// ErrReadTimeout is returned by ctxReadMsg when no response arrives
	// within dhtReadMessageTimeout.
	ErrReadTimeout = fmt.Errorf("timed out reading response")
)

24 25 26 27 28 29 30 31
// The Protobuf writer performs multiple small writes when writing a message.
// We need to buffer those writes, to make sure that we're not sending a new
// packet for every single write.
type bufferedDelimitedWriter struct {
	*bufio.Writer
	ggio.WriteCloser
}

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
var writerPool = sync.Pool{
	New: func() interface{} {
		w := bufio.NewWriter(nil)
		return &bufferedDelimitedWriter{
			Writer:      w,
			WriteCloser: ggio.NewDelimitedWriter(w),
		}
	},
}

func writeMsg(w io.Writer, mes *pb.Message) error {
	bw := writerPool.Get().(*bufferedDelimitedWriter)
	bw.Reset(w)
	err := bw.WriteMsg(mes)
	if err == nil {
		err = bw.Flush()
48
	}
49 50 51
	bw.Reset(nil)
	writerPool.Put(bw)
	return err
52 53 54 55 56 57
}

// Flush writes any buffered bytes to the underlying writer by delegating to
// the embedded bufio.Writer.
func (bw *bufferedDelimitedWriter) Flush() error {
	buf := bw.Writer
	return buf.Flush()
}

58 59
// handleNewStream implements the inet.StreamHandler
func (dht *IpfsDHT) handleNewStream(s inet.Stream) {
Matt Joiner's avatar
Matt Joiner committed
60 61 62 63 64
	defer s.Reset()
	if dht.handleNewMessage(s) {
		// Gracefully close the stream for writes.
		s.Close()
	}
65 66
}

Matt Joiner's avatar
Matt Joiner committed
67 68
// Returns true on orderly completion of writes (so we can Close the stream).
func (dht *IpfsDHT) handleNewMessage(s inet.Stream) bool {
69 70
	ctx := dht.ctx

71 72
	cr := ctxio.NewReader(ctx, s) // ok to use. we defer close stream in this func
	cw := ctxio.NewWriter(ctx, s) // ok to use. we defer close stream in this func
73
	r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax)
74 75
	mPeer := s.Conn().RemotePeer()

76
	for {
Matt Joiner's avatar
Matt Joiner committed
77 78
		var req pb.Message
		switch err := r.ReadMsg(&req); err {
79
		case io.EOF:
Matt Joiner's avatar
Matt Joiner committed
80
			return true
81
		default:
Matt Joiner's avatar
Matt Joiner committed
82 83 84 85 86
			// This string test is necessary because there isn't a single stream reset error
			// instance	in use.
			if err.Error() != "stream reset" {
				logger.Debugf("error reading message: %#v", err)
			}
87 88 89 90 91
			stats.RecordWithTags(
				ctx,
				[]tag.Mutator{tag.Upsert(metrics.KeyMessageType, "UNKNOWN")},
				metrics.ReceivedMessageErrors.M(1),
			)
Matt Joiner's avatar
Matt Joiner committed
92 93
			return false
		case nil:
94 95
		}

96 97 98 99 100 101 102 103 104 105 106 107
		startTime := time.Now()
		ctx, _ = tag.New(
			ctx,
			tag.Upsert(metrics.KeyMessageType, req.GetType().String()),
		)

		stats.Record(
			ctx,
			metrics.ReceivedMessages.M(1),
			metrics.ReceivedBytes.M(int64(req.Size())),
		)

Matt Joiner's avatar
Matt Joiner committed
108
		handler := dht.handlerForMsgType(req.GetType())
109
		if handler == nil {
110
			stats.Record(ctx, metrics.ReceivedMessageErrors.M(1))
Matt Joiner's avatar
Matt Joiner committed
111 112
			logger.Warningf("can't handle received message of type %v", req.GetType())
			return false
113 114
		}

Matt Joiner's avatar
Matt Joiner committed
115
		resp, err := handler(ctx, mPeer, &req)
116
		if err != nil {
117
			stats.Record(ctx, metrics.ReceivedMessageErrors.M(1))
Matt Joiner's avatar
Matt Joiner committed
118 119
			logger.Debugf("error handling message: %v", err)
			return false
120 121
		}

Matt Joiner's avatar
Matt Joiner committed
122 123 124
		dht.updateFromMessage(ctx, mPeer, &req)

		if resp == nil {
125 126 127 128
			continue
		}

		// send out response msg
129
		err = writeMsg(cw, resp)
130
		if err != nil {
131
			stats.Record(ctx, metrics.ReceivedMessageErrors.M(1))
Matt Joiner's avatar
Matt Joiner committed
132 133
			logger.Debugf("error writing response: %v", err)
			return false
134
		}
Matt Joiner's avatar
Matt Joiner committed
135

136 137 138
		elapsedTime := time.Since(startTime)
		latencyMillis := float64(elapsedTime) / float64(time.Millisecond)
		stats.Record(ctx, metrics.InboundRequestLatency.M(latencyMillis))
139 140 141 142 143
	}
}

// sendRequest sends out a request, but also makes sure to
// measure the RTT for latency measurements.
144
func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
145
	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))
146

Steven Allen's avatar
Steven Allen committed
147
	ms, err := dht.messageSenderForPeer(ctx, p)
148
	if err != nil {
149
		stats.Record(ctx, metrics.SentRequestErrors.M(1))
150 151
		return nil, err
	}
152 153 154

	start := time.Now()

155 156
	rpmes, err := ms.SendRequest(ctx, pmes)
	if err != nil {
157
		stats.Record(ctx, metrics.SentRequestErrors.M(1))
158 159 160
		return nil, err
	}

161 162 163
	// update the peer (on valid msgs only)
	dht.updateFromMessage(ctx, p, rpmes)

164 165 166 167 168 169 170 171
	stats.Record(
		ctx,
		metrics.SentRequests.M(1),
		metrics.SentBytes.M(int64(pmes.Size())),
		metrics.OutboundRequestLatency.M(
			float64(time.Since(start))/float64(time.Millisecond),
		),
	)
172
	dht.peerstore.RecordLatency(p, time.Since(start))
Matt Joiner's avatar
Matt Joiner committed
173
	logger.Event(ctx, "dhtReceivedMessage", dht.self, p, rpmes)
174 175
	return rpmes, nil
}
176 177

// sendMessage sends out a message
178
func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
179 180
	ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

Steven Allen's avatar
Steven Allen committed
181
	ms, err := dht.messageSenderForPeer(ctx, p)
182
	if err != nil {
183
		stats.Record(ctx, metrics.SentMessageErrors.M(1))
184 185
		return err
	}
186

187
	if err := ms.SendMessage(ctx, pmes); err != nil {
188
		stats.Record(ctx, metrics.SentMessageErrors.M(1))
189 190
		return err
	}
191 192 193 194 195 196

	stats.Record(
		ctx,
		metrics.SentMessages.M(1),
		metrics.SentBytes.M(int64(pmes.Size())),
	)
Matt Joiner's avatar
Matt Joiner committed
197
	logger.Event(ctx, "dhtSentMessage", dht.self, p, pmes)
198 199
	return nil
}
200 201

func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error {
202 203 204 205 206
	// Make sure that this node is actually a DHT server, not just a client.
	protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...)
	if err == nil && len(protos) > 0 {
		dht.Update(ctx, p)
	}
207 208
	return nil
}
209

Steven Allen's avatar
Steven Allen committed
210
func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
211 212
	dht.smlk.Lock()
	ms, ok := dht.strmap[p]
213 214 215
	if ok {
		dht.smlk.Unlock()
		return ms, nil
216
	}
217 218 219 220
	ms = &messageSender{p: p, dht: dht}
	dht.strmap[p] = ms
	dht.smlk.Unlock()

Steven Allen's avatar
Steven Allen committed
221
	if err := ms.prepOrInvalidate(ctx); err != nil {
222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
		dht.smlk.Lock()
		defer dht.smlk.Unlock()

		if msCur, ok := dht.strmap[p]; ok {
			// Changed. Use the new one, old one is invalid and
			// not in the map so we can just throw it away.
			if ms != msCur {
				return msCur, nil
			}
			// Not changed, remove the now invalid stream from the
			// map.
			delete(dht.strmap, p)
		}
		// Invalid but not in map. Must have been removed by a disconnect.
		return nil, err
	}
	// All ready to go.
	return ms, nil
240 241 242 243 244 245 246 247
}

type messageSender struct {
	s   inet.Stream
	r   ggio.ReadCloser
	lk  sync.Mutex
	p   peer.ID
	dht *IpfsDHT
Jeromy's avatar
Jeromy committed
248

249
	invalid   bool
Jeromy's avatar
Jeromy committed
250
	singleMes int
251 252
}

Steven Allen's avatar
Steven Allen committed
253 254 255
// invalidate is called before this messageSender is removed from the strmap.
// It prevents the messageSender from being reused/reinitialized and then
// forgotten (leaving the stream open).
256 257 258 259 260 261 262 263
func (ms *messageSender) invalidate() {
	ms.invalid = true
	if ms.s != nil {
		ms.s.Reset()
		ms.s = nil
	}
}

Steven Allen's avatar
Steven Allen committed
264
func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
265 266
	ms.lk.Lock()
	defer ms.lk.Unlock()
Steven Allen's avatar
Steven Allen committed
267
	if err := ms.prep(ctx); err != nil {
268 269 270 271
		ms.invalidate()
		return err
	}
	return nil
272 273
}

Steven Allen's avatar
Steven Allen committed
274
func (ms *messageSender) prep(ctx context.Context) error {
275 276 277
	if ms.invalid {
		return fmt.Errorf("message sender has been invalidated")
	}
278 279 280 281
	if ms.s != nil {
		return nil
	}

Steven Allen's avatar
Steven Allen committed
282
	nstr, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...)
283 284 285 286
	if err != nil {
		return err
	}

287 288 289 290 291 292
	ms.r = ggio.NewDelimitedReader(nstr, inet.MessageSizeMax)
	ms.s = nstr

	return nil
}

// streamReuseTries caps how many sends that needed a retry we tolerate on a
// peer before giving up on stream reuse and reverting to the old
// one-message-per-stream behaviour.
const streamReuseTries = 3

298 299 300
func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
	ms.lk.Lock()
	defer ms.lk.Unlock()
301 302
	retry := false
	for {
Steven Allen's avatar
Steven Allen committed
303
		if err := ms.prep(ctx); err != nil {
Jeromy's avatar
Jeromy committed
304 305 306
			return err
		}

307
		if err := ms.writeMsg(pmes); err != nil {
308 309 310 311
			ms.s.Reset()
			ms.s = nil

			if retry {
Matt Joiner's avatar
Matt Joiner committed
312
				logger.Info("error writing message, bailing: ", err)
313 314
				return err
			}
315 316 317
			logger.Info("error writing message, trying again: ", err)
			retry = true
			continue
318 319
		}

Matt Joiner's avatar
Matt Joiner committed
320
		logger.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)
321 322

		if ms.singleMes > streamReuseTries {
Steven Allen's avatar
Steven Allen committed
323
			go inet.FullClose(ms.s)
324 325 326
			ms.s = nil
		} else if retry {
			ms.singleMes++
Jeromy's avatar
Jeromy committed
327 328
		}

329
		return nil
330
	}
331
}
332 333 334 335

func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
	ms.lk.Lock()
	defer ms.lk.Unlock()
336 337
	retry := false
	for {
Steven Allen's avatar
Steven Allen committed
338
		if err := ms.prep(ctx); err != nil {
339 340
			return nil, err
		}
341

342
		if err := ms.writeMsg(pmes); err != nil {
343 344 345 346
			ms.s.Reset()
			ms.s = nil

			if retry {
Matt Joiner's avatar
Matt Joiner committed
347
				logger.Info("error writing message, bailing: ", err)
348 349
				return nil, err
			}
350 351 352
			logger.Info("error writing message, trying again: ", err)
			retry = true
			continue
353
		}
354

355 356 357 358 359 360
		mes := new(pb.Message)
		if err := ms.ctxReadMsg(ctx, mes); err != nil {
			ms.s.Reset()
			ms.s = nil

			if retry {
Matt Joiner's avatar
Matt Joiner committed
361
				logger.Info("error reading message, bailing: ", err)
362 363
				return nil, err
			}
364 365 366
			logger.Info("error reading message, trying again: ", err)
			retry = true
			continue
367
		}
368

Matt Joiner's avatar
Matt Joiner committed
369
		logger.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)
370

371
		if ms.singleMes > streamReuseTries {
Steven Allen's avatar
Steven Allen committed
372
			go inet.FullClose(ms.s)
373 374 375 376
			ms.s = nil
		} else if retry {
			ms.singleMes++
		}
Jeromy's avatar
Jeromy committed
377

378 379
		return mes, nil
	}
380
}
381

382
func (ms *messageSender) writeMsg(pmes *pb.Message) error {
383
	return writeMsg(ms.s, pmes)
384 385
}

386 387
func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {
	errc := make(chan error, 1)
388 389 390
	go func(r ggio.ReadCloser) {
		errc <- r.ReadMsg(mes)
	}(ms.r)
391

392 393 394
	t := time.NewTimer(dhtReadMessageTimeout)
	defer t.Stop()

395 396 397 398 399
	select {
	case err := <-errc:
		return err
	case <-ctx.Done():
		return ctx.Err()
400 401
	case <-t.C:
		return ErrReadTimeout
402 403
	}
}