package cidranger

import (
	"fmt"
	"net"
	"strings"

	rnet "github.com/libp2p/go-cidranger/net"
)

11
// prefixTrie is a path-compressed (PC) trie implementation of the
12 13 14 15 16 17 18 19 20
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents current CIDR
// block.
//
// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
// 32 bits long and each bit represents a prefix tree starting at that bit. This
Yulin Chen's avatar
Yulin Chen committed
21
// property also guarantees constant lookup time in Big-O notation.
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
//
// Path compression compresses a string of node with only 1 child into a single
// node, decrease the amount of lookups necessary during containment tests.
//
// Level compression dictates the amount of direct children of a node by
// allowing it to handle multiple bits in the path.  The heuristic (based on
// children population) to decide when the compression and decompression happens
// is outlined in the prior linked blog, and will be experimented with in more
// depth in this project in the future.
//
// Note: Can not insert both IPv4 and IPv6 network addresses into the same
// prefix trie, use versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	parent   *prefixTrie
Steven Allen's avatar
Steven Allen committed
38
	children [2]*prefixTrie
39

40 41
	numBitsSkipped uint
	numBitsHandled uint
42

43 44
	network rnet.Network
	entry   RangerEntry
Yulin Chen's avatar
Yulin Chen committed
45 46

	size int // This is only maintained in the root trie.
47 48
}

// ip4ZeroCIDR and ip6ZeroCIDR are the catch-all 0.0.0.0/0 and ::/0 networks
// used as the root network of IPv4 and IPv6 tries respectively.
var (
	ip4ZeroCIDR net.IPNet
	ip6ZeroCIDR net.IPNet
)

func init() {
	// Both literals are constant, well-formed CIDR strings, so the parse
	// errors cannot be non-nil and are deliberately discarded.
	_, allV4, _ := net.ParseCIDR("0.0.0.0/0")
	ip4ZeroCIDR = *allV4
	_, allV6, _ := net.ParseCIDR("0::0/0")
	ip6ZeroCIDR = *allV6
}

// newRanger returns an empty Ranger for the given IP version, backed by a
// path-compressed prefix trie.
func newRanger(version rnet.IPVersion) Ranger {
	trie := newPrefixTree(version)
	return trie
}

62
// newPrefixTree creates a new prefixTrie.
Steven Allen's avatar
Steven Allen committed
63 64
func newPrefixTree(version rnet.IPVersion) *prefixTrie {
	rootNet := ip4ZeroCIDR
65
	if version == rnet.IPv6 {
Steven Allen's avatar
Steven Allen committed
66
		rootNet = ip6ZeroCIDR
67 68
	}
	return &prefixTrie{
69 70
		numBitsSkipped: 0,
		numBitsHandled: 1,
Steven Allen's avatar
Steven Allen committed
71
		network:        rnet.NewNetwork(rootNet),
72 73 74
	}
}

75 76 77 78 79
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	version := rnet.IPv4
	if len(network.Number) == rnet.IPv6Uint32Count {
		version = rnet.IPv6
	}
Steven Allen's avatar
Steven Allen committed
80
	path := newPrefixTree(version)
81
	path.numBitsSkipped = numBitsSkipped
82
	path.network = network.Masked(int(numBitsSkipped))
Yulin Chen's avatar
Yulin Chen committed
83
	return path
84 85
}

86
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
87
	ones, _ := network.IPNet.Mask.Size()
88
	leaf := newPathprefixTrie(network, uint(ones))
89
	leaf.entry = entry
Yulin Chen's avatar
Yulin Chen committed
90
	return leaf
91 92
}

93 94 95
// Insert inserts a RangerEntry into prefix trie.
func (p *prefixTrie) Insert(entry RangerEntry) error {
	network := entry.Network()
Yulin Chen's avatar
Yulin Chen committed
96 97 98 99 100
	sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry)
	if sizeIncreased {
		p.size++
	}
	return err
101 102
}

103 104
// Remove removes RangerEntry identified by given network from trie.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
Yulin Chen's avatar
Yulin Chen committed
105 106 107 108 109
	entry, err := p.remove(rnet.NewNetwork(network))
	if entry != nil {
		p.size--
	}
	return entry, err
110 111 112 113
}

// Contains returns boolean indicating whether given ip is contained in any
// of the inserted networks.
114
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
115 116
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
117
		return false, ErrInvalidNetworkNumberInput
118
	}
119
	return p.contains(nn)
120 121
}

122 123 124
// ContainingNetworks returns the list of RangerEntry(s) the given ip is
// contained in in ascending prefix order.
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
125 126
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
127
		return nil, ErrInvalidNetworkNumberInput
128
	}
129
	return p.containingNetworks(nn)
130 131
}

Rob Adams's avatar
Rob Adams committed
132 133 134 135 136 137 138 139
// CoveredNetworks returns the list of RangerEntry(s) the given ipnet
// covers.  That is, the networks that are completely subsumed by the
// specified network.
func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
	net := rnet.NewNetwork(network)
	return p.coveredNetworks(net)
}

Yulin Chen's avatar
Yulin Chen committed
140 141 142 143 144
// Len returns number of networks in ranger.
func (p *prefixTrie) Len() int {
	return p.size
}

145 146
// String returns string representation of trie, mainly for visualization and
// debugging.
147
func (p *prefixTrie) String() string {
148 149 150 151 152 153 154 155 156 157
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bits, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
158
		p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
159 160
}

161
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
162
	if !p.network.Contains(number) {
163 164
		return false, nil
	}
165
	if p.hasEntry() {
166 167
		return true, nil
	}
168 169 170
	if p.targetBitPosition() < 0 {
		return false, nil
	}
171
	bit, err := p.targetBitFromIP(number)
172 173 174
	if err != nil {
		return false, err
	}
175
	child := p.children[bit]
176
	if child != nil {
177
		return child.contains(number)
178 179 180 181
	}
	return false, nil
}

182 183
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
	results := []RangerEntry{}
184
	if !p.network.Contains(number) {
185 186
		return results, nil
	}
187 188
	if p.hasEntry() {
		results = []RangerEntry{p.entry}
189
	}
190 191 192
	if p.targetBitPosition() < 0 {
		return results, nil
	}
193
	bit, err := p.targetBitFromIP(number)
194 195 196
	if err != nil {
		return nil, err
	}
197
	child := p.children[bit]
198
	if child != nil {
199
		ranges, err := child.containingNetworks(number)
200 201 202
		if err != nil {
			return nil, err
		}
203
		if len(ranges) > 0 {
204 205 206 207 208
			if len(results) > 0 {
				results = append(results, ranges...)
			} else {
				results = ranges
			}
209
		}
210 211 212 213
	}
	return results, nil
}

Rob Adams's avatar
Rob Adams committed
214 215
func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
	var results []RangerEntry
216 217 218 219 220 221 222 223 224 225
	if network.Covers(p.network) {
		for entry := range p.walkDepth() {
			results = append(results, entry)
		}
	} else if p.targetBitPosition() >= 0 {
		bit, err := p.targetBitFromIP(network.Number)
		if err != nil {
			return results, err
		}
		child := p.children[bit]
Rob Adams's avatar
Rob Adams committed
226
		if child != nil {
227
			return child.coveredNetworks(network)
Rob Adams's avatar
Rob Adams committed
228 229 230 231 232
		}
	}
	return results, nil
}

Yulin Chen's avatar
Yulin Chen committed
233
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) {
234
	if p.network.Equal(network) {
Yulin Chen's avatar
Yulin Chen committed
235
		sizeIncreased := p.entry == nil
236
		p.entry = entry
Yulin Chen's avatar
Yulin Chen committed
237
		return sizeIncreased, nil
238
	}
239

240
	bit, err := p.targetBitFromIP(network.Number)
241
	if err != nil {
Yulin Chen's avatar
Yulin Chen committed
242
		return false, err
243
	}
244
	existingChild := p.children[bit]
245

246 247 248
	// No existing child, insert new leaf trie.
	if existingChild == nil {
		p.appendTrie(bit, newEntryTrie(network, entry))
Yulin Chen's avatar
Yulin Chen committed
249
		return true, nil
250
	}
251 252 253 254 255 256 257 258

	// Check whether it is necessary to insert additional path prefix between current trie and existing child,
	// in the case that inserted network diverges on its path to existing child.
	lcb, err := network.LeastCommonBitPosition(existingChild.network)
	divergingBitPos := int(lcb) - 1
	if divergingBitPos > existingChild.targetBitPosition() {
		pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
		err := p.insertPrefix(bit, pathPrefix, existingChild)
259
		if err != nil {
Yulin Chen's avatar
Yulin Chen committed
260
			return false, err
261
		}
262 263
		// Update new child
		existingChild = pathPrefix
264
	}
265
	return existingChild.insert(network, entry)
266 267
}

268 269
func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) {
	p.children[bit] = prefix
270
	prefix.parent = p
271 272 273 274 275 276 277 278 279 280 281 282 283 284
}

func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error {
	// Set parent/child relationship between current trie and inserted pathPrefix
	p.children[bit] = pathPrefix
	pathPrefix.parent = p

	// Set parent/child relationship between inserted pathPrefix and original child
	pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number)
	if err != nil {
		return err
	}
	pathPrefix.children[pathPrefixBit] = child
	child.parent = pathPrefix
285 286 287
	return nil
}

288 289 290
func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
	if p.hasEntry() && p.network.Equal(network) {
		entry := p.entry
291 292 293 294 295
		p.entry = nil

		err := p.compressPathIfPossible()
		if err != nil {
			return nil, err
296
		}
297
		return entry, nil
298 299 300 301 302 303 304 305 306 307
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil, nil
308 309
}

310 311 312 313 314
func (p *prefixTrie) qualifiesForPathCompression() bool {
	// Current prefix trie can be path compressed if it meets all following.
	//		1. records no CIDR entry
	//		2. has single or no child
	//		3. is not root trie
315
	return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil
316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347
}

// compressPathIfPossible splices p (and any qualifying single-child ancestors
// directly above it) out of the trie, and hands p's lone child, if any, to
// the nearest ancestor that does not qualify for compression. It then retries
// at that ancestor, since the splice may have left it compressible too.
func (p *prefixTrie) compressPathIfPossible() error {
	if !p.qualifiesForPathCompression() {
		// Does not qualify to be compressed.
		return nil
	}

	// Find the lone child, if any.
	var loneChild *prefixTrie
	for _, child := range p.children {
		if child != nil {
			loneChild = child
			break
		}
	}

	// Find the root of the current single-child lineage. The loop ends at the
	// latest at the trie root, which never qualifies.
	parent := p.parent
	for parent.qualifiesForPathCompression() {
		parent = parent.parent
	}
	parentBit, err := parent.targetBitFromIP(p.network.Number)
	if err != nil {
		return err
	}
	parent.children[parentBit] = loneChild
	if loneChild != nil {
		// Re-parent the lone child. Otherwise its parent pointer keeps
		// referring to the node just spliced out, which skews level() and
		// String() and keeps the removed nodes reachable.
		loneChild.parent = parent
	}

	// Attempt to further apply path compression at the lineage parent, in
	// case the current lineage was compressed into it.
	return parent.compressPathIfPossible()
}

348
func (p *prefixTrie) childrenCount() int {
Yulin Chen's avatar
Yulin Chen committed
349 350 351 352 353 354 355 356 357
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}

358 359 360 361
func (p *prefixTrie) totalNumberOfBits() uint {
	return rnet.BitsPerUint32 * uint(len(p.network.Number))
}

362 363
func (p *prefixTrie) targetBitPosition() int {
	return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
364 365
}

366
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
367 368 369
	// This is a safe uint boxing of int since we should never attempt to get
	// target bit at a negative position.
	return n.Bit(uint(p.targetBitPosition()))
370 371
}

372 373 374 375
func (p *prefixTrie) hasEntry() bool {
	return p.entry != nil
}

376
func (p *prefixTrie) level() int {
377 378 379 380 381 382 383
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}

// walkDepth walks the trie in depth order, for unit testing.
384 385
func (p *prefixTrie) walkDepth() <-chan RangerEntry {
	entries := make(chan RangerEntry)
386
	go func() {
387 388
		if p.hasEntry() {
			entries <- p.entry
389
		}
390
		childEntriesList := []<-chan RangerEntry{}
391 392 393 394
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
395
			childEntriesList = append(childEntriesList, trie.walkDepth())
396
		}
397 398 399
		for _, childEntries := range childEntriesList {
			for entry := range childEntries {
				entries <- entry
400 401
			}
		}
402
		close(entries)
403
	}()
404
	return entries
405
}