diff --git a/go.mod b/go.mod index b91ebd7cf4d87455d24902c5e32304885e3782ae..50915bd74e0c5fef7bcd5cdef54b93cfb7784c8d 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,4 @@ module github.com/yl2chen/cidranger go 1.13 -require ( - github.com/davecgh/go-spew v1.1.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.2.1 -) +require github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index 314bd6aeb72ec4670a13787c9f349efa05fba016..2789b9ac31c8192ee032b3bb0f65e862e99bdc4c 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,11 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.1 h1:52QO5WkIUcHGIR7EnGagH88x1bUzqGXTC5/1bDTUQ7U= github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/trie.go b/trie.go index 82a84501639274bcd1e4b5164cf9dfa336aca0d9..fdcfa2d6dce05351d54c47347690d8ffab8c52fa 100644 --- a/trie.go +++ b/trie.go @@ -248,22 +248,11 @@ func (p *prefixTrie) insertPrefix(bits uint32, prefix *prefixTrie) error { func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) { if p.hasEntry() && p.network.Equal(network) { entry := p.entry - if p.childrenCount() > 1 { - p.entry = nil - } else { - // Has 0 or 1 child. - parentBits, err := p.parent.targetBitFromIP(network.Number) - if err != nil { - return nil, err - } - var skipChild *prefixTrie - for _, child := range p.children { - if child != nil { - skipChild = child - break - } - } - p.parent.children[parentBits] = skipChild + p.entry = nil + + err := p.compressPathIfPossible() + if err != nil { + return nil, err } return entry, nil } @@ -278,6 +267,44 @@ func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) { return nil, nil } +func (p *prefixTrie) qualifiesForPathCompression() bool { + // Current prefix trie can be path compressed if it meets all following. + // 1. records no CIDR entry + // 2. has single or no child + // 3. is not root trie + return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil; +} + +func (p *prefixTrie) compressPathIfPossible() error { + if !p.qualifiesForPathCompression() { + // Does not qualify to be compressed + return nil + } + + // Find lone child. + var loneChild *prefixTrie + for _, child := range p.children { + if child != nil { + loneChild = child + break + } + } + + // Find root of currnt single child lineage. + parent := p.parent + for ; parent.qualifiesForPathCompression(); parent = parent.parent { + } + parentBit, err := parent.targetBitFromIP(p.network.Number) + if err != nil { + return err + } + parent.children[parentBit] = loneChild + + // Attempts to furthur apply path compression at current lineage parent, in case current lineage + // compressed into parent. + return parent.compressPathIfPossible() +} + func (p *prefixTrie) childrenCount() int { count := 0 for _, child := range p.children { diff --git a/trie_test.go b/trie_test.go index 29125fd6112d3bf6f3a14bdc86d40843857fcb44..56b0ebea321c003459072a1c81a2585d61d6ab01 100644 --- a/trie_test.go +++ b/trie_test.go @@ -1,8 +1,12 @@ package cidranger import ( + "encoding/binary" + "math/rand" "net" + "runtime" "testing" + "time" "github.com/stretchr/testify/assert" rnet "github.com/yl2chen/cidranger/net" @@ -105,6 +109,7 @@ func TestPrefixTrieRemove(t *testing.T) { removes []string expectedRemoves []string expectedNetworksInDepthOrder []string + expectedTrieString string name string }{ { @@ -113,6 +118,7 @@ func TestPrefixTrieRemove(t *testing.T) { []string{"192.168.0.1/24"}, []string{"192.168.0.1/24"}, []string{}, + "0.0.0.0/0 (target_pos:31:has_entry:false)", "basic remove", }, { @@ -121,6 +127,8 @@ func TestPrefixTrieRemove(t *testing.T) { []string{"1.2.3.5/32"}, []string{"1.2.3.5/32"}, []string{"1.2.3.4/32"}, + `0.0.0.0/0 (target_pos:31:has_entry:false) +| 0--> 1.2.3.4/32 (target_pos:-1:has_entry:true)`, "single ip IPv4 network remove", }, { @@ -129,6 +137,8 @@ func TestPrefixTrieRemove(t *testing.T) { []string{"0::2/128"}, []string{"0::2/128"}, []string{"0::1/128"}, + `0.0.0.0/0 (target_pos:31:has_entry:false) +| 0--> ::1/128 (target_pos:-1:has_entry:true)`, "single ip IPv6 network remove", }, { @@ -137,6 +147,9 @@ func TestPrefixTrieRemove(t *testing.T) { []string{"192.168.0.1/25"}, []string{"192.168.0.1/25"}, []string{"192.168.0.1/24", "192.168.0.1/26"}, + `0.0.0.0/0 (target_pos:31:has_entry:false) +| 1--> 192.168.0.0/24 (target_pos:7:has_entry:true) +| | 0--> 192.168.0.0/26 (target_pos:5:has_entry:true)`, "remove path prefix", }, { @@ -145,6 +158,11 @@ func TestPrefixTrieRemove(t *testing.T) { []string{"192.168.0.1/25"}, []string{"192.168.0.1/25"}, []string{"192.168.0.1/24", "192.168.0.1/26", "192.168.0.64/26"}, + `0.0.0.0/0 (target_pos:31:has_entry:false) +| 1--> 192.168.0.0/24 (target_pos:7:has_entry:true) +| | 0--> 192.168.0.0/25 (target_pos:6:has_entry:false) +| | | 0--> 192.168.0.0/26 (target_pos:5:has_entry:true) +| | | 1--> 192.168.0.64/26 (target_pos:5:has_entry:true)`, "remove path prefix with more than 1 children", }, { @@ -153,6 +171,9 @@ func TestPrefixTrieRemove(t *testing.T) { []string{"192.168.0.1/26"}, []string{""}, []string{"192.168.0.1/24", "192.168.0.1/25"}, + `0.0.0.0/0 (target_pos:31:has_entry:false) +| 1--> 192.168.0.0/24 (target_pos:7:has_entry:true) +| | 0--> 192.168.0.0/25 (target_pos:6:has_entry:true)`, "remove non existent", }, } @@ -189,6 +210,8 @@ func TestPrefixTrieRemove(t *testing.T) { for network := range walk { assert.Nil(t, network) } + + assert.Equal(t, tc.expectedTrieString, trie.String()) }) } } @@ -437,3 +460,72 @@ func TestPrefixTrieCoveredNetworks(t *testing.T) { }) } } + +func TestTrieMemUsage(t *testing.T) { + if testing.Short() { + t.Skip("Skipping memory test in `-short` mode") + } + numIPs := 100000 + runs := 10 + + // Avg heap allocation over all runs should not be more than the heap allocation of first run multiplied + // by threshold, picking 1% as sane number for detecting memory leak. + thresh := 1.01 + + trie := newPrefixTree(rnet.IPv4) + + var baseLineHeap, totalHeapAllocOverRuns uint64 + for i := 0; i < runs; i++ { + + // Insert networks. + for n := 0; n < numIPs; n++ { + trie.Insert(NewBasicRangerEntry(GenLeafIPNet(GenIPV4()))) + } + + // Remove networks. + _, all, _ := net.ParseCIDR("0.0.0.0/0") + ll, _ := trie.CoveredNetworks(*all) + for i := 0; i < len(ll); i++ { + trie.Remove(ll[i].Network()) + } + + // Perform GC + runtime.GC() + + // Get HeapAlloc stats. + heapAlloc := GetHeapAllocation() + totalHeapAllocOverRuns += heapAlloc + if i ==0 { + baseLineHeap = heapAlloc + } + } + + // Assert that heap allocation from first loop is within set threshold of avg over all runs. + assert.Less(t, uint64(0), baseLineHeap) + assert.LessOrEqual(t, float64(baseLineHeap), float64(totalHeapAllocOverRuns/uint64(runs))*thresh) +} + +func GenLeafIPNet(ip net.IP) net.IPNet { + return net.IPNet{ + IP: ip, + Mask: net.CIDRMask(32, 32), + } +} + +// GenIPV4 generates an IPV4 address +func GenIPV4() net.IP { + rand.Seed(time.Now().UnixNano()) + var min, max int + min = 1 + max = 4294967295 + nn := rand.Intn(max-min) + min + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, uint32(nn)) + return ip +} + +func GetHeapAllocation() uint64 { + var m runtime.MemStats + runtime.ReadMemStats(&m) + return m.HeapAlloc +}