Unverified Commit ad0faef7 authored by Raúl Kripalani's avatar Raúl Kripalani Committed by GitHub

make RemoveProtocols take the write lock. (#90)

parent 06edc321
...@@ -39,8 +39,9 @@ func NewProtoBook(meta pstore.PeerMetadata) pstore.ProtoBook { ...@@ -39,8 +39,9 @@ func NewProtoBook(meta pstore.PeerMetadata) pstore.ProtoBook {
} }
func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error {
pb.segments.get(p).Lock() s := pb.segments.get(p)
defer pb.segments.get(p).Unlock() s.Lock()
defer s.Unlock()
protomap := make(map[string]struct{}, len(protos)) protomap := make(map[string]struct{}, len(protos))
for _, proto := range protos { for _, proto := range protos {
...@@ -51,8 +52,9 @@ func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { ...@@ -51,8 +52,9 @@ func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error {
} }
func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error {
pb.segments.get(p).Lock() s := pb.segments.get(p)
defer pb.segments.get(p).Unlock() s.Lock()
defer s.Unlock()
pmap, err := pb.getProtocolMap(p) pmap, err := pb.getProtocolMap(p)
if err != nil { if err != nil {
...@@ -67,8 +69,9 @@ func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { ...@@ -67,8 +69,9 @@ func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error {
} }
func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) {
pb.segments.get(p).RLock() s := pb.segments.get(p)
defer pb.segments.get(p).RUnlock() s.RLock()
defer s.RUnlock()
pmap, err := pb.getProtocolMap(p) pmap, err := pb.getProtocolMap(p)
if err != nil { if err != nil {
...@@ -84,8 +87,9 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { ...@@ -84,8 +87,9 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) {
} }
func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) { func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) {
pb.segments.get(p).RLock() s := pb.segments.get(p)
defer pb.segments.get(p).RUnlock() s.RLock()
defer s.RUnlock()
pmap, err := pb.getProtocolMap(p) pmap, err := pb.getProtocolMap(p)
if err != nil { if err != nil {
...@@ -104,8 +108,8 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, ...@@ -104,8 +108,8 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string,
func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
s := pb.segments.get(p) s := pb.segments.get(p)
s.RLock() s.Lock()
defer s.RUnlock() defer s.Unlock()
pmap, err := pb.getProtocolMap(p) pmap, err := pb.getProtocolMap(p)
if err != nil { if err != nil {
......
...@@ -114,8 +114,8 @@ func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) { ...@@ -114,8 +114,8 @@ func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) {
func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
s := pb.segments.get(p) s := pb.segments.get(p)
s.RLock() s.Lock()
defer s.RUnlock() defer s.Unlock()
protomap, ok := s.protocols[p] protomap, ok := s.protocols[p]
if !ok { if !ok {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment