diff --git a/cidranger_test.go b/cidranger_test.go index e727b95b5def70b31185ce86b78143f900ac6880..8e806275cac700724b5383d5ec828bbbec3619c2 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -107,10 +107,10 @@ func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkG for i := 0; i < iterations; i++ { network := netGen() - expected, err := baseRanger.CoveredNetworks(network.IPNet) + expected, err := baseRanger.CoveredNetworks(network.IPNet()) assert.NoError(t, err) for _, ranger := range rangers { - actual, err := ranger.CoveredNetworks(network.IPNet) + actual, err := ranger.CoveredNetworks(network.IPNet()) assert.NoError(t, err) assert.Equal(t, len(expected), len(actual)) for _, network := range actual { diff --git a/net/ip.go b/net/ip.go index de700cf3b8f8be006302276c655ea3915734acf3..f00ec208fddc01c8b86ef59ae39da4687d068fc1 100644 --- a/net/ip.go +++ b/net/ip.go @@ -170,63 +170,82 @@ func (n NetworkNumber) LeastCommonBitPosition(n1 NetworkNumber) (uint, error) { // Network represents a block of network numbers, also known as CIDR. type Network struct { - net.IPNet Number NetworkNumber Mask NetworkNumberMask } // NewNetwork returns Network built using given net.IPNet. func NewNetwork(ipNet net.IPNet) Network { + ones, _ := ipNet.Mask.Size() return Network{ - IPNet: ipNet, Number: NewNetworkNumber(ipNet.IP), - Mask: NetworkNumberMask(NewNetworkNumber(net.IP(ipNet.Mask))), + Mask: NetworkNumberMask(ones), } } // Masked returns a new network conforming to new mask. func (n Network) Masked(ones int) Network { - mask := net.CIDRMask(ones, len(n.Number)*BitsPerUint32) - return NewNetwork(net.IPNet{ - IP: n.IP.Mask(mask), - Mask: mask, - }) + mask := NetworkNumberMask(ones) + return Network{ + Number: mask.Mask(n.Number), + Mask: mask, + } +} + +func sub(a, b uint8) uint8 { + res := a - b + if res > a { + res = 0 + } + return res +} + +func mask(m NetworkNumberMask) (mask1, mask2, mask3, mask4 uint32) { + // We're relying on overflow here. + const ones uint32 = 0xFFFFFFFF + mask1 = ones << sub(1*32, uint8(m)) + mask2 = ones << sub(2*32, uint8(m)) + mask3 = ones << sub(3*32, uint8(m)) + mask4 = ones << sub(4*32, uint8(m)) + return } // Contains returns true if NetworkNumber is in range of Network, false // otherwise. func (n Network) Contains(nn NetworkNumber) bool { - if len(n.Mask) != len(nn) { + if len(n.Number) != len(nn) { return false } - if nn[0]&n.Mask[0] != n.Number[0] { + const ones uint32 = 0xFFFFFFFF + + mask1, mask2, mask3, mask4 := mask(n.Mask) + switch len(n.Number) { + case IPv4Uint32Count: + return nn[0]&mask1 == n.Number[0] + case IPv6Uint32Count: + return nn[0]&mask1 == n.Number[0] && + nn[1]&mask2 == n.Number[1] && + nn[2]&mask3 == n.Number[2] && + nn[3]&mask4 == n.Number[3] + default: return false } - if len(nn) == IPv6Uint32Count { - return nn[1]&n.Mask[1] == n.Number[1] && nn[2]&n.Mask[2] == n.Number[2] && nn[3]&n.Mask[3] == n.Number[3] - } - return true } // Contains returns true if Network covers o, false otherwise func (n Network) Covers(o Network) bool { - if len(n.Number) != len(o.Number) { - return false - } - nMaskSize, _ := n.IPNet.Mask.Size() - oMaskSize, _ := o.IPNet.Mask.Size() - return n.Contains(o.Number) && nMaskSize <= oMaskSize + return n.Contains(o.Number) && n.Mask <= o.Mask } // LeastCommonBitPosition returns the smallest position of the preceding common // bits of the 2 networks, and returns an error ErrNoGreatestCommonBit // if the two network number diverges from the first bit. func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) { - maskSize, _ := n.IPNet.Mask.Size() - if maskSize1, _ := n1.IPNet.Mask.Size(); maskSize1 < maskSize { - maskSize = maskSize1 + maskSize := n.Mask + if n1.Mask < n.Mask { + maskSize = n1.Mask } - maskPosition := len(n1.Number)*BitsPerUint32 - maskSize + maskPosition := len(n1.Number)*BitsPerUint32 - int(maskSize) lcb, err := n.Number.LeastCommonBitPosition(n1.Number) if err != nil { return 0, err @@ -236,32 +255,38 @@ func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) { // Equal is the equality test for 2 networks. func (n Network) Equal(n1 Network) bool { - nones, nbits := n.IPNet.Mask.Size() - n1ones, n1bits := n1.IPNet.Mask.Size() - - return nones == n1ones && nbits == n1bits && n.IPNet.IP.Equal(n1.IPNet.IP) + return n.Number.Equal(n1.Number) && n.Mask == n1.Mask } func (n Network) String() string { - return n.IPNet.String() + return fmt.Sprintf("%s/%d", n.Number.ToIP(), n.Mask) +} + +func (n Network) IPNet() net.IPNet { + return net.IPNet{ + IP: n.Number.ToIP(), + Mask: net.CIDRMask(int(n.Mask), len(n.Number)*32), + } } // NetworkNumberMask is an IP address. -type NetworkNumberMask NetworkNumber +type NetworkNumberMask int // Mask returns a new masked NetworkNumber from given NetworkNumber. -func (m NetworkNumberMask) Mask(n NetworkNumber) (NetworkNumber, error) { - if len(m) != len(n) { - return nil, ErrVersionMismatch - } - result := make(NetworkNumber, len(m)) - result[0] = m[0] & n[0] - if len(m) == IPv6Uint32Count { - result[1] = m[1] & n[1] - result[2] = m[2] & n[2] - result[3] = m[3] & n[3] +func (m NetworkNumberMask) Mask(n NetworkNumber) NetworkNumber { + mask1, mask2, mask3, mask4 := mask(m) + + result := make(NetworkNumber, len(n)) + switch len(n) { + case IPv4Uint32Count: + result[0] = n[0] & mask1 + case IPv6Uint32Count: + result[0] = n[0] & mask1 + result[1] = n[1] & mask2 + result[2] = n[2] & mask3 + result[3] = n[3] & mask4 } - return result, nil + return result } // NextIP returns the next sequential ip. diff --git a/net/ip_test.go b/net/ip_test.go index 30e6407ebd1e5aee2a8c250e0d36aaa7bd4a7e8a..8bc40ab24f646d16757b1d03e9d8303cc4076c9b 100644 --- a/net/ip_test.go +++ b/net/ip_test.go @@ -220,9 +220,11 @@ func TestNewNetwork(t *testing.T) { _, ipNet, _ := net.ParseCIDR("192.128.0.0/24") n := NewNetwork(*ipNet) - assert.Equal(t, *ipNet, n.IPNet) + newIPNet := n.IPNet() + assert.True(t, ipNet.IP.Equal(newIPNet.IP)) + assert.Equal(t, ipNet.Mask, newIPNet.Mask) assert.Equal(t, NetworkNumber{3229614080}, n.Number) - assert.Equal(t, NetworkNumberMask{math.MaxUint32 - uint32(math.MaxUint8)}, n.Mask) + assert.Equal(t, NetworkNumberMask(24), n.Mask) } func TestNetworkMasked(t *testing.T) { @@ -245,7 +247,7 @@ func TestNetworkMasked(t *testing.T) { _, expected, _ := net.ParseCIDR(testcase.maskedNetwork) n1 := NewNetwork(*network) e1 := NewNetwork(*expected) - assert.True(t, e1.String() == n1.Masked(testcase.mask).String()) + assert.Equal(t, e1.String(), n1.Masked(testcase.mask).String()) } } @@ -377,22 +379,19 @@ func TestMask(t *testing.T) { mask NetworkNumberMask ip NetworkNumber masked NetworkNumber - err error name string }{ - {NetworkNumberMask{math.MaxUint32}, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32}, nil, "nop IPv4 mask"}, - {NetworkNumberMask{math.MaxUint32 - math.MaxUint16}, NetworkNumber{math.MaxUint16 + 1}, NetworkNumber{math.MaxUint16 + 1}, nil, "nop IPv4 mask"}, - {NetworkNumberMask{math.MaxUint32 - math.MaxUint16}, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32 - math.MaxUint16}, nil, "IPv4 masked"}, - {NetworkNumberMask{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, nil, "nop IPv6 mask"}, - {NetworkNumberMask{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, nil, "nop IPv6 mask"}, - {NetworkNumberMask{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, nil, "IPv6 masked"}, - {NetworkNumberMask{math.MaxUint32}, NetworkNumber{math.MaxUint32, 0}, nil, ErrVersionMismatch, "Version mismatch"}, + {32, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32}, "nop IPv4 mask"}, + {16, NetworkNumber{math.MaxUint16 + 1}, NetworkNumber{math.MaxUint16 + 1}, "nop IPv4 mask"}, + {16, NetworkNumber{math.MaxUint32}, NetworkNumber{math.MaxUint32 - math.MaxUint16}, "IPv4 masked"}, + {96, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32, 0, 0, 0}, "nop IPv6 mask"}, + {16, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, NetworkNumber{math.MaxUint16 + 1, 0, 0, 0}, "nop IPv6 mask"}, + {16, NetworkNumber{math.MaxUint32, 0, 0, 0}, NetworkNumber{math.MaxUint32 - math.MaxUint16, 0, 0, 0}, "IPv6 masked"}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - masked, err := tc.mask.Mask(tc.ip) - assert.Equal(t, tc.masked, masked) - assert.Equal(t, tc.err, err) + masked := tc.mask.Mask(tc.ip) + assert.Equal(t, tc.masked, masked, tc.name) }) } } diff --git a/trie.go b/trie.go index 6c1c5437d05e0f2e60c64b48eb2a1717c93e7bf5..eaed9e731258537bfc38c111ddd14966c852c69b 100644 --- a/trie.go +++ b/trie.go @@ -84,8 +84,7 @@ func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie { } func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie { - ones, _ := network.IPNet.Mask.Size() - leaf := newPathprefixTrie(network, uint(ones)) + leaf := newPathprefixTrie(network, uint(network.Mask)) leaf.entry = entry return leaf }