trie.go 7.85 KB
Newer Older
1
package cidranger
2 3 4 5 6 7

import (
	"fmt"
	"net"
	"strings"

8
	rnet "github.com/yl2chen/cidranger/net"
9 10
)

11
// prefixTrie is a path-compressed (PC) trie implementation of the
12 13 14 15 16 17 18 19 20
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents current CIDR
// block.
//
// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
// 32 bits long and each bit represents a prefix tree starting at that bit. This
Yulin Chen's avatar
Yulin Chen committed
21
// property also guarantees constant lookup time in Big-O notation.
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
//
// Path compression compresses a string of node with only 1 child into a single
// node, decrease the amount of lookups necessary during containment tests.
//
// Level compression dictates the amount of direct children of a node by
// allowing it to handle multiple bits in the path.  The heuristic (based on
// children population) to decide when the compression and decompression happens
// is outlined in the prior linked blog, and will be experimented with in more
// depth in this project in the future.
//
// Note: Can not insert both IPv4 and IPv6 network addresses into the same
// prefix trie, use versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	parent   *prefixTrie
	children []*prefixTrie
39

40 41
	numBitsSkipped uint
	numBitsHandled uint
42

43 44
	network rnet.Network
	entry   RangerEntry
45 46
}

47 48
// newPrefixTree creates a new prefixTrie.
func newPrefixTree(version rnet.IPVersion) Ranger {
49
	_, rootNet, _ := net.ParseCIDR("0.0.0.0/0")
50 51 52 53 54
	if version == rnet.IPv6 {
		_, rootNet, _ = net.ParseCIDR("0::0/0")
	}
	return &prefixTrie{
		children:       make([]*prefixTrie, 2, 2),
55 56
		numBitsSkipped: 0,
		numBitsHandled: 1,
57
		network:        rnet.NewNetwork(*rootNet),
58 59 60
	}
}

61 62 63 64 65 66
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	version := rnet.IPv4
	if len(network.Number) == rnet.IPv6Uint32Count {
		version = rnet.IPv6
	}
	path := newPrefixTree(version).(*prefixTrie)
67
	path.numBitsSkipped = numBitsSkipped
68
	path.network = network.Masked(int(numBitsSkipped))
Yulin Chen's avatar
Yulin Chen committed
69
	return path
70 71
}

72
func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie {
73
	ones, _ := network.IPNet.Mask.Size()
74
	leaf := newPathprefixTrie(network, uint(ones))
75
	leaf.entry = entry
Yulin Chen's avatar
Yulin Chen committed
76
	return leaf
77 78
}

79 80 81 82
// Insert inserts a RangerEntry into prefix trie.
func (p *prefixTrie) Insert(entry RangerEntry) error {
	network := entry.Network()
	return p.insert(rnet.NewNetwork(network), entry)
83 84
}

85 86
// Remove removes RangerEntry identified by given network from trie.
func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) {
87
	return p.remove(rnet.NewNetwork(network))
88 89 90 91
}

// Contains returns boolean indicating whether given ip is contained in any
// of the inserted networks.
92
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
93 94
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
95
		return false, ErrInvalidNetworkNumberInput
96
	}
97
	return p.contains(nn)
98 99
}

100 101 102
// ContainingNetworks returns the list of RangerEntry(s) the given ip is
// contained in in ascending prefix order.
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
103 104
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
105
		return nil, ErrInvalidNetworkNumberInput
106
	}
107
	return p.containingNetworks(nn)
108 109 110 111
}

// String returns string representation of trie, mainly for visualization and
// debugging.
112
func (p *prefixTrie) String() string {
113 114 115 116 117 118 119 120 121 122
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bits, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
123
		p.targetBitPosition(), p.hasEntry(), strings.Join(children, ""))
124 125
}

126
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
127
	if !p.network.Contains(number) {
128 129
		return false, nil
	}
130
	if p.hasEntry() {
131 132
		return true, nil
	}
133 134 135
	if p.targetBitPosition() < 0 {
		return false, nil
	}
136
	bit, err := p.targetBitFromIP(number)
137 138 139
	if err != nil {
		return false, err
	}
140
	child := p.children[bit]
141
	if child != nil {
142
		return child.contains(number)
143 144 145 146
	}
	return false, nil
}

147 148
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntry, error) {
	results := []RangerEntry{}
149
	if !p.network.Contains(number) {
150 151
		return results, nil
	}
152 153
	if p.hasEntry() {
		results = []RangerEntry{p.entry}
154
	}
155 156 157
	if p.targetBitPosition() < 0 {
		return results, nil
	}
158
	bit, err := p.targetBitFromIP(number)
159 160 161
	if err != nil {
		return nil, err
	}
162
	child := p.children[bit]
163
	if child != nil {
164
		ranges, err := child.containingNetworks(number)
165 166 167
		if err != nil {
			return nil, err
		}
168
		if len(ranges) > 0 {
169 170 171 172 173
			if len(results) > 0 {
				results = append(results, ranges...)
			} else {
				results = ranges
			}
174
		}
175 176 177 178
	}
	return results, nil
}

179
func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error {
180
	if p.network.Equal(network) {
181
		p.entry = entry
182 183
		return nil
	}
184
	bit, err := p.targetBitFromIP(network.Number)
185 186 187
	if err != nil {
		return err
	}
188
	child := p.children[bit]
189
	if child == nil {
190
		return p.insertPrefix(bit, newEntryTrie(network, entry))
191 192
	}

193
	lcb, err := network.LeastCommonBitPosition(child.network)
194 195 196
	if err != nil {
		return err
	}
197
	if int(lcb) > child.targetBitPosition()+1 {
198
		child = newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
199
		err := p.insertPrefix(bit, child)
200 201 202 203
		if err != nil {
			return err
		}
	}
204
	return child.insert(network, entry)
205 206
}

207
func (p *prefixTrie) insertPrefix(bits uint32, prefix *prefixTrie) error {
208 209
	child := p.children[bits]
	if child != nil {
210
		prefixBit, err := prefix.targetBitFromIP(child.network.Number)
211 212 213
		if err != nil {
			return err
		}
214
		prefix.insertPrefix(prefixBit, child)
215 216 217 218 219 220
	}
	p.children[bits] = prefix
	prefix.parent = p
	return nil
}

221 222 223
func (p *prefixTrie) remove(network rnet.Network) (RangerEntry, error) {
	if p.hasEntry() && p.network.Equal(network) {
		entry := p.entry
224
		if p.childrenCount() > 1 {
225
			p.entry = nil
226 227 228 229 230 231
		} else {
			// Has 0 or 1 child.
			parentBits, err := p.parent.targetBitFromIP(network.Number)
			if err != nil {
				return nil, err
			}
232
			var skipChild *prefixTrie
233 234 235 236 237 238 239 240
			for _, child := range p.children {
				if child != nil {
					skipChild = child
					break
				}
			}
			p.parent.children[parentBits] = skipChild
		}
241
		return entry, nil
242 243 244 245 246 247 248 249 250 251
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil, nil
252 253
}

254
func (p *prefixTrie) childrenCount() int {
Yulin Chen's avatar
Yulin Chen committed
255 256 257 258 259 260 261 262 263
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}

264 265 266 267
func (p *prefixTrie) totalNumberOfBits() uint {
	return rnet.BitsPerUint32 * uint(len(p.network.Number))
}

268 269
func (p *prefixTrie) targetBitPosition() int {
	return int(p.totalNumberOfBits()-p.numBitsSkipped) - 1
270 271
}

272
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
273 274 275
	// This is a safe uint boxing of int since we should never attempt to get
	// target bit at a negative position.
	return n.Bit(uint(p.targetBitPosition()))
276 277
}

278 279 280 281
func (p *prefixTrie) hasEntry() bool {
	return p.entry != nil
}

282
func (p *prefixTrie) level() int {
283 284 285 286 287 288 289
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}

// walkDepth walks the trie in depth order, for unit testing.
290 291
func (p *prefixTrie) walkDepth() <-chan RangerEntry {
	entries := make(chan RangerEntry)
292
	go func() {
293 294
		if p.hasEntry() {
			entries <- p.entry
295
		}
296
		childEntriesList := []<-chan RangerEntry{}
297 298 299 300
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
301
			childEntriesList = append(childEntriesList, trie.walkDepth())
302
		}
303 304 305
		for _, childEntries := range childEntriesList {
			for entry := range childEntries {
				entries <- entry
306 307
			}
		}
308
		close(entries)
309
	}()
310
	return entries
311
}