Commit f2b8803a authored by Juan Batiz-Benet's avatar Juan Batiz-Benet

net/service now uses ctxcloser

parent 93497c2d
...@@ -115,19 +115,9 @@ func NewIpfsNode(cfg *config.Config, online bool) (*IpfsNode, error) { ...@@ -115,19 +115,9 @@ func NewIpfsNode(cfg *config.Config, online bool) (*IpfsNode, error) {
if online { if online {
dhtService := netservice.NewService(nil) // nil handler for now, need to patch it dhtService := netservice.NewService(ctx, nil) // nil handler for now, need to patch it
exchangeService := netservice.NewService(nil) // nil handler for now, need to patch it exchangeService := netservice.NewService(ctx, nil) // nil handler for now, need to patch it
diagService := netservice.NewService(nil) diagService := netservice.NewService(ctx, nil)
if err := dhtService.Start(ctx); err != nil {
return nil, err
}
if err := exchangeService.Start(ctx); err != nil {
return nil, err
}
if err := diagService.Start(ctx); err != nil {
return nil, err
}
net, err = inet.NewIpfsNetwork(ctx, local, peerstore, &mux.ProtocolMap{ net, err = inet.NewIpfsNetwork(ctx, local, peerstore, &mux.ProtocolMap{
mux.ProtocolID_Routing: dhtService, mux.ProtocolID_Routing: dhtService,
......
...@@ -2,10 +2,12 @@ package service ...@@ -2,10 +2,12 @@ package service
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
msg "github.com/jbenet/go-ipfs/net/message" msg "github.com/jbenet/go-ipfs/net/message"
u "github.com/jbenet/go-ipfs/util" u "github.com/jbenet/go-ipfs/util"
ctxc "github.com/jbenet/go-ipfs/util/ctxcloser"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
) )
...@@ -39,10 +41,7 @@ type Sender interface { ...@@ -39,10 +41,7 @@ type Sender interface {
// incomig (SetHandler) requests. // incomig (SetHandler) requests.
type Service interface { type Service interface {
Sender Sender
ctxc.ContextCloser
// Start + Stop Service
Start(ctx context.Context) error
Stop()
// GetPipe // GetPipe
GetPipe() *msg.Pipe GetPipe() *msg.Pipe
...@@ -56,45 +55,30 @@ type Service interface { ...@@ -56,45 +55,30 @@ type Service interface {
// messages over the same channel, and to issue + handle requests. // messages over the same channel, and to issue + handle requests.
type service struct { type service struct {
// Handler is the object registered to handle incoming requests. // Handler is the object registered to handle incoming requests.
Handler Handler Handler Handler
HandlerLock sync.RWMutex
// Requests are all the pending requests on this service. // Requests are all the pending requests on this service.
Requests RequestMap Requests RequestMap
RequestsLock sync.RWMutex RequestsLock sync.RWMutex
// cancel is the function to stop the Service
cancel context.CancelFunc
// Message Pipe (connected to the outside world) // Message Pipe (connected to the outside world)
*msg.Pipe *msg.Pipe
ctxc.ContextCloser
} }
// NewService creates a service object with given type ID and Handler // NewService creates a service object with given type ID and Handler
func NewService(h Handler) Service { func NewService(ctx context.Context, h Handler) Service {
return &service{ s := &service{
Handler: h, Handler: h,
Requests: RequestMap{}, Requests: RequestMap{},
Pipe: msg.NewPipe(10), Pipe: msg.NewPipe(10),
} ContextCloser: ctxc.NewContextCloser(ctx, nil),
}
// Start kicks off the Service goroutines.
func (s *service) Start(ctx context.Context) error {
if s.cancel != nil {
return errors.New("Service already started.")
} }
// make a cancellable context. s.Children().Add(1)
ctx, s.cancel = context.WithCancel(ctx) go s.handleIncomingMessages()
return s
go s.handleIncomingMessages(ctx)
return nil
}
// Stop stops Service activity.
func (s *service) Stop() {
s.cancel()
s.cancel = context.CancelFunc(nil)
} }
// GetPipe implements the mux.Protocol interface // GetPipe implements the mux.Protocol interface
...@@ -132,6 +116,15 @@ func (s *service) SendMessage(ctx context.Context, m msg.NetMessage) error { ...@@ -132,6 +116,15 @@ func (s *service) SendMessage(ctx context.Context, m msg.NetMessage) error {
// SendRequest sends a request message out and awaits a response. // SendRequest sends a request message out and awaits a response.
func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMessage, error) { func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMessage, error) {
// check if we should bail given our contexts
select {
default:
case <-s.Closing():
return nil, fmt.Errorf("service closed: %s", s.Context().Err())
case <-ctx.Done():
return nil, ctx.Err()
}
// create a request // create a request
r, err := NewRequest(m.Peer().ID()) r, err := NewRequest(m.Peer().ID())
if err != nil { if err != nil {
...@@ -153,6 +146,8 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes ...@@ -153,6 +146,8 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes
// check if we should bail after waiting for mutex // check if we should bail after waiting for mutex
select { select {
default: default:
case <-s.Closing():
return nil, fmt.Errorf("service closed: %s", s.Context().Err())
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
} }
...@@ -165,6 +160,8 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes ...@@ -165,6 +160,8 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes
err = nil err = nil
select { select {
case m = <-r.Response: case m = <-r.Response:
case <-s.Closed():
err = fmt.Errorf("service closed: %s", s.Context().Err())
case <-ctx.Done(): case <-ctx.Done():
err = ctx.Err() err = ctx.Err()
} }
...@@ -178,43 +175,50 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes ...@@ -178,43 +175,50 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes
// handleIncoming consumes the messages on the s.Incoming channel and // handleIncoming consumes the messages on the s.Incoming channel and
// routes them appropriately (to requests, or handler). // routes them appropriately (to requests, or handler).
func (s *service) handleIncomingMessages(ctx context.Context) { func (s *service) handleIncomingMessages() {
defer s.Children().Done()
for { for {
select { select {
case m, more := <-s.Incoming: case m, more := <-s.Incoming:
if !more { if !more {
return return
} }
go s.handleIncomingMessage(ctx, m) s.Children().Add(1)
go s.handleIncomingMessage(m)
case <-ctx.Done(): case <-s.Closing():
return return
} }
} }
} }
func (s *service) handleIncomingMessage(ctx context.Context, m msg.NetMessage) { func (s *service) handleIncomingMessage(m msg.NetMessage) {
defer s.Children().Done()
// unwrap the incoming message // unwrap the incoming message
data, rid, err := unwrapData(m.Data()) data, rid, err := unwrapData(m.Data())
if err != nil { if err != nil {
log.Errorf("de-serializing error: %v", err) log.Errorf("service de-serializing error: %v", err)
return
} }
m2 := msg.New(m.Peer(), data) m2 := msg.New(m.Peer(), data)
// if it's a request (or has no RequestID), handle it // if it's a request (or has no RequestID), handle it
if rid == nil || rid.IsRequest() { if rid == nil || rid.IsRequest() {
if s.Handler == nil { handler := s.GetHandler()
if handler == nil {
log.Errorf("service dropped msg: %v", m) log.Errorf("service dropped msg: %v", m)
return // no handler, drop it. return // no handler, drop it.
} }
// should this be "go HandleMessage ... ?" // should this be "go HandleMessage ... ?"
r1 := s.Handler.HandleMessage(ctx, m2) r1 := handler.HandleMessage(s.Context(), m2)
// if handler gave us a response, send it back out! // if handler gave us a response, send it back out!
if r1 != nil { if r1 != nil {
err := s.sendMessage(ctx, r1, rid.Response()) err := s.sendMessage(s.Context(), r1, rid.Response())
if err != nil { if err != nil {
log.Errorf("error sending response message: %v", err) log.Errorf("error sending response message: %v", err)
} }
...@@ -239,16 +243,20 @@ func (s *service) handleIncomingMessage(ctx context.Context, m msg.NetMessage) { ...@@ -239,16 +243,20 @@ func (s *service) handleIncomingMessage(ctx context.Context, m msg.NetMessage) {
select { select {
case r.Response <- m2: case r.Response <- m2:
case <-ctx.Done(): case <-s.Closing():
} }
} }
// SetHandler assigns the request Handler for this service. // SetHandler assigns the request Handler for this service.
func (s *service) SetHandler(h Handler) { func (s *service) SetHandler(h Handler) {
s.HandlerLock.Lock()
defer s.HandlerLock.Unlock()
s.Handler = h s.Handler = h
} }
// GetHandler returns the request Handler for this service. // GetHandler returns the request Handler for this service.
func (s *service) GetHandler() Handler { func (s *service) GetHandler() Handler {
s.HandlerLock.RLock()
defer s.HandlerLock.RUnlock()
return s.Handler return s.Handler
} }
...@@ -38,13 +38,9 @@ func newPeer(t *testing.T, id string) peer.Peer { ...@@ -38,13 +38,9 @@ func newPeer(t *testing.T, id string) peer.Peer {
func TestServiceHandler(t *testing.T) { func TestServiceHandler(t *testing.T) {
ctx := context.Background() ctx := context.Background()
h := &ReverseHandler{} h := &ReverseHandler{}
s := NewService(h) s := NewService(ctx, h)
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
if err := s.Start(ctx); err != nil {
t.Error(err)
}
d, err := wrapData([]byte("beep"), nil) d, err := wrapData([]byte("beep"), nil)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -70,16 +66,8 @@ func TestServiceHandler(t *testing.T) { ...@@ -70,16 +66,8 @@ func TestServiceHandler(t *testing.T) {
func TestServiceRequest(t *testing.T) { func TestServiceRequest(t *testing.T) {
ctx := context.Background() ctx := context.Background()
s1 := NewService(&ReverseHandler{}) s1 := NewService(ctx, &ReverseHandler{})
s2 := NewService(&ReverseHandler{}) s2 := NewService(ctx, &ReverseHandler{})
if err := s1.Start(ctx); err != nil {
t.Error(err)
}
if err := s2.Start(ctx); err != nil {
t.Error(err)
}
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
...@@ -110,18 +98,10 @@ func TestServiceRequest(t *testing.T) { ...@@ -110,18 +98,10 @@ func TestServiceRequest(t *testing.T) {
func TestServiceRequestTimeout(t *testing.T) { func TestServiceRequestTimeout(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), time.Millisecond)
s1 := NewService(&ReverseHandler{}) s1 := NewService(ctx, &ReverseHandler{})
s2 := NewService(&ReverseHandler{}) s2 := NewService(ctx, &ReverseHandler{})
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
if err := s1.Start(ctx); err != nil {
t.Error(err)
}
if err := s2.Start(ctx); err != nil {
t.Error(err)
}
// patch services together // patch services together
go func() { go func() {
for { for {
...@@ -143,3 +123,41 @@ func TestServiceRequestTimeout(t *testing.T) { ...@@ -143,3 +123,41 @@ func TestServiceRequestTimeout(t *testing.T) {
t.Error("should've timed out") t.Error("should've timed out")
} }
} }
func TestServiceClose(t *testing.T) {
ctx := context.Background()
s1 := NewService(ctx, &ReverseHandler{})
s2 := NewService(ctx, &ReverseHandler{})
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
// patch services together
go func() {
for {
select {
case m := <-s1.GetPipe().Outgoing:
s2.GetPipe().Incoming <- m
case m := <-s2.GetPipe().Outgoing:
s1.GetPipe().Incoming <- m
case <-ctx.Done():
return
}
}
}()
m1 := msg.New(peer1, []byte("beep"))
m2, err := s1.SendRequest(ctx, m1)
if err != nil {
t.Error(err)
}
if !bytes.Equal(m2.Data(), []byte("peeb")) {
t.Errorf("service handler data incorrect: %v != %v", m2.Data(), "oof")
}
s1.Close()
s2.Close()
<-s1.Closed()
<-s2.Closed()
}
...@@ -23,11 +23,7 @@ import ( ...@@ -23,11 +23,7 @@ import (
func setupDHT(ctx context.Context, t *testing.T, p peer.Peer) *IpfsDHT { func setupDHT(ctx context.Context, t *testing.T, p peer.Peer) *IpfsDHT {
peerstore := peer.NewPeerstore() peerstore := peer.NewPeerstore()
dhts := netservice.NewService(nil) // nil handler for now, need to patch it dhts := netservice.NewService(ctx, nil) // nil handler for now, need to patch it
if err := dhts.Start(ctx); err != nil {
t.Fatal(err)
}
net, err := inet.NewIpfsNetwork(ctx, p, peerstore, &mux.ProtocolMap{ net, err := inet.NewIpfsNetwork(ctx, p, peerstore, &mux.ProtocolMap{
mux.ProtocolID_Routing: dhts, mux.ProtocolID_Routing: dhts,
}) })
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment