Unverified Commit dbb3d2c0 authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #462 from libp2p/fix/observe-context-in-message-sender

fix: obey the context when sending messages to peers
parents a92f79b9 0b029388
package dht
import (
"context"
)
type ctxMutex chan struct{}
func newCtxMutex() ctxMutex {
return make(ctxMutex, 1)
}
func (m ctxMutex) Lock(ctx context.Context) error {
select {
case m <- struct{}{}:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (m ctxMutex) Unlock() {
select {
case <-m:
default:
panic("not locked")
}
}
......@@ -246,7 +246,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
dht.smlk.Unlock()
return ms, nil
}
ms = &messageSender{p: p, dht: dht}
ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
dht.strmap[p] = ms
dht.smlk.Unlock()
......@@ -274,7 +274,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
type messageSender struct {
s network.Stream
r msgio.ReadCloser
lk sync.Mutex
lk ctxMutex
p peer.ID
dht *IpfsDHT
......@@ -294,8 +294,11 @@ func (ms *messageSender) invalidate() {
}
func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return err
}
defer ms.lk.Unlock()
if err := ms.prep(ctx); err != nil {
ms.invalidate()
return err
......@@ -328,8 +331,11 @@ func (ms *messageSender) prep(ctx context.Context) error {
const streamReuseTries = 3
func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return err
}
defer ms.lk.Unlock()
retry := false
for {
if err := ms.prep(ctx); err != nil {
......@@ -363,8 +369,11 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro
}
func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return nil, err
}
defer ms.lk.Unlock()
retry := false
for {
if err := ms.prep(ctx); err != nil {
......
......@@ -18,6 +18,49 @@ import (
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
)
func TestHang(t *testing.T) {
ctx := context.Background()
mn, err := mocknet.FullMeshConnected(ctx, 2)
if err != nil {
t.Fatal(err)
}
hosts := mn.Hosts()
os := []opts.Option{opts.DisableAutoRefresh()}
d, err := New(ctx, hosts[0], os...)
if err != nil {
t.Fatal(err)
}
// Hang on every request.
hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) {
defer s.Reset()
<-ctx.Done()
})
d.Update(ctx, hosts[1].ID())
ctx1, cancel1 := context.WithTimeout(ctx, 1*time.Second)
defer cancel1()
peers, err := d.GetClosestPeers(ctx1, testCaseCids[0].KeyString())
if err != nil {
t.Fatal(err)
}
time.Sleep(100 * time.Millisecond)
ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel2()
_ = d.Provide(ctx2, testCaseCids[0], true)
if ctx2.Err() != context.DeadlineExceeded {
t.Errorf("expected to fail with deadline exceeded, got: %s", ctx2.Err())
}
select {
case <-peers:
t.Error("GetClosestPeers should not have returned yet")
default:
}
}
func TestGetFailures(t *testing.T) {
if testing.Short() {
t.SkipNow()
......
package dht
import (
"context"
"github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/network"
......@@ -130,7 +132,7 @@ func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) {
// Do this asynchronously as ms.lk can block for a while.
go func() {
ms.lk.Lock()
ms.lk.Lock(context.Background())
defer ms.lk.Unlock()
ms.invalidate()
}()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment