package cidranger

import (
	"fmt"
	"net"
	"strings"

	rnet "github.com/yl2chen/cidranger/net"
)

11
// prefixTrie is a path-compressed (PC) trie implementation of the
12 13 14 15 16 17 18 19 20
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents current CIDR
// block.
//
// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
// 32 bits long and each bit represents a prefix tree starting at that bit. This
Yulin Chen's avatar
Yulin Chen committed
21
// property also guarantees constant lookup time in Big-O notation.
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
//
// Path compression compresses a string of node with only 1 child into a single
// node, decrease the amount of lookups necessary during containment tests.
//
// Level compression dictates the amount of direct children of a node by
// allowing it to handle multiple bits in the path.  The heuristic (based on
// children population) to decide when the compression and decompression happens
// is outlined in the prior linked blog, and will be experimented with in more
// depth in this project in the future.
//
// Note: Can not insert both IPv4 and IPv6 network addresses into the same
// prefix trie, use versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	parent   *prefixTrie
	children []*prefixTrie
39

40 41
	numBitsSkipped uint
	numBitsHandled uint
42

43 44
	network rnet.Network
	entry   RangerEntry
45 46
}

47 48
// newPrefixTree creates a new prefixTrie.
func newPrefixTree(version rnet.IPVersion) Ranger {
49
	_, rootNet, _ := net.ParseCIDR("0.0.0.0/0")
50 51 52 53 54
	if version == rnet.IPv6 {
		_, rootNet, _ = net.ParseCIDR("0::0/0")
	}
	return &prefixTrie{
		children:       make([]*prefixTrie, 2, 2),
55 56
		numBitsSkipped: 0,
		numBitsHandled: 1,
57
		network:        rnet.NewNetwork(*rootNet),
58 59 60
	}
}

61 62 63 64 65 66
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	version := rnet.IPv4
	if len(network.Number) == rnet.IPv6Uint32Count {
		version = rnet.IPv6
	}
	path := newPrefixTree(version).(*prefixTrie)
67
	path.numBitsSkipped = numBitsSkipped
68
	path.network = network.Masked(int(numBitsSkipped))
Yulin Chen's avatar
Yulin Chen committed
69
	return path
70 71
}

72
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
73
	ones, _ := network.IPNet.Mask.Size()
74
	leaf := newPathprefixTrie(network, uint(ones))
75
	leaf.entry = entry
Yulin Chen's avatar
Yulin Chen committed
76
	return leaf
77 78
}

79 80 81 82
// Insert inserts a RangerEntry into prefix trie.
func (p *prefixTrie) Insert(entry RangerEntry) error {
	network := entry.Network()
	return p.insert(rnet.NewNetwork(network), entry)
83 84
}

85 86
// Remove removes RangerEntry identified by given network from trie.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
87
	return p.remove(rnet.NewNetwork(network))
88 89 90 91
}

// Contains returns boolean indicating whether given ip is contained in any
// of the inserted networks.
92
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
93 94
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
95
		return false, ErrInvalidNetworkNumberInput
96
	}
97
	return p.contains(nn)
98 99
}

100 101 102
// ContainingNetworks returns the list of RangerEntry(s) the given ip is
// contained in in ascending prefix order.
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
103 104
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
105
		return nil, ErrInvalidNetworkNumberInput
106
	}
107
	return p.containingNetworks(nn)
108 109
}

Rob Adams's avatar
Rob Adams committed
110 111 112 113 114 115 116 117
// CoveredNetworks returns the list of RangerEntry(s) the given ipnet
// covers.  That is, the networks that are completely subsumed by the
// specified network.
func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
	net := rnet.NewNetwork(network)
	return p.coveredNetworks(net)
}

118 119
// String returns string representation of trie, mainly for visualization and
// debugging.
120
func (p *prefixTrie) String() string {
121 122 123 124 125 126 127 128 129 130
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bits, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
131
		p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
132 133
}

134
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
135
	if !p.network.Contains(number) {
136 137
		return false, nil
	}
138
	if p.hasEntry() {
139 140
		return true, nil
	}
141 142 143
	if p.targetBitPosition() < 0 {
		return false, nil
	}
144
	bit, err := p.targetBitFromIP(number)
145 146 147
	if err != nil {
		return false, err
	}
148
	child := p.children[bit]
149
	if child != nil {
150
		return child.contains(number)
151 152 153 154
	}
	return false, nil
}

155 156
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
	results := []RangerEntry{}
157
	if !p.network.Contains(number) {
158 159
		return results, nil
	}
160 161
	if p.hasEntry() {
		results = []RangerEntry{p.entry}
162
	}
163 164 165
	if p.targetBitPosition() < 0 {
		return results, nil
	}
166
	bit, err := p.targetBitFromIP(number)
167 168 169
	if err != nil {
		return nil, err
	}
170
	child := p.children[bit]
171
	if child != nil {
172
		ranges, err := child.containingNetworks(number)
173 174 175
		if err != nil {
			return nil, err
		}
176
		if len(ranges) > 0 {
177 178 179 180 181
			if len(results) > 0 {
				results = append(results, ranges...)
			} else {
				results = ranges
			}
182
		}
183 184 185 186
	}
	return results, nil
}

Rob Adams's avatar
Rob Adams committed
187 188
func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
	var results []RangerEntry
189 190 191 192 193 194 195 196 197 198
	if network.Covers(p.network) {
		for entry := range p.walkDepth() {
			results = append(results, entry)
		}
	} else if p.targetBitPosition() >= 0 {
		bit, err := p.targetBitFromIP(network.Number)
		if err != nil {
			return results, err
		}
		child := p.children[bit]
Rob Adams's avatar
Rob Adams committed
199
		if child != nil {
200
			return child.coveredNetworks(network)
Rob Adams's avatar
Rob Adams committed
201 202 203 204 205
		}
	}
	return results, nil
}

206
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error {
207
	if p.network.Equal(network) {
208
		p.entry = entry
209 210
		return nil
	}
211

212
	bit, err := p.targetBitFromIP(network.Number)
213 214 215
	if err != nil {
		return err
	}
216
	existingChild := p.children[bit]
217

218 219 220 221
	// No existing child, insert new leaf trie.
	if existingChild == nil {
		p.appendTrie(bit, newEntryTrie(network, entry))
		return nil
222
	}
223 224 225 226 227 228 229 230

	// Check whether it is necessary to insert additional path prefix between current trie and existing child,
	// in the case that inserted network diverges on its path to existing child.
	lcb, err := network.LeastCommonBitPosition(existingChild.network)
	divergingBitPos := int(lcb) - 1
	if divergingBitPos > existingChild.targetBitPosition() {
		pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
		err := p.insertPrefix(bit, pathPrefix, existingChild)
231 232 233
		if err != nil {
			return err
		}
234 235
		// Update new child
		existingChild = pathPrefix
236
	}
237
	return existingChild.insert(network, entry)
238 239
}

240 241
func (p *prefixTrie) appendTrie(bit uint32, prefix *prefixTrie) {
	p.children[bit] = prefix
242
	prefix.parent = p
243 244 245 246 247 248 249 250 251 252 253 254 255 256
}

func (p *prefixTrie) insertPrefix(bit uint32, pathPrefix, child *prefixTrie) error {
	// Set parent/child relationship between current trie and inserted pathPrefix
	p.children[bit] = pathPrefix
	pathPrefix.parent = p

	// Set parent/child relationship between inserted pathPrefix and original child
	pathPrefixBit, err := pathPrefix.targetBitFromIP(child.network.Number)
	if err != nil {
		return err
	}
	pathPrefix.children[pathPrefixBit] = child
	child.parent = pathPrefix
257 258 259
	return nil
}

260 261 262
func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
	if p.hasEntry() && p.network.Equal(network) {
		entry := p.entry
263 264 265 266 267
		p.entry = nil

		err := p.compressPathIfPossible()
		if err != nil {
			return nil, err
268
		}
269
		return entry, nil
270 271 272 273 274 275 276 277 278 279
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil, nil
280 281
}

282 283 284 285 286
func (p *prefixTrie) qualifiesForPathCompression() bool {
	// Current prefix trie can be path compressed if it meets all following.
	//		1. records no CIDR entry
	//		2. has single or no child
	//		3. is not root trie
287
	return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil
288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319
}

// compressPathIfPossible removes p from the trie when it no longer stores an
// entry and has at most one child (see qualifiesForPathCompression). The whole
// chain of qualifying ancestors directly above p is skipped as well: p's lone
// child (or nil, when p has no children) is attached straight to the nearest
// ancestor that does not qualify. Compression is then retried on that
// ancestor, because losing a child may have made it eligible too.
func (p *prefixTrie) compressPathIfPossible() error {
	if !p.qualifiesForPathCompression() {
		return nil
	}

	// Find the lone child, if any.
	var loneChild *prefixTrie
	for _, child := range p.children {
		if child != nil {
			loneChild = child
			break
		}
	}

	// Walk up to the root of the current single-child lineage. This terminates
	// because a node without a parent (the root) never qualifies.
	parent := p.parent
	for parent.qualifiesForPathCompression() {
		parent = parent.parent
	}
	parentBit, err := parent.targetBitFromIP(p.network.Number)
	if err != nil {
		return err
	}
	parent.children[parentBit] = loneChild
	if loneChild != nil {
		// Re-point the lifted child at its new parent. Otherwise it keeps
		// referencing the detached node p, so level() (and therefore String())
		// reports the wrong depth and the detached lineage stays reachable.
		loneChild.parent = parent
	}

	// Attempt further compression at the lineage parent, in case removing this
	// lineage left it eligible.
	return parent.compressPathIfPossible()
}

320
func (p *prefixTrie) childrenCount() int {
Yulin Chen's avatar
Yulin Chen committed
321 322 323 324 325 326 327 328 329
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}

330 331 332 333
func (p *prefixTrie) totalNumberOfBits() uint {
	return rnet.BitsPerUint32 * uint(len(p.network.Number))
}

334 335
func (p *prefixTrie) targetBitPosition() int {
	return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
336 337
}

338
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
339 340 341
	// This is a safe uint boxing of int since we should never attempt to get
	// target bit at a negative position.
	return n.Bit(uint(p.targetBitPosition()))
342 343
}

344 345 346 347
func (p *prefixTrie) hasEntry() bool {
	return p.entry != nil
}

348
func (p *prefixTrie) level() int {
349 350 351 352 353 354 355
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}

// walkDepth walks the trie in depth order, for unit testing.
356 357
func (p *prefixTrie) walkDepth() <-chan RangerEntry {
	entries := make(chan RangerEntry)
358
	go func() {
359 360
		if p.hasEntry() {
			entries <- p.entry
361
		}
362
		childEntriesList := []<-chan RangerEntry{}
363 364 365 366
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
367
			childEntriesList = append(childEntriesList, trie.walkDepth())
368
		}
369 370 371
		for _, childEntries := range childEntriesList {
			for entry := range childEntries {
				entries <- entry
372 373
			}
		}
374
		close(entries)
375
	}()
376
	return entries
377
}