Commit 96a5b4b6 authored by Petar Maymounkov's avatar Petar Maymounkov

Refactor.

parent d706c85b
package xor package key
import "bytes" import "bytes"
// TrieKey is a vector of bits backed by a Go byte slice in big endian byte order and big-endian bit order. // Key is a vector of bits backed by a Go byte slice in big endian byte order and big-endian bit order.
type TrieKey []byte type Key []byte
func (bs TrieKey) BitAt(offset int) byte { func (bs Key) BitAt(offset int) byte {
if bs[offset/8]&(1<<(offset%8)) == 0 { if bs[offset/8]&(1<<(offset%8)) == 0 {
return 0 return 0
} else { } else {
...@@ -13,10 +13,10 @@ func (bs TrieKey) BitAt(offset int) byte { ...@@ -13,10 +13,10 @@ func (bs TrieKey) BitAt(offset int) byte {
} }
} }
func (bs TrieKey) BitLen() int { func (bs Key) BitLen() int {
return 8 * len(bs) return 8 * len(bs)
} }
func TrieKeyEqual(x, y TrieKey) bool { func Equal(x, y Key) bool {
return bytes.Equal(x, y) return bytes.Equal(x, y)
} }
package xor package trie
import (
"github.com/libp2p/go-libp2p-xor/key"
)
// Add adds the key q to trie, returning a new trie. // Add adds the key q to trie, returning a new trie.
// Add is immutable/non-destructive: The original trie remains unchanged. // Add is immutable/non-destructive: The original trie remains unchanged.
func Add(trie *XorTrie, q TrieKey) *XorTrie { func Add(trie *XorTrie, q key.Key) *XorTrie {
return add(0, trie, q) return add(0, trie, q)
} }
func add(depth int, trie *XorTrie, q TrieKey) *XorTrie { func add(depth int, trie *XorTrie, q key.Key) *XorTrie {
dir := q.BitAt(depth) dir := q.BitAt(depth)
if !trie.isLeaf() { if !trie.isLeaf() {
s := &XorTrie{} s := &XorTrie{}
...@@ -17,7 +21,7 @@ func add(depth int, trie *XorTrie, q TrieKey) *XorTrie { ...@@ -17,7 +21,7 @@ func add(depth int, trie *XorTrie, q TrieKey) *XorTrie {
if trie.key == nil { if trie.key == nil {
return &XorTrie{key: q} return &XorTrie{key: q}
} else { } else {
if TrieKeyEqual(trie.key, q) { if key.Equal(trie.key, q) {
return trie return trie
} else { } else {
s := &XorTrie{} s := &XorTrie{}
......
package xor package trie
import "testing" import (
"testing"
"github.com/libp2p/go-libp2p-xor/key"
)
// Verify mutable and immutable add do the same thing. // Verify mutable and immutable add do the same thing.
func TestMutableAndImmutableAddSame(t *testing.T) { func TestMutableAndImmutableAddSame(t *testing.T) {
for _, s := range testAddSameSamples { for _, s := range testAddSameSamples {
mut := NewXorTrie() mut := New()
immut := NewXorTrie() immut := New()
for _, k := range s.Keys { for _, k := range s.Keys {
mut.Add(k) mut.Add(k)
immut = Add(immut, k) immut = Add(immut, k)
} }
if !XorTrieEqual(mut, immut) { if !Equal(mut, immut) {
t.Errorf("mutable trie %v differs from immutable trie %v", mut, immut) t.Errorf("mutable trie %v differs from immutable trie %v", mut, immut)
} }
} }
} }
type testAddSameSample struct { type testAddSameSample struct {
Keys []TrieKey Keys []key.Key
} }
var testAddSameSamples = []*testAddSameSample{ var testAddSameSamples = []*testAddSameSample{
{Keys: []TrieKey{{1, 3, 5, 7, 11, 13}}}, {Keys: []key.Key{{1, 3, 5, 7, 11, 13}}},
} }
package xor package trie
func XorTrieEqual(p, q *XorTrie) bool { import (
"github.com/libp2p/go-libp2p-xor/key"
)
func Equal(p, q *XorTrie) bool {
switch { switch {
case p.isLeaf() && q.isLeaf(): case p.isLeaf() && q.isLeaf():
return TrieKeyEqual(p.key, q.key) return key.Equal(p.key, q.key)
case !p.isLeaf() && !q.isLeaf(): case !p.isLeaf() && !q.isLeaf():
return XorTrieEqual(p.branch[0], q.branch[0]) && XorTrieEqual(p.branch[1], q.branch[1]) return Equal(p.branch[0], q.branch[0]) && Equal(p.branch[1], q.branch[1])
} }
return false return false
} }
package xor package trie
import (
"github.com/libp2p/go-libp2p-xor/key"
)
// Intersect computes the intersection of the keys in p and q. // Intersect computes the intersection of the keys in p and q.
// p and q must be non-nil. The returned trie is never nil. // p and q must be non-nil. The returned trie is never nil.
...@@ -12,7 +16,7 @@ func intersect(depth int, p, q *XorTrie) *XorTrie { ...@@ -12,7 +16,7 @@ func intersect(depth int, p, q *XorTrie) *XorTrie {
if p.isEmpty() || q.isEmpty() { if p.isEmpty() || q.isEmpty() {
return &XorTrie{} // empty set return &XorTrie{} // empty set
} else { } else {
if TrieKeyEqual(p.key, q.key) { if key.Equal(p.key, q.key) {
return &XorTrie{key: p.key} // singleton return &XorTrie{key: p.key} // singleton
} else { } else {
return &XorTrie{} // empty set return &XorTrie{} // empty set
......
package xor package trie
import "testing" import (
"testing"
"github.com/libp2p/go-libp2p-xor/key"
)
func TestIntersectRandom(t *testing.T) { func TestIntersectRandom(t *testing.T) {
for _, s := range testIntersectSamples { for _, s := range testIntersectSamples {
...@@ -9,7 +13,7 @@ func TestIntersectRandom(t *testing.T) { ...@@ -9,7 +13,7 @@ func TestIntersectRandom(t *testing.T) {
} }
func testIntersect(t *testing.T, sample *testIntersectSample) { func testIntersect(t *testing.T, sample *testIntersectSample) {
left, right, expected := NewXorTrie(), NewXorTrie(), NewXorTrie() left, right, expected := New(), New(), New()
for _, l := range sample.LeftKeys { for _, l := range sample.LeftKeys {
left.Add(l) left.Add(l)
} }
...@@ -20,17 +24,17 @@ func testIntersect(t *testing.T, sample *testIntersectSample) { ...@@ -20,17 +24,17 @@ func testIntersect(t *testing.T, sample *testIntersectSample) {
expected.Add(s) expected.Add(s)
} }
got := Intersect(left, right) got := Intersect(left, right)
if !XorTrieEqual(expected, got) { if !Equal(expected, got) {
t.Errorf("intersection of %v and %v: expected %v, got %v", t.Errorf("intersection of %v and %v: expected %v, got %v",
sample.LeftKeys, sample.RightKeys, expected, got) sample.LeftKeys, sample.RightKeys, expected, got)
} }
} }
func setIntersect(left, right []TrieKey) []TrieKey { func setIntersect(left, right []key.Key) []key.Key {
intersection := []TrieKey{} intersection := []key.Key{}
for _, l := range left { for _, l := range left {
for _, r := range right { for _, r := range right {
if TrieKeyEqual(l, r) { if key.Equal(l, r) {
intersection = append(intersection, r) intersection = append(intersection, r)
} }
} }
...@@ -39,21 +43,21 @@ func setIntersect(left, right []TrieKey) []TrieKey { ...@@ -39,21 +43,21 @@ func setIntersect(left, right []TrieKey) []TrieKey {
} }
type testIntersectSample struct { type testIntersectSample struct {
LeftKeys []TrieKey LeftKeys []key.Key
RightKeys []TrieKey RightKeys []key.Key
} }
var testIntersectSamples = []*testIntersectSample{ var testIntersectSamples = []*testIntersectSample{
{ {
LeftKeys: []TrieKey{{1, 2, 3}}, LeftKeys: []key.Key{{1, 2, 3}},
RightKeys: []TrieKey{{1, 3, 5}}, RightKeys: []key.Key{{1, 3, 5}},
}, },
{ {
LeftKeys: []TrieKey{{1, 2, 3, 4, 5, 6}}, LeftKeys: []key.Key{{1, 2, 3, 4, 5, 6}},
RightKeys: []TrieKey{{3, 5, 7}}, RightKeys: []key.Key{{3, 5, 7}},
}, },
{ {
LeftKeys: []TrieKey{{23, 3, 7, 13, 17}}, LeftKeys: []key.Key{{23, 3, 7, 13, 17}},
RightKeys: []TrieKey{{2, 11, 17, 19, 23}}, RightKeys: []key.Key{{2, 11, 17, 19, 23}},
}, },
} }
package xor package trie
import (
"github.com/libp2p/go-libp2p-xor/key"
)
// XorTrie is a trie for equal-length bit vectors, which stores values only in the leaves. // XorTrie is a trie for equal-length bit vectors, which stores values only in the leaves.
// XorTrie node invariants: // XorTrie node invariants:
...@@ -6,10 +10,10 @@ package xor ...@@ -6,10 +10,10 @@ package xor
// (2) If both branches are leaves, then they are both non-empty (have keys). // (2) If both branches are leaves, then they are both non-empty (have keys).
type XorTrie struct { type XorTrie struct {
branch [2]*XorTrie branch [2]*XorTrie
key TrieKey key key.Key
} }
func NewXorTrie() *XorTrie { func New() *XorTrie {
return &XorTrie{} return &XorTrie{}
} }
...@@ -32,29 +36,29 @@ func max(x, y int) int { ...@@ -32,29 +36,29 @@ func max(x, y int) int {
return y return y
} }
func (trie *XorTrie) Find(q TrieKey) (reachedDepth int, found bool) { func (trie *XorTrie) Find(q key.Key) (reachedDepth int, found bool) {
return trie.find(0, q) return trie.find(0, q)
} }
func (trie *XorTrie) find(depth int, q TrieKey) (reachedDepth int, found bool) { func (trie *XorTrie) find(depth int, q key.Key) (reachedDepth int, found bool) {
if qb := trie.branch[q.BitAt(depth)]; qb != nil { if qb := trie.branch[q.BitAt(depth)]; qb != nil {
return qb.find(depth+1, q) return qb.find(depth+1, q)
} else { } else {
if trie.key == nil { if trie.key == nil {
return depth, false return depth, false
} else { } else {
return depth, TrieKeyEqual(trie.key, q) return depth, key.Equal(trie.key, q)
} }
} }
} }
// Add adds the key q to the trie. Add mutates the trie. // Add adds the key q to the trie. Add mutates the trie.
// TODO: Also implement an immutable version of Add. // TODO: Also implement an immutable version of Add.
func (trie *XorTrie) Add(q TrieKey) (insertedDepth int, insertedOK bool) { func (trie *XorTrie) Add(q key.Key) (insertedDepth int, insertedOK bool) {
return trie.add(0, q) return trie.add(0, q)
} }
func (trie *XorTrie) add(depth int, q TrieKey) (insertedDepth int, insertedOK bool) { func (trie *XorTrie) add(depth int, q key.Key) (insertedDepth int, insertedOK bool) {
if qb := trie.branch[q.BitAt(depth)]; qb != nil { if qb := trie.branch[q.BitAt(depth)]; qb != nil {
return qb.add(depth+1, q) return qb.add(depth+1, q)
} else { } else {
...@@ -62,7 +66,7 @@ func (trie *XorTrie) add(depth int, q TrieKey) (insertedDepth int, insertedOK bo ...@@ -62,7 +66,7 @@ func (trie *XorTrie) add(depth int, q TrieKey) (insertedDepth int, insertedOK bo
trie.key = q trie.key = q
return depth, true return depth, true
} else { } else {
if TrieKeyEqual(trie.key, q) { if key.Equal(trie.key, q) {
return depth, false return depth, false
} else { } else {
p := trie.key p := trie.key
...@@ -78,11 +82,11 @@ func (trie *XorTrie) add(depth int, q TrieKey) (insertedDepth int, insertedOK bo ...@@ -78,11 +82,11 @@ func (trie *XorTrie) add(depth int, q TrieKey) (insertedDepth int, insertedOK bo
// Remove removes the key q from the trie. Remove mutates the trie. // Remove removes the key q from the trie. Remove mutates the trie.
// TODO: Also implement an immutable version of Add. // TODO: Also implement an immutable version of Add.
func (trie *XorTrie) Remove(q TrieKey) (removedDepth int, removed bool) { func (trie *XorTrie) Remove(q key.Key) (removedDepth int, removed bool) {
return trie.remove(0, q) return trie.remove(0, q)
} }
func (trie *XorTrie) remove(depth int, q TrieKey) (reachedDepth int, removed bool) { func (trie *XorTrie) remove(depth int, q key.Key) (reachedDepth int, removed bool) {
if qb := trie.branch[q.BitAt(depth)]; qb != nil { if qb := trie.branch[q.BitAt(depth)]; qb != nil {
if d, ok := qb.remove(depth+1, q); ok { if d, ok := qb.remove(depth+1, q); ok {
trie.shrink() trie.shrink()
...@@ -91,7 +95,7 @@ func (trie *XorTrie) remove(depth int, q TrieKey) (reachedDepth int, removed boo ...@@ -91,7 +95,7 @@ func (trie *XorTrie) remove(depth int, q TrieKey) (reachedDepth int, removed boo
return d, false return d, false
} }
} else { } else {
if trie.key != nil && TrieKeyEqual(q, trie.key) { if trie.key != nil && key.Equal(q, trie.key) {
trie.key = nil trie.key = nil
return depth, true return depth, true
} else { } else {
......
package xor package trie
import "testing" import (
"testing"
"github.com/libp2p/go-libp2p-xor/key"
)
func TestInsertRemove(t *testing.T) { func TestInsertRemove(t *testing.T) {
r := NewXorTrie() r := New()
testSeq(r, t) testSeq(r, t)
testSeq(r, t) testSeq(r, t)
} }
func testSeq(r *XorTrie, t *testing.T) { func testSeq(r *XorTrie, t *testing.T) {
for _, s := range testInsertSeq { for _, s := range testInsertSeq {
depth, _ := r.Add(TrieKey(s.key)) depth, _ := r.Add(key.Key(s.key))
if depth != s.insertedDepth { if depth != s.insertedDepth {
t.Errorf("inserting expected %d, got %d", s.insertedDepth, depth) t.Errorf("inserting expected %d, got %d", s.insertedDepth, depth)
} }
} }
for _, s := range testRemoveSeq { for _, s := range testRemoveSeq {
depth, _ := r.Remove(TrieKey(s.key)) depth, _ := r.Remove(key.Key(s.key))
if depth != s.reachedDepth { if depth != s.reachedDepth {
t.Errorf("removing expected %d, got %d", s.reachedDepth, depth) t.Errorf("removing expected %d, got %d", s.reachedDepth, depth)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment