Commit 4c98edaf authored by Łukasz Magiera's avatar Łukasz Magiera

p2p: fix remote/local listener races

License: MIT
Signed-off-by: default avatarŁukasz Magiera <magik6k@gmail.com>
parent a3c84e20
...@@ -31,26 +31,33 @@ type ListenerRegistry struct { ...@@ -31,26 +31,33 @@ type ListenerRegistry struct {
sync.Mutex sync.Mutex
Listeners map[listenerKey]Listener Listeners map[listenerKey]Listener
starting map[listenerKey]struct{}
} }
// Register registers listenerInfo into this registry and starts it // Register registers listenerInfo into this registry and starts it
func (r *ListenerRegistry) Register(l Listener) error { func (r *ListenerRegistry) Register(l Listener) error {
r.Lock() r.Lock()
k := getListenerKey(l)
if _, ok := r.Listeners[getListenerKey(l)]; ok { if _, ok := r.Listeners[k]; ok {
r.Unlock() r.Unlock()
return errors.New("listener already registered") return errors.New("listener already registered")
} }
r.Listeners[getListenerKey(l)] = l r.Listeners[k] = l
r.starting[k] = struct{}{}
r.Unlock() r.Unlock()
if err := l.start(); err != nil { err := l.start()
r.Lock()
defer r.Lock()
delete(r.Listeners, getListenerKey(l)) r.Lock()
defer r.Unlock()
delete(r.starting, k)
if err != nil {
delete(r.Listeners, k)
return err return err
} }
...@@ -58,13 +65,17 @@ func (r *ListenerRegistry) Register(l Listener) error { ...@@ -58,13 +65,17 @@ func (r *ListenerRegistry) Register(l Listener) error {
} }
// Deregister removes p2p listener from this registry // Deregister removes p2p listener from this registry
func (r *ListenerRegistry) Deregister(k listenerKey) bool { func (r *ListenerRegistry) Deregister(k listenerKey) (bool, error) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
if _, ok := r.starting[k]; ok {
return false, errors.New("listener didn't start yet")
}
_, ok := r.Listeners[k] _, ok := r.Listeners[k]
delete(r.Listeners, k) delete(r.Listeners, k)
return ok return ok, nil
} }
func getListenerKey(l Listener) listenerKey { func getListenerKey(l Listener) listenerKey {
......
...@@ -2,7 +2,6 @@ package p2p ...@@ -2,7 +2,6 @@ package p2p
import ( import (
"context" "context"
"errors"
"time" "time"
"gx/ipfs/QmV6FjemM1K8oXjrvuq3wuVWWoU2TLDPmNnKrxHzY3v6Ai/go-multiaddr-net" "gx/ipfs/QmV6FjemM1K8oXjrvuq3wuVWWoU2TLDPmNnKrxHzY3v6Ai/go-multiaddr-net"
...@@ -105,11 +104,11 @@ func (l *localListener) start() error { ...@@ -105,11 +104,11 @@ func (l *localListener) start() error {
} }
func (l *localListener) Close() error { func (l *localListener) Close() error {
if l.listener == nil { ok, err := l.p2p.Listeners.Deregister(getListenerKey(l))
return errors.New("uninitialized") if err != nil {
return err
} }
if ok {
if l.p2p.Listeners.Deregister(getListenerKey(l)) {
l.listener.Close() l.listener.Close()
l.listener = nil l.listener = nil
} }
......
...@@ -28,6 +28,7 @@ func NewP2P(identity peer.ID, peerHost p2phost.Host, peerstore pstore.Peerstore) ...@@ -28,6 +28,7 @@ func NewP2P(identity peer.ID, peerHost p2phost.Host, peerstore pstore.Peerstore)
Listeners: &ListenerRegistry{ Listeners: &ListenerRegistry{
Listeners: map[listenerKey]Listener{}, Listeners: map[listenerKey]Listener{},
starting: map[listenerKey]struct{}{},
}, },
Streams: &StreamRegistry{ Streams: &StreamRegistry{
Streams: map[uint64]*Stream{}, Streams: map[uint64]*Stream{},
......
...@@ -2,7 +2,6 @@ package p2p ...@@ -2,7 +2,6 @@ package p2p
import ( import (
"context" "context"
"errors"
manet "gx/ipfs/QmV6FjemM1K8oXjrvuq3wuVWWoU2TLDPmNnKrxHzY3v6Ai/go-multiaddr-net" manet "gx/ipfs/QmV6FjemM1K8oXjrvuq3wuVWWoU2TLDPmNnKrxHzY3v6Ai/go-multiaddr-net"
ma "gx/ipfs/QmYmsdtJ3HsodkePE3eU3TsCaP2YvPZJ4LoXnNkDE5Tpt7/go-multiaddr" ma "gx/ipfs/QmYmsdtJ3HsodkePE3eU3TsCaP2YvPZJ4LoXnNkDE5Tpt7/go-multiaddr"
...@@ -21,8 +20,6 @@ type remoteListener struct { ...@@ -21,8 +20,6 @@ type remoteListener struct {
// Address to proxy the incoming connections to // Address to proxy the incoming connections to
addr ma.Multiaddr addr ma.Multiaddr
initialized bool
} }
// ForwardRemote creates new p2p listener // ForwardRemote creates new p2p listener
...@@ -72,7 +69,6 @@ func (l *remoteListener) start() error { ...@@ -72,7 +69,6 @@ func (l *remoteListener) start() error {
stream.startStreaming() stream.startStreaming()
}) })
l.initialized = true
return nil return nil
} }
...@@ -93,13 +89,12 @@ func (l *remoteListener) TargetAddress() ma.Multiaddr { ...@@ -93,13 +89,12 @@ func (l *remoteListener) TargetAddress() ma.Multiaddr {
} }
func (l *remoteListener) Close() error { func (l *remoteListener) Close() error {
if !l.initialized { ok, err := l.p2p.Listeners.Deregister(getListenerKey(l))
return errors.New("uninitialized") if err != nil {
return err
} }
if ok {
if l.p2p.Listeners.Deregister(getListenerKey(l)) {
l.p2p.peerHost.RemoveStreamHandler(l.proto) l.p2p.peerHost.RemoveStreamHandler(l.proto)
l.initialized = false
} }
return nil return nil
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment