Unverified Commit 1c6c8201 authored by Yulin Chen's avatar Yulin Chen Committed by GitHub

Add Len() to Ranger (#24)

* Add Len() to Ranger
parent 31e96aaf
...@@ -108,6 +108,11 @@ func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) ...@@ -108,6 +108,11 @@ func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error)
return results, nil return results, nil
} }
// Len returns number of networks in ranger.
func (b *bruteRanger) Len() int {
return len(b.ipV4Entries) + len(b.ipV6Entries)
}
func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) { func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) {
if ip.To4() != nil { if ip.To4() != nil {
return b.ipV4Entries, nil return b.ipV4Entries, nil
......
...@@ -72,6 +72,7 @@ type Ranger interface { ...@@ -72,6 +72,7 @@ type Ranger interface {
Contains(ip net.IP) (bool, error) Contains(ip net.IP) (bool, error)
ContainingNetworks(ip net.IP) ([]RangerEntry, error) ContainingNetworks(ip net.IP) ([]RangerEntry, error)
CoveredNetworks(network net.IPNet) ([]RangerEntry, error) CoveredNetworks(network net.IPNet) ([]RangerEntry, error)
Len() int
} }
// NewPCTrieRanger returns a versionedRanger that supports both IPv4 and IPv6 // NewPCTrieRanger returns a versionedRanger that supports both IPv4 and IPv6
......
...@@ -42,6 +42,8 @@ type prefixTrie struct { ...@@ -42,6 +42,8 @@ type prefixTrie struct {
network rnet.Network network rnet.Network
entry RangerEntry entry RangerEntry
size int // This is only maintained in the root trie.
} }
// newPrefixTree creates a new prefixTrie. // newPrefixTree creates a new prefixTrie.
...@@ -79,12 +81,20 @@ func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie { ...@@ -79,12 +81,20 @@ func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
// Insert inserts a RangerEntry into prefix trie. // Insert inserts a RangerEntry into prefix trie.
func (p *prefixTrie) Insert(entry RangerEntry) error { func (p *prefixTrie) Insert(entry RangerEntry) error {
network := entry.Network() network := entry.Network()
return p.insert(rnet.NewNetwork(network), entry) sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry)
if sizeIncreased {
p.size++
}
return err
} }
// Remove removes RangerEntry identified by given network from trie. // Remove removes RangerEntry identified by given network from trie.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) { func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
return p.remove(rnet.NewNetwork(network)) entry, err := p.remove(rnet.NewNetwork(network))
if entry != nil {
p.size--
}
return entry, err
} }
// Contains returns boolean indicating whether given ip is contained in any // Contains returns boolean indicating whether given ip is contained in any
...@@ -115,6 +125,11 @@ func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { ...@@ -115,6 +125,11 @@ func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
return p.coveredNetworks(net) return p.coveredNetworks(net)
} }
// Len returns number of networks in ranger.
func (p *prefixTrie) Len() int {
return p.size
}
// String returns string representation of trie, mainly for visualization and // String returns string representation of trie, mainly for visualization and
// debugging. // debugging.
func (p *prefixTrie) String() string { func (p *prefixTrie) String() string {
...@@ -203,22 +218,23 @@ func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error ...@@ -203,22 +218,23 @@ func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error
return results, nil return results, nil
} }
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error { func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) {
if p.network.Equal(network) { if p.network.Equal(network) {
sizeIncreased := p.entry == nil
p.entry = entry p.entry = entry
return nil return sizeIncreased, nil
} }
bit, err := p.targetBitFromIP(network.Number) bit, err := p.targetBitFromIP(network.Number)
if err != nil { if err != nil {
return err return false, err
} }
existingChild := p.children[bit] existingChild := p.children[bit]
// No existing child, insert new leaf trie. // No existing child, insert new leaf trie.
if existingChild == nil { if existingChild == nil {
p.appendTrie(bit, newEntryTrie(network, entry)) p.appendTrie(bit, newEntryTrie(network, entry))
return nil return true, nil
} }
// Check whether it is necessary to insert additional path prefix between current trie and existing child, // Check whether it is necessary to insert additional path prefix between current trie and existing child,
...@@ -229,7 +245,7 @@ func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error { ...@@ -229,7 +245,7 @@ func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error {
pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb) pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
err := p.insertPrefix(bit, pathPrefix, existingChild) err := p.insertPrefix(bit, pathPrefix, existingChild)
if err != nil { if err != nil {
return err return false, err
} }
// Update new child // Update new child
existingChild = pathPrefix existingChild = pathPrefix
......
...@@ -71,6 +71,9 @@ func TestPrefixTrieInsert(t *testing.T) { ...@@ -71,6 +71,9 @@ func TestPrefixTrieInsert(t *testing.T) {
err := trie.Insert(NewBasicRangerEntry(*network)) err := trie.Insert(NewBasicRangerEntry(*network))
assert.NoError(t, err) assert.NoError(t, err)
} }
assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match")
walk := trie.walkDepth() walk := trie.walkDepth()
for _, network := range tc.expectedNetworksInDepthOrder { for _, network := range tc.expectedNetworksInDepthOrder {
_, ipnet, _ := net.ParseCIDR(network) _, ipnet, _ := net.ParseCIDR(network)
...@@ -198,6 +201,9 @@ func TestPrefixTrieRemove(t *testing.T) { ...@@ -198,6 +201,9 @@ func TestPrefixTrieRemove(t *testing.T) {
assert.Nil(t, removed) assert.Nil(t, removed)
} }
} }
assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match after revmoval")
walk := trie.walkDepth() walk := trie.walkDepth()
for _, network := range tc.expectedNetworksInDepthOrder { for _, network := range tc.expectedNetworksInDepthOrder {
_, ipnet, _ := net.ParseCIDR(network) _, ipnet, _ := net.ParseCIDR(network)
...@@ -476,11 +482,15 @@ func TestTrieMemUsage(t *testing.T) { ...@@ -476,11 +482,15 @@ func TestTrieMemUsage(t *testing.T) {
var baseLineHeap, totalHeapAllocOverRuns uint64 var baseLineHeap, totalHeapAllocOverRuns uint64
for i := 0; i < runs; i++ { for i := 0; i < runs; i++ {
t.Logf("Executing Run %d of %d", i+1, runs)
// Insert networks. // Insert networks.
for n := 0; n < numIPs; n++ { for n := 0; n < numIPs; n++ {
trie.Insert(NewBasicRangerEntry(GenLeafIPNet(GenIPV4()))) trie.Insert(NewBasicRangerEntry(GenLeafIPNet(GenIPV4())))
} }
t.Logf("Inserted All (%d networks)", trie.Len())
assert.Less(t, 0, trie.Len(), "Len should > 0")
assert.LessOrEqualf(t, trie.Len(), numIPs, "Len should <= %d", numIPs)
// Remove networks. // Remove networks.
_, all, _ := net.ParseCIDR("0.0.0.0/0") _, all, _ := net.ParseCIDR("0.0.0.0/0")
...@@ -488,6 +498,8 @@ func TestTrieMemUsage(t *testing.T) { ...@@ -488,6 +498,8 @@ func TestTrieMemUsage(t *testing.T) {
for i := 0; i < len(ll); i++ { for i := 0; i < len(ll); i++ {
trie.Remove(ll[i].Network()) trie.Remove(ll[i].Network())
} }
t.Logf("Removed All (%d networks)", len(ll))
assert.Equal(t, 0, trie.Len(), "Len after removal should == 0")
// Perform GC // Perform GC
runtime.GC() runtime.GC()
...@@ -495,7 +507,7 @@ func TestTrieMemUsage(t *testing.T) { ...@@ -495,7 +507,7 @@ func TestTrieMemUsage(t *testing.T) {
// Get HeapAlloc stats. // Get HeapAlloc stats.
heapAlloc := GetHeapAllocation() heapAlloc := GetHeapAllocation()
totalHeapAllocOverRuns += heapAlloc totalHeapAllocOverRuns += heapAlloc
if i ==0 { if i == 0 {
baseLineHeap = heapAlloc baseLineHeap = heapAlloc
} }
} }
......
...@@ -61,6 +61,11 @@ func (v *versionedRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, err ...@@ -61,6 +61,11 @@ func (v *versionedRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, err
return ranger.CoveredNetworks(network) return ranger.CoveredNetworks(network)
} }
// Len returns number of networks in ranger.
func (v *versionedRanger) Len() int {
return v.ipV4Ranger.Len() + v.ipV6Ranger.Len()
}
func (v *versionedRanger) getRangerForIP(ip net.IP) (Ranger, error) { func (v *versionedRanger) getRangerForIP(ip net.IP) (Ranger, error) {
if ip.To4() != nil { if ip.To4() != nil {
return v.ipV4Ranger, nil return v.ipV4Ranger, nil
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment