package sessionpeermanager

import (
	"context"
	"math/rand"
	"sync"
	"testing"
	"time"

	"github.com/ipfs/go-bitswap/testutil"

	cid "github.com/ipfs/go-cid"
	peer "github.com/libp2p/go-libp2p-core/peer"
)

type fakePeerProviderFinder struct {
	peers     []peer.ID
	completed chan struct{}
19 20
}

func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid) <-chan peer.ID {
22 23
	peerCh := make(chan peer.ID)
	go func() {
24 25

		for _, p := range fppf.peers {
26 27 28
			select {
			case peerCh <- p:
			case <-ctx.Done():
29
				close(peerCh)
30 31 32
				return
			}
		}
33 34 35
		close(peerCh)

		select {
36
		case fppf.completed <- struct{}{}:
37 38
		case <-ctx.Done():
		}
39 40 41 42
	}()
	return peerCh
}

type fakePeerTagger struct {
44
	lk          sync.Mutex
45
	taggedPeers []peer.ID
46
	wait        sync.WaitGroup
47 48
}

func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
	fpt.wait.Add(1)
51 52 53

	fpt.lk.Lock()
	defer fpt.lk.Unlock()
54
	fpt.taggedPeers = append(fpt.taggedPeers, p)
55
}

func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
	defer fpt.wait.Done()

60 61
	fpt.lk.Lock()
	defer fpt.lk.Unlock()
62 63 64 65
	for i := 0; i < len(fpt.taggedPeers); i++ {
		if fpt.taggedPeers[i] == p {
			fpt.taggedPeers[i] = fpt.taggedPeers[len(fpt.taggedPeers)-1]
			fpt.taggedPeers = fpt.taggedPeers[:len(fpt.taggedPeers)-1]
66 67 68 69
			return
		}
	}
}

// count returns the number of tag records currently held, read under the
// tagger's lock so it is safe to call while TagPeer/UntagPeer may run.
func (fpt *fakePeerTagger) count() int {
	fpt.lk.Lock()
	total := len(fpt.taggedPeers)
	fpt.lk.Unlock()
	return total
}

func TestFindingMorePeers(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
81 82
	completed := make(chan struct{})

83
	peers := testutil.GeneratePeers(5)
84 85
	fpt := &fakePeerTagger{}
	fppf := &fakePeerProviderFinder{peers, completed}
86 87 88
	c := testutil.GenerateCids(1)[0]
	id := testutil.GenerateSessionID()

89
	sessionPeerManager := New(ctx, id, fpt, fppf)
90 91 92 93

	findCtx, findCancel := context.WithTimeout(ctx, 10*time.Millisecond)
	defer findCancel()
	sessionPeerManager.FindMorePeers(ctx, c)
94 95 96 97 98 99 100
	select {
	case <-completed:
	case <-findCtx.Done():
		t.Fatal("Did not finish finding providers")
	}
	time.Sleep(2 * time.Millisecond)

101 102 103 104 105 106 107 108 109
	sessionPeers := sessionPeerManager.GetOptimizedPeers()
	if len(sessionPeers) != len(peers) {
		t.Fatal("incorrect number of peers found")
	}
	for _, p := range sessionPeers {
		if !testutil.ContainsPeer(peers, p) {
			t.Fatal("incorrect peer found through finding providers")
		}
	}
110
	if len(fpt.taggedPeers) != len(peers) {
111 112 113 114 115 116 117 118 119
		t.Fatal("Peers were not tagged!")
	}
}

func TestRecordingReceivedBlocks(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	p := testutil.GeneratePeers(1)[0]
120 121
	fpt := &fakePeerTagger{}
	fppf := &fakePeerProviderFinder{}
122 123 124
	c := testutil.GenerateCids(1)[0]
	id := testutil.GenerateSessionID()

125
	sessionPeerManager := New(ctx, id, fpt, fppf)
126 127 128 129 130 131 132 133 134
	sessionPeerManager.RecordPeerResponse(p, c)
	time.Sleep(10 * time.Millisecond)
	sessionPeers := sessionPeerManager.GetOptimizedPeers()
	if len(sessionPeers) != 1 {
		t.Fatal("did not add peer on receive")
	}
	if sessionPeers[0] != p {
		t.Fatal("incorrect peer added on receive")
	}
135
	if len(fpt.taggedPeers) != 1 {
136 137 138 139
		t.Fatal("Peers was not tagged!")
	}
}

func TestOrderingPeers(t *testing.T) {
	ctx := context.Background()
142
	ctx, cancel := context.WithTimeout(ctx, 30*time.Millisecond)
hannahhoward's avatar
hannahhoward committed
143 144
	defer cancel()
	peers := testutil.GeneratePeers(100)
145
	completed := make(chan struct{})
146 147
	fpt := &fakePeerTagger{}
	fppf := &fakePeerProviderFinder{peers, completed}
hannahhoward's avatar
hannahhoward committed
148 149
	c := testutil.GenerateCids(1)
	id := testutil.GenerateSessionID()
150
	sessionPeerManager := New(ctx, id, fpt, fppf)
hannahhoward's avatar
hannahhoward committed
151 152 153

	// add all peers to session
	sessionPeerManager.FindMorePeers(ctx, c[0])
154 155 156 157 158 159
	select {
	case <-completed:
	case <-ctx.Done():
		t.Fatal("Did not finish finding providers")
	}
	time.Sleep(2 * time.Millisecond)
hannahhoward's avatar
hannahhoward committed
160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209

	// record broadcast
	sessionPeerManager.RecordPeerRequests(nil, c)

	// record receives
	peer1 := peers[rand.Intn(100)]
	peer2 := peers[rand.Intn(100)]
	peer3 := peers[rand.Intn(100)]
	time.Sleep(1 * time.Millisecond)
	sessionPeerManager.RecordPeerResponse(peer1, c[0])
	time.Sleep(1 * time.Millisecond)
	sessionPeerManager.RecordPeerResponse(peer2, c[0])
	time.Sleep(1 * time.Millisecond)
	sessionPeerManager.RecordPeerResponse(peer3, c[0])

	sessionPeers := sessionPeerManager.GetOptimizedPeers()
	if len(sessionPeers) != maxOptimizedPeers {
		t.Fatal("Should not return more than the max of optimized peers")
	}

	// should prioritize peers which have received blocks
	if (sessionPeers[0] != peer3) || (sessionPeers[1] != peer2) || (sessionPeers[2] != peer1) {
		t.Fatal("Did not prioritize peers that received blocks")
	}

	// Receive a second time from same node
	sessionPeerManager.RecordPeerResponse(peer3, c[0])

	// call again
	nextSessionPeers := sessionPeerManager.GetOptimizedPeers()
	if len(nextSessionPeers) != maxOptimizedPeers {
		t.Fatal("Should not return more than the max of optimized peers")
	}

	// should not duplicate
	if (nextSessionPeers[0] != peer3) || (nextSessionPeers[1] != peer2) || (nextSessionPeers[2] != peer1) {
		t.Fatal("Did dedup peers which received multiple blocks")
	}

	// should randomize other peers
	totalSame := 0
	for i := 3; i < maxOptimizedPeers; i++ {
		if sessionPeers[i] == nextSessionPeers[i] {
			totalSame++
		}
	}
	if totalSame >= maxOptimizedPeers-3 {
		t.Fatal("should not return the same random peers each time")
	}
}

func TestUntaggingPeers(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
	defer cancel()
	peers := testutil.GeneratePeers(5)
216
	completed := make(chan struct{})
217 218
	fpt := &fakePeerTagger{}
	fppf := &fakePeerProviderFinder{peers, completed}
219 220 221
	c := testutil.GenerateCids(1)[0]
	id := testutil.GenerateSessionID()

222
	sessionPeerManager := New(ctx, id, fpt, fppf)
223 224

	sessionPeerManager.FindMorePeers(ctx, c)
225 226 227 228 229 230 231
	select {
	case <-completed:
	case <-ctx.Done():
		t.Fatal("Did not finish finding providers")
	}
	time.Sleep(2 * time.Millisecond)

232
	if fpt.count() != len(peers) {
233 234 235
		t.Fatal("Peers were not tagged!")
	}
	<-ctx.Done()
236
	fpt.wait.Wait()
237

238
	if fpt.count() != 0 {
239 240 241
		t.Fatal("Peers were not untagged!")
	}
}