Commit 5f9a182b authored by yl2chen's avatar yl2chen

Add cidr and ip utilities

parent a2d7e23d
hash: 534e4e9766f824c96de0b446e6a0b2c3ec248a6fb43e532a6cf15c887792a536
updated: 2017-07-26T00:09:24.018576859-07:00
imports:
- name: github.com/stretchr/testify
version: 05e8a0eda380579888eb53c394909df027f06991
devImports: []
package: github.com/yl2chen/iptable
import:
- package: github.com/stretchr/testify
package cidr
import (
"fmt"
"net"
"github.com/yl2chen/cidranger/util/ip"
)
// ErrNoGreatestCommonBit is an error returned when no greatest common bit
// exists for the cidr ranges.
var ErrNoGreatestCommonBit = fmt.Errorf("No greatest common bit")
// GreatestCommonBitPosition returns the greatest common bit position of
// given cidr blocks.
func GreatestCommonBitPosition(network1 *net.IPNet, network2 *net.IPNet) (uint8, error) {
ip1, err := ip.IPv4ToBigEndianUint32(network1.IP)
if err != nil {
return 0, err
}
ip2, err := ip.IPv4ToBigEndianUint32(network2.IP)
if err != nil {
return 0, err
}
maskSize, _ := network1.Mask.Size()
if maskSize2, _ := network2.Mask.Size(); maskSize2 < maskSize {
maskSize = maskSize2
}
mask := uint32(1) << 31
if ip1&mask != ip2&mask {
return 0, ErrNoGreatestCommonBit
}
var i = 1
for ; i < maskSize; i++ {
mask = mask >> 1
if ip1&mask != ip2&mask {
break
}
}
return uint8(31 - i + 1), nil
}
// MaskNetwork returns a copy of given network with new mask.
func MaskNetwork(network *net.IPNet, ones int) *net.IPNet {
mask := net.CIDRMask(ones, 32)
return &net.IPNet{
IP: network.IP.Mask(mask),
Mask: mask,
}
}
// IPsInNetwork returns a channel that generates all ips in given network.
func IPsInNetwork(network net.IPNet) <-chan net.IP {
ipChannel := make(chan net.IP)
startingIP := network.IP
ones, bits := network.Mask.Size()
networkSize := 1 << uint(bits-ones)
go func() {
for i := 0; i < networkSize; i++ {
ipChannel <- startingIP
startingIP = ip.NextIP(startingIP)
}
close(ipChannel)
}()
return ipChannel
}
package cidr
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/yl2chen/cidranger/util/ip"
)
func TestGreatestCommonBitPosition(t *testing.T) {
cases := []struct {
cidr1 string
cidr2 string
expectedPos uint8
expectedErr error
}{
{"0.0.1.0/24", "0.0.0.0/24", uint8(9), nil},
{"0.0.0.0/24", "0.0.0.0/24", uint8(8), nil},
{"128.0.0.0/24", "0.0.0.0/24", 0, ErrNoGreatestCommonBit},
{"128.0.0.0/24", "192.0.0.0/16", uint8(31), nil},
{"128.0.0.0/24", "128.0.0.0/16", uint8(16), nil},
{"128.0.0.0/24", "128.1.0.0/16", uint8(17), nil},
}
for _, c := range cases {
_, cidr1, err := net.ParseCIDR(c.cidr1)
assert.NoError(t, err)
_, cidr2, err := net.ParseCIDR(c.cidr2)
assert.NoError(t, err)
pos, err := GreatestCommonBitPosition(cidr1, cidr2)
if c.expectedErr != nil {
assert.Equal(t, c.expectedErr, err)
} else {
assert.Equal(t, c.expectedPos, pos)
}
}
}
func TestMaskNetwork(t *testing.T) {
cases := []struct {
network string
mask int
maskedNetwork string
}{
{"192.168.0.0/16", 16, "192.168.0.0/16"},
{"192.168.0.0/16", 14, "192.168.0.0/14"},
{"192.168.0.0/16", 18, "192.168.0.0/18"},
{"192.168.0.0/16", 8, "192.0.0.0/8"},
}
for _, testcase := range cases {
_, network, err := net.ParseCIDR(testcase.network)
assert.NoError(t, err)
_, expected, err := net.ParseCIDR(testcase.maskedNetwork)
assert.NoError(t, err)
assert.Equal(t, expected, MaskNetwork(network, testcase.mask))
}
}
func TestIPsInNetwork(t *testing.T) {
cases := []struct {
network string
start net.IP
end net.IP
name string
}{
{
"192.168.0.0/30",
net.ParseIP("192.168.0.0"),
net.ParseIP("192.168.0.4"),
"IPs for 192.168.0.0/30",
},
{
"192.168.0.0/29",
net.ParseIP("192.168.0.0"),
net.ParseIP("192.168.0.8"),
"IPs for 192.168.0.0/29",
},
{
"192.168.0.0/24",
net.ParseIP("192.168.0.0"),
net.ParseIP("192.168.1.0"),
"IPs for 192.168.0.0/24",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, network, err := net.ParseCIDR(tc.network)
assert.NoError(t, err)
ips := IPsInNetwork(*network)
start := tc.start
for actual := range ips {
assert.Equal(t, start, actual.To16())
start = ip.NextIP(start)
}
assert.Equal(t, tc.end, start)
})
}
}
package ip
import (
"encoding/binary"
"fmt"
"net"
)
// ErrNotIPv4Error is returned when IPv4 operations is performed on IPv6.
var ErrNotIPv4Error = fmt.Errorf("IP is not IPv4")
// ErrBitsNotValid is returned when bits requested is not valid.
var ErrBitsNotValid = fmt.Errorf("bits requested not valid")
const ipv4BitLength = 32
// IPv4ToBigEndianUint32 converts IPv4 to uint32.
func IPv4ToBigEndianUint32(ip net.IP) (uint32, error) {
ip = ip.To4()
if ip == nil {
return 0, ErrNotIPv4Error
}
return binary.BigEndian.Uint32(ip), nil
}
// IPv4BitsAsUint returns uint32 representing bits at position of length
// numberOfBits, position is a number in [0, 31] representing the starting
// position in ip, with 31 being the most significant bit.
// E.g.,
// "128.0.0.0" has bit value of 1 at the 31th bit.
func IPv4BitsAsUint(ip uint32, position uint8, numberOfBits uint8) (uint32, error) {
if numberOfBits == 0 || numberOfBits > ipv4BitLength || position > ipv4BitLength-1 {
return 0, ErrBitsNotValid
}
if numberOfBits-1 > position {
return 0, ErrBitsNotValid
}
shiftLeft := position - (numberOfBits - 1)
mask := (uint32(1)<<numberOfBits - 1) << shiftLeft
return (ip & mask) >> shiftLeft, nil
}
// NextIP returns the next sequential ip.
func NextIP(ip net.IP) net.IP {
newIP := make([]byte, len(ip))
copy(newIP, ip)
for i := len(newIP) - 1; i >= 0; i-- {
newIP[i]++
if newIP[i] > 0 {
break
}
}
return newIP
}
// PreviousIP returns the previous sequential ip.
func PreviousIP(ip net.IP) net.IP {
newIP := make([]byte, len(ip))
copy(newIP, ip)
for i := len(newIP) - 1; i >= 0; i-- {
newIP[i]--
if newIP[i] < 255 {
break
}
}
return newIP
}
package ip
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIPv4ToLittleEndianUint32(t *testing.T) {
cases := []struct {
ip string
ipUint32 uint32
expectedErr error
}{
{"0.0.0.1", 1, nil},
{"1.0.0.0", 16777216, nil},
{"2001:0db8:0000:0000:0000:ff00:0042:8329", 0, ErrNotIPv4Error},
}
for _, c := range cases {
t.Run(c.ip, func(t *testing.T) {
ret, err := IPv4ToBigEndianUint32(net.ParseIP(c.ip))
assert.Equal(t, c.expectedErr, err)
assert.Equal(t, c.ipUint32, ret)
})
}
}
func TestIPv4BitsAsUint(t *testing.T) {
cases := []struct {
ip string
ipUint32 uint32
position uint8
bits uint8
expectedBits uint32
expectedErr error
}{
{"0.0.0.1", 1, 0, 0, 0, ErrBitsNotValid},
{"0.0.0.1", 1, 0, 2, 0, ErrBitsNotValid},
{"0.0.0.1", 1, 0, 33, 0, ErrBitsNotValid},
{"0.0.0.1", 1, 32, 1, 0, ErrBitsNotValid},
{"0.0.0.1", 1, 0, 1, 1, nil},
{"0.0.0.1", 1, 1, 2, 1, nil},
{"0.0.0.1", 1, 2, 1, 0, nil},
{"1.0.0.0", 16777216, 24, 1, 1, nil},
{"1.0.0.0", 16777216, 24, 2, 2, nil},
{"1.0.0.0", 16777216, 24, 3, 4, nil},
{"1.0.0.0", 16777216, 24, 25, 16777216, nil},
}
for _, c := range cases {
t.Run(c.ip, func(t *testing.T) {
ret, err := IPv4BitsAsUint(c.ipUint32, c.position, c.bits)
assert.Equal(t, c.expectedErr, err)
assert.Equal(t, c.expectedBits, ret)
})
}
}
// TODO: add test cases for ipV6
func TestNextIP(t *testing.T) {
cases := []struct {
ip string
next string
name string
}{
{"0.0.0.0", "0.0.0.1", "basic"},
{"0.0.0.255", "0.0.1.0", "rollover"},
{"0.255.255.255", "1.0.0.0", "consecutive rollover"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, net.ParseIP(tc.next), NextIP(net.ParseIP(tc.ip)))
})
}
}
// TODO: add test cases for ipV6
func TestPreviousIP(t *testing.T) {
cases := []struct {
ip string
next string
name string
}{
{"0.0.0.1", "0.0.0.0", "basic"},
{"0.0.1.0", "0.0.0.255", "rollover"},
{"1.0.0.0", "0.255.255.255", "consecutive rollover"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, net.ParseIP(tc.next), PreviousIP(net.ParseIP(tc.ip)))
})
}
}
testify @ 05e8a0ed
Subproject commit 05e8a0eda380579888eb53c394909df027f06991
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment