Commit d3d572b7 authored by Steven Allen

Retry queries if a read from the stream fails.

Previously, we'd only retry when a write failed (in case the peer was using
single-use streams). However, now that streams are half-open, a write can still
succeed on a stream the remote side has closed, so retrying on write failure
alone isn't sufficient. Instead, we need to retry on read failure as well.

Note: this was technically broken before, too: a peer could write to a stream
before receiving the close message, causing the write to succeed but the
subsequent read to fail.
parent 286dd24f
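For illustration, the pattern the diff below adopts can be distilled as follows. This is only a sketch, not the commit's code: the stream interface and the open helper are hypothetical stand-ins for a libp2p stream wrapped in a delimited message reader/writer. The point is that a failure at either step, write or read, resets the stream and retries the whole exchange exactly once on a fresh one.

package retry

import "fmt"

// stream is a hypothetical stand-in for an inet.Stream wrapped in a
// delimited message reader/writer; the real types live in go-libp2p.
type stream interface {
	WriteMsg(msg []byte) error
	ReadMsg() ([]byte, error)
	Reset() error
}

// sendRequest writes a request and reads the response, retrying the whole
// exchange once on either a failed write or a failed read. open is expected
// to return a fresh stream to the peer on each call.
func sendRequest(open func() (stream, error), req []byte) ([]byte, error) {
	retry := false
	for {
		s, err := open()
		if err != nil {
			return nil, err
		}
		if err := s.WriteMsg(req); err != nil {
			s.Reset() // a failed stream is useless; throw it away
			if retry {
				return nil, fmt.Errorf("write failed twice: %v", err)
			}
			retry = true
			continue
		}
		resp, err := s.ReadMsg()
		if err != nil {
			// The write "succeeded", but with half-open streams the peer
			// may already have closed its end; retry on a fresh stream.
			s.Reset()
			if retry {
				return nil, fmt.Errorf("read failed twice: %v", err)
			}
			retry = true
			continue
		}
		return resp, nil
	}
}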
@@ -161,73 +161,88 @@ const streamReuseTries = 3
 func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
 	ms.lk.Lock()
 	defer ms.lk.Unlock()
-	if err := ms.prep(); err != nil {
-		return err
-	}
-
-	if err := ms.writeMessage(pmes); err != nil {
-		return err
-	}
-
-	if ms.singleMes > streamReuseTries {
-		ms.s.Close()
-		ms.s = nil
-	}
-
-	return nil
-}
-
-func (ms *messageSender) writeMessage(pmes *pb.Message) error {
-	err := ms.w.WriteMsg(pmes)
-	if err != nil {
-		// If the other side isnt expecting us to be reusing streams, we're gonna
-		// end up erroring here. To make sure things work seamlessly, lets retry once
-		// before continuing
-		log.Infof("error writing message: ", err)
-		ms.s.Close()
-		ms.s = nil
-		if err := ms.prep(); err != nil {
-			return err
-		}
-		if err := ms.w.WriteMsg(pmes); err != nil {
-			return err
-		}
-
-		// keep track of this happening. If it happens a few times, its
-		// likely we can assume the otherside will never support stream reuse
-		ms.singleMes++
-	}
-
-	return nil
-}
+	retry := false
+	for {
+		if err := ms.prep(); err != nil {
+			return err
+		}
+
+		if err := ms.w.WriteMsg(pmes); err != nil {
+			ms.s.Reset()
+			ms.s = nil
+
+			if retry {
+				log.Info("error writing message, bailing: ", err)
+				return err
+			} else {
+				log.Info("error writing message, trying again: ", err)
+				retry = true
+				continue
+			}
+		}
+
+		log.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)
+
+		if ms.singleMes > streamReuseTries {
+			ms.s.Close()
+			ms.s = nil
+		} else if retry {
+			ms.singleMes++
+		}
+
+		return nil
+	}
+}
 
 func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
 	ms.lk.Lock()
 	defer ms.lk.Unlock()
-	if err := ms.prep(); err != nil {
-		return nil, err
-	}
-
-	if err := ms.writeMessage(pmes); err != nil {
-		return nil, err
-	}
-
-	mes := new(pb.Message)
-	if err := ms.ctxReadMsg(ctx, mes); err != nil {
-		ms.s.Reset()
-		ms.s = nil
-		return nil, err
-	}
-
-	log.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)
-
-	if ms.singleMes > streamReuseTries {
-		ms.s.Close()
-		ms.s = nil
-	}
-
-	return mes, nil
-}
+	retry := false
+	for {
+		if err := ms.prep(); err != nil {
+			return nil, err
+		}
+
+		if err := ms.w.WriteMsg(pmes); err != nil {
+			ms.s.Reset()
+			ms.s = nil
+
+			if retry {
+				log.Info("error writing message, bailing: ", err)
+				return nil, err
+			} else {
+				log.Info("error writing message, trying again: ", err)
+				retry = true
+				continue
+			}
+		}
+
+		log.Event(ctx, "dhtSentMessage", ms.dht.self, ms.p, pmes)
+
+		mes := new(pb.Message)
+		if err := ms.ctxReadMsg(ctx, mes); err != nil {
+			ms.s.Reset()
+			ms.s = nil
+
+			if retry {
+				log.Info("error reading message, bailing: ", err)
+				return nil, err
+			} else {
+				log.Info("error reading message, trying again: ", err)
+				retry = true
+				continue
+			}
+		}
+
+		if ms.singleMes > streamReuseTries {
+			ms.s.Close()
+			ms.s = nil
+		} else if retry {
+			ms.singleMes++
+		}
+
+		return mes, nil
+	}
+}
 
 func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {
@@ -192,7 +192,6 @@ func TestNotFound(t *testing.T) {
 			if err := pbw.WriteMsg(resp); err != nil {
 				panic(err)
 			}
-
 		default:
 			panic("Shouldnt recieve this.")
 		}
@@ -288,3 +287,68 @@ func TestLessThanKResponses(t *testing.T) {
 	}
 	t.Fatal("Expected to recieve an error.")
 }
+
+// Test multiple queries against a node that closes its stream after every query.
+func TestMultipleQueries(t *testing.T) {
+	if testing.Short() {
+		t.SkipNow()
+	}
+
+	ctx := context.Background()
+	mn, err := mocknet.FullMeshConnected(ctx, 2)
+	if err != nil {
+		t.Fatal(err)
+	}
+	hosts := mn.Hosts()
+	tsds := dssync.MutexWrap(ds.NewMapDatastore())
+	d := NewDHT(ctx, hosts[0], tsds)
+	d.Update(ctx, hosts[1].ID())
+
+	// It would be nice to be able to just get a value and succeed but then
+	// we'd need to deal with selectors and validators...
+	hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) {
+		defer s.Close()
+
+		pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
+		pbw := ggio.NewDelimitedWriter(s)
+
+		pmes := new(pb.Message)
+		if err := pbr.ReadMsg(pmes); err != nil {
+			panic(err)
+		}
+
+		switch pmes.GetType() {
+		case pb.Message_GET_VALUE:
+			pi := hosts[1].Peerstore().PeerInfo(hosts[0].ID())
+			resp := &pb.Message{
+				Type:        pmes.Type,
+				CloserPeers: pb.PeerInfosToPBPeers(d.host.Network(), []pstore.PeerInfo{pi}),
+			}
+
+			if err := pbw.WriteMsg(resp); err != nil {
+				panic(err)
+			}
+		default:
+			panic("Shouldnt recieve this.")
+		}
+	})
+
+	// long timeout to ensure timing is not at play.
+	ctx, cancel := context.WithTimeout(ctx, time.Second*20)
+	defer cancel()
+
+	for i := 0; i < 10; i++ {
+		if _, err := d.GetValue(ctx, "hello"); err != nil {
+			switch err {
+			case routing.ErrNotFound:
+				//Success!
+				continue
+			case u.ErrTimeout:
+				t.Fatal("Should not have gotten timeout!")
+			default:
+				t.Fatalf("Got unexpected error: %s", err)
+			}
+		}
+		t.Fatal("Expected to recieve an error.")
+	}
+}