From 39f3c34e410ee1a00af4255279f69adb31229ba9 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Thu, 1 Apr 2021 20:09:10 -0700 Subject: [PATCH] fix: handle missing session exchange in Session Otherwise, we'll panic. --- blockservice.go | 12 ++++++++++-- blockservice_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/blockservice.go b/blockservice.go index 33f6914..2f320b1 100644 --- a/blockservice.go +++ b/blockservice.go @@ -366,12 +366,20 @@ func (s *Session) getSession() exchange.Fetcher { // GetBlock gets a block in the context of a request session func (s *Session) GetBlock(ctx context.Context, c cid.Cid) (blocks.Block, error) { - return getBlock(ctx, c, s.bs, s.getSession) // hash security + var f func() exchange.Fetcher + if s.sessEx != nil { + f = s.getSession + } + return getBlock(ctx, c, s.bs, f) // hash security } // GetBlocks gets blocks in the context of a request session func (s *Session) GetBlocks(ctx context.Context, ks []cid.Cid) <-chan blocks.Block { - return getBlocks(ctx, ks, s.bs, s.getSession) // hash security + var f func() exchange.Fetcher + if s.sessEx != nil { + f = s.getSession + } + return getBlocks(ctx, ks, s.bs, f) // hash security } var _ BlockGetter = (*Session)(nil) diff --git a/blockservice_test.go b/blockservice_test.go index dfd12fc..36cdf03 100644 --- a/blockservice_test.go +++ b/blockservice_test.go @@ -119,3 +119,31 @@ func (fe *fakeSessionExchange) NewSession(ctx context.Context) exchange.Fetcher } return fe.session } + +func TestNilExchange(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + bgen := butil.NewBlockGenerator() + block := bgen.Next() + + bs := blockstore.NewBlockstore(dssync.MutexWrap(ds.NewMapDatastore())) + bserv := NewWriteThrough(bs, nil) + sess := NewSession(ctx, bserv) + _, err := sess.GetBlock(ctx, block.Cid()) + if err != ErrNotFound { + t.Fatal("expected block to not be found") + } + err = bs.Put(block) + if err != nil { + t.Fatal(err) + } + b, err := sess.GetBlock(ctx, block.Cid()) + if err != nil { + t.Fatal(err) + } + if b.Cid() != block.Cid() { + t.Fatal("got the wrong block") + } +} -- GitLab