Commit cdc87be0 authored by Steven Allen's avatar Steven Allen

engine(test): make the test peer tagger more reliable

parent 9d580a65
...@@ -19,38 +19,63 @@ import ( ...@@ -19,38 +19,63 @@ import (
testutil "github.com/libp2p/go-libp2p-core/test" testutil "github.com/libp2p/go-libp2p-core/test"
) )
type peerTag struct {
done chan struct{}
peers map[peer.ID]int
}
type fakePeerTagger struct { type fakePeerTagger struct {
lk sync.Mutex lk sync.Mutex
wait sync.WaitGroup tags map[string]*peerTag
taggedPeers []peer.ID
} }
func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) { func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
fpt.wait.Add(1)
fpt.lk.Lock() fpt.lk.Lock()
defer fpt.lk.Unlock() defer fpt.lk.Unlock()
fpt.taggedPeers = append(fpt.taggedPeers, p) if fpt.tags == nil {
fpt.tags = make(map[string]*peerTag, 1)
}
pt, ok := fpt.tags[tag]
if !ok {
pt = &peerTag{peers: make(map[peer.ID]int, 1), done: make(chan struct{})}
fpt.tags[tag] = pt
}
pt.peers[p] = n
} }
func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) { func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
defer fpt.wait.Done()
fpt.lk.Lock() fpt.lk.Lock()
defer fpt.lk.Unlock() defer fpt.lk.Unlock()
for i := 0; i < len(fpt.taggedPeers); i++ { pt := fpt.tags[tag]
if fpt.taggedPeers[i] == p { if pt == nil {
fpt.taggedPeers[i] = fpt.taggedPeers[len(fpt.taggedPeers)-1] return
fpt.taggedPeers = fpt.taggedPeers[:len(fpt.taggedPeers)-1] }
return delete(pt.peers, p)
} if len(pt.peers) == 0 {
close(pt.done)
delete(fpt.tags, tag)
} }
} }
func (fpt *fakePeerTagger) count() int { func (fpt *fakePeerTagger) count(tag string) int {
fpt.lk.Lock() fpt.lk.Lock()
defer fpt.lk.Unlock() defer fpt.lk.Unlock()
return len(fpt.taggedPeers) if pt, ok := fpt.tags[tag]; ok {
return len(pt.peers)
}
return 0
}
func (fpt *fakePeerTagger) wait(tag string) {
fpt.lk.Lock()
pt := fpt.tags[tag]
if pt == nil {
fpt.lk.Unlock()
return
}
doneCh := pt.done
fpt.lk.Unlock()
<-doneCh
} }
type engineSet struct { type engineSet struct {
...@@ -241,13 +266,13 @@ func TestTaggingPeers(t *testing.T) { ...@@ -241,13 +266,13 @@ func TestTaggingPeers(t *testing.T) {
next := <-sanfrancisco.Engine.Outbox() next := <-sanfrancisco.Engine.Outbox()
envelope := <-next envelope := <-next
if sanfrancisco.PeerTagger.count() != 1 { if sanfrancisco.PeerTagger.count(sanfrancisco.Engine.tagQueued) != 1 {
t.Fatal("Incorrect number of peers tagged") t.Fatal("Incorrect number of peers tagged")
} }
envelope.Sent() envelope.Sent()
<-sanfrancisco.Engine.Outbox() <-sanfrancisco.Engine.Outbox()
sanfrancisco.PeerTagger.wait.Wait() sanfrancisco.PeerTagger.wait(sanfrancisco.Engine.tagQueued)
if sanfrancisco.PeerTagger.count() != 0 { if sanfrancisco.PeerTagger.count(sanfrancisco.Engine.tagQueued) != 0 {
t.Fatal("Peers should be untagged but weren't") t.Fatal("Peers should be untagged but weren't")
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment