trie.go 8.72 KB
Newer Older
1
package cidranger
2 3 4 5 6 7

import (
	"fmt"
	"net"
	"strings"

8
	rnet "github.com/yl2chen/cidranger/net"
9 10
)

11
// prefixTrie is a path-compressed (PC) trie implementation of the
12 13 14 15 16 17 18 19 20
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents current CIDR
// block.
//
// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
// 32 bits long and each bit represents a prefix tree starting at that bit. This
Yulin Chen's avatar
Yulin Chen committed
21
// property also guarantees constant lookup time in Big-O notation.
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
//
// Path compression compresses a string of node with only 1 child into a single
// node, decrease the amount of lookups necessary during containment tests.
//
// Level compression dictates the amount of direct children of a node by
// allowing it to handle multiple bits in the path.  The heuristic (based on
// children population) to decide when the compression and decompression happens
// is outlined in the prior linked blog, and will be experimented with in more
// depth in this project in the future.
//
// Note: Can not insert both IPv4 and IPv6 network addresses into the same
// prefix trie, use versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	parent   *prefixTrie
	children []*prefixTrie
39

40 41
	numBitsSkipped uint
	numBitsHandled uint
42

43 44
	network rnet.Network
	entry   RangerEntry
45 46
}

47 48
// newPrefixTree creates a new prefixTrie.
func newPrefixTree(version rnet.IPVersion) Ranger {
49
	_, rootNet, _ := net.ParseCIDR("0.0.0.0/0")
50 51 52 53 54
	if version == rnet.IPv6 {
		_, rootNet, _ = net.ParseCIDR("0::0/0")
	}
	return &prefixTrie{
		children:       make([]*prefixTrie, 2, 2),
55 56
		numBitsSkipped: 0,
		numBitsHandled: 1,
57
		network:        rnet.NewNetwork(*rootNet),
58 59 60
	}
}

61 62 63 64 65 66
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	version := rnet.IPv4
	if len(network.Number) == rnet.IPv6Uint32Count {
		version = rnet.IPv6
	}
	path := newPrefixTree(version).(*prefixTrie)
67
	path.numBitsSkipped = numBitsSkipped
68
	path.network = network.Masked(int(numBitsSkipped))
Yulin Chen's avatar
Yulin Chen committed
69
	return path
70 71
}

72
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
73
	ones, _ := network.IPNet.Mask.Size()
74
	leaf := newPathprefixTrie(network, uint(ones))
75
	leaf.entry = entry
Yulin Chen's avatar
Yulin Chen committed
76
	return leaf
77 78
}

79 80 81 82
// Insert inserts a RangerEntry into prefix trie.
func (p *prefixTrie) Insert(entry RangerEntry) error {
	network := entry.Network()
	return p.insert(rnet.NewNetwork(network), entry)
83 84
}

85 86
// Remove removes RangerEntry identified by given network from trie.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
87
	return p.remove(rnet.NewNetwork(network))
88 89 90 91
}

// Contains returns boolean indicating whether given ip is contained in any
// of the inserted networks.
92
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
93 94
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
95
		return false, ErrInvalidNetworkNumberInput
96
	}
97
	return p.contains(nn)
98 99
}

100 101 102
// ContainingNetworks returns the list of RangerEntry(s) the given ip is
// contained in in ascending prefix order.
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
103 104
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
105
		return nil, ErrInvalidNetworkNumberInput
106
	}
107
	return p.containingNetworks(nn)
108 109
}

Rob Adams's avatar
Rob Adams committed
110 111 112 113 114 115 116 117
// CoveredNetworks returns the RangerEntry(s) covered by the given network,
// that is, the stored networks that are completely subsumed by it.
func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
	// Converted inline so that no local variable shadows the net package.
	return p.coveredNetworks(rnet.NewNetwork(network))
}

118 119
// String returns string representation of trie, mainly for visualization and
// debugging.
120
func (p *prefixTrie) String() string {
121 122 123 124 125 126 127 128 129 130
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bits, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
131
		p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
132 133
}

134
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
135
	if !p.network.Contains(number) {
136 137
		return false, nil
	}
138
	if p.hasEntry() {
139 140
		return true, nil
	}
141 142 143
	if p.targetBitPosition() < 0 {
		return false, nil
	}
144
	bit, err := p.targetBitFromIP(number)
145 146 147
	if err != nil {
		return false, err
	}
148
	child := p.children[bit]
149
	if child != nil {
150
		return child.contains(number)
151 152 153 154
	}
	return false, nil
}

155 156
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
	results := []RangerEntry{}
157
	if !p.network.Contains(number) {
158 159
		return results, nil
	}
160 161
	if p.hasEntry() {
		results = []RangerEntry{p.entry}
162
	}
163 164 165
	if p.targetBitPosition() < 0 {
		return results, nil
	}
166
	bit, err := p.targetBitFromIP(number)
167 168 169
	if err != nil {
		return nil, err
	}
170
	child := p.children[bit]
171
	if child != nil {
172
		ranges, err := child.containingNetworks(number)
173 174 175
		if err != nil {
			return nil, err
		}
176
		if len(ranges) > 0 {
177 178 179 180 181
			if len(results) > 0 {
				results = append(results, ranges...)
			} else {
				results = ranges
			}
182
		}
183 184 185 186
	}
	return results, nil
}

Rob Adams's avatar
Rob Adams committed
187 188
func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
	var results []RangerEntry
189
	if p.hasEntry() && network.Covers(p.network) {
Rob Adams's avatar
Rob Adams committed
190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
		results = []RangerEntry{p.entry}
	}
	if p.targetBitPosition() < 0 {
		return results, nil
	}

	masked := network.Masked(int(p.numBitsSkipped))
	if !masked.Equal(p.network) {
		return results, nil
	}
	for _, child := range p.children {
		if child != nil {
			ranges, err := child.coveredNetworks(network)
			if err != nil {
				return nil, err
			}
			results = append(results, ranges...)
		}
	}
	return results, nil
}

212
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error {
213
	if p.network.Equal(network) {
214
		p.entry = entry
215 216
		return nil
	}
217
	bit, err := p.targetBitFromIP(network.Number)
218 219 220
	if err != nil {
		return err
	}
221
	child := p.children[bit]
222
	if child == nil {
223
		return p.insertPrefix(bit, newEntryTrie(network, entry))
224 225
	}

226
	lcb, err := network.LeastCommonBitPosition(child.network)
227 228 229
	if err != nil {
		return err
	}
230
	if int(lcb) > child.targetBitPosition()+1 {
231
		child = newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
232
		err := p.insertPrefix(bit, child)
233 234 235 236
		if err != nil {
			return err
		}
	}
237
	return child.insert(network, entry)
238 239
}

240
func (p *prefixTrie) insertPrefix(bits uint32, prefix *prefixTrie) error {
241 242
	child := p.children[bits]
	if child != nil {
243
		prefixBit, err := prefix.targetBitFromIP(child.network.Number)
244 245 246
		if err != nil {
			return err
		}
247
		prefix.insertPrefix(prefixBit, child)
248 249 250 251 252 253
	}
	p.children[bits] = prefix
	prefix.parent = p
	return nil
}

254 255 256
func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
	if p.hasEntry() && p.network.Equal(network) {
		entry := p.entry
257
		if p.childrenCount() > 1 {
258
			p.entry = nil
259 260 261 262 263 264
		} else {
			// Has 0 or 1 child.
			parentBits, err := p.parent.targetBitFromIP(network.Number)
			if err != nil {
				return nil, err
			}
265
			var skipChild *prefixTrie
266 267 268 269 270 271 272 273
			for _, child := range p.children {
				if child != nil {
					skipChild = child
					break
				}
			}
			p.parent.children[parentBits] = skipChild
		}
274
		return entry, nil
275 276 277 278 279 280 281 282 283 284
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil, nil
285 286
}

287
func (p *prefixTrie) childrenCount() int {
Yulin Chen's avatar
Yulin Chen committed
288 289 290 291 292 293 294 295 296
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}

297 298 299 300
// totalNumberOfBits returns the bit width of this node's network number: the
// count of uint32 words it holds times rnet.BitsPerUint32.
func (p *prefixTrie) totalNumberOfBits() uint {
	words := uint(len(p.network.Number))
	return words * rnet.BitsPerUint32
}

301 302
func (p *prefixTrie) targetBitPosition() int {
	return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
303 304
}

305
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
306 307 308
	// This is a safe uint boxing of int since we should never attempt to get
	// target bit at a negative position.
	return n.Bit(uint(p.targetBitPosition()))
309 310
}

311 312 313 314
// hasEntry reports whether this node stores a RangerEntry, i.e. whether its
// entry field is non-nil. Nodes without an entry only serve as branching
// points on the way to deeper nodes.
func (p *prefixTrie) hasEntry() bool {
	return p.entry != nil
}

315
func (p *prefixTrie) level() int {
316 317 318 319 320 321 322
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}

// walkDepth walks the trie in depth order, for unit testing.
323 324
func (p *prefixTrie) walkDepth() <-chan RangerEntry {
	entries := make(chan RangerEntry)
325
	go func() {
326 327
		if p.hasEntry() {
			entries <- p.entry
328
		}
329
		childEntriesList := []<-chan RangerEntry{}
330 331 332 333
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
334
			childEntriesList = append(childEntriesList, trie.walkDepth())
335
		}
336 337 338
		for _, childEntries := range childEntriesList {
			for entry := range childEntries {
				entries <- entry
339 340
			}
		}
341
		close(entries)
342
	}()
343
	return entries
344
}