trie.go 10.3 KB
Newer Older
1
package cidranger
2 3 4 5 6 7

import (
	"fmt"
	"net"
	"strings"

Steven Allen's avatar
Steven Allen committed
8
	rnet "github.com/libp2p/go-cidranger/net"
9 10
)

11
// prefixTrie is a path-compressed (PC) trie implementation of the
12 13 14 15 16 17 18 19 20
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents current CIDR
// block.
//
// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
// 32 bits long and each bit represents a prefix tree starting at that bit. This
Yulin Chen's avatar
Yulin Chen committed
21
// property also guarantees constant lookup time in Big-O notation.
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
//
// Path compression compresses a string of node with only 1 child into a single
// node, decrease the amount of lookups necessary during containment tests.
//
// Level compression dictates the amount of direct children of a node by
// allowing it to handle multiple bits in the path.  The heuristic (based on
// children population) to decide when the compression and decompression happens
// is outlined in the prior linked blog, and will be experimented with in more
// depth in this project in the future.
//
// Note: Can not insert both IPv4 and IPv6 network addresses into the same
// prefix trie, use versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	parent   *prefixTrie
	children []*prefixTrie
39

40 41
	numBitsSkipped uint
	numBitsHandled uint
42

43 44
	network rnet.Network
	entry   RangerEntry
Yulin Chen's avatar
Yulin Chen committed
45 46

	size int // This is only maintained in the root trie.
47 48
}

49 50
// newPrefixTree creates a new prefixTrie.
func newPrefixTree(version rnet.IPVersion) Ranger {
51
	_, rootNet, _ := net.ParseCIDR("0.0.0.0/0")
52 53 54 55 56
	if version == rnet.IPv6 {
		_, rootNet, _ = net.ParseCIDR("0::0/0")
	}
	return &prefixTrie{
		children:       make([]*prefixTrie, 2, 2),
57 58
		numBitsSkipped: 0,
		numBitsHandled: 1,
59
		network:        rnet.NewNetwork(*rootNet),
60 61 62
	}
}

63 64 65 66 67 68
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	version := rnet.IPv4
	if len(network.Number) == rnet.IPv6Uint32Count {
		version = rnet.IPv6
	}
	path := newPrefixTree(version).(*prefixTrie)
69
	path.numBitsSkipped = numBitsSkipped
70
	path.network = network.Masked(int(numBitsSkipped))
Yulin Chen's avatar
Yulin Chen committed
71
	return path
72 73
}

74
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
75
	ones, _ := network.IPNet.Mask.Size()
76
	leaf := newPathprefixTrie(network, uint(ones))
77
	leaf.entry = entry
Yulin Chen's avatar
Yulin Chen committed
78
	return leaf
79 80
}

81 82 83
// Insert inserts a RangerEntry into prefix trie.
func (p *prefixTrie) Insert(entry RangerEntry) error {
	network := entry.Network()
Yulin Chen's avatar
Yulin Chen committed
84 85 86 87 88
	sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry)
	if sizeIncreased {
		p.size++
	}
	return err
89 90
}

91 92
// Remove removes RangerEntry identified by given network from trie.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
Yulin Chen's avatar
Yulin Chen committed
93 94 95 96 97
	entry, err := p.remove(rnet.NewNetwork(network))
	if entry != nil {
		p.size--
	}
	return entry, err
98 99 100 101
}

// Contains returns boolean indicating whether given ip is contained in any
// of the inserted networks.
102
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
103 104
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
105
		return false, ErrInvalidNetworkNumberInput
106
	}
107
	return p.contains(nn)
108 109
}

110 111 112
// ContainingNetworks returns the list of RangerEntry(s) the given ip is
// contained in in ascending prefix order.
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
113 114
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
115
		return nil, ErrInvalidNetworkNumberInput
116
	}
117
	return p.containingNetworks(nn)
118 119
}

Rob Adams's avatar
Rob Adams committed
120 121 122 123 124 125 126 127
// CoveredNetworks returns the list of RangerEntry(s) the given ipnet
// covers.  That is, the networks that are completely subsumed by the
// specified network.
func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
	return p.coveredNetworks(rnet.NewNetwork(network))
}

Yulin Chen's avatar
Yulin Chen committed
128 129 130 131 132
// Len returns number of networks in ranger, as tracked by the size counter
// that Insert and Remove maintain on the root trie.
func (p *prefixTrie) Len() int {
	count := p.size
	return count
}

133 134
// String returns string representation of trie, mainly for visualization and
// debugging.
135
func (p *prefixTrie) String() string {
136 137 138 139 140 141 142 143 144 145
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bits, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
146
		p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
147 148
}

149
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
150
	if !p.network.Contains(number) {
151 152
		return false, nil
	}
153
	if p.hasEntry() {
154 155
		return true, nil
	}
156 157 158
	if p.targetBitPosition() < 0 {
		return false, nil
	}
159
	bit, err := p.targetBitFromIP(number)
160 161 162
	if err != nil {
		return false, err
	}
163
	child := p.children[bit]
164
	if child != nil {
165
		return child.contains(number)
166 167 168 169
	}
	return false, nil
}

170 171
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
	results := []RangerEntry{}
172
	if !p.network.Contains(number) {
173 174
		return results, nil
	}
175 176
	if p.hasEntry() {
		results = []RangerEntry{p.entry}
177
	}
178 179 180
	if p.targetBitPosition() < 0 {
		return results, nil
	}
181
	bit, err := p.targetBitFromIP(number)
182 183 184
	if err != nil {
		return nil, err
	}
185
	child := p.children[bit]
186
	if child != nil {
187
		ranges, err := child.containingNetworks(number)
188 189 190
		if err != nil {
			return nil, err
		}
191
		if len(ranges) > 0 {
192 193 194 195 196
			if len(results) > 0 {
				results = append(results, ranges...)
			} else {
				results = ranges
			}
197
		}
198 199 200 201
	}
	return results, nil
}

Rob Adams's avatar
Rob Adams committed
202 203
func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
	var results []RangerEntry
204 205 206 207 208 209 210 211 212 213
	if network.Covers(p.network) {
		for entry := range p.walkDepth() {
			results = append(results, entry)
		}
	} else if p.targetBitPosition() >= 0 {
		bit, err := p.targetBitFromIP(network.Number)
		if err != nil {
			return results, err
		}
		child := p.children[bit]
Rob Adams's avatar
Rob Adams committed
214
		if child != nil {
215
			return child.coveredNetworks(network)
Rob Adams's avatar
Rob Adams committed
216 217 218 219 220
		}
	}
	return results, nil
}

Yulin Chen's avatar
Yulin Chen committed
221
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, error) {
222
	if p.network.Equal(network) {
Yulin Chen's avatar
Yulin Chen committed
223
		sizeIncreased := p.entry == nil
224
		p.entry = entry
Yulin Chen's avatar
Yulin Chen committed
225
		return sizeIncreased, nil
226
	}
227

228
	bit, err := p.targetBitFromIP(network.Number)
229
	if err != nil {
Yulin Chen's avatar
Yulin Chen committed
230
		return false, err
231
	}
232
	existingChild := p.children[bit]
233

234 235 236
	// No existing child, insert new leaf trie.
	if existingChild == nil {
		p.appendTrie(bit, newEntryTrie(network, entry))
Yulin Chen's avatar
Yulin Chen committed
237
		return true, nil
238
	}
239 240 241 242 243 244 245 246

	// Check whether it is necessary to insert additional path prefix between current trie and existing child,
	// in the case that inserted network diverges on its path to existing child.
	lcb, err := network.LeastCommonBitPosition(existingChild.network)
	divergingBitPos := int(lcb) - 1
	if divergingBitPos > existingChild.targetBitPosition() {
		pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
		err := p.insertPrefix(bit, pathPrefix, existingChild)
247
		if err != nil {
Yulin Chen's avatar
Yulin Chen committed
248
			return false, err
249
		}
250 251
		// Update new child
		existingChild = pathPrefix
252
	}
253
	return existingChild.insert(network, entry)
254 255
}

256 257
func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) {
	p.children[bit] = prefix
258
	prefix.parent = p
259 260 261 262 263 264 265 266 267 268 269 270 271 272
}

func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error {
	// Set parent/child relationship between current trie and inserted pathPrefix
	p.children[bit] = pathPrefix
	pathPrefix.parent = p

	// Set parent/child relationship between inserted pathPrefix and original child
	pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number)
	if err != nil {
		return err
	}
	pathPrefix.children[pathPrefixBit] = child
	child.parent = pathPrefix
273 274 275
	return nil
}

276 277 278
func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
	if p.hasEntry() && p.network.Equal(network) {
		entry := p.entry
279 280 281 282 283
		p.entry = nil

		err := p.compressPathIfPossible()
		if err != nil {
			return nil, err
284
		}
285
		return entry, nil
286 287 288 289 290 291 292 293 294 295
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil, nil
296 297
}

298 299 300 301 302
func (p *prefixTrie) qualifiesForPathCompression() bool {
	// Current prefix trie can be path compressed if it meets all following.
	//		1. records no CIDR entry
	//		2. has single or no child
	//		3. is not root trie
303
	return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil
304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
}

// compressPathIfPossible removes p from the trie when it qualifies for path
// compression, attaching p's lone child (or nil, when p has no child) directly
// to the nearest ancestor that does not itself qualify. It then retries the
// compression on that ancestor, which may have become compressible.
//
// It returns any error from computing the ancestor's child index.
func (p *prefixTrie) compressPathIfPossible() error {
	if !p.qualifiesForPathCompression() {
		// Does not qualify to be compressed
		return nil
	}

	// Find lone child.
	var loneChild *prefixTrie
	for _, child := range p.children {
		if child != nil {
			loneChild = child
			break
		}
	}

	// Find root of current single child lineage.
	parent := p.parent
	for parent.qualifiesForPathCompression() {
		parent = parent.parent
	}
	parentBit, err := parent.targetBitFromIP(p.network.Number)
	if err != nil {
		return err
	}
	parent.children[parentBit] = loneChild
	if loneChild != nil {
		// Keep the parent link in step with the child link. Otherwise the
		// child keeps pointing at the detached node p, which skews level()
		// and keeps p reachable.
		loneChild.parent = parent
	}

	// Attempts to further apply path compression at current lineage parent, in case current lineage
	// compressed into parent.
	return parent.compressPathIfPossible()
}

336
func (p *prefixTrie) childrenCount() int {
Yulin Chen's avatar
Yulin Chen committed
337 338 339 340 341 342 343 344 345
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}

346 347 348 349
// totalNumberOfBits returns the bit width of p's network number: the number of
// uint32 words in it times rnet.BitsPerUint32.
func (p *prefixTrie) totalNumberOfBits() uint {
	words := uint(len(p.network.Number))
	return rnet.BitsPerUint32 * words
}

350 351
func (p *prefixTrie) targetBitPosition() int {
	return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
352 353
}

354
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
355 356 357
	// This is a safe uint boxing of int since we should never attempt to get
	// target bit at a negative position.
	return n.Bit(uint(p.targetBitPosition()))
358 359
}

360 361 362 363
// hasEntry reports whether a RangerEntry is stored at p.
func (p *prefixTrie) hasEntry() bool {
	occupied := p.entry != nil
	return occupied
}

364
func (p *prefixTrie) level() int {
365 366 367 368 369 370 371
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}

// walkDepth walks the trie in depth order, for unit testing.
372 373
func (p *prefixTrie) walkDepth() <-chan RangerEntry {
	entries := make(chan RangerEntry)
374
	go func() {
375 376
		if p.hasEntry() {
			entries <- p.entry
377
		}
378
		childEntriesList := []<-chan RangerEntry{}
379 380 381 382
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
383
			childEntriesList = append(childEntriesList, trie.walkDepth())
384
		}
385 386 387
		for _, childEntries := range childEntriesList {
			for entry := range childEntries {
				entries <- entry
388 389
			}
		}
390
		close(entries)
391
	}()
392
	return entries
393
}