Commit 46362a13 authored by hannahhoward's avatar hannahhoward

refactor(message): convert Extension to return bool

return a boolean instead of an error if an extention is found on a message
parent 8b6458eb
package message
import (
"errors"
"fmt"
"io"
......@@ -88,11 +87,6 @@ const (
RequestFailedContentNotFound = GraphSyncResponseStatusCode(34)
)
var (
// ErrExtensionNotPresent means the looked up extension was not found
ErrExtensionNotPresent = errors.New("Extension is missing from this message")
)
// IsTerminalSuccessCode returns true if the response code indicates the
// request terminated successfully.
func IsTerminalSuccessCode(status GraphSyncResponseStatusCode) bool {
......@@ -398,15 +392,15 @@ func (gsr GraphSyncRequest) Priority() GraphSyncPriority { return gsr.priority }
// Extension returns the content for an extension on a response, or errors
// if extension is not present
func (gsr GraphSyncRequest) Extension(name GraphSyncExtensionName) ([]byte, error) {
func (gsr GraphSyncRequest) Extension(name GraphSyncExtensionName) ([]byte, bool) {
if gsr.extensions == nil {
return nil, ErrExtensionNotPresent
return nil, false
}
val, ok := gsr.extensions[string(name)]
if !ok {
return nil, ErrExtensionNotPresent
return nil, false
}
return val, nil
return val, true
}
// IsCancel returns true if this particular request is being cancelled
......@@ -420,14 +414,14 @@ func (gsr GraphSyncResponse) Status() GraphSyncResponseStatusCode { return gsr.s
// Extension returns the content for an extension on a response, or errors
// if extension is not present
func (gsr GraphSyncResponse) Extension(name GraphSyncExtensionName) ([]byte, error) {
func (gsr GraphSyncResponse) Extension(name GraphSyncExtensionName) ([]byte, bool) {
if gsr.extensions == nil {
return nil, ErrExtensionNotPresent
return nil, false
}
val, ok := gsr.extensions[string(name)]
if !ok {
return nil, ErrExtensionNotPresent
return nil, false
}
return val, nil
return val, true
}
......@@ -30,13 +30,13 @@ func TestAppendingRequests(t *testing.T) {
t.Fatal("Did not add request to message")
}
request := requests[0]
extensionData, err := request.Extension(extensionName)
extensionData, found := request.Extension(extensionName)
if request.ID() != id ||
request.IsCancel() != false ||
request.Priority() != priority ||
request.Root().String() != root.String() ||
!reflect.DeepEqual(request.Selector(), selector) ||
err != nil ||
!found ||
!reflect.DeepEqual(extension.Data, extensionData) {
t.Fatal("Did not properly add request to message")
}
......@@ -61,13 +61,13 @@ func TestAppendingRequests(t *testing.T) {
t.Fatal("Did not add request to deserialized message")
}
deserializedRequest := deserializedRequests[0]
extensionData, err = deserializedRequest.Extension(extensionName)
extensionData, found = deserializedRequest.Extension(extensionName)
if deserializedRequest.ID() != id ||
deserializedRequest.IsCancel() != false ||
deserializedRequest.Priority() != priority ||
deserializedRequest.Root().String() != root.String() ||
!reflect.DeepEqual(deserializedRequest.Selector(), selector) ||
err != nil ||
!found ||
!reflect.DeepEqual(extension.Data, extensionData) {
t.Fatal("Did not properly deserialize protobuf messages so requests are equal")
}
......@@ -89,10 +89,10 @@ func TestAppendingResponses(t *testing.T) {
t.Fatal("Did not add response to message")
}
response := responses[0]
extensionData, err := response.Extension(extensionName)
extensionData, found := response.Extension(extensionName)
if response.RequestID() != requestID ||
response.Status() != status ||
err != nil ||
!found ||
!reflect.DeepEqual(extension.Data, extensionData) {
t.Fatal("Did not properly add response to message")
}
......@@ -114,10 +114,10 @@ func TestAppendingResponses(t *testing.T) {
t.Fatal("Did not add response to message")
}
deserializedResponse := deserializedResponses[0]
extensionData, err = deserializedResponse.Extension(extensionName)
extensionData, found = deserializedResponse.Extension(extensionName)
if deserializedResponse.RequestID() != response.RequestID() ||
deserializedResponse.Status() != response.Status() ||
err != nil ||
!found ||
!reflect.DeepEqual(extensionData, extension.Data) {
t.Fatal("Did not properly deserialize protobuf messages so responses are equal")
}
......@@ -216,13 +216,13 @@ func TestToNetFromNetEquivalency(t *testing.T) {
t.Fatal("Did not add request to deserialized message")
}
deserializedRequest := deserializedRequests[0]
extensionData, err := deserializedRequest.Extension(extensionName)
extensionData, found := deserializedRequest.Extension(extensionName)
if deserializedRequest.ID() != request.ID() ||
deserializedRequest.IsCancel() != request.IsCancel() ||
deserializedRequest.Priority() != request.Priority() ||
deserializedRequest.Root().String() != request.Root().String() ||
!reflect.DeepEqual(deserializedRequest.Selector(), request.Selector()) ||
err != nil ||
!found ||
!reflect.DeepEqual(extensionData, extension.Data) {
t.Fatal("Did not keep requests when writing to stream and back")
}
......@@ -237,10 +237,10 @@ func TestToNetFromNetEquivalency(t *testing.T) {
t.Fatal("Did not add response to message")
}
deserializedResponse := deserializedResponses[0]
extensionData, err = deserializedResponse.Extension(extensionName)
extensionData, found = deserializedResponse.Extension(extensionName)
if deserializedResponse.RequestID() != response.RequestID() ||
deserializedResponse.Status() != response.Status() ||
err != nil ||
!found ||
!reflect.DeepEqual(extensionData, extension.Data) {
t.Fatal("Did not keep responses when writing to stream and back")
}
......
......@@ -209,10 +209,10 @@ func TestProcessingNotification(t *testing.T) {
}
}
firstResponse := message.Responses()[0]
extensionData, err := firstResponse.Extension(extensionName)
extensionData, found := firstResponse.Extension(extensionName)
if responseID != firstResponse.RequestID() ||
status != firstResponse.Status() ||
err != nil ||
!found ||
!reflect.DeepEqual(extension.Data, extensionData) {
t.Fatal("Send incorrect response")
}
......
......@@ -133,10 +133,10 @@ func TestMessageSendAndReceive(t *testing.T) {
t.Fatal("Did not add response to received message")
}
receivedResponse := receivedResponses[0]
extensionData, err := receivedResponse.Extension(extensionName)
extensionData, found := receivedResponse.Extension(extensionName)
if receivedResponse.RequestID() != sentResponse.RequestID() ||
receivedResponse.Status() != sentResponse.Status() ||
err != nil ||
!found ||
!reflect.DeepEqual(extension.Data, extensionData) {
t.Fatal("Sent message responses did not match received message responses")
}
......
......@@ -27,8 +27,8 @@ func visitToChannel(ctx context.Context, inProgressChan chan types.ResponseProgr
func metadataForResponses(responses []gsmsg.GraphSyncResponse, ipldBridge ipldbridge.IPLDBridge) map[gsmsg.GraphSyncRequestID]metadata.Metadata {
responseMetadata := make(map[gsmsg.GraphSyncRequestID]metadata.Metadata, len(responses))
for _, response := range responses {
mdRaw, err := response.Extension(gsmsg.ExtensionMetadata)
if err != nil {
mdRaw, found := response.Extension(gsmsg.ExtensionMetadata)
if !found {
log.Warningf("Unable to decode metadata in response for request id: %d", response.RequestID())
continue
}
......
......@@ -66,8 +66,8 @@ func TestMessageBuilding(t *testing.T) {
t.Fatal("did not generate completed partial response")
}
response1MetadataRaw, err := response1.Extension(gsmsg.ExtensionMetadata)
if err != nil {
response1MetadataRaw, found := response1.Extension(gsmsg.ExtensionMetadata)
if !found {
t.Fatal("Metadata not included in response")
}
response1Metadata, err := metadata.DecodeMetadata(response1MetadataRaw, ipldBridge)
......@@ -83,8 +83,8 @@ func TestMessageBuilding(t *testing.T) {
if err != nil || response2.Status() != gsmsg.RequestCompletedFull {
t.Fatal("did not generate completed partial response")
}
response2MetadataRaw, err := response2.Extension(gsmsg.ExtensionMetadata)
if err != nil {
response2MetadataRaw, found := response2.Extension(gsmsg.ExtensionMetadata)
if !found {
t.Fatal("Metadata not included in response")
}
response2Metadata, err := metadata.DecodeMetadata(response2MetadataRaw, ipldBridge)
......@@ -100,8 +100,8 @@ func TestMessageBuilding(t *testing.T) {
if err != nil || response3.Status() != gsmsg.PartialResponse {
t.Fatal("did not generate completed partial response")
}
response3MetadataRaw, err := response3.Extension(gsmsg.ExtensionMetadata)
if err != nil {
response3MetadataRaw, found := response3.Extension(gsmsg.ExtensionMetadata)
if !found {
t.Fatal("Metadata not included in response")
}
response3Metadata, err := metadata.DecodeMetadata(response3MetadataRaw, ipldBridge)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment