dht_net.go 3.42 KB
Newer Older
1 2 3 4 5 6 7 8 9
package dht

import (
	"errors"
	"time"

	inet "github.com/jbenet/go-ipfs/net"
	peer "github.com/jbenet/go-ipfs/peer"
	pb "github.com/jbenet/go-ipfs/routing/dht/pb"
10
	ctxutil "github.com/jbenet/go-ipfs/util/ctx"
11 12

	context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
13
	ggio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/gogoprotobuf/io"
14 15 16 17 18 19 20 21 22 23 24
)

// handleNewStream satisfies inet.StreamHandler. Each incoming stream is
// processed by handleNewMessage on a separate goroutine, so this call returns
// without waiting for the message to be handled.
func (dht *IpfsDHT) handleNewStream(s inet.Stream) {
	serve := dht.handleNewMessage
	go serve(s)
}

func (dht *IpfsDHT) handleNewMessage(s inet.Stream) {
	defer s.Close()

	ctx := dht.Context()
25 26 27 28
	cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func
	cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
	r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax)
	w := ggio.NewDelimitedWriter(cw)
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
	mPeer := s.Conn().RemotePeer()

	// receive msg
	pmes := new(pb.Message)
	if err := r.ReadMsg(pmes); err != nil {
		log.Error("Error unmarshaling data")
		return
	}
	// update the peer (on valid msgs only)
	dht.Update(ctx, mPeer)

	log.Event(ctx, "foo", dht.self, mPeer, pmes)

	// get handler for this msg type.
	handler := dht.handlerForMsgType(pmes.GetType())
	if handler == nil {
		log.Error("got back nil handler from handlerForMsgType")
		return
	}

	// dispatch handler.
	rpmes, err := handler(ctx, mPeer, pmes)
	if err != nil {
		log.Errorf("handle message error: %s", err)
		return
	}

	// if nil response, return it before serializing
	if rpmes == nil {
		log.Warning("Got back nil response from request.")
		return
	}

	// send out response msg
	if err := w.WriteMsg(rpmes); err != nil {
		log.Errorf("send response error: %s", err)
		return
	}

	return
}

// sendRequest opens a new DHT stream to p, writes pmes, and waits for a single
// response message, which it returns. The time between starting the write and
// successfully reading the reply is recorded in the peerstore as p's latency.
// The stream is closed before the function returns.
func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
	log.Debugf("%s dht starting stream", dht.self)
	stream, err := dht.network.NewStream(inet.ProtocolDHT, p)
	if err != nil {
		return nil, err
	}
	defer stream.Close()

	// The ctxutil wrappers are safe to use because the deferred Close above
	// releases the stream when this function returns.
	reader := ggio.NewDelimitedReader(ctxutil.NewReader(ctx, stream), inet.MessageSizeMax)
	writer := ggio.NewDelimitedWriter(ctxutil.NewWriter(ctx, stream))

	sentAt := time.Now()

	log.Debugf("%s writing", dht.self)
	if err = writer.WriteMsg(pmes); err != nil {
		return nil, err
	}
	log.Event(ctx, "dhtSentMessage", dht.self, p, pmes)

	log.Debugf("%s reading", dht.self)
	// Registered only after a successful write, so "done" is logged on both
	// the read-failure and success paths, just before the stream is closed.
	defer log.Debugf("%s done", dht.self)

	reply := new(pb.Message)
	if err = reader.ReadMsg(reply); err != nil {
		return nil, err
	}
	// Defensive only: new() never yields nil, so this branch cannot fire as
	// written.
	if reply == nil {
		return nil, errors.New("no response to request")
	}

	dht.peerstore.RecordLatency(p, time.Since(sentAt))
	log.Event(ctx, "dhtReceivedMessage", dht.self, p, reply)
	return reply, nil
}

// sendMessage opens a new DHT stream to p and writes pmes to it. It does not
// read a reply. The stream is closed before the function returns.
func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
	log.Debugf("%s dht starting stream", dht.self)
	stream, err := dht.network.NewStream(inet.ProtocolDHT, p)
	if err != nil {
		return err
	}
	defer stream.Close()

	// The ctxutil writer is safe to use because the deferred Close above
	// releases the stream when this function returns.
	writer := ggio.NewDelimitedWriter(ctxutil.NewWriter(ctx, stream))

	log.Debugf("%s writing", dht.self)
	if err = writer.WriteMsg(pmes); err != nil {
		return err
	}
	log.Event(ctx, "dhtSentMessage", dht.self, p, pmes)
	log.Debugf("%s done", dht.self)
	return nil
}