Commit 15fff359 authored by Steven Allen's avatar Steven Allen Committed by Adin Schmahmann

Reduce network size & allocations

parent 8a304357
......@@ -107,10 +107,10 @@ func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkG
for i := 0; i < iterations; i++ {
network := netGen()
expected, err := baseRanger.CoveredNetworks(network.IPNet)
expected, err := baseRanger.CoveredNetworks(network.IPNet())
assert.NoError(t, err)
for _, ranger := range rangers {
actual, err := ranger.CoveredNetworks(network.IPNet)
actual, err := ranger.CoveredNetworks(network.IPNet())
assert.NoError(t, err)
assert.Equal(t, len(expected), len(actual))
for _, network := range actual {
......
......@@ -170,63 +170,82 @@ func (n NetworkNumber) LeastCommonBitPosition(n1 NetworkNumber) (uint, error) {
// Network represents a block of network numbers, also known as CIDR.
type Network struct {
net.IPNet
Number NetworkNumber
Mask NetworkNumberMask
}
// NewNetwork returns Network built using given net.IPNet.
func NewNetwork(ipNet net.IPNet) Network {
ones, _ := ipNet.Mask.Size()
return Network{
IPNet: ipNet,
Number: NewNetworkNumber(ipNet.IP),
Mask: NetworkNumberMask(NewNetworkNumber(net.IP(ipNet.Mask))),
Mask: NetworkNumberMask(ones),
}
}
// Masked returns a new network conforming to new mask.
func (n Network) Masked(ones int) Network {
mask := net.CIDRMask(ones, len(n.Number)*BitsPerUint32)
return NewNetwork(net.IPNet{
IP: n.IP.Mask(mask),
Mask: mask,
})
mask := NetworkNumberMask(ones)
return Network{
Number: mask.Mask(n.Number),
Mask: mask,
}
}
func sub(a, b uint8) uint8 {
res := a - b
if res > a {
res = 0
}
return res
}
func mask(m NetworkNumberMask) (mask1, mask2, mask3, mask4 uint32) {
// We're relying on overflow here.
const ones uint32 = 0xFFFFFFFF
mask1 = ones << sub(1*32, uint8(m))
mask2 = ones << sub(2*32, uint8(m))
mask3 = ones << sub(3*32, uint8(m))
mask4 = ones << sub(4*32, uint8(m))
return
}
// Contains returns true if NetworkNumber is in range of Network, false
// otherwise.
func (n Network) Contains(nn NetworkNumber) bool {
if len(n.Mask) != len(nn) {
if len(n.Number) != len(nn) {
return false
}
if nn[0]&n.Mask[0] != n.Number[0] {
const ones uint32 = 0xFFFFFFFF
mask1, mask2, mask3, mask4 := mask(n.Mask)
switch len(n.Number) {
case IPv4Uint32Count:
return nn[0]&mask1 == n.Number[0]
case IPv6Uint32Count:
return nn[0]&mask1 == n.Number[0] &&
nn[1]&mask2 == n.Number[1] &&
nn[2]&mask3 == n.Number[2] &&
nn[3]&mask4 == n.Number[3]
default:
return false
}
if len(nn) == IPv6Uint32Count {
return nn[1]&n.Mask[1] == n.Number[1] && nn[2]&n.Mask[2] == n.Number[2] && nn[3]&n.Mask[3] == n.Number[3]
}
return true
}
// Contains returns true if Network covers o, false otherwise
func (n Network) Covers(o Network) bool {
if len(n.Number) != len(o.Number) {
return false
}
nMaskSize, _ := n.IPNet.Mask.Size()
oMaskSize, _ := o.IPNet.Mask.Size()
return n.Contains(o.Number) && nMaskSize <= oMaskSize
return n.Contains(o.Number) && n.Mask <= o.Mask
}
// LeastCommonBitPosition returns the smallest position of the preceding common
// bits of the 2 networks, and returns an error ErrNoGreatestCommonBit
// if the two network number diverges from the first bit.
func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) {
maskSize, _ := n.IPNet.Mask.Size()
if maskSize1, _ := n1.IPNet.Mask.Size(); maskSize1 < maskSize {
maskSize = maskSize1
maskSize := n.Mask
if n1.Mask < n.Mask {
maskSize = n1.Mask
}
maskPosition := len(n1.Number)*BitsPerUint32 - maskSize
maskPosition := len(n1.Number)*BitsPerUint32 - int(maskSize)
lcb, err := n.Number.LeastCommonBitPosition(n1.Number)
if err != nil {
return 0, err
......@@ -236,32 +255,38 @@ func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) {
// Equal is the equality test for 2 networks.
func (n Network) Equal(n1 Network) bool {
nones, nbits := n.IPNet.Mask.Size()
n1ones, n1bits := n1.IPNet.Mask.Size()
return nones == n1ones && nbits == n1bits && n.IPNet.IP.Equal(n1.IPNet.IP)
return n.Number.Equal(n1.Number) && n.Mask == n1.Mask
}
func (n Network) String() string {
return n.IPNet.String()
return fmt.Sprintf("%s/%d", n.Number.ToIP(), n.Mask)
}
func (n Network) IPNet() net.IPNet {
return net.IPNet{
IP: n.Number.ToIP(),
Mask: net.CIDRMask(int(n.Mask), len(n.Number)*32),
}
}
// NetworkNumberMask is an IP address.
type NetworkNumberMask NetworkNumber
type NetworkNumberMask int
// Mask returns a new masked NetworkNumber from given NetworkNumber.
func (m NetworkNumberMask) Mask(n NetworkNumber) (NetworkNumber, error) {
if len(m) != len(n) {
return nil, ErrVersionMismatch
}
result := make(NetworkNumber, len(m))
result[0] = m[0] & n[0]
if len(m) == IPv6Uint32Count {
result[1] = m[1] & n[1]
result[2] = m[2] & n[2]
result[3] = m[3] & n[3]
func (m NetworkNumberMask) Mask(n NetworkNumber) NetworkNumber {
mask1, mask2, mask3, mask4 := mask(m)
result := make(NetworkNumber, len(n))
switch len(n) {
case IPv4Uint32Count:
result[0] = n[0] & mask1
case IPv6Uint32Count:
result[0] = n[0] & mask1
result[1] = n[1] & mask2
result[2] = n[2] & mask3
result[3] = n[3] & mask4
}
return result, nil
return result
}
// NextIP returns the next sequential ip.
......
......@@ -220,9 +220,11 @@ func TestNewNetwork(t *testing.T) {
_, ipNet, _ := net.ParseCIDR("192.128.0.0/24")
n := NewNetwork(*ipNet)
assert.Equal(t, *ipNet, n.IPNet)
newIPNet := n.IPNet()
assert.True(t, ipNet.IP.Equal(newIPNet.IP))
assert.Equal(t, ipNet.Mask, newIPNet.Mask)
assert.Equal(t, NetworkNumber{3229614080}, n.Number)
assert.Equal(t, NetworkNumberMask{math.MaxUint32 - uint32(math.MaxUint8)}, n.Mask)
assert.Equal(t, NetworkNumberMask(24), n.Mask)
}
func TestNetworkMasked(t *testing.T) {
......@@ -245,7 +247,7 @@ func TestNetworkMasked(t *testing.T) {
_, expected, _ := net.ParseCIDR(testcase.maskedNetwork)
n1 := NewNetwork(*network)
e1 := NewNetwork(*expected)
assert.True(t, e1.String() == n1.Masked(testcase.mask).String())
assert.Equal(t, e1.String(), n1.Masked(testcase.mask).String())
}
}
......@@ -377,22 +379,19 @@ func TestMask(t *testing.T) {
mask NetworkNumberMask
ip NetworkNumber
masked NetworkNumber
err error
name string
}{
{NetworkNumberMask{math.MaxUint32}, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32}, nil, "nop IPv4 mask"},
{NetworkNumberMask{math.MaxUint32 - math.MaxUint16}, NetworkNumber{math.MaxUint16 + 1}, NetworkNumber{math.MaxUint16 + 1}, nil, "nop IPv4 mask"},
{NetworkNumberMask{math.MaxUint32 - math.MaxUint16}, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32 - math.MaxUint16}, nil, "IPv4 masked"},
{NetworkNumberMask{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, nil, "nop IPv6 mask"},
{NetworkNumberMask{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, nil, "nop IPv6 mask"},
{NetworkNumberMask{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, nil, "IPv6 masked"},
{NetworkNumberMask{math.MaxUint32}, NetworkNumber{math.MaxUint32, 0}, nil, ErrVersionMismatch, "Version mismatch"},
{32, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32}, "nop IPv4 mask"},
{16, NetworkNumber{math.MaxUint16 + 1}, NetworkNumber{math.MaxUint16 + 1}, "nop IPv4 mask"},
{16, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32 - math.MaxUint16}, "IPv4 masked"},
{96, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, "nop IPv6 mask"},
{16, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, "nop IPv6 mask"},
{16, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, "IPv6 masked"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
masked, err := tc.mask.Mask(tc.ip)
assert.Equal(t, tc.masked, masked)
assert.Equal(t, tc.err, err)
masked := tc.mask.Mask(tc.ip)
assert.Equal(t, tc.masked, masked, tc.name)
})
}
}
......
......@@ -84,8 +84,7 @@ func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
}
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
ones, _ := network.IPNet.Mask.Size()
leaf := newPathprefixTrie(network, uint(ones))
leaf := newPathprefixTrie(network, uint(network.Mask))
leaf.entry = entry
return leaf
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment