diff --git a/brute.go b/brute.go index 4173256ec6496471783ef75eeadf844df303264e..c8a1b4fadbb80bb9fff28f269765701a9e4b06e4 100644 --- a/brute.go +++ b/brute.go @@ -2,6 +2,8 @@ package cidranger import ( "net" + + rnet "github.com/yl2chen/cidranger/net" ) // bruteRanger is a brute force implementation of Ranger. Insertion and @@ -96,13 +98,10 @@ func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) return nil, err } var results []RangerEntry + testNetwork := rnet.NewNetwork(network) for _, entry := range entries { - entrynetwork := entry.Network() - - searchMaskSize, _ := network.Mask.Size() - entryMaskSize, _ := entrynetwork.Mask.Size() - - if network.Contains(entrynetwork.IP) && searchMaskSize <= entryMaskSize { + entryNetwork := rnet.NewNetwork(entry.Network()) + if testNetwork.Covers(entryNetwork) { results = append(results, entry) } } diff --git a/cidranger_test.go b/cidranger_test.go index 17cd7a848d04a0a2441c7bd1093ff0fe4aca93d3..6ea7158fa8eb1e111f900b2a2321c28420f72cc9 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -26,6 +26,10 @@ func TestContainingNetworksAgaistBaseIPv4(t *testing.T) { testContainingNetworksAgainstBase(t, 100000, randIPv4Gen) } +func TestCoveredNetworksAgainstBaseIPv4(t *testing.T) { + testCoversNetworksAgainstBase(t, 100000, randomIPNetGenFactory(ipV4AWSRangesIPNets)) +} + // IPv6 spans an extremely large address space (2^128), randomly generated IPs // will often fall outside of the test ranges (AWS public CIDR blocks), so it // it more meaningful for testing to run from a curated list of IPv6 IPs. @@ -37,6 +41,10 @@ func TestContainingNetworksAgaistBaseIPv6(t *testing.T) { testContainingNetworksAgainstBase(t, 100000, curatedAWSIPv6Gen) } +func TestCoveredNetworksAgainstBaseIPv6(t *testing.T) { + testCoversNetworksAgainstBase(t, 100000, randomIPNetGenFactory(ipV6AWSRangesIPNets)) +} + func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { rangers := []Ranger{NewPCTrieRanger()} baseRanger := newBruteRanger() @@ -80,6 +88,29 @@ func testContainingNetworksAgainstBase(t *testing.T, iterations int, ipGen ipGen } } +func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkGenerator) { + rangers := []Ranger{NewPCTrieRanger()} + baseRanger := newBruteRanger() + for _, ranger := range rangers { + configureRangerWithAWSRanges(t, ranger) + } + configureRangerWithAWSRanges(t, baseRanger) + + for i := 0; i < iterations; i++ { + network := netGen() + expected, err := baseRanger.CoveredNetworks(network.IPNet) + assert.NoError(t, err) + for _, ranger := range rangers { + actual, err := ranger.CoveredNetworks(network.IPNet) + assert.NoError(t, err) + assert.Equal(t, len(expected), len(actual)) + for _, network := range actual { + assert.Contains(t, expected, network) + } + } + } +} + /* ****************************************************************** Benchmarks. @@ -183,6 +214,14 @@ func curatedAWSIPv6Gen() rnet.NetworkNumber { return nn } +type networkGenerator func() rnet.Network + +func randomIPNetGenFactory(pool []*net.IPNet) networkGenerator { + return func() rnet.Network { + return rnet.NewNetwork(*pool[rand.Intn(len(pool))]) + } +} + type AWSRanges struct { Prefixes []Prefix `json:"prefixes"` IPv6Prefixes []IPv6Prefix `json:"ipv6_prefixes"` @@ -201,6 +240,7 @@ type IPv6Prefix struct { } var awsRanges *AWSRanges +var ipV4AWSRangesIPNets []*net.IPNet var ipV6AWSRangesIPNets []*net.IPNet func loadAWSRanges() *AWSRanges { @@ -235,5 +275,9 @@ func init() { _, network, _ := net.ParseCIDR(prefix.IPPrefix) ipV6AWSRangesIPNets = append(ipV6AWSRangesIPNets, network) } + for _, prefix := range awsRanges.Prefixes { + _, network, _ := net.ParseCIDR(prefix.IPPrefix) + ipV4AWSRangesIPNets = append(ipV4AWSRangesIPNets, network) + } rand.Seed(time.Now().Unix()) } diff --git a/trie.go b/trie.go index 34dd54490635f2dc73e6c473bd372a16a179c5ac..82a84501639274bcd1e4b5164cf9dfa336aca0d9 100644 --- a/trie.go +++ b/trie.go @@ -186,24 +186,18 @@ func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntr func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) { var results []RangerEntry - if p.hasEntry() && network.Covers(p.network) { - results = []RangerEntry{p.entry} - } - if p.targetBitPosition() < 0 { - return results, nil - } - - masked := network.Masked(int(p.numBitsSkipped)) - if !masked.Equal(p.network) { - return results, nil - } - for _, child := range p.children { + if network.Covers(p.network) { + for entry := range p.walkDepth() { + results = append(results, entry) + } + } else if p.targetBitPosition() >= 0 { + bit, err := p.targetBitFromIP(network.Number) + if err != nil { + return results, err + } + child := p.children[bit] if child != nil { - ranges, err := child.coveredNetworks(network) - if err != nil { - return nil, err - } - results = append(results, ranges...) + return child.coveredNetworks(network) } } return results, nil