From 1c6c8201b51cac7e8dab570d828e8f071e488abc Mon Sep 17 00:00:00 2001 From: Yulin Chen Date: Mon, 23 Dec 2019 03:40:02 -0800 Subject: [PATCH] Add Len() to Ranger (#24) * Add Len() to Ranger --- brute.go | 5 +++++ cidranger.go | 1 + trie.go | 30 +++++++++++++++++++++++------- trie_test.go | 16 ++++++++++++++-- version.go | 5 +++++ 5 files changed, 48 insertions(+), 9 deletions(-) diff --git a/brute.go b/brute.go index c8a1b4f..37a68be 100644 --- a/brute.go +++ b/brute.go @@ -108,6 +108,11 @@ func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) return results, nil } +// Len returns number of networks in ranger. +func (b *bruteRanger) Len() int { + return len(b.ipV4Entries) + len(b.ipV6Entries) +} + func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) { if ip.To4() != nil { return b.ipV4Entries, nil diff --git a/cidranger.go b/cidranger.go index e2c9ee5..2f38618 100644 --- a/cidranger.go +++ b/cidranger.go @@ -72,6 +72,7 @@ type Ranger interface { Contains(ip net.IP) (bool, error) ContainingNetworks(ip net.IP) ([]RangerEntry, error) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) + Len() int } // NewPCTrieRanger returns a versionedRanger that supports both IPv4 and IPv6 diff --git a/trie.go b/trie.go index 24bd678..8b2badc 100644 --- a/trie.go +++ b/trie.go @@ -42,6 +42,8 @@ type prefixTrie struct { network rnet.Network entry RangerEntry + + size int // This is only maintained in the root trie. } // newPrefixTree creates a new prefixTrie. @@ -79,12 +81,20 @@ func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie { // Insert inserts a RangerEntry into prefix trie. func (p *prefixTrie) Insert(entry RangerEntry) error { network := entry.Network() - return p.insert(rnet.NewNetwork(network), entry) + sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry) + if sizeIncreased { + p.size++ + } + return err } // Remove removes RangerEntry identified by given network from trie. func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) { - return p.remove(rnet.NewNetwork(network)) + entry, err := p.remove(rnet.NewNetwork(network)) + if entry != nil { + p.size-- + } + return entry, err } // Contains returns boolean indicating whether given ip is contained in any @@ -115,6 +125,11 @@ func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { return p.coveredNetworks(net) } +// Len returns number of networks in ranger. +func (p *prefixTrie) Len() int { + return p.size +} + // String returns string representation of trie, mainly for visualization and // debugging. func (p *prefixTrie) String() string { @@ -203,22 +218,23 @@ func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error return results, nil } -func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error { +func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) { if p.network.Equal(network) { + sizeIncreased := p.entry == nil p.entry = entry - return nil + return sizeIncreased, nil } bit, err := p.targetBitFromIP(network.Number) if err != nil { - return err + return false, err } existingChild := p.children[bit] // No existing child, insert new leaf trie. if existingChild == nil { p.appendTrie(bit, newEntryTrie(network, entry)) - return nil + return true, nil } // Check whether it is necessary to insert additional path prefix between current trie and existing child, @@ -229,7 +245,7 @@ func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error { pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb) err := p.insertPrefix(bit, pathPrefix, existingChild) if err != nil { - return err + return false, err } // Update new child existingChild = pathPrefix diff --git a/trie_test.go b/trie_test.go index 56b0ebe..e73508d 100644 --- a/trie_test.go +++ b/trie_test.go @@ -71,6 +71,9 @@ func TestPrefixTrieInsert(t *testing.T) { err := trie.Insert(NewBasicRangerEntry(*network)) assert.NoError(t, err) } + + assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match") + walk := trie.walkDepth() for _, network := range tc.expectedNetworksInDepthOrder { _, ipnet, _ := net.ParseCIDR(network) @@ -198,6 +201,9 @@ func TestPrefixTrieRemove(t *testing.T) { assert.Nil(t, removed) } } + + assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match after revmoval") + walk := trie.walkDepth() for _, network := range tc.expectedNetworksInDepthOrder { _, ipnet, _ := net.ParseCIDR(network) @@ -466,7 +472,7 @@ func TestTrieMemUsage(t *testing.T) { t.Skip("Skipping memory test in `-short` mode") } numIPs := 100000 - runs := 10 + runs := 10 // Avg heap allocation over all runs should not be more than the heap allocation of first run multiplied // by threshold, picking 1% as sane number for detecting memory leak. @@ -476,11 +482,15 @@ func TestTrieMemUsage(t *testing.T) { var baseLineHeap, totalHeapAllocOverRuns uint64 for i := 0; i < runs; i++ { + t.Logf("Executing Run %d of %d", i+1, runs) // Insert networks. for n := 0; n < numIPs; n++ { trie.Insert(NewBasicRangerEntry(GenLeafIPNet(GenIPV4()))) } + t.Logf("Inserted All (%d networks)", trie.Len()) + assert.Less(t, 0, trie.Len(), "Len should > 0") + assert.LessOrEqualf(t, trie.Len(), numIPs, "Len should <= %d", numIPs) // Remove networks. _, all, _ := net.ParseCIDR("0.0.0.0/0") @@ -488,6 +498,8 @@ func TestTrieMemUsage(t *testing.T) { for i := 0; i < len(ll); i++ { trie.Remove(ll[i].Network()) } + t.Logf("Removed All (%d networks)", len(ll)) + assert.Equal(t, 0, trie.Len(), "Len after removal should == 0") // Perform GC runtime.GC() @@ -495,7 +507,7 @@ func TestTrieMemUsage(t *testing.T) { // Get HeapAlloc stats. heapAlloc := GetHeapAllocation() totalHeapAllocOverRuns += heapAlloc - if i ==0 { + if i == 0 { baseLineHeap = heapAlloc } } diff --git a/version.go b/version.go index cfea061..2c3fe2b 100644 --- a/version.go +++ b/version.go @@ -61,6 +61,11 @@ func (v *versionedRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, err return ranger.CoveredNetworks(network) } +// Len returns number of networks in ranger. +func (v *versionedRanger) Len() int { + return v.ipV4Ranger.Len() + v.ipV6Ranger.Len() +} + func (v *versionedRanger) getRangerForIP(ip net.IP) (Ranger, error) { if ip.To4() != nil { return v.ipV4Ranger, nil -- GitLab