Commit 16f00de5 authored by hannahhoward

test(session): make test more reliable

parent c5f9a91e
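The diff below replaces the sleep-and-inspect pattern in the session tests with channel-based fakes: each WantBlocks/CancelWants call is pushed onto a buffered channel, and the test blocks on a receive instead of sleeping and then reading a mutex-guarded slice. As a rough, standalone illustration of that synchronization pattern only (this is not code from the commit; the request/fakeWants names are invented for the example):

	package main

	import (
		"fmt"
		"time"
	)

	// request mirrors the shape of a recorded call: what was asked for and of whom.
	type request struct {
		ids   []string
		peers []string
	}

	// fakeWants stands in for a collaborator under test; every call is delivered
	// to the test through a buffered channel instead of appended to a locked slice.
	type fakeWants struct {
		calls chan request
	}

	func (f *fakeWants) Want(ids []string, peers []string) {
		f.calls <- request{ids, peers}
	}

	func main() {
		fw := &fakeWants{calls: make(chan request, 1)}

		// The code under test would call Want from its own goroutine.
		go fw.Want([]string{"cid-1"}, nil)

		// The test blocks until the call arrives; no sleep, no race, no polling.
		select {
		case got := <-fw.calls:
			fmt.Println("recorded call:", got)
		case <-time.After(time.Second):
			fmt.Println("timed out waiting for call")
		}
	}

Buffering the channel by one lets the code under test complete its send even if the test has not yet reached the receive, which is the same reason the test below sizes its wantReqs/cancelReqs channels at 1.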
@@ -2,7 +2,6 @@ package session
 
 import (
 	"context"
-	"fmt"
 	"sync"
 	"testing"
 	"time"
@@ -18,48 +17,52 @@ import (
 type wantReq struct {
 	cids  []cid.Cid
 	peers []peer.ID
-	isCancel bool
 }
 type fakeWantManager struct {
-	lk       sync.RWMutex
-	wantReqs []wantReq
+	wantReqs   chan wantReq
+	cancelReqs chan wantReq
 }
 func (fwm *fakeWantManager) WantBlocks(ctx context.Context, cids []cid.Cid, peers []peer.ID, ses uint64) {
-	fwm.lk.Lock()
-	fwm.wantReqs = append(fwm.wantReqs, wantReq{cids, peers, false})
-	fwm.lk.Unlock()
+	fwm.wantReqs <- wantReq{cids, peers}
 }
 func (fwm *fakeWantManager) CancelWants(ctx context.Context, cids []cid.Cid, peers []peer.ID, ses uint64) {
-	fwm.lk.Lock()
-	fwm.wantReqs = append(fwm.wantReqs, wantReq{cids, peers, true})
-	fwm.lk.Unlock()
+	fwm.cancelReqs <- wantReq{cids, peers}
 }
 type fakePeerManager struct {
-	lk                     sync.RWMutex
 	peers                  []peer.ID
 	findMorePeersRequested bool
 }
 func (fpm *fakePeerManager) FindMorePeers(context.Context, cid.Cid) {
-	fpm.lk.Lock()
 	fpm.findMorePeersRequested = true
-	fpm.lk.Unlock()
 }
 func (fpm *fakePeerManager) GetOptimizedPeers() []peer.ID {
-	fpm.lk.Lock()
-	defer fpm.lk.Unlock()
 	return fpm.peers
 }
 func (fpm *fakePeerManager) RecordPeerRequests([]peer.ID, []cid.Cid) {}
 func (fpm *fakePeerManager) RecordPeerResponse(p peer.ID, c cid.Cid) {
-	fpm.lk.Lock()
 	fpm.peers = append(fpm.peers, p)
-	fpm.lk.Unlock()
 }
 func TestSessionGetBlocks(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
 	defer cancel()
-	fwm := &fakeWantManager{}
+	wantReqs := make(chan wantReq, 1)
+	cancelReqs := make(chan wantReq, 1)
+	fwm := &fakeWantManager{wantReqs, cancelReqs}
 	fpm := &fakePeerManager{}
 	id := testutil.GenerateSessionID()
 	session := New(ctx, id, fwm, fpm)
@@ -69,24 +72,15 @@ func TestSessionGetBlocks(t *testing.T) {
 	for _, block := range blks {
 		cids = append(cids, block.Cid())
 	}
-	var receivedBlocks []blocks.Block
 	getBlocksCh, err := session.GetBlocks(ctx, cids)
-	go func() {
-		for block := range getBlocksCh {
-			receivedBlocks = append(receivedBlocks, block)
-		}
-	}()
 	if err != nil {
 		t.Fatal("error getting blocks")
 	}
 	// check initial want request
-	time.Sleep(3 * time.Millisecond)
-	if len(fwm.wantReqs) != 1 {
-		t.Fatal("failed to enqueue wants")
-	}
-	fwm.lk.Lock()
-	receivedWantReq := fwm.wantReqs[0]
+	receivedWantReq := <-fwm.wantReqs
 	if len(receivedWantReq.cids) != activeWantsLimit {
 		t.Fatal("did not enqueue correct initial number of wants")
 	}
@@ -94,17 +88,23 @@ func TestSessionGetBlocks(t *testing.T) {
 		t.Fatal("first want request should be a broadcast")
 	}
-	fwm.wantReqs = nil
-	fwm.lk.Unlock()
 	// now receive the first set of blocks
 	peers := testutil.GeneratePeers(activeWantsLimit)
+	var newCancelReqs []wantReq
+	var newBlockReqs []wantReq
+	var receivedBlocks []blocks.Block
 	for i, p := range peers {
-		session.ReceiveBlockFrom(p, blks[i])
+		session.ReceiveBlockFrom(p, blks[testutil.IndexOf(blks, receivedWantReq.cids[i])])
+		receivedBlock := <-getBlocksCh
+		receivedBlocks = append(receivedBlocks, receivedBlock)
+		cancelBlock := <-cancelReqs
+		newCancelReqs = append(newCancelReqs, cancelBlock)
+		wantBlock := <-wantReqs
+		newBlockReqs = append(newBlockReqs, wantBlock)
 	}
-	time.Sleep(3 * time.Millisecond)
 	// verify new peers were recorded
-	fpm.lk.Lock()
 	if len(fpm.peers) != activeWantsLimit {
 		t.Fatal("received blocks not recorded by the peer manager")
 	}
@@ -113,21 +113,12 @@ func TestSessionGetBlocks(t *testing.T) {
 			t.Fatal("incorrect peer recorded to peer manager")
 		}
 	}
-	fpm.lk.Unlock()
 	// look at new interactions with want manager
-	var cancelReqs []wantReq
-	var newBlockReqs []wantReq
-	fwm.lk.Lock()
-	for _, w := range fwm.wantReqs {
-		if w.isCancel {
-			cancelReqs = append(cancelReqs, w)
-		} else {
-			newBlockReqs = append(newBlockReqs, w)
-		}
-	}
 	// should have cancelled each received block
-	if len(cancelReqs) != activeWantsLimit {
+	if len(newCancelReqs) != activeWantsLimit {
 		t.Fatal("did not cancel each block once it was received")
 	}
 	// new session reqs should be targeted
@@ -138,7 +129,6 @@ func TestSessionGetBlocks(t *testing.T) {
 		}
 		totalEnqueued += len(w.cids)
 	}
-	fwm.lk.Unlock()
 	// full new round of cids should be requested
 	if totalEnqueued != activeWantsLimit {
@@ -147,15 +137,13 @@ func TestSessionGetBlocks(t *testing.T) {
 	// receive remaining blocks
 	for i, p := range peers {
-		session.ReceiveBlockFrom(p, blks[i+activeWantsLimit])
+		session.ReceiveBlockFrom(p, blks[testutil.IndexOf(blks, newBlockReqs[i].cids[0])])
+		receivedBlock := <-getBlocksCh
+		receivedBlocks = append(receivedBlocks, receivedBlock)
+		cancelBlock := <-cancelReqs
+		newCancelReqs = append(newCancelReqs, cancelBlock)
 	}
-	// wait for everything to wrap up
-	<-ctx.Done()
-	// check that we got everything
-	fmt.Printf("%d\n", len(receivedBlocks))
 	if len(receivedBlocks) != len(blks) {
 		t.Fatal("did not receive enough blocks")
 	}
@@ -170,60 +158,52 @@ func TestSessionFindMorePeers(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
 	defer cancel()
-	fwm := &fakeWantManager{}
+	wantReqs := make(chan wantReq, 1)
+	cancelReqs := make(chan wantReq, 1)
+	fwm := &fakeWantManager{wantReqs, cancelReqs}
 	fpm := &fakePeerManager{}
 	id := testutil.GenerateSessionID()
 	session := New(ctx, id, fwm, fpm)
-	session.SetBaseTickDelay(1 * time.Millisecond)
+	session.SetBaseTickDelay(200 * time.Microsecond)
 	blockGenerator := blocksutil.NewBlockGenerator()
 	blks := blockGenerator.Blocks(activeWantsLimit * 2)
 	var cids []cid.Cid
 	for _, block := range blks {
 		cids = append(cids, block.Cid())
 	}
-	var receivedBlocks []blocks.Block
 	getBlocksCh, err := session.GetBlocks(ctx, cids)
-	go func() {
-		for block := range getBlocksCh {
-			receivedBlocks = append(receivedBlocks, block)
-		}
-	}()
 	if err != nil {
 		t.Fatal("error getting blocks")
 	}
+	// clear the initial block of wants
+	<-wantReqs
 	// receive a block to trigger a tick reset
-	time.Sleep(1 * time.Millisecond)
+	time.Sleep(200 * time.Microsecond)
 	p := testutil.GeneratePeers(1)[0]
 	session.ReceiveBlockFrom(p, blks[0])
-	// wait then clear the want list
-	time.Sleep(1 * time.Millisecond)
-	fwm.lk.Lock()
-	fwm.wantReqs = nil
-	fwm.lk.Unlock()
+	<-getBlocksCh
+	<-wantReqs
+	<-cancelReqs
 	// wait long enough for a tick to occur
-	// baseTickDelay + 3 * latency = 4ms
-	time.Sleep(6 * time.Millisecond)
+	time.Sleep(20 * time.Millisecond)
 	// trigger to find providers should have happened
-	fpm.lk.Lock()
 	if fpm.findMorePeersRequested != true {
 		t.Fatal("should have attempted to find more peers but didn't")
 	}
-	fpm.lk.Unlock()
 	// verify a broadcast was made
-	fwm.lk.Lock()
-	if len(fwm.wantReqs) != 1 {
-		t.Fatal("did not make a new broadcast")
-	}
-	receivedWantReq := fwm.wantReqs[0]
+	receivedWantReq := <-wantReqs
 	if len(receivedWantReq.cids) != activeWantsLimit {
 		t.Fatal("did not rebroadcast whole live list")
 	}
 	if receivedWantReq.peers != nil {
 		t.Fatal("did not make a broadcast")
 	}
-	fwm.wantReqs = nil
-	fwm.lk.Unlock()
+	<-ctx.Done()
 }
@@ -130,6 +130,7 @@ func TestUntaggingPeers(t *testing.T) {
 		t.Fatal("Peers were not tagged!")
 	}
 	<-ctx.Done()
+	time.Sleep(5 * time.Millisecond)
 	if len(fcm.taggedPeers) != 0 {
 		t.Fatal("Peers were not untagged!")
 	}
...
@@ -78,12 +78,17 @@ func ContainsPeer(peers []peer.ID, p peer.ID) bool {
 	return false
 }
 
-// ContainsBlock returns true if a block is found n a list of blocks
-func ContainsBlock(blks []blocks.Block, block blocks.Block) bool {
-	for _, n := range blks {
-		if block.Cid() == n.Cid() {
-			return true
+// IndexOf returns the index of a given cid in an array of blocks
+func IndexOf(blks []blocks.Block, c cid.Cid) int {
+	for i, n := range blks {
+		if n.Cid() == c {
+			return i
 		}
 	}
-	return false
+	return -1
+}
+
+// ContainsBlock returns true if a block is found n a list of blocks
+func ContainsBlock(blks []blocks.Block, block blocks.Block) bool {
+	return IndexOf(blks, block.Cid()) != -1
 }
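For reference, the new IndexOf helper is what lets the session test hand back exactly the block a given want request asked for, instead of assuming the request order matches the order of blks. A minimal standalone sketch of the same lookup, with CIDs simplified to strings (indexOf here is a local stand-in, not the testutil function):

	package main

	import "fmt"

	// indexOf returns the position of id in ids, or -1 if it is absent.
	func indexOf(ids []string, id string) int {
		for i, v := range ids {
			if v == id {
				return i
			}
		}
		return -1
	}

	func main() {
		blks := []string{"blockA", "blockB", "blockC"}
		// A want request may list items in any order; look each one up rather
		// than assuming blks[i] is the block requested at position i.
		wanted := []string{"blockC", "blockA"}
		for _, c := range wanted {
			if i := indexOf(blks, c); i != -1 {
				fmt.Println("deliver", blks[i])
			}
		}
	}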