package network

import (
	"context"
	"math/rand"
	"reflect"
	"testing"
	"time"

	gsmsg "github.com/ipfs/go-graphsync/message"
	"github.com/ipfs/go-graphsync/testutil"
	"github.com/libp2p/go-libp2p-core/peer"
	mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
)

// Receiver is an interface for receiving messages from the GraphSyncNetwork.
type receiver struct {
	messageReceived chan struct{}
	lastMessage     gsmsg.GraphSyncMessage
	lastSender      peer.ID
21
	connectedPeers  chan peer.ID
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
}

// ReceiveMessage stores the sender and the message, then notifies the test on
// messageReceived. If ctx is cancelled before the notification is taken, the
// notification is abandoned, but the stored sender and message are kept.
func (r *receiver) ReceiveMessage(ctx context.Context, sender peer.ID, incoming gsmsg.GraphSyncMessage) {
	r.lastSender, r.lastMessage = sender, incoming

	// TestMessageSendAndReceive creates messageReceived unbuffered, so this
	// send waits until the test is ready to read lastSender/lastMessage.
	select {
	case r.messageReceived <- struct{}{}:
	case <-ctx.Done():
	}
}

func (r *receiver) ReceiveError(err error) {
}

39 40 41 42 43 44 45
func (r *receiver) Connected(p peer.ID) {
	r.connectedPeers <- p
}

func (r *receiver) Disconnected(p peer.ID) {
}

func TestMessageSendAndReceive(t *testing.T) {
	// create network
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	mn := mocknet.New(ctx)

	host1, err := mn.GenPeer()
	if err != nil {
		t.Fatal("error generating host")
	}
	host2, err := mn.GenPeer()
	if err != nil {
		t.Fatal("error generating host")
	}
	err = mn.LinkAll()
	if err != nil {
		t.Fatal("error linking hosts")
	}
65 66
	gsnet1 := NewFromLibp2pHost(host1)
	gsnet2 := NewFromLibp2pHost(host2)
67 68
	r := &receiver{
		messageReceived: make(chan struct{}),
69
		connectedPeers:  make(chan peer.ID, 2),
70 71 72 73
	}
	gsnet1.SetDelegate(r)
	gsnet2.SetDelegate(r)

74
	root := testutil.GenerateCids(1)[0]
75
	selector := testutil.RandomBytes(100)
76 77 78 79 80
	extensionName := gsmsg.GraphSyncExtensionName("graphsync/awesome")
	extension := gsmsg.GraphSyncExtension{
		Name: extensionName,
		Data: testutil.RandomBytes(100),
	}
81 82 83 84 85
	id := gsmsg.GraphSyncRequestID(rand.Int31())
	priority := gsmsg.GraphSyncPriority(rand.Int31())
	status := gsmsg.RequestAcknowledged

	sent := gsmsg.New()
86
	sent.AddRequest(gsmsg.NewRequest(id, root, selector, priority))
87
	sent.AddResponse(gsmsg.NewResponse(id, status, extension))
88

89 90 91 92 93
	err = gsnet1.ConnectTo(ctx, host2.ID())
	if err != nil {
		t.Fatal("Unable to connect peers")
	}

94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
	gsnet1.SendMessage(ctx, host2.ID(), sent)

	select {
	case <-ctx.Done():
		t.Fatal("did not receive message sent")
	case <-r.messageReceived:
	}

	sender := r.lastSender
	if sender != host1.ID() {
		t.Fatal("received message from wrong node")
	}

	received := r.lastMessage

	sentRequests := sent.Requests()
	if len(sentRequests) != 1 {
		t.Fatal("Did not add request to sent message")
	}
	sentRequest := sentRequests[0]
	receivedRequests := received.Requests()
	if len(receivedRequests) != 1 {
		t.Fatal("Did not add request to received message")
	}
	receivedRequest := receivedRequests[0]
	if receivedRequest.ID() != sentRequest.ID() ||
		receivedRequest.IsCancel() != sentRequest.IsCancel() ||
		receivedRequest.Priority() != sentRequest.Priority() ||
122
		receivedRequest.Root().String() != sentRequest.Root().String() ||
123 124 125 126 127 128 129 130 131 132 133 134 135
		!reflect.DeepEqual(receivedRequest.Selector(), sentRequest.Selector()) {
		t.Fatal("Sent message requests did not match received message requests")
	}
	sentResponses := sent.Responses()
	if len(sentResponses) != 1 {
		t.Fatal("Did not add response to sent message")
	}
	sentResponse := sentResponses[0]
	receivedResponses := received.Responses()
	if len(receivedResponses) != 1 {
		t.Fatal("Did not add response to received message")
	}
	receivedResponse := receivedResponses[0]
136
	extensionData, err := receivedResponse.Extension(extensionName)
137 138
	if receivedResponse.RequestID() != sentResponse.RequestID() ||
		receivedResponse.Status() != sentResponse.Status() ||
139 140
		err != nil ||
		!reflect.DeepEqual(extension.Data, extensionData) {
141 142
		t.Fatal("Sent message responses did not match received message responses")
	}
143 144 145 146 147 148 149 150 151

	for i := 0; i < 2; i++ {
		select {
		case <-ctx.Done():
			t.Fatal("did notify of all peer connections")
		case <-r.connectedPeers:
		}
	}

152
}