From 8c56974d0513150f1d1022d8ecaa8eee52b1f9c5 Mon Sep 17 00:00:00 2001 From: Rob Adams Date: Mon, 18 Dec 2017 17:06:47 -0800 Subject: [PATCH] Add new CoveredNetworks option * Search by CIDR rather than just by IP. Signed-off-by: Rob Adams --- brute.go | 18 +++++++++++ brute_test.go | 31 +++++++++++++++++++ cidranger.go | 1 + trie.go | 33 ++++++++++++++++++++ trie_test.go | 86 +++++++++++++++++++++++++++++++++++++++++++++++++++ version.go | 8 +++++ 6 files changed, 177 insertions(+) diff --git a/brute.go b/brute.go index 622e80f..8da7425 100644 --- a/brute.go +++ b/brute.go @@ -87,6 +87,24 @@ func (b *bruteRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { return results, nil } +// CoveredNetworks returns the list of RangerEntry(s) the given ipnet +// covers. That is, the networks that are completely subsumed by the +// specified network. +func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { + entries, err := b.getEntriesByVersion(network.IP) + if err != nil { + return nil, err + } + var results []RangerEntry + for _, entry := range entries { + entrynetwork := entry.Network() + if network.Contains(entrynetwork.IP) { + results = append(results, entry) + } + } + return results, nil +} + func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) { if ip.To4() != nil { return b.ipV4Entries, nil diff --git a/brute_test.go b/brute_test.go index 8254e4a..71ee637 100644 --- a/brute_test.go +++ b/brute_test.go @@ -2,6 +2,7 @@ package cidranger import ( "net" + "sort" "testing" "github.com/stretchr/testify/assert" @@ -144,3 +145,33 @@ func TestContainingNetworks(t *testing.T) { }) } } + +func TestCoveredNetworks(t *testing.T) { + for _, tc := range coveredNetworkTests { + t.Run(tc.name, func(t *testing.T) { + ranger := newBruteRanger() + for _, insert := range tc.inserts { + _, network, _ := net.ParseCIDR(insert) + err := ranger.Insert(NewBasicRangerEntry(*network)) + assert.NoError(t, err) + } + var expectedEntries []string + for _, network := range tc.networks { + expectedEntries = append(expectedEntries, network) + } + sort.Strings(expectedEntries) + _, snet, _ := net.ParseCIDR(tc.search) + networks, err := ranger.CoveredNetworks(*snet) + assert.NoError(t, err) + + var results []string + for _, result := range networks { + net := result.Network() + results = append(results, net.String()) + } + sort.Strings(results) + + assert.Equal(t, expectedEntries, results) + }) + } +} diff --git a/cidranger.go b/cidranger.go index a3ef101..e2c9ee5 100644 --- a/cidranger.go +++ b/cidranger.go @@ -71,6 +71,7 @@ type Ranger interface { Remove(network net.IPNet) (RangerEntry, error) Contains(ip net.IP) (bool, error) ContainingNetworks(ip net.IP) ([]RangerEntry, error) + CoveredNetworks(network net.IPNet) ([]RangerEntry, error) } // NewPCTrieRanger returns a versionedRanger that supports both IPv4 and IPv6 diff --git a/trie.go b/trie.go index 7fe8368..0b9c9e2 100644 --- a/trie.go +++ b/trie.go @@ -107,6 +107,14 @@ func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { return p.containingNetworks(nn) } +// CoveredNetworks returns the list of RangerEntry(s) the given ipnet +// covers. That is, the networks that are completely subsumed by the +// specified network. +func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { + net := rnet.NewNetwork(network) + return p.coveredNetworks(net) +} + // String returns string representation of trie, mainly for visualization and // debugging. func (p *prefixTrie) String() string { @@ -176,6 +184,31 @@ func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntr return results, nil } +func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) { + var results []RangerEntry + if p.hasEntry() && network.Contains(p.network.Number) { + results = []RangerEntry{p.entry} + } + if p.targetBitPosition() < 0 { + return results, nil + } + + masked := network.Masked(int(p.numBitsSkipped)) + if !masked.Equal(p.network) { + return results, nil + } + for _, child := range p.children { + if child != nil { + ranges, err := child.coveredNetworks(network) + if err != nil { + return nil, err + } + results = append(results, ranges...) + } + } + return results, nil +} + func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error { if p.network.Equal(network) { p.entry = entry diff --git a/trie_test.go b/trie_test.go index 44ca547..8955e2b 100644 --- a/trie_test.go +++ b/trie_test.go @@ -342,3 +342,89 @@ func TestPrefixTrieContainingNetworks(t *testing.T) { }) } } + +type coveredNetworkTest struct { + version rnet.IPVersion + inserts []string + search string + networks []string + name string +} + +var coveredNetworkTests = []coveredNetworkTest{ + { + rnet.IPv4, + []string{"192.168.0.0/24"}, + "192.168.0.0/16", + []string{"192.168.0.0/24"}, + "basic covered networks", + }, + { + rnet.IPv4, + []string{"192.168.0.0/24"}, + "10.1.0.0/16", + nil, + "nothing", + }, + { + rnet.IPv4, + []string{"192.168.0.0/24", "192.168.0.0/25"}, + "192.168.0.0/16", + []string{"192.168.0.0/24", "192.168.0.0/25"}, + "multiple networks", + }, + { + rnet.IPv4, + []string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"}, + "192.168.0.0/16", + []string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"}, + "multiple networks 2", + }, + { + rnet.IPv4, + []string{"192.168.1.1/32"}, + "192.168.0.0/16", + []string{"192.168.1.1/32"}, + "leaf", + }, + { + rnet.IPv4, + []string{"0.0.0.0/0", "192.168.1.1/32"}, + "192.168.0.0/16", + []string{"192.168.1.1/32"}, + "leaf with root", + }, + { + rnet.IPv4, + []string{ + "0.0.0.0/0", "192.168.0.0/24", "192.168.1.1/32", + "10.1.0.0/16", "10.1.1.0/24", + }, + "192.168.0.0/16", + []string{"192.168.0.0/24", "192.168.1.1/32"}, + "path not taken", + }, +} + +func TestPrefixTrieCoveredNetworks(t *testing.T) { + for _, tc := range coveredNetworkTests { + t.Run(tc.name, func(t *testing.T) { + trie := newPrefixTree(tc.version) + for _, insert := range tc.inserts { + _, network, _ := net.ParseCIDR(insert) + err := trie.Insert(NewBasicRangerEntry(*network)) + assert.NoError(t, err) + } + var expectedEntries []RangerEntry + for _, network := range tc.networks { + _, net, _ := net.ParseCIDR(network) + expectedEntries = append(expectedEntries, + NewBasicRangerEntry(*net)) + } + _, snet, _ := net.ParseCIDR(tc.search) + networks, err := trie.CoveredNetworks(*snet) + assert.NoError(t, err) + assert.Equal(t, expectedEntries, networks) + }) + } +} diff --git a/version.go b/version.go index b9af754..cfea061 100644 --- a/version.go +++ b/version.go @@ -53,6 +53,14 @@ func (v *versionedRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { return ranger.ContainingNetworks(ip) } +func (v *versionedRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { + ranger, err := v.getRangerForIP(network.IP) + if err != nil { + return nil, err + } + return ranger.CoveredNetworks(network) +} + func (v *versionedRanger) getRangerForIP(ip net.IP) (Ranger, error) { if ip.To4() != nil { return v.ipV4Ranger, nil -- GitLab