From f6390fed24fc8fd2f26c4d59879a5c4a61150930 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Tue, 4 Aug 2020 21:04:33 -0700 Subject: [PATCH] Make ScanForLinks non-recursive This way, we can't blow out our stack. --- utils.go | 92 ++++++++++++++++++++++++-------------------------------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/utils.go b/utils.go index 100a7b9..2307716 100644 --- a/utils.go +++ b/utils.go @@ -17,67 +17,55 @@ const maxCidLength = 100 const maxHeaderSize = 9 func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { - buf := make([]byte, maxCidLength) - return scanForLinksRec(br, cb, buf) -} - -func scanForLinksRec(br io.Reader, cb func(cid.Cid), scratch []byte) error { - maj, extra, err := CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - - switch maj { - case MajUnsignedInt, MajNegativeInt, MajOther: - case MajByteString, MajTextString: - _, err := io.CopyN(ioutil.Discard, br, int64(extra)) + scratch := make([]byte, maxCidLength) + for remaining := uint64(1); remaining > 0; remaining-- { + maj, extra, err := CborReadHeaderBuf(br, scratch) if err != nil { return err } - case MajTag: - if extra == 42 { - maj, extra, err = CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - - if maj != MajByteString { - return fmt.Errorf("expected cbor type 'byte string' in input") - } - - if extra > maxCidLength { - return fmt.Errorf("string in cbor input too long") - } - - if _, err := io.ReadAtLeast(br, scratch[:extra], int(extra)); err != nil { - return err - } - c, err := cid.Cast(scratch[1:extra]) + switch maj { + case MajUnsignedInt, MajNegativeInt, MajOther: + case MajByteString, MajTextString: + _, err := io.CopyN(ioutil.Discard, br, int64(extra)) if err != nil { return err } - cb(c) - - } else { - if err := scanForLinksRec(br, cb, scratch); err != nil { - return err - } - } - case MajArray: - for i := 0; i < int(extra); i++ { - if err := scanForLinksRec(br, cb, scratch); err != nil { - return err - } - } - case MajMap: - for i := 0; i < int(extra*2); i++ { - if err := scanForLinksRec(br, cb, scratch); err != nil { - return err + case MajTag: + if extra == 42 { + maj, extra, err = CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if maj != MajByteString { + return fmt.Errorf("expected cbor type 'byte string' in input") + } + + if extra > maxCidLength { + return fmt.Errorf("string in cbor input too long") + } + + if _, err := io.ReadAtLeast(br, scratch[:extra], int(extra)); err != nil { + return err + } + + c, err := cid.Cast(scratch[1:extra]) + if err != nil { + return err + } + cb(c) + + } else { + remaining++ } + case MajArray: + remaining += extra + case MajMap: + remaining += (extra * 2) + default: + return fmt.Errorf("unhandled cbor type: %d", maj) } - default: - return fmt.Errorf("unhandled cbor type: %d", maj) } return nil } -- GitLab