Commit 5ad877dd authored by Yulin Chen's avatar Yulin Chen

Simplify covered networks logic and add random tests

parent 3e90dcc8
......@@ -2,6 +2,8 @@ package cidranger
import (
"net"
rnet "github.com/yl2chen/cidranger/net"
)
// bruteRanger is a brute force implementation of Ranger. Insertion and
......@@ -96,13 +98,10 @@ func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error)
return nil, err
}
var results []RangerEntry
testNetwork := rnet.NewNetwork(network)
for _, entry := range entries {
entrynetwork := entry.Network()
searchMaskSize, _ := network.Mask.Size()
entryMaskSize, _ := entrynetwork.Mask.Size()
if network.Contains(entrynetwork.IP) && searchMaskSize <= entryMaskSize {
entryNetwork := rnet.NewNetwork(entry.Network())
if testNetwork.Covers(entryNetwork) {
results = append(results, entry)
}
}
......
......@@ -26,6 +26,10 @@ func TestContainingNetworksAgaistBaseIPv4(t *testing.T) {
testContainingNetworksAgainstBase(t, 100000, randIPv4Gen)
}
func TestCoveredNetworksAgainstBaseIPv4(t *testing.T) {
testCoversNetworksAgainstBase(t, 100000, randomIPNetGenFactory(ipV4AWSRangesIPNets))
}
// IPv6 spans an extremely large address space (2^128), randomly generated IPs
// will often fall outside of the test ranges (AWS public CIDR blocks), so it
// it more meaningful for testing to run from a curated list of IPv6 IPs.
......@@ -37,6 +41,10 @@ func TestContainingNetworksAgaistBaseIPv6(t *testing.T) {
testContainingNetworksAgainstBase(t, 100000, curatedAWSIPv6Gen)
}
func TestCoveredNetworksAgainstBaseIPv6(t *testing.T) {
testCoversNetworksAgainstBase(t, 100000, randomIPNetGenFactory(ipV6AWSRangesIPNets))
}
func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) {
rangers := []Ranger{NewPCTrieRanger()}
baseRanger := newBruteRanger()
......@@ -80,6 +88,29 @@ func testContainingNetworksAgainstBase(t *testing.T, iterations int, ipGen ipGen
}
}
func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkGenerator) {
rangers := []Ranger{NewPCTrieRanger()}
baseRanger := newBruteRanger()
for _, ranger := range rangers {
configureRangerWithAWSRanges(t, ranger)
}
configureRangerWithAWSRanges(t, baseRanger)
for i := 0; i < iterations; i++ {
network := netGen()
expected, err := baseRanger.CoveredNetworks(network.IPNet)
assert.NoError(t, err)
for _, ranger := range rangers {
actual, err := ranger.CoveredNetworks(network.IPNet)
assert.NoError(t, err)
assert.Equal(t, len(expected), len(actual))
for _, network := range actual {
assert.Contains(t, expected, network)
}
}
}
}
/*
******************************************************************
Benchmarks.
......@@ -183,6 +214,14 @@ func curatedAWSIPv6Gen() rnet.NetworkNumber {
return nn
}
type networkGenerator func() rnet.Network
func randomIPNetGenFactory(pool []*net.IPNet) networkGenerator {
return func() rnet.Network {
return rnet.NewNetwork(*pool[rand.Intn(len(pool))])
}
}
type AWSRanges struct {
Prefixes []Prefix `json:"prefixes"`
IPv6Prefixes []IPv6Prefix `json:"ipv6_prefixes"`
......@@ -201,6 +240,7 @@ type IPv6Prefix struct {
}
var awsRanges *AWSRanges
var ipV4AWSRangesIPNets []*net.IPNet
var ipV6AWSRangesIPNets []*net.IPNet
func loadAWSRanges() *AWSRanges {
......@@ -235,5 +275,9 @@ func init() {
_, network, _ := net.ParseCIDR(prefix.IPPrefix)
ipV6AWSRangesIPNets = append(ipV6AWSRangesIPNets, network)
}
for _, prefix := range awsRanges.Prefixes {
_, network, _ := net.ParseCIDR(prefix.IPPrefix)
ipV4AWSRangesIPNets = append(ipV4AWSRangesIPNets, network)
}
rand.Seed(time.Now().Unix())
}
......@@ -186,24 +186,18 @@ func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntr
func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
var results []RangerEntry
if p.hasEntry() && network.Covers(p.network) {
results = []RangerEntry{p.entry}
}
if p.targetBitPosition() < 0 {
return results, nil
}
masked := network.Masked(int(p.numBitsSkipped))
if !masked.Equal(p.network) {
return results, nil
}
for _, child := range p.children {
if network.Covers(p.network) {
for entry := range p.walkDepth() {
results = append(results, entry)
}
} else if p.targetBitPosition() >= 0 {
bit, err := p.targetBitFromIP(network.Number)
if err != nil {
return results, err
}
child := p.children[bit]
if child != nil {
ranges, err := child.coveredNetworks(network)
if err != nil {
return nil, err
}
results = append(results, ranges...)
return child.coveredNetworks(network)
}
}
return results, nil
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment