diff --git a/brute.go b/brute.go index 8da74258835a3f1586735b252074314815d03705..4173256ec6496471783ef75eeadf844df303264e 100644 --- a/brute.go +++ b/brute.go @@ -98,7 +98,11 @@ func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) var results []RangerEntry for _, entry := range entries { entrynetwork := entry.Network() - if network.Contains(entrynetwork.IP) { + + searchMaskSize, _ := network.Mask.Size() + entryMaskSize, _ := entrynetwork.Mask.Size() + + if network.Contains(entrynetwork.IP) && searchMaskSize <= entryMaskSize { results = append(results, entry) } } diff --git a/net/ip.go b/net/ip.go index f26044cc15f0986775dffbdda6eb96b929ad0728..d310f17b98b29f346f5daae6b687db0585f8daf9 100644 --- a/net/ip.go +++ b/net/ip.go @@ -208,6 +208,13 @@ func (n Network) Contains(nn NetworkNumber) bool { return true } +// Contains returns true if Network covers o, false otherwise +func (n Network) Covers(o Network) bool { + nMaskSize, _ := n.IPNet.Mask.Size() + oMaskSize, _ := o.IPNet.Mask.Size() + return n.Contains(o.Number) && nMaskSize <= oMaskSize +} + // LeastCommonBitPosition returns the smallest position of the preceding common // bits of the 2 networks, and returns an error ErrNoGreatestCommonBit // if the two network number diverges from the first bit. diff --git a/net/ip_test.go b/net/ip_test.go index 538983add1817f01a0a68045e2a798f31b78f7a5..6a9d2b1b276b1a1cbc3740cc9c86aefa6f8ec57e 100644 --- a/net/ip_test.go +++ b/net/ip_test.go @@ -412,6 +412,31 @@ func TestPreviousIP(t *testing.T) { } } +func TestNetworkCovers(t *testing.T) { + cases := []struct { + network string + covers string + result bool + name string + }{ + {"10.0.0.0/24", "10.0.0.1/25", true, "contains"}, + {"10.0.0.0/24", "11.0.0.1/25", false, "not contains"}, + {"10.0.0.0/16", "10.0.0.0/15", false, "prefix false"}, + {"10.0.0.0/15", "10.0.0.0/16", true, "prefix true"}, + {"10.0.0.0/15", "10.0.0.0/15", true, "same"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, n, _ := net.ParseCIDR(tc.network) + network := NewNetwork(*n) + _, n, _ = net.ParseCIDR(tc.covers) + covers := NewNetwork(*n) + assert.Equal(t, tc.result, network.Covers(covers)) + }) + } +} + /* ********************************* Benchmarking ip manipulations. diff --git a/trie.go b/trie.go index 0b9c9e2f4f37f51b0e1ea207b0ca263a809ff8da..34dd54490635f2dc73e6c473bd372a16a179c5ac 100644 --- a/trie.go +++ b/trie.go @@ -186,7 +186,7 @@ func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntr func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) { var results []RangerEntry - if p.hasEntry() && network.Contains(p.network.Number) { + if p.hasEntry() && network.Covers(p.network) { results = []RangerEntry{p.entry} } if p.targetBitPosition() < 0 { diff --git a/trie_test.go b/trie_test.go index 8955e2b22f8604e6daa4288d6ce62f3ccf2bc3ac..29125fd6112d3bf6f3a14bdc86d40843857fcb44 100644 --- a/trie_test.go +++ b/trie_test.go @@ -404,6 +404,15 @@ var coveredNetworkTests = []coveredNetworkTest{ []string{"192.168.0.0/24", "192.168.1.1/32"}, "path not taken", }, + { + rnet.IPv4, + []string{ + "192.168.0.0/15", + }, + "192.168.0.0/16", + nil, + "only masks different", + }, } func TestPrefixTrieCoveredNetworks(t *testing.T) {