Commit eaca669f authored by Adin Schmahmann's avatar Adin Schmahmann

fullrt: fix dividing up bulk sending of keys into groups

parent c5223963
...@@ -962,19 +962,9 @@ func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func( ...@@ -962,19 +962,9 @@ func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func(
var numSendsSuccessful uint64 = 0 var numSendsSuccessful uint64 = 0
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(dht.bulkSendParallelism)
chunkSize := len(sortedKeys) / dht.bulkSendParallelism
onePctKeys := uint64(len(sortedKeys)) / 100 onePctKeys := uint64(len(sortedKeys)) / 100
for i := 0; i < dht.bulkSendParallelism; i++ {
var chunk []peer.ID
end := (i + 1) * chunkSize
if end > len(sortedKeys) {
chunk = sortedKeys[i*chunkSize:]
} else {
chunk = sortedKeys[i*chunkSize : end]
}
go func() { bulkSendFn := func(chunk []peer.ID) {
defer wg.Done() defer wg.Done()
for _, key := range chunk { for _, key := range chunk {
if ctx.Err() != nil { if ctx.Err() != nil {
...@@ -997,8 +987,14 @@ func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func( ...@@ -997,8 +987,14 @@ func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func(
atomic.AddUint64(&numSendsSuccessful, 1) atomic.AddUint64(&numSendsSuccessful, 1)
} }
} }
}()
} }
keyGroups := divideIntoGroups(sortedKeys, dht.bulkSendParallelism)
wg.Add(len(keyGroups))
for _, chunk := range keyGroups {
go bulkSendFn(chunk)
}
wg.Wait() wg.Wait()
if numSendsSuccessful == 0 { if numSendsSuccessful == 0 {
...@@ -1010,6 +1006,39 @@ func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func( ...@@ -1010,6 +1006,39 @@ func (dht *FullRT) bulkMessageSend(ctx context.Context, keys []peer.ID, fn func(
return nil return nil
} }
// divideIntoGroups divides the set of keys into (at most) the number of groups
func divideIntoGroups(keys []peer.ID, groups int) [][]peer.ID {
var keyGroups [][]peer.ID
if len(keys) < groups {
for i := 0; i < len(keys); i++ {
keyGroups = append(keyGroups, keys[i:i+1])
}
return keyGroups
}
chunkSize := len(keys) / groups
remainder := len(keys) % groups
start := 0
end := chunkSize
for i := 0; i < groups; i++ {
var chunk []peer.ID
// distribute the remainder as one extra entry per parallel thread
if remainder > 0 {
chunk = keys[start : end+1]
remainder--
start = end + 1
end = end + 1 + chunkSize
} else {
chunk = keys[start:end]
start = end
end = end + chunkSize
}
keyGroups = append(keyGroups, chunk)
}
return keyGroups
}
// FindProviders searches until the context expires. // FindProviders searches until the context expires.
func (dht *FullRT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) { func (dht *FullRT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
if !dht.enableProviders { if !dht.enableProviders {
......
package fullrt
import (
"strconv"
"testing"
"github.com/libp2p/go-libp2p-core/peer"
)
func TestDivideIntoGroups(t *testing.T) {
var keys []peer.ID
for i := 0; i < 10; i++ {
keys = append(keys, peer.ID(strconv.Itoa(i)))
}
convertToStrings := func(peers []peer.ID) []string {
var out []string
for _, p := range peers {
out = append(out, string(p))
}
return out
}
pidsEquals := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
t.Run("Divides", func(t *testing.T) {
gr := divideIntoGroups(keys, 2)
if len(gr) != 2 {
t.Fatal("incorrect number of groups")
}
if g1, expected := convertToStrings(gr[0]), []string{"0", "1", "2", "3", "4"}; !pidsEquals(g1, expected) {
t.Fatalf("expected %v, got %v", expected, g1)
}
if g2, expected := convertToStrings(gr[1]), []string{"5", "6", "7", "8", "9"}; !pidsEquals(g2, expected) {
t.Fatalf("expected %v, got %v", expected, g2)
}
})
t.Run("Remainder", func(t *testing.T) {
gr := divideIntoGroups(keys, 3)
if len(gr) != 3 {
t.Fatal("incorrect number of groups")
}
if g, expected := convertToStrings(gr[0]), []string{"0", "1", "2", "3"}; !pidsEquals(g, expected) {
t.Fatalf("expected %v, got %v", expected, g)
}
if g, expected := convertToStrings(gr[1]), []string{"4", "5", "6"}; !pidsEquals(g, expected) {
t.Fatalf("expected %v, got %v", expected, g)
}
if g, expected := convertToStrings(gr[2]), []string{"7", "8", "9"}; !pidsEquals(g, expected) {
t.Fatalf("expected %v, got %v", expected, g)
}
})
t.Run("OneEach", func(t *testing.T) {
gr := divideIntoGroups(keys, 10)
if len(gr) != 10 {
t.Fatal("incorrect number of groups")
}
for i := 0; i < 10; i++ {
if g, expected := convertToStrings(gr[i]), []string{strconv.Itoa(i)}; !pidsEquals(g, expected) {
t.Fatalf("expected %v, got %v", expected, g)
}
}
})
t.Run("TooManyGroups", func(t *testing.T) {
gr := divideIntoGroups(keys, 11)
if len(gr) != 10 {
t.Fatal("incorrect number of groups")
}
for i := 0; i < 10; i++ {
if g, expected := convertToStrings(gr[i]), []string{strconv.Itoa(i)}; !pidsEquals(g, expected) {
t.Fatalf("expected %v, got %v", expected, g)
}
}
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment