trie.go 9.34 KB
Newer Older
1
package cidranger
2 3 4 5 6 7

import (
	"fmt"
	"net"
	"strings"

8
	rnet "github.com/yl2chen/cidranger/net"
9 10
)

11
// prefixTrie is a path-compressed (PC) trie implementation of the
12 13 14 15 16 17 18 19 20
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents current CIDR
// block.
//
// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
// 32 bits long and each bit represents a prefix tree starting at that bit. This
Yulin Chen's avatar
Yulin Chen committed
21
// property also guarantees constant lookup time in Big-O notation.
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
//
// Path compression compresses a string of node with only 1 child into a single
// node, decrease the amount of lookups necessary during containment tests.
//
// Level compression dictates the amount of direct children of a node by
// allowing it to handle multiple bits in the path.  The heuristic (based on
// children population) to decide when the compression and decompression happens
// is outlined in the prior linked blog, and will be experimented with in more
// depth in this project in the future.
//
// Note: Can not insert both IPv4 and IPv6 network addresses into the same
// prefix trie, use versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	parent   *prefixTrie
	children []*prefixTrie
39

40 41
	numBitsSkipped uint
	numBitsHandled uint
42

43 44
	network rnet.Network
	entry   RangerEntry
45 46
}

47 48
// newPrefixTree creates a new prefixTrie.
func newPrefixTree(version rnet.IPVersion) Ranger {
49
	_, rootNet, _ := net.ParseCIDR("0.0.0.0/0")
50 51 52 53 54
	if version == rnet.IPv6 {
		_, rootNet, _ = net.ParseCIDR("0::0/0")
	}
	return &prefixTrie{
		children:       make([]*prefixTrie, 2, 2),
55 56
		numBitsSkipped: 0,
		numBitsHandled: 1,
57
		network:        rnet.NewNetwork(*rootNet),
58 59 60
	}
}

61 62 63 64 65 66
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	version := rnet.IPv4
	if len(network.Number) == rnet.IPv6Uint32Count {
		version = rnet.IPv6
	}
	path := newPrefixTree(version).(*prefixTrie)
67
	path.numBitsSkipped = numBitsSkipped
68
	path.network = network.Masked(int(numBitsSkipped))
Yulin Chen's avatar
Yulin Chen committed
69
	return path
70 71
}

72
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
73
	ones, _ := network.IPNet.Mask.Size()
74
	leaf := newPathprefixTrie(network, uint(ones))
75
	leaf.entry = entry
Yulin Chen's avatar
Yulin Chen committed
76
	return leaf
77 78
}

79 80 81 82
// Insert inserts a RangerEntry into prefix trie.
func (p *prefixTrie) Insert(entry RangerEntry) error {
	network := entry.Network()
	return p.insert(rnet.NewNetwork(network), entry)
83 84
}

85 86
// Remove removes RangerEntry identified by given network from trie.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
87
	return p.remove(rnet.NewNetwork(network))
88 89 90 91
}

// Contains returns boolean indicating whether given ip is contained in any
// of the inserted networks.
92
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
93 94
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
95
		return false, ErrInvalidNetworkNumberInput
96
	}
97
	return p.contains(nn)
98 99
}

100 101 102
// ContainingNetworks returns the list of RangerEntry(s) the given ip is
// contained in in ascending prefix order.
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
103 104
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
105
		return nil, ErrInvalidNetworkNumberInput
106
	}
107
	return p.containingNetworks(nn)
108 109
}

Rob Adams's avatar
Rob Adams committed
110 111 112 113 114 115 116 117
// CoveredNetworks returns the RangerEntry(s) whose networks are completely
// subsumed by (covered by) the given network.
func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
	// Named query rather than net so the imported net package is not shadowed.
	query := rnet.NewNetwork(network)
	return p.coveredNetworks(query)
}

118 119
// String returns string representation of trie, mainly for visualization and
// debugging.
120
func (p *prefixTrie) String() string {
121 122 123 124 125 126 127 128 129 130
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bits, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
131
		p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
132 133
}

134
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
135
	if !p.network.Contains(number) {
136 137
		return false, nil
	}
138
	if p.hasEntry() {
139 140
		return true, nil
	}
141 142 143
	if p.targetBitPosition() < 0 {
		return false, nil
	}
144
	bit, err := p.targetBitFromIP(number)
145 146 147
	if err != nil {
		return false, err
	}
148
	child := p.children[bit]
149
	if child != nil {
150
		return child.contains(number)
151 152 153 154
	}
	return false, nil
}

155 156
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
	results := []RangerEntry{}
157
	if !p.network.Contains(number) {
158 159
		return results, nil
	}
160 161
	if p.hasEntry() {
		results = []RangerEntry{p.entry}
162
	}
163 164 165
	if p.targetBitPosition() < 0 {
		return results, nil
	}
166
	bit, err := p.targetBitFromIP(number)
167 168 169
	if err != nil {
		return nil, err
	}
170
	child := p.children[bit]
171
	if child != nil {
172
		ranges, err := child.containingNetworks(number)
173 174 175
		if err != nil {
			return nil, err
		}
176
		if len(ranges) > 0 {
177 178 179 180 181
			if len(results) > 0 {
				results = append(results, ranges...)
			} else {
				results = ranges
			}
182
		}
183 184 185 186
	}
	return results, nil
}

Rob Adams's avatar
Rob Adams committed
187 188
func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
	var results []RangerEntry
189 190 191 192 193 194 195 196 197 198
	if network.Covers(p.network) {
		for entry := range p.walkDepth() {
			results = append(results, entry)
		}
	} else if p.targetBitPosition() >= 0 {
		bit, err := p.targetBitFromIP(network.Number)
		if err != nil {
			return results, err
		}
		child := p.children[bit]
Rob Adams's avatar
Rob Adams committed
199
		if child != nil {
200
			return child.coveredNetworks(network)
Rob Adams's avatar
Rob Adams committed
201 202 203 204 205
		}
	}
	return results, nil
}

206
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error {
207
	if p.network.Equal(network) {
208
		p.entry = entry
209 210
		return nil
	}
211
	bit, err := p.targetBitFromIP(network.Number)
212 213 214
	if err != nil {
		return err
	}
215
	child := p.children[bit]
216
	if child == nil {
217
		return p.insertPrefix(bit, newEntryTrie(network, entry))
218 219
	}

220
	lcb, err := network.LeastCommonBitPosition(child.network)
221 222 223
	if err != nil {
		return err
	}
224
	if int(lcb) > child.targetBitPosition()+1 {
225
		child = newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
226
		err := p.insertPrefix(bit, child)
227 228 229 230
		if err != nil {
			return err
		}
	}
231
	return child.insert(network, entry)
232 233
}

234
func (p *prefixTrie) insertPrefix(bits uint32, prefix *prefixTrie) error {
235 236
	child := p.children[bits]
	if child != nil {
237
		prefixBit, err := prefix.targetBitFromIP(child.network.Number)
238 239 240
		if err != nil {
			return err
		}
241
		prefix.insertPrefix(prefixBit, child)
242 243 244 245 246 247
	}
	p.children[bits] = prefix
	prefix.parent = p
	return nil
}

248 249 250
func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
	if p.hasEntry() && p.network.Equal(network) {
		entry := p.entry
251 252 253 254 255
		p.entry = nil

		err := p.compressPathIfPossible()
		if err != nil {
			return nil, err
256
		}
257
		return entry, nil
258 259 260 261 262 263 264 265 266 267
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil, nil
268 269
}

270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
// qualifiesForPathCompression reports whether p can be spliced out of the
// trie. All of the following must hold:
//  1. p records no CIDR entry;
//  2. p has a single child or none;
//  3. p is not the root trie (it has a parent).
func (p *prefixTrie) qualifiesForPathCompression() bool {
	return !p.hasEntry() && p.childrenCount() <= 1 && p.parent != nil
}

// compressPathIfPossible splices p out of the trie when it qualifies for path
// compression (see qualifiesForPathCompression). It finds the nearest
// ancestor that does not itself qualify, skipping p and any qualifying
// ancestors in between. p's lone child, or nil if p is childless, is then
// placed in that ancestor's slot. The ancestor is re-examined afterwards,
// since losing a child may make it eligible for compression too.
func (p *prefixTrie) compressPathIfPossible() error {
	if !p.qualifiesForPathCompression() {
		// Does not qualify to be compressed.
		return nil
	}

	// Find the lone child (stays nil when p has no children).
	var loneChild *prefixTrie
	for _, child := range p.children {
		if child != nil {
			loneChild = child
			break
		}
	}

	// Find the root of the current single-child lineage. p qualifies, so
	// p.parent is non-nil, and the walk stops at the latest at the root,
	// which never qualifies because it has no parent.
	parent := p.parent
	for parent.qualifiesForPathCompression() {
		parent = parent.parent
	}
	parentBit, err := parent.targetBitFromIP(p.network.Number)
	if err != nil {
		return err
	}
	parent.children[parentBit] = loneChild
	if loneChild != nil {
		// Keep the back-pointer in sync with the child's new position. Without
		// this it keeps pointing at the spliced-out p, which skews level() and
		// keeps detached nodes referenced.
		loneChild.parent = parent
	}

	// Attempt further path compression at the lineage root, in case the
	// current lineage was compressed into it.
	return parent.compressPathIfPossible()
}

308
func (p *prefixTrie) childrenCount() int {
Yulin Chen's avatar
Yulin Chen committed
309 310 311 312 313 314 315 316 317
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}

318 319 320 321
// totalNumberOfBits returns the address width of p's network in bits:
// rnet.BitsPerUint32 for every uint32 word of the network number.
func (p *prefixTrie) totalNumberOfBits() uint {
	words := uint(len(p.network.Number))
	return words * rnet.BitsPerUint32
}

322 323
func (p *prefixTrie) targetBitPosition() int {
	return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
324 325
}

326
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
327 328 329
	// This is a safe uint boxing of int since we should never attempt to get
	// target bit at a negative position.
	return n.Bit(uint(p.targetBitPosition()))
330 331
}

332 333 334 335
// hasEntry reports whether a RangerEntry is stored at this node.
func (p *prefixTrie) hasEntry() bool { return p.entry != nil }

336
func (p *prefixTrie) level() int {
337 338 339 340 341 342 343
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}

// walkDepth walks the trie in depth order, for unit testing.
344 345
func (p *prefixTrie) walkDepth() <-chan RangerEntry {
	entries := make(chan RangerEntry)
346
	go func() {
347 348
		if p.hasEntry() {
			entries <- p.entry
349
		}
350
		childEntriesList := []<-chan RangerEntry{}
351 352 353 354
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
355
			childEntriesList = append(childEntriesList, trie.walkDepth())
356
		}
357 358 359
		for _, childEntries := range childEntriesList {
			for entry := range childEntries {
				entries <- entry
360 361
			}
		}
362
		close(entries)
363
	}()
364
	return entries
365
}