trie.go 10.4 KB
Newer Older
1
package cidranger
2 3 4 5 6 7

import (
	"fmt"
	"net"
	"strings"

Steven Allen's avatar
Steven Allen committed
8
	rnet "github.com/libp2p/go-cidranger/net"
9 10
)

11
// prefixTrie is a path-compressed (PC) trie implementation of the
12 13 14 15 16 17 18 19 20
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents current CIDR
// block.
//
// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
// 32 bits long and each bit represents a prefix tree starting at that bit. This
Yulin Chen's avatar
Yulin Chen committed
21
// property also guarantees constant lookup time in Big-O notation.
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
//
// Path compression compresses a string of node with only 1 child into a single
// node, decrease the amount of lookups necessary during containment tests.
//
// Level compression dictates the amount of direct children of a node by
// allowing it to handle multiple bits in the path.  The heuristic (based on
// children population) to decide when the compression and decompression happens
// is outlined in the prior linked blog, and will be experimented with in more
// depth in this project in the future.
//
// Note: Can not insert both IPv4 and IPv6 network addresses into the same
// prefix trie, use versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	parent   *prefixTrie
Steven Allen's avatar
Steven Allen committed
38
	children [2]*prefixTrie
39

40 41
	numBitsSkipped uint
	numBitsHandled uint
42

43 44
	network rnet.Network
	entry   RangerEntry
Yulin Chen's avatar
Yulin Chen committed
45 46

	size int // This is only maintained in the root trie.
47 48
}

Steven Allen's avatar
Steven Allen committed
49 50 51 52 53 54 55 56 57 58 59 60 61
// ip4ZeroCIDR and ip6ZeroCIDR are the all-encompassing /0 networks used as
// the root network of an IPv4 or IPv6 trie respectively.
var ip4ZeroCIDR, ip6ZeroCIDR net.IPNet

// init parses the two zero networks once at package load.
func init() {
	zeroNets := []struct {
		cidr string
		dst  *net.IPNet
	}{
		{"0.0.0.0/0", &ip4ZeroCIDR},
		{"0::0/0", &ip6ZeroCIDR},
	}
	for _, z := range zeroNets {
		// The error is ignored on purpose: both inputs are constant,
		// well-formed CIDR literals.
		_, parsed, _ := net.ParseCIDR(z.cidr)
		*z.dst = *parsed
	}
}

// newRanger returns a Ranger for the given IP version, backed by a freshly
// created, empty prefix trie.
func newRanger(version rnet.IPVersion) Ranger {
	return newPrefixTree(version)
}

62
// newPrefixTree creates a new prefixTrie.
Steven Allen's avatar
Steven Allen committed
63 64
func newPrefixTree(version rnet.IPVersion) *prefixTrie {
	rootNet := ip4ZeroCIDR
65
	if version == rnet.IPv6 {
Steven Allen's avatar
Steven Allen committed
66
		rootNet = ip6ZeroCIDR
67 68
	}
	return &prefixTrie{
69 70
		numBitsSkipped: 0,
		numBitsHandled: 1,
Steven Allen's avatar
Steven Allen committed
71
		network:        rnet.NewNetwork(rootNet),
72 73 74
	}
}

75 76 77 78 79
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	version := rnet.IPv4
	if len(network.Number) == rnet.IPv6Uint32Count {
		version = rnet.IPv6
	}
Steven Allen's avatar
Steven Allen committed
80
	path := newPrefixTree(version)
81
	path.numBitsSkipped = numBitsSkipped
82
	path.network = network.Masked(int(numBitsSkipped))
Yulin Chen's avatar
Yulin Chen committed
83
	return path
84 85
}

86
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
87
	leaf := newPathprefixTrie(network, uint(network.Mask))
88
	leaf.entry = entry
Yulin Chen's avatar
Yulin Chen committed
89
	return leaf
90 91
}

92 93 94
// Insert inserts a RangerEntry into prefix trie.
func (p *prefixTrie) Insert(entry RangerEntry) error {
	network := entry.Network()
Yulin Chen's avatar
Yulin Chen committed
95 96 97 98 99
	sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry)
	if sizeIncreased {
		p.size++
	}
	return err
100 101
}

102 103
// Remove removes RangerEntry identified by given network from trie.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
Yulin Chen's avatar
Yulin Chen committed
104 105 106 107 108
	entry, err := p.remove(rnet.NewNetwork(network))
	if entry != nil {
		p.size--
	}
	return entry, err
109 110 111 112
}

// Contains returns boolean indicating whether given ip is contained in any
// of the inserted networks.
113
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
114 115
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
116
		return false, ErrInvalidNetworkNumberInput
117
	}
118
	return p.contains(nn)
119 120
}

121 122 123
// ContainingNetworks returns the list of RangerEntry(s) the given ip is
// contained in in ascending prefix order.
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
124 125
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
126
		return nil, ErrInvalidNetworkNumberInput
127
	}
128
	return p.containingNetworks(nn)
129 130
}

Rob Adams's avatar
Rob Adams committed
131 132 133 134 135 136 137 138
// CoveredNetworks returns the RangerEntry(s) that the given network covers,
// that is, the stored networks that are completely subsumed by it.
func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
	return p.coveredNetworks(rnet.NewNetwork(network))
}

Yulin Chen's avatar
Yulin Chen committed
139 140 141 142 143
// Len returns the number of networks in the ranger. It reads the size
// counter, which is only maintained on the root trie, so the result is
// meaningful only when called on the root.
func (p *prefixTrie) Len() int {
	return p.size
}

144 145
// String returns string representation of trie, mainly for visualization and
// debugging.
146
func (p *prefixTrie) String() string {
147 148 149 150 151 152 153 154 155 156
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bits, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
157
		p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
158 159
}

160
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
161
	if !p.network.Contains(number) {
162 163
		return false, nil
	}
164
	if p.hasEntry() {
165 166
		return true, nil
	}
167 168 169
	if p.targetBitPosition() < 0 {
		return false, nil
	}
170
	bit, err := p.targetBitFromIP(number)
171 172 173
	if err != nil {
		return false, err
	}
174
	child := p.children[bit]
175
	if child != nil {
176
		return child.contains(number)
177 178 179 180
	}
	return false, nil
}

181 182
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
	results := []RangerEntry{}
183
	if !p.network.Contains(number) {
184 185
		return results, nil
	}
186 187
	if p.hasEntry() {
		results = []RangerEntry{p.entry}
188
	}
189 190 191
	if p.targetBitPosition() < 0 {
		return results, nil
	}
192
	bit, err := p.targetBitFromIP(number)
193 194 195
	if err != nil {
		return nil, err
	}
196
	child := p.children[bit]
197
	if child != nil {
198
		ranges, err := child.containingNetworks(number)
199 200 201
		if err != nil {
			return nil, err
		}
202
		if len(ranges) > 0 {
203 204 205 206 207
			if len(results) > 0 {
				results = append(results, ranges...)
			} else {
				results = ranges
			}
208
		}
209 210 211 212
	}
	return results, nil
}

Rob Adams's avatar
Rob Adams committed
213 214
func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
	var results []RangerEntry
215 216 217 218 219 220 221 222 223 224
	if network.Covers(p.network) {
		for entry := range p.walkDepth() {
			results = append(results, entry)
		}
	} else if p.targetBitPosition() >= 0 {
		bit, err := p.targetBitFromIP(network.Number)
		if err != nil {
			return results, err
		}
		child := p.children[bit]
Rob Adams's avatar
Rob Adams committed
225
		if child != nil {
226
			return child.coveredNetworks(network)
Rob Adams's avatar
Rob Adams committed
227 228 229 230 231
		}
	}
	return results, nil
}

Yulin Chen's avatar
Yulin Chen committed
232
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) {
233
	if p.network.Equal(network) {
Yulin Chen's avatar
Yulin Chen committed
234
		sizeIncreased := p.entry == nil
235
		p.entry = entry
Yulin Chen's avatar
Yulin Chen committed
236
		return sizeIncreased, nil
237
	}
238

239
	bit, err := p.targetBitFromIP(network.Number)
240
	if err != nil {
Yulin Chen's avatar
Yulin Chen committed
241
		return false, err
242
	}
243
	existingChild := p.children[bit]
244

245 246 247
	// No existing child, insert new leaf trie.
	if existingChild == nil {
		p.appendTrie(bit, newEntryTrie(network, entry))
Yulin Chen's avatar
Yulin Chen committed
248
		return true, nil
249
	}
250 251 252 253 254 255 256 257

	// Check whether it is necessary to insert additional path prefix between current trie and existing child,
	// in the case that inserted network diverges on its path to existing child.
	lcb, err := network.LeastCommonBitPosition(existingChild.network)
	divergingBitPos := int(lcb) - 1
	if divergingBitPos > existingChild.targetBitPosition() {
		pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
		err := p.insertPrefix(bit, pathPrefix, existingChild)
258
		if err != nil {
Yulin Chen's avatar
Yulin Chen committed
259
			return false, err
260
		}
261 262
		// Update new child
		existingChild = pathPrefix
263
	}
264
	return existingChild.insert(network, entry)
265 266
}

267 268
func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) {
	p.children[bit] = prefix
269
	prefix.parent = p
270 271 272 273 274 275 276 277 278 279 280 281 282 283
}

func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error {
	// Set parent/child relationship between current trie and inserted pathPrefix
	p.children[bit] = pathPrefix
	pathPrefix.parent = p

	// Set parent/child relationship between inserted pathPrefix and original child
	pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number)
	if err != nil {
		return err
	}
	pathPrefix.children[pathPrefixBit] = child
	child.parent = pathPrefix
284 285 286
	return nil
}

287 288 289
func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
	if p.hasEntry() && p.network.Equal(network) {
		entry := p.entry
290 291 292 293 294
		p.entry = nil

		err := p.compressPathIfPossible()
		if err != nil {
			return nil, err
295
		}
296
		return entry, nil
297 298 299 300 301 302 303 304 305 306
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil, nil
307 308
}

309 310 311 312 313
func (p *prefixTrie) qualifiesForPathCompression() bool {
	// Current prefix trie can be path compressed if it meets all following.
	//		1. records no CIDR entry
	//		2. has single or no child
	//		3. is not root trie
314
	return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil
315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
}

// compressPathIfPossible removes p from the trie when it no longer carries
// any information (see qualifiesForPathCompression), attaching its lone
// child, if any, directly to the nearest ancestor that must be kept. It then
// retries the compression on that ancestor, which may itself have become
// redundant.
func (p *prefixTrie) compressPathIfPossible() error {
	if !p.qualifiesForPathCompression() {
		// Does not qualify to be compressed.
		return nil
	}

	// Find lone child; it stays nil when p has no children at all.
	var loneChild *prefixTrie
	for _, child := range p.children {
		if child != nil {
			loneChild = child
			break
		}
	}

	// Find root of current single child lineage: the first ancestor that
	// does not itself qualify for compression.
	parent := p.parent
	for parent.qualifiesForPathCompression() {
		parent = parent.parent
	}
	parentBit, err := parent.targetBitFromIP(p.network.Number)
	if err != nil {
		return err
	}
	parent.children[parentBit] = loneChild
	if loneChild != nil {
		// Keep the back-pointer in sync; otherwise the promoted child would
		// still reference the node that was just unlinked.
		loneChild.parent = parent
	}

	// Attempt to further apply path compression at the current lineage
	// parent, in case the current lineage compressed into the parent.
	return parent.compressPathIfPossible()
}

347
func (p *prefixTrie) childrenCount() int {
Yulin Chen's avatar
Yulin Chen committed
348 349 350 351 352 353 354 355 356
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}

357 358 359 360
// totalNumberOfBits returns the width in bits of this node's network number:
// the number of uint32 words in p.network.Number times rnet.BitsPerUint32.
func (p *prefixTrie) totalNumberOfBits() uint {
	return rnet.BitsPerUint32 * uint(len(p.network.Number))
}

361 362
func (p *prefixTrie) targetBitPosition() int {
	return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
363 364
}

365
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
366 367 368
	// This is a safe uint boxing of int since we should never attempt to get
	// target bit at a negative position.
	return n.Bit(uint(p.targetBitPosition()))
369 370
}

371 372 373 374
// hasEntry reports whether this node stores a RangerEntry, as opposed to
// being a node that only exists as part of a path.
func (p *prefixTrie) hasEntry() bool {
	return p.entry != nil
}

375
func (p *prefixTrie) level() int {
376 377 378 379 380 381 382
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}

// walkDepth walks the trie in depth order, for unit testing.
383 384
func (p *prefixTrie) walkDepth() <-chan RangerEntry {
	entries := make(chan RangerEntry)
385
	go func() {
386 387
		if p.hasEntry() {
			entries <- p.entry
388
		}
389
		childEntriesList := []<-chan RangerEntry{}
390 391 392 393
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
394
			childEntriesList = append(childEntriesList, trie.walkDepth())
395
		}
396 397 398
		for _, childEntries := range childEntriesList {
			for entry := range childEntries {
				entries <- entry
399 400
			}
		}
401
		close(entries)
402
	}()
403
	return entries
404
}