trie.go 7.34 KB
Newer Older
1
package cidranger
2 3 4 5 6 7

import (
	"fmt"
	"net"
	"strings"

8
	rnet "github.com/yl2chen/cidranger/net"
9 10
)

11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
// prefixTrie is a level-path-compressed (LPC) trie implementation of the
// ranger interface inspired by this blog post:
// https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node has its
// parent as prefix, and the path from the root node represents current CIDR
// block.
//
// For IPv4, the trie structure guarantees max depth of 32 as IPv4 addresses are
// 32 bits long and each bit represents a prefix tree starting at that bit. This
// property also gaurantees constant lookup time in Big-O notation.
//
// Path compression compresses a string of node with only 1 child into a single
// node, decrease the amount of lookups necessary during containment tests.
//
// Level compression dictates the amount of direct children of a node by
// allowing it to handle multiple bits in the path.  The heuristic (based on
// children population) to decide when the compression and decompression happens
// is outlined in the prior linked blog, and will be experimented with in more
// depth in this project in the future.
//
// Note: Can not insert both IPv4 and IPv6 network addresses into the same
// prefix trie, use versionedRanger wrapper instead.
//
// TODO: Implement level-compressed component of the LPC trie.
type prefixTrie struct {
	parent   *prefixTrie
	children []*prefixTrie
39

40 41
	numBitsSkipped uint
	numBitsHandled uint
42

43 44
	network  rnet.Network
	hasEntry bool
45 46
}

47 48
// newPrefixTree creates a new prefixTrie.
func newPrefixTree(version rnet.IPVersion) Ranger {
49
	_, rootNet, _ := net.ParseCIDR("0.0.0.0/0")
50 51 52 53 54
	if version == rnet.IPv6 {
		_, rootNet, _ = net.ParseCIDR("0::0/0")
	}
	return &prefixTrie{
		children:       make([]*prefixTrie, 2, 2),
55 56
		numBitsSkipped: 0,
		numBitsHandled: 1,
57
		network:        rnet.NewNetwork(*rootNet),
58 59 60
	}
}

61 62 63 64 65 66
func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie {
	version := rnet.IPv4
	if len(network.Number) == rnet.IPv6Uint32Count {
		version = rnet.IPv6
	}
	path := newPrefixTree(version).(*prefixTrie)
67
	path.numBitsSkipped = numBitsSkipped
68
	path.network = network.Masked(int(numBitsSkipped))
Yulin Chen's avatar
Yulin Chen committed
69
	return path
70 71
}

72
func newEntryTrie(network rnet.Network) *prefixTrie {
73
	ones, _ := network.IPNet.Mask.Size()
74
	leaf := newPathprefixTrie(network, uint(ones))
75
	leaf.hasEntry = true
Yulin Chen's avatar
Yulin Chen committed
76
	return leaf
77 78 79
}

// Insert inserts the given cidr range into prefix trie.
80
func (p *prefixTrie) Insert(network net.IPNet) error {
81
	return p.insert(rnet.NewNetwork(network))
82 83 84
}

// Remove removes network from trie.
85
func (p *prefixTrie) Remove(network net.IPNet) (*net.IPNet, error) {
86
	return p.remove(rnet.NewNetwork(network))
87 88 89 90
}

// Contains returns boolean indicating whether given ip is contained in any
// of the inserted networks.
91
func (p *prefixTrie) Contains(ip net.IP) (bool, error) {
92 93
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
94
		return false, ErrInvalidNetworkNumberInput
95
	}
96
	return p.contains(nn)
97 98 99 100
}

// ContainingNetworks returns the list of networks given ip is a part of in
// ascending prefix order.
101
func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]net.IPNet, error) {
102 103
	nn := rnet.NewNetworkNumber(ip)
	if nn == nil {
104
		return nil, ErrInvalidNetworkNumberInput
105
	}
106
	return p.containingNetworks(nn)
107 108 109 110
}

// String returns string representation of trie, mainly for visualization and
// debugging.
111
func (p *prefixTrie) String() string {
112 113 114 115 116 117 118 119 120 121 122 123 124
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bits, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bits, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (target_pos:%d:has_entry:%t)%s", p.network,
		p.targetBitPosition(), p.hasEntry, strings.Join(children, ""))
}

125
func (p *prefixTrie) contains(number rnet.NetworkNumber) (bool, error) {
126
	if !p.network.Contains(number) {
127 128 129 130 131
		return false, nil
	}
	if p.hasEntry {
		return true, nil
	}
132
	bit, err := p.targetBitFromIP(number)
133 134 135
	if err != nil {
		return false, err
	}
136
	child := p.children[bit]
137
	if child != nil {
138
		return child.contains(number)
139 140 141 142
	}
	return false, nil
}

143
func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]net.IPNet, error) {
144
	results := []net.IPNet{}
145
	if !p.network.Contains(number) {
146 147 148
		return results, nil
	}
	if p.hasEntry {
149
		results = []net.IPNet{p.network.IPNet}
150
	}
151
	bit, err := p.targetBitFromIP(number)
152 153 154
	if err != nil {
		return nil, err
	}
155
	child := p.children[bit]
156
	if child != nil {
157
		ranges, err := child.containingNetworks(number)
158 159 160
		if err != nil {
			return nil, err
		}
161 162 163
		if len(ranges) > 0 {
			results = append(results, ranges...)
		}
164 165 166 167
	}
	return results, nil
}

168
func (p *prefixTrie) insert(network rnet.Network) error {
169
	if p.network.Equal(network) {
170 171 172
		p.hasEntry = true
		return nil
	}
173
	bit, err := p.targetBitFromIP(network.Number)
174 175 176
	if err != nil {
		return err
	}
177
	child := p.children[bit]
178
	if child == nil {
Yulin Chen's avatar
Yulin Chen committed
179
		return p.insertPrefix(bit, newEntryTrie(network))
180 181
	}

182
	lcb, err := network.LeastCommonBitPosition(child.network)
183 184 185
	if err != nil {
		return err
	}
186
	if lcb-1 > child.targetBitPosition() {
187
		child = newPathprefixTrie(network, p.totalNumberOfBits()-lcb)
188
		err := p.insertPrefix(bit, child)
189 190 191 192
		if err != nil {
			return err
		}
	}
193
	return child.insert(network)
194 195
}

196
func (p *prefixTrie) insertPrefix(bits uint32, prefix *prefixTrie) error {
197 198
	child := p.children[bits]
	if child != nil {
199
		prefixBit, err := prefix.targetBitFromIP(child.network.Number)
200 201 202
		if err != nil {
			return err
		}
203
		prefix.insertPrefix(prefixBit, child)
204 205 206 207 208 209
	}
	p.children[bits] = prefix
	prefix.parent = p
	return nil
}

210
func (p *prefixTrie) remove(network rnet.Network) (*net.IPNet, error) {
211 212 213 214 215 216 217 218 219
	if p.hasEntry && p.network.Equal(network) {
		if p.childrenCount() > 1 {
			p.hasEntry = false
		} else {
			// Has 0 or 1 child.
			parentBits, err := p.parent.targetBitFromIP(network.Number)
			if err != nil {
				return nil, err
			}
220
			var skipChild *prefixTrie
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
			for _, child := range p.children {
				if child != nil {
					skipChild = child
					break
				}
			}
			p.parent.children[parentBits] = skipChild
		}
		return &network.IPNet, nil
	}
	bit, err := p.targetBitFromIP(network.Number)
	if err != nil {
		return nil, err
	}
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil, nil
240 241
}

242
func (p *prefixTrie) childrenCount() int {
Yulin Chen's avatar
Yulin Chen committed
243 244 245 246 247 248 249 250 251
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}

252 253 254 255 256 257
// totalNumberOfBits returns the address width of p's network in bits: the
// number of 32-bit words in its network number times rnet.BitsPerUint32.
func (p *prefixTrie) totalNumberOfBits() uint {
	words := uint(len(p.network.Number))
	return words * rnet.BitsPerUint32
}

func (p *prefixTrie) targetBitPosition() uint {
	return p.totalNumberOfBits() - p.numBitsSkipped - 1
258 259
}

260
func (p *prefixTrie) targetBitFromIP(n rnet.NetworkNumber) (uint32, error) {
261
	return n.Bit(p.targetBitPosition())
262 263
}

264
func (p *prefixTrie) level() int {
265 266 267 268 269 270 271
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}

// walkDepth walks the trie in depth order, for unit testing.
272
func (p *prefixTrie) walkDepth() <-chan net.IPNet {
273 274 275
	networks := make(chan net.IPNet)
	go func() {
		if p.hasEntry {
276
			networks <- p.network.IPNet
277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
		}
		subNetworks := []<-chan net.IPNet{}
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
			subNetworks = append(subNetworks, trie.walkDepth())
		}
		for _, subNetwork := range subNetworks {
			for network := range subNetwork {
				networks <- network
			}
		}
		close(networks)
	}()
	return networks
}