loader_test.go 3.58 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
package loader

import (
	"context"
	"errors"
	"io"
	"io/ioutil"
	"math/rand"
	"reflect"
	"testing"
	"time"

	"github.com/ipfs/go-graphsync/ipldbridge"

	"github.com/ipfs/go-graphsync/requestmanager/asyncloader"
	"github.com/ipfs/go-graphsync/testbridge"
	"github.com/ipfs/go-graphsync/testutil"
	"github.com/ipld/go-ipld-prime"

	gsmsg "github.com/ipfs/go-graphsync/message"
)

// callParams captures the arguments of one call made to the stub AsyncLoadFn
// produced by makeAsyncLoadFn, so a test can inspect what was requested.
type callParams struct {
	// requestID is the request ID the stub was called with.
	requestID gsmsg.GraphSyncRequestID

	// link is the IPLD link the stub was asked to load.
	link ipld.Link
}

// makeAsyncLoadFn builds a stub AsyncLoadFn for tests.
//
// Every invocation of the stub first records its arguments on calls (a send
// that blocks while calls is full) and then returns the shared responseChan,
// so the test controls what a "load" yields by what it puts on responseChan.
func makeAsyncLoadFn(responseChan chan asyncloader.AsyncLoadResult, calls chan callParams) AsyncLoadFn {
	stub := func(id gsmsg.GraphSyncRequestID, lnk ipld.Link) <-chan asyncloader.AsyncLoadResult {
		calls <- callParams{requestID: id, link: lnk}
		return responseChan
	}
	return stub
}

// TestWrappedAsyncLoaderReturnsValues covers the success path of
// WrapAsyncLoader: when the async load function yields data without an error,
// the wrapped loader returns a reader over exactly that data, and the async
// load function is invoked with the request ID and link the loader was given.
func TestWrappedAsyncLoaderReturnsValues(t *testing.T) {
	ctx := context.Background()
	// The deadline only bounds a hung test; a passing run never waits for it.
	// A very short deadline (formerly 10ms) can expire on a slow or
	// race-instrumented run and race the loader's receive, causing flakes.
	ctx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()
	responseChan := make(chan asyncloader.AsyncLoadResult, 1)
	calls := make(chan callParams, 1)
	asyncLoadFn := makeAsyncLoadFn(responseChan, calls)
	// Buffered like the other tests so an unexpected side-channel error
	// report never blocks on a missing reader.
	errChan := make(chan error, 1)
	requestID := gsmsg.GraphSyncRequestID(rand.Int31())
	loader := WrapAsyncLoader(ctx, asyncLoadFn, requestID, errChan)

	link := testbridge.NewMockLink()
	data := testutil.RandomBytes(100)
	responseChan <- asyncloader.AsyncLoadResult{Data: data, Err: nil}
	stream, err := loader(link, ipldbridge.LinkContext{})
	if err != nil {
		t.Fatalf("Should not have errored on load: %v", err)
	}

	// The stub records its arguments before handing out responseChan, so once
	// the loader has returned the data the call must already be on calls.
	select {
	case call := <-calls:
		// DeepEqual rather than == so comparing the interface-typed link can
		// never panic on an uncomparable dynamic type.
		if call.requestID != requestID || !reflect.DeepEqual(call.link, link) {
			t.Fatal("async load function called with wrong request ID or link")
		}
	default:
		t.Fatal("loader did not call the async load function")
	}

	returnedData, err := ioutil.ReadAll(stream)
	if err != nil {
		t.Fatalf("error in return stream: %v", err)
	}
	if !reflect.DeepEqual(data, returnedData) {
		t.Fatal("returned data did not match expected")
	}
}

// TestWrappedAsyncLoaderSideChannelsErrors covers the error path of
// WrapAsyncLoader: when the async load function yields an error, the wrapped
// loader returns no stream and ipldbridge.ErrDoNotFollow(), and the original
// error is delivered on errChan.
func TestWrappedAsyncLoaderSideChannelsErrors(t *testing.T) {
	ctx := context.Background()
	// The deadline only bounds a hung test. With a very short deadline
	// (formerly 10ms) a slow run could find both ctx.Done() and errChan ready
	// in the final select, which then picks one at random and may fail falsely.
	ctx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()
	responseChan := make(chan asyncloader.AsyncLoadResult, 1)
	calls := make(chan callParams, 1)
	asyncLoadFn := makeAsyncLoadFn(responseChan, calls)
	errChan := make(chan error, 1)
	requestID := gsmsg.GraphSyncRequestID(rand.Int31())
	loader := WrapAsyncLoader(ctx, asyncLoadFn, requestID, errChan)

	link := testbridge.NewMockLink()
	err := errors.New("something went wrong")
	responseChan <- asyncloader.AsyncLoadResult{Data: nil, Err: err}
	stream, loadErr := loader(link, ipldbridge.LinkContext{})
	if stream != nil || loadErr != ipldbridge.ErrDoNotFollow() {
		t.Fatalf("Should have errored on load: got stream %v, error %v", stream, loadErr)
	}
	select {
	case <-ctx.Done():
		t.Fatal("should have returned an error on side channel but didn't")
	case returnedErr := <-errChan:
		if returnedErr != err {
			t.Fatalf("returned wrong error on side channel: %v", returnedErr)
		}
	}
}

// TestWrappedAsyncLoaderContextCancels checks that a load waiting on the
// async load function returns once the loader's context is cancelled,
// yielding no stream and a non-nil error.
func TestWrappedAsyncLoaderContextCancels(t *testing.T) {
	ctx := context.Background()
	// Bounds the test only; it is reached only if the loader never returns.
	ctx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()
	subCtx, subCancel := context.WithCancel(ctx)
	// responseChan is never filled, so the load can only finish via subCtx.
	responseChan := make(chan asyncloader.AsyncLoadResult, 1)
	calls := make(chan callParams, 1)
	asyncLoadFn := makeAsyncLoadFn(responseChan, calls)
	errChan := make(chan error, 1)
	requestID := gsmsg.GraphSyncRequestID(rand.Int31())
	loader := WrapAsyncLoader(subCtx, asyncLoadFn, requestID, errChan)
	link := testbridge.NewMockLink()

	type loadResult struct {
		stream io.Reader
		err    error
	}
	// Buffered so the loading goroutine can always deliver its result and
	// exit, even when the test has already stopped waiting on ctx.Done();
	// with an unbuffered channel it would leak, blocked on the send.
	resultsChan := make(chan loadResult, 1)
	go func() {
		stream, err := loader(link, ipldbridge.LinkContext{})
		resultsChan <- loadResult{stream: stream, err: err}
	}()
	subCancel()

	select {
	case <-ctx.Done():
		t.Fatal("should have returned from context cancelling but didn't")
	case result := <-resultsChan:
		if result.stream != nil || result.err == nil {
			t.Fatal("should have errored from context cancelling but didn't")
		}
	}
}