package providerquerymanager

import (
	"context"
	"fmt"
	"sync"
	"time"

	"gitlab.dms3.io/dms3/go-cid"
	logging "gitlab.dms3.io/dms3/go-log"
	peer "gitlab.dms3.io/p2p/go-p2p-core/peer"
)

var log = logging.Logger("bitswap")

const (
	maxProviders         = 10
	maxInProcessRequests = 6
	defaultTimeout       = 10 * time.Second
)

type inProgressRequestStatus struct {
	ctx            context.Context
	cancelFn       func()
	providersSoFar []peer.ID
	listeners      map[chan peer.ID]struct{}
}

type findProviderRequest struct {
	k   cid.Cid
	ctx context.Context
}

// ProviderQueryNetwork is an interface for finding providers and connecting to
// peers.
type ProviderQueryNetwork interface {
	ConnectTo(context.Context, peer.ID) error
	FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.ID
}

type providerQueryMessage interface {
	debugMessage() string
	handle(pqm *ProviderQueryManager)
}

type receivedProviderMessage struct {
	k cid.Cid
	p peer.ID
}

type finishedProviderQueryMessage struct {
	k cid.Cid
}

type newProvideQueryMessage struct {
	k                     cid.Cid
	inProgressRequestChan chan<- inProgressRequest
}

type cancelRequestMessage struct {
	incomingProviders chan peer.ID
	k                 cid.Cid
}

// ProviderQueryManager manages requests to find more providers for blocks
// for bitswap sessions. Its main goals are to:
// - rate limit requests -- don't have too many find provider calls running
// simultaneously
// - connect to found peers and filter out those it can't connect to
// - ensure two find provider calls for the same block don't run concurrently
// - manage timeouts
type ProviderQueryManager struct {
	ctx                          context.Context
	network                      ProviderQueryNetwork
	providerQueryMessages        chan providerQueryMessage
	providerRequestsProcessing   chan *findProviderRequest
	incomingFindProviderRequests chan *findProviderRequest

	findProviderTimeout time.Duration
	timeoutMutex        sync.RWMutex

	// do not touch outside the run loop
	inProgressRequestStatuses map[cid.Cid]*inProgressRequestStatus
}

// New initializes a new ProviderQueryManager for a given context and a given
// network provider.
func New(ctx context.Context, network ProviderQueryNetwork) *ProviderQueryManager {
	return &ProviderQueryManager{
		ctx:                          ctx,
		network:                      network,
		providerQueryMessages:        make(chan providerQueryMessage, 16),
		providerRequestsProcessing:   make(chan *findProviderRequest),
		incomingFindProviderRequests: make(chan *findProviderRequest),
		inProgressRequestStatuses:    make(map[cid.Cid]*inProgressRequestStatus),
		findProviderTimeout:          defaultTimeout,
	}
}

// Startup starts processing for the ProviderQueryManager.
func (pqm *ProviderQueryManager) Startup() {
	go pqm.run()
}
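
// A minimal usage sketch (hypothetical names: "network" stands for some
// concrete ProviderQueryNetwork implementation and "someCid" for a block CID):
//
//	pqm := New(ctx, network)
//	pqm.Startup()
//	for p := range pqm.FindProvidersAsync(sessionCtx, someCid) {
//		// p is a provider that was found and successfully connected to
//		fmt.Println("provider:", p)
//	}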

type inProgressRequest struct {
	providersSoFar []peer.ID
	incoming       chan peer.ID
}

// SetFindProviderTimeout changes the timeout for finding providers
func (pqm *ProviderQueryManager) SetFindProviderTimeout(findProviderTimeout time.Duration) {
	pqm.timeoutMutex.Lock()
	pqm.findProviderTimeout = findProviderTimeout
	pqm.timeoutMutex.Unlock()
}

// FindProvidersAsync finds providers for the given block. The returned channel
// is closed when the query completes or either context is cancelled.
func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid) <-chan peer.ID {
	inProgressRequestChan := make(chan inProgressRequest)

	select {
	case pqm.providerQueryMessages <- &newProvideQueryMessage{
		k:                     k,
		inProgressRequestChan: inProgressRequestChan,
	}:
	case <-pqm.ctx.Done():
		ch := make(chan peer.ID)
		close(ch)
		return ch
	case <-sessionCtx.Done():
		ch := make(chan peer.ID)
		close(ch)
		return ch
	}

	// DO NOT select on sessionCtx. We only want to abort here if we're
	// shutting down because we can't actually _cancel_ the request till we
	// get to receiveProviders.
	var receivedInProgressRequest inProgressRequest
	select {
	case <-pqm.ctx.Done():
		ch := make(chan peer.ID)
		close(ch)
		return ch
	case receivedInProgressRequest = <-inProgressRequestChan:
	}

	return pqm.receiveProviders(sessionCtx, k, receivedInProgressRequest)
}

func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, receivedInProgressRequest inProgressRequest) <-chan peer.ID {
	// maintains an unbounded queue of incoming providers for a given request for a given session:
	// as a provider comes in for a given CID, we want to immediately broadcast it to all
	// sessions that queried that CID, without worrying about whether the client code is actually
	// reading from the returned channel -- so that the broadcast never blocks
	// based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd
	returnedProviders := make(chan peer.ID)
	receivedProviders := append([]peer.ID(nil), receivedInProgressRequest.providersSoFar...)
	incomingProviders := receivedInProgressRequest.incoming

	go func() {
		defer close(returnedProviders)
		outgoingProviders := func() chan<- peer.ID {
			if len(receivedProviders) == 0 {
				return nil
			}
			return returnedProviders
		}
		nextProvider := func() peer.ID {
			if len(receivedProviders) == 0 {
				return ""
			}
			return receivedProviders[0]
		}
		for len(receivedProviders) > 0 || incomingProviders != nil {
			select {
			case <-pqm.ctx.Done():
				return
			case <-sessionCtx.Done():
				if incomingProviders != nil {
					pqm.cancelProviderRequest(k, incomingProviders)
				}
				return
			case provider, ok := <-incomingProviders:
				if !ok {
					incomingProviders = nil
				} else {
					receivedProviders = append(receivedProviders, provider)
				}
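			// when receivedProviders is empty, outgoingProviders() returns a nil
			// channel, disabling this send case until a provider is buffered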
			case outgoingProviders() <- nextProvider():
				receivedProviders = receivedProviders[1:]
			}
		}
	}()
	return returnedProviders
}

func (pqm *ProviderQueryManager) cancelProviderRequest(k cid.Cid, incomingProviders chan peer.ID) {
	cancelMessageChannel := pqm.providerQueryMessages
	for {
		select {
		case cancelMessageChannel <- &cancelRequestMessage{
			incomingProviders: incomingProviders,
			k:                 k,
		}:
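			// nil the channel so the cancel message is sent at most once; we keep
			// looping below to drain providers until the listener is closed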
			cancelMessageChannel = nil
		// clear out any remaining providers, in case any "incoming provider"
		// messages get processed before our cancel message
		case _, ok := <-incomingProviders:
			if !ok {
				return
			}
		case <-pqm.ctx.Done():
			return
		}
	}
}

func (pqm *ProviderQueryManager) findProviderWorker() {
	// findProviderWorker just cycles through incoming provider queries one
	// at a time. We have maxInProcessRequests of these workers running at once
	// to let requests go in parallel but keep them rate limited
	for {
		select {
		case fpr, ok := <-pqm.providerRequestsProcessing:
			if !ok {
				return
			}
			k := fpr.k
			log.Debugf("Beginning Find Provider Request for cid: %s", k.String())
			pqm.timeoutMutex.RLock()
			findProviderCtx, cancel := context.WithTimeout(fpr.ctx, pqm.findProviderTimeout)
			pqm.timeoutMutex.RUnlock()
			providers := pqm.network.FindProvidersAsync(findProviderCtx, k, maxProviders)
			wg := &sync.WaitGroup{}
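			// dial each candidate concurrently; only providers we can actually
			// connect to are reported back to the run loop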
			for p := range providers {
				wg.Add(1)
				go func(p peer.ID) {
					defer wg.Done()
					err := pqm.network.ConnectTo(findProviderCtx, p)
					if err != nil {
						log.Debugf("failed to connect to provider %s: %s", p, err)
						return
					}
					select {
					case pqm.providerQueryMessages <- &receivedProviderMessage{
						k: k,
						p: p,
					}:
					case <-pqm.ctx.Done():
						return
					}
				}(p)
			}
			wg.Wait()
			cancel()
			select {
			case pqm.providerQueryMessages <- &finishedProviderQueryMessage{
				k: k,
			}:
			case <-pqm.ctx.Done():
			}
		case <-pqm.ctx.Done():
			return
		}
	}
}

func (pqm *ProviderQueryManager) providerRequestBufferWorker() {
	// the provider request buffer worker just maintains an unbounded
	// buffer for incoming provider queries and dispatches to the find
	// provider workers as they become available
	// based on: https://medium.com/capital-one-tech/building-an-unbounded-channel-in-go-789e175cd2cd
	var providerQueryRequestBuffer []*findProviderRequest
	nextProviderQuery := func() *findProviderRequest {
		if len(providerQueryRequestBuffer) == 0 {
			return nil
		}
		return providerQueryRequestBuffer[0]
	}
	outgoingRequests := func() chan<- *findProviderRequest {
		if len(providerQueryRequestBuffer) == 0 {
			return nil
		}
		return pqm.providerRequestsProcessing
	}

	for {
		select {
		case incomingRequest, ok := <-pqm.incomingFindProviderRequests:
			if !ok {
				return
			}
			providerQueryRequestBuffer = append(providerQueryRequestBuffer, incomingRequest)
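		// outgoingRequests() returns nil while the buffer is empty, which
		// disables this send case until a request is queued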
		case outgoingRequests() <- nextProviderQuery():
			providerQueryRequestBuffer = providerQueryRequestBuffer[1:]
		case <-pqm.ctx.Done():
			return
		}
	}
}

func (pqm *ProviderQueryManager) cleanupInProcessRequests() {
	for _, requestStatus := range pqm.inProgressRequestStatuses {
		for listener := range requestStatus.listeners {
			close(listener)
		}
		requestStatus.cancelFn()
	}
}

func (pqm *ProviderQueryManager) run() {
	defer pqm.cleanupInProcessRequests()

	go pqm.providerRequestBufferWorker()
	for i := 0; i < maxInProcessRequests; i++ {
		go pqm.findProviderWorker()
	}

	for {
		select {
		case nextMessage := <-pqm.providerQueryMessages:
			log.Debug(nextMessage.debugMessage())
			nextMessage.handle(pqm)
		case <-pqm.ctx.Done():
			return
		}
	}
}

func (rpm *receivedProviderMessage) debugMessage() string {
	return fmt.Sprintf("Received provider (%s) for cid (%s)", rpm.p.String(), rpm.k.String())
}

func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) {
	requestStatus, ok := pqm.inProgressRequestStatuses[rpm.k]
	if !ok {
		log.Errorf("Received provider (%s) for cid (%s) not requested", rpm.p.String(), rpm.k.String())
		return
	}
	requestStatus.providersSoFar = append(requestStatus.providersSoFar, rpm.p)
	for listener := range requestStatus.listeners {
		select {
		case listener <- rpm.p:
		case <-pqm.ctx.Done():
			return
		}
	}
}

func (fpqm *finishedProviderQueryMessage) debugMessage() string {
	return fmt.Sprintf("Finished Provider Query on cid: %s", fpqm.k.String())
}

func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) {
	requestStatus, ok := pqm.inProgressRequestStatuses[fpqm.k]
	if !ok {
		// we canceled the request as it finished.
		return
	}
	for listener := range requestStatus.listeners {
		close(listener)
	}
	delete(pqm.inProgressRequestStatuses, fpqm.k)
	requestStatus.cancelFn()
}

func (npqm *newProvideQueryMessage) debugMessage() string {
	return fmt.Sprintf("New Provider Query on cid: %s", npqm.k.String())
}

func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
	requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k]
	if !ok {
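		// no query is in flight for this CID yet -- start one; otherwise the
		// new session just registers as another listener on the existing query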
		ctx, cancelFn := context.WithCancel(pqm.ctx)
		requestStatus = &inProgressRequestStatus{
			listeners: make(map[chan peer.ID]struct{}),
			ctx:       ctx,
			cancelFn:  cancelFn,
		}
		pqm.inProgressRequestStatuses[npqm.k] = requestStatus
		select {
		case pqm.incomingFindProviderRequests <- &findProviderRequest{
			k:   npqm.k,
			ctx: ctx,
		}:
		case <-pqm.ctx.Done():
			return
		}
	}
	inProgressChan := make(chan peer.ID)
	requestStatus.listeners[inProgressChan] = struct{}{}
	select {
	case npqm.inProgressRequestChan <- inProgressRequest{
		providersSoFar: requestStatus.providersSoFar,
		incoming:       inProgressChan,
	}:
	case <-pqm.ctx.Done():
	}
}

func (crm *cancelRequestMessage) debugMessage() string {
	return fmt.Sprintf("Cancel provider query on cid: %s", crm.k.String())
}

func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
	requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
	if !ok {
		// Request finished while queued.
		return
	}
	_, ok = requestStatus.listeners[crm.incomingProviders]
	if !ok {
		// Request finished and _restarted_ while queued.
		return
	}
	delete(requestStatus.listeners, crm.incomingProviders)
	close(crm.incomingProviders)
	if len(requestStatus.listeners) == 0 {
		delete(pqm.inProgressRequestStatuses, crm.k)
		requestStatus.cancelFn()
	}
}