providerquerymanager.go 11.8 KB
Newer Older
1 2 3 4
package providerquerymanager

import (
	"context"
5
	"fmt"
6
	"sync"
7
	"time"
8 9 10 11 12 13 14 15 16 17 18

	"github.com/ipfs/go-cid"
	logging "github.com/ipfs/go-log"
	peer "github.com/libp2p/go-libp2p-peer"
)

// log is the shared bitswap logger used throughout the provider query manager.
var log = logging.Logger("bitswap")

const (
	// maxProviders caps how many providers each find-providers call requests.
	maxProviders         = 10
	// maxInProcessRequests is the number of findProviderWorker goroutines,
	// i.e. the number of find-provider queries allowed to run concurrently.
	maxInProcessRequests = 6
	// defaultTimeout bounds each network query unless overridden via
	// SetFindProviderTimeout.
	defaultTimeout       = 10 * time.Second
)

// inProgressRequestStatus tracks one outstanding find-providers query and the
// set of session channels listening for its results. Only the run loop may
// touch it (see inProgressRequestStatuses).
type inProgressRequestStatus struct {
	// ctx governs the network query; cancelFn aborts it when the query
	// finishes or no listeners remain.
	ctx            context.Context
	cancelFn       func()
	// providersSoFar accumulates providers already found, so sessions that
	// join an in-flight query can be replayed the earlier results.
	providersSoFar []peer.ID
	// listeners is the set of per-session channels each new provider is
	// broadcast to.
	listeners      map[chan peer.ID]struct{}
}

// findProviderRequest is a single queued query for providers of cid k,
// bounded by ctx (cancelled when the query is no longer wanted).
type findProviderRequest struct {
	k   cid.Cid
	ctx context.Context
}

// ProviderQueryNetwork is an interface for finding providers and connecting to
// peers. It is the subset of network functionality the manager needs, kept
// small so tests can stub it.
type ProviderQueryNetwork interface {
	ConnectTo(context.Context, peer.ID) error
	FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.ID
}

// providerQueryMessage is a message processed serially by the run loop.
// handle is only ever invoked from that loop, so implementations may read and
// write pqm.inProgressRequestStatuses without locking.
type providerQueryMessage interface {
	// debugMessage returns a human-readable description for debug logging.
	debugMessage() string
	// handle applies this message's state change; run-loop goroutine only.
	handle(pqm *ProviderQueryManager)
}

// receivedProviderMessage reports that provider p was found (and successfully
// connected to) for cid k.
type receivedProviderMessage struct {
	k cid.Cid
	p peer.ID
}

// finishedProviderQueryMessage reports that the find-providers query for cid k
// has completed.
type finishedProviderQueryMessage struct {
	k cid.Cid
}

// newProvideQueryMessage asks the run loop to start (or join) a query for cid
// k; the run loop replies on inProgressRequestChan.
type newProvideQueryMessage struct {
	k                     cid.Cid
	inProgressRequestChan chan<- inProgressRequest
}

// cancelRequestMessage removes one session's listener channel from the query
// for cid k, cancelling the query itself if no listeners remain.
type cancelRequestMessage struct {
	incomingProviders chan peer.ID
	k                 cid.Cid
}

// ProviderQueryManager manages requests to find more providers for blocks
// for bitswap sessions. It's main goals are to:
// - rate limit requests -- don't have too many find provider calls running
// simultaneously
// - connect to found peers and filter them if it can't connect
// - ensure two findprovider calls for the same block don't run concurrently
// - manage timeouts
type ProviderQueryManager struct {
	ctx                          context.Context
	network                      ProviderQueryNetwork
	// providerQueryMessages carries all state-changing messages into the run
	// loop, which processes them one at a time.
	providerQueryMessages        chan providerQueryMessage
	// incomingFindProviderRequests feeds providerRequestBufferWorker, which
	// buffers unboundedly and dispatches to providerRequestsProcessing, the
	// unbuffered channel the findProviderWorker goroutines pull from.
	providerRequestsProcessing   chan *findProviderRequest
	incomingFindProviderRequests chan *findProviderRequest

	// findProviderTimeout is guarded by timeoutMutex; it may be changed at
	// runtime via SetFindProviderTimeout.
	findProviderTimeout time.Duration
	timeoutMutex        sync.RWMutex

	// do not touch outside the run loop
	inProgressRequestStatuses map[cid.Cid]*inProgressRequestStatus
}

// New initializes a new ProviderQueryManager for a given context and a given
// network provider.
func New(ctx context.Context, network ProviderQueryNetwork) *ProviderQueryManager {
	return &ProviderQueryManager{
		ctx:                          ctx,
		network:                      network,
		providerQueryMessages:        make(chan providerQueryMessage, 16),
93 94
		providerRequestsProcessing:   make(chan *findProviderRequest),
		incomingFindProviderRequests: make(chan *findProviderRequest),
95
		inProgressRequestStatuses:    make(map[cid.Cid]*inProgressRequestStatus),
96
		findProviderTimeout:          defaultTimeout,
97 98 99 100 101 102 103 104 105 106
	}
}

// Startup starts processing for the ProviderQueryManager. It launches the run
// loop on its own goroutine; the loop exits when the context passed to New is
// cancelled.
func (pqm *ProviderQueryManager) Startup() {
	go pqm.run()
}

// inProgressRequest is the run loop's reply to a newProvideQueryMessage: the
// providers already found plus the channel future providers will arrive on.
type inProgressRequest struct {
	providersSoFar []peer.ID
	incoming       chan peer.ID
}

110 111 112 113 114 115 116
// SetFindProviderTimeout changes the timeout for finding providers
func (pqm *ProviderQueryManager) SetFindProviderTimeout(findProviderTimeout time.Duration) {
	pqm.timeoutMutex.Lock()
	pqm.findProviderTimeout = findProviderTimeout
	pqm.timeoutMutex.Unlock()
}

117
// FindProvidersAsync finds providers for the given block.
118
func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid) <-chan peer.ID {
119 120 121 122 123 124 125 126
	inProgressRequestChan := make(chan inProgressRequest)

	select {
	case pqm.providerQueryMessages <- &newProvideQueryMessage{
		k:                     k,
		inProgressRequestChan: inProgressRequestChan,
	}:
	case <-pqm.ctx.Done():
127 128 129
		ch := make(chan peer.ID)
		close(ch)
		return ch
130
	case <-sessionCtx.Done():
131 132 133
		ch := make(chan peer.ID)
		close(ch)
		return ch
134 135 136 137
	}

	var receivedInProgressRequest inProgressRequest
	select {
138
	case <-pqm.ctx.Done():
139 140 141
		ch := make(chan peer.ID)
		close(ch)
		return ch
142
	case <-sessionCtx.Done():
143 144 145
		ch := make(chan peer.ID)
		close(ch)
		return ch
146 147 148
	case receivedInProgressRequest = <-inProgressRequestChan:
	}

149
	return pqm.receiveProviders(sessionCtx, k, receivedInProgressRequest)
150 151
}

// receiveProviders bridges an in-progress request to a session-facing channel,
// returning the channel handed back to the FindProvidersAsync caller.
func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, receivedInProgressRequest inProgressRequest) <-chan peer.ID {
	// maintains an unbuffered queue for incoming providers for given request for a given session
	// essentially, as a provider comes in, for a given CID, we want to immediately broadcast to all
	// sessions that queried that CID, without worrying about whether the client code is actually
	// reading from the returned channel -- so that the broadcast never blocks
	// based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd
	returnedProviders := make(chan peer.ID)
	// copy providersSoFar so later run-loop appends don't race with this goroutine
	receivedProviders := append([]peer.ID(nil), receivedInProgressRequest.providersSoFar[0:]...)
	incomingProviders := receivedInProgressRequest.incoming

	go func() {
		defer close(returnedProviders)
		// outgoingProviders returns nil when the buffer is empty, which makes
		// the send case below a permanently-blocked (disabled) select case.
		outgoingProviders := func() chan<- peer.ID {
			if len(receivedProviders) == 0 {
				return nil
			}
			return returnedProviders
		}
		nextProvider := func() peer.ID {
			if len(receivedProviders) == 0 {
				return ""
			}
			return receivedProviders[0]
		}
		// run until the incoming channel is closed AND the buffer is drained
		for len(receivedProviders) > 0 || incomingProviders != nil {
			select {
			case <-pqm.ctx.Done():
				return
			case <-sessionCtx.Done():
				// session gave up; tell the run loop to stop broadcasting to us
				pqm.cancelProviderRequest(k, incomingProviders)
				return
			case provider, ok := <-incomingProviders:
				if !ok {
					// query finished; keep looping only to drain the buffer
					incomingProviders = nil
				} else {
					receivedProviders = append(receivedProviders, provider)
				}
			case outgoingProviders() <- nextProvider():
				receivedProviders = receivedProviders[1:]
			}
		}
	}()
	return returnedProviders
}

// cancelProviderRequest tells the run loop that this session's listener
// channel should be removed, while continuously draining that channel so the
// run loop can never deadlock broadcasting a provider to us.
func (pqm *ProviderQueryManager) cancelProviderRequest(k cid.Cid, incomingProviders chan peer.ID) {
	cancelMessageChannel := pqm.providerQueryMessages
	for {
		select {
		case cancelMessageChannel <- &cancelRequestMessage{
			incomingProviders: incomingProviders,
			k:                 k,
		}:
			// message sent; nil the channel so this case is disabled, then
			// keep draining until the run loop closes incomingProviders
			cancelMessageChannel = nil
		// clear out any remaining providers, in case any "incoming provider"
		// messages get processed before our cancel message
		case _, ok := <-incomingProviders:
			if !ok {
				return
			}
		case <-pqm.ctx.Done():
			return
		}
	}
}

// findProviderWorker executes queued find-provider requests, one at a time
// per worker. maxInProcessRequests of these run concurrently, which is what
// rate limits simultaneous network queries.
func (pqm *ProviderQueryManager) findProviderWorker() {
	// findProviderWorker just cycles through incoming provider queries one
	// at a time. We have six of these workers running at once
	// to let requests go in parallel but keep them rate limited
	for {
		select {
		case fpr, ok := <-pqm.providerRequestsProcessing:
			if !ok {
				return
			}
			k := fpr.k
			log.Debugf("Beginning Find Provider Request for cid: %s", k.String())
			// snapshot the timeout under the read lock; SetFindProviderTimeout
			// may change it concurrently
			pqm.timeoutMutex.RLock()
			findProviderCtx, cancel := context.WithTimeout(fpr.ctx, pqm.findProviderTimeout)
			pqm.timeoutMutex.RUnlock()
			providers := pqm.network.FindProvidersAsync(findProviderCtx, k, maxProviders)
			// try to connect to each provider concurrently; only providers we
			// can actually reach are reported back to the run loop
			wg := &sync.WaitGroup{}
			for p := range providers {
				wg.Add(1)
				go func(p peer.ID) {
					defer wg.Done()
					err := pqm.network.ConnectTo(pqm.ctx, p)
					if err != nil {
						log.Debugf("failed to connect to provider %s: %s", p, err)
						return
					}
					select {
					case pqm.providerQueryMessages <- &receivedProviderMessage{
						k: k,
						p: p,
					}:
					case <-pqm.ctx.Done():
						return
					}
				}(p)
			}
			// providers channel closed => query done; release the timeout
			// context before waiting on in-flight connection attempts
			cancel()
			wg.Wait()
			select {
			case pqm.providerQueryMessages <- &finishedProviderQueryMessage{
				k: k,
			}:
			case <-pqm.ctx.Done():
			}
		case <-pqm.ctx.Done():
			return
		}
	}
}

// providerRequestBufferWorker decouples request submission from the worker
// pool so that enqueuing a query never blocks the run loop.
func (pqm *ProviderQueryManager) providerRequestBufferWorker() {
	// the provider request buffer worker just maintains an unbounded
	// buffer for incoming provider queries and dispatches to the find
	// provider workers as they become available
	// based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd
	var providerQueryRequestBuffer []*findProviderRequest
	nextProviderQuery := func() *findProviderRequest {
		if len(providerQueryRequestBuffer) == 0 {
			return nil
		}
		return providerQueryRequestBuffer[0]
	}
	// returns nil (a disabled select case) while the buffer is empty
	outgoingRequests := func() chan<- *findProviderRequest {
		if len(providerQueryRequestBuffer) == 0 {
			return nil
		}
		return pqm.providerRequestsProcessing
	}

	for {
		select {
		case incomingRequest, ok := <-pqm.incomingFindProviderRequests:
			if !ok {
				return
			}
			providerQueryRequestBuffer = append(providerQueryRequestBuffer, incomingRequest)
		case outgoingRequests() <- nextProviderQuery():
			providerQueryRequestBuffer = providerQueryRequestBuffer[1:]
		case <-pqm.ctx.Done():
			return
		}
	}
}

func (pqm *ProviderQueryManager) cleanupInProcessRequests() {
	for _, requestStatus := range pqm.inProgressRequestStatuses {
304
		for listener := range requestStatus.listeners {
305 306
			close(listener)
		}
307
		requestStatus.cancelFn()
308 309 310 311 312 313 314 315 316 317 318 319 320 321
	}
}

// run is the central event loop. All status-map mutations happen here,
// serially, which is why inProgressRequestStatuses needs no lock. It also
// launches the buffer worker and the fixed pool of find-provider workers.
func (pqm *ProviderQueryManager) run() {
	defer pqm.cleanupInProcessRequests()

	go pqm.providerRequestBufferWorker()
	for i := 0; i < maxInProcessRequests; i++ {
		go pqm.findProviderWorker()
	}

	for {
		select {
		case nextMessage := <-pqm.providerQueryMessages:
			log.Debug(nextMessage.debugMessage())
			nextMessage.handle(pqm)
		case <-pqm.ctx.Done():
			return
		}
	}
}

330 331 332 333
func (rpm *receivedProviderMessage) debugMessage() string {
	return fmt.Sprintf("Received provider (%s) for cid (%s)", rpm.p.String(), rpm.k.String())
}

// handle records a newly found provider and broadcasts it to every listening
// session. Runs on the run loop goroutine only.
func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) {
	requestStatus, ok := pqm.inProgressRequestStatuses[rpm.k]
	if !ok {
		log.Errorf("Received provider (%s) for cid (%s) not requested", rpm.p.String(), rpm.k.String())
		return
	}
	requestStatus.providersSoFar = append(requestStatus.providersSoFar, rpm.p)
	for listener := range requestStatus.listeners {
		// each listener is serviced by a receiveProviders goroutine that
		// buffers without blocking, so this send resolves promptly
		select {
		case listener <- rpm.p:
		case <-pqm.ctx.Done():
			return
		}
	}
}

350 351 352 353
func (fpqm *finishedProviderQueryMessage) debugMessage() string {
	return fmt.Sprintf("Finished Provider Query on cid: %s", fpqm.k.String())
}

// handle closes out a completed query: every listener channel is closed
// (signalling end-of-results to the sessions), the status entry is dropped,
// and the query context is cancelled. Runs on the run loop goroutine only.
func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) {
	requestStatus, ok := pqm.inProgressRequestStatuses[fpqm.k]
	if !ok {
		// can happen if all sessions cancelled before the query finished
		log.Errorf("Ended request for cid (%s) not in progress", fpqm.k.String())
		return
	}
	for listener := range requestStatus.listeners {
		close(listener)
	}
	delete(pqm.inProgressRequestStatuses, fpqm.k)
	requestStatus.cancelFn()
}

367
func (npqm *newProvideQueryMessage) debugMessage() string {
368
	return fmt.Sprintf("New Provider Query on cid: %s", npqm.k.String())
369 370
}

// handle starts a new query for npqm.k (or joins one already in flight),
// registers a fresh listener channel, and replies to the requesting session
// with the providers found so far plus that channel. Runs on the run loop
// goroutine only.
func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
	requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k]
	if !ok {
		// no query in flight for this cid -- create one and enqueue it
		ctx, cancelFn := context.WithCancel(pqm.ctx)
		requestStatus = &inProgressRequestStatus{
			listeners: make(map[chan peer.ID]struct{}),
			ctx:       ctx,
			cancelFn:  cancelFn,
		}
		pqm.inProgressRequestStatuses[npqm.k] = requestStatus
		select {
		case pqm.incomingFindProviderRequests <- &findProviderRequest{
			k:   npqm.k,
			ctx: ctx,
		}:
		case <-pqm.ctx.Done():
			// shutting down: status stays registered but cleanupInProcessRequests
			// will cancel it when the run loop exits
			return
		}
	}
	inProgressChan := make(chan peer.ID)
	requestStatus.listeners[inProgressChan] = struct{}{}
	select {
	case npqm.inProgressRequestChan <- inProgressRequest{
		providersSoFar: requestStatus.providersSoFar,
		incoming:       inProgressChan,
	}:
	case <-pqm.ctx.Done():
	}
}

401
func (crm *cancelRequestMessage) debugMessage() string {
402
	return fmt.Sprintf("Cancel provider query on cid: %s", crm.k.String())
403 404
}

405 406
func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
	requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
407
	if !ok {
408
		log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String())
409
		return
410
	}
411 412 413 414 415 416
	_, ok = requestStatus.listeners[crm.incomingProviders]
	if !ok {
		log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String())
		return
	}
	delete(requestStatus.listeners, crm.incomingProviders)
417
	close(crm.incomingProviders)
418 419 420 421
	if len(requestStatus.listeners) == 0 {
		delete(pqm.inProgressRequestStatuses, crm.k)
		requestStatus.cancelFn()
	}
422
}