Commit 994279bd authored by Dirk McCormick's avatar Dirk McCormick

fix: make sure GetBlocks() channel is closed on session close

parent 693e97d0
......@@ -61,15 +61,19 @@ func SyncGetBlock(p context.Context, k cid.Cid, gb GetBlocksFunc) (blocks.Block,
type WantFunc func(context.Context, []cid.Cid)
// AsyncGetBlocks take a set of block cids, a pubsub channel for incoming
// blocks, a want function, and a close function,
// and returns a channel of incoming blocks.
func AsyncGetBlocks(ctx context.Context, keys []cid.Cid, notif notifications.PubSub, want WantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) {
// blocks, a want function, and a close function, and returns a channel of
// incoming blocks.
func AsyncGetBlocks(ctx context.Context, sessctx context.Context, keys []cid.Cid, notif notifications.PubSub,
want WantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) {
// If there are no keys supplied, just return a closed channel
if len(keys) == 0 {
out := make(chan blocks.Block)
close(out)
return out, nil
}
// Use a PubSub notifier to listen for incoming blocks for each key
remaining := cid.NewSet()
promise := notif.Subscribe(ctx, keys...)
for _, k := range keys {
......@@ -77,24 +81,36 @@ func AsyncGetBlocks(ctx context.Context, keys []cid.Cid, notif notifications.Pub
remaining.Add(k)
}
// Send the want request for the keys to the network
want(ctx, keys)
out := make(chan blocks.Block)
go handleIncoming(ctx, remaining, promise, out, cwants)
go handleIncoming(ctx, sessctx, remaining, promise, out, cwants)
return out, nil
}
func handleIncoming(ctx context.Context, remaining *cid.Set, in <-chan blocks.Block, out chan blocks.Block, cfun func([]cid.Cid)) {
// Listens for incoming blocks, passing them to the out channel.
// If the context is cancelled or the incoming channel closes, calls cfun with
// any keys corresponding to blocks that were never received.
func handleIncoming(ctx context.Context, sessctx context.Context, remaining *cid.Set,
in <-chan blocks.Block, out chan blocks.Block, cfun func([]cid.Cid)) {
ctx, cancel := context.WithCancel(ctx)
// Clean up before exiting this function, and call the cancel function on
// any remaining keys
defer func() {
cancel()
close(out)
// can't just defer this call on its own, arguments are resolved *when* the defer is created
cfun(remaining.Keys())
}()
for {
select {
case blk, ok := <-in:
// If the channel is closed, we're done (note that PubSub closes
// the channel once all the keys have been received)
if !ok {
return
}
......@@ -104,9 +120,13 @@ func handleIncoming(ctx context.Context, remaining *cid.Set, in <-chan blocks.Bl
case out <- blk:
case <-ctx.Done():
return
case <-sessctx.Done():
return
}
case <-ctx.Done():
return
case <-sessctx.Done():
return
}
}
}
......@@ -60,8 +60,8 @@ func (ps *impl) Shutdown() {
}
// Subscribe returns a channel of blocks for the given |keys|. |blockChannel|
// is closed if the |ctx| times out or is cancelled, or after sending len(keys)
// blocks.
// is closed if the |ctx| times out or is cancelled, or after receiving the blocks
// corresponding to |keys|.
func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Block {
blocksCh := make(chan blocks.Block, len(keys))
......@@ -82,6 +82,8 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Bl
default:
}
// AddSubOnceEach listens for each key in the list, and closes the channel
// once all keys have been received
ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...)
go func() {
defer func() {
......
......@@ -182,7 +182,8 @@ func (s *Session) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, err
// guaranteed on the returned blocks.
func (s *Session) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks.Block, error) {
ctx = logging.ContextWithLoggable(ctx, s.uuid)
return bsgetter.AsyncGetBlocks(ctx, keys, s.notif,
return bsgetter.AsyncGetBlocks(ctx, s.ctx, keys, s.notif,
func(ctx context.Context, keys []cid.Cid) {
select {
case s.newReqs <- keys:
......
......@@ -416,3 +416,45 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
t.Fatal("Did not rebroadcast to find more peers")
}
}
// TestSessionCtxCancelClosesGetBlocksChannel verifies that cancelling the
// session's own context (as opposed to the context passed to GetBlocks)
// causes the channel returned by GetBlocks() to be closed.
func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {
	wantCh := make(chan wantReq, 1)
	cancelCh := make(chan wantReq, 1)
	wm := &fakeWantManager{wantCh, cancelCh}
	pm := &fakePeerManager{}
	rs := &fakeRequestSplitter{}

	notif := notifications.New()
	defer notif.Shutdown()

	// Create a new session with its own context
	sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	sess := New(sessctx, testutil.GenerateSessionID(), wm, pm, rs, notif, time.Second, delay.Fixed(time.Minute))

	// Deadline after which the test is considered failed
	timerCtx, timerCancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer timerCancel()

	// Request a block with a new context
	blks := blocksutil.NewBlockGenerator().Blocks(1)
	getctx, getcancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer getcancel()
	getBlocksCh, err := sess.GetBlocks(getctx, []cid.Cid{blks[0].Cid()})
	if err != nil {
		t.Fatal("error getting blocks")
	}

	// Cancel the session context
	sesscancel()

	// Expect the GetBlocks() channel to be closed
	select {
	case _, open := <-getBlocksCh:
		if open {
			t.Fatal("expected channel to be closed but was not closed")
		}
	case <-timerCtx.Done():
		t.Fatal("expected channel to be closed before timeout")
	}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment