diff --git a/pool.go b/pool.go index 8771e023a02108c64be6d9a7cd57360d5a8c4e83..df8fa922771ba06db653239f7a62cd196c93e914 100644 --- a/pool.go +++ b/pool.go @@ -72,8 +72,9 @@ func (p *Pool) getPool(idx uint32) *sync.Pool { // If Get would otherwise return nil and p.New is non-nil, Get returns the // result of calling p.New. func (p *Pool) Get(length uint32) interface{} { - idx := largerPowerOfTwo(length) + idx := nextPowerOfTwo(length) sp := p.getPool(idx) + // fmt.Printf("Get(%d) idx(%d)\n", length, idx) val := sp.Get() if val == nil && p.New != nil { val = p.New(0x1 << idx) @@ -83,27 +84,42 @@ func (p *Pool) Get(length uint32) interface{} { // Put adds x to the pool. func (p *Pool) Put(length uint32, val interface{}) { - idx := smallerPowerOfTwo(length) + idx := prevPowerOfTwo(length) + // fmt.Printf("Put(%d, -) idx(%d)\n", length, idx) sp := p.getPool(idx) sp.Put(val) } -func largerPowerOfTwo(num uint32) uint32 { - for p := uint32(0); p < 32; p++ { - if (0x1 << p) >= num { - return p - } +func nextPowerOfTwo(v uint32) uint32 { + // fmt.Printf("nextPowerOfTwo(%d) ", v) + v-- + v |= v >> 1 + v |= v >> 2 + v |= v >> 4 + v |= v >> 8 + v |= v >> 16 + v++ + + // fmt.Printf("-> %d", v) + + i := uint32(0) + for i = 0; v > 1; i++ { + v = v >> 1 } - panic("unreachable") + // fmt.Printf("-> %d\n", i) + return i } -func smallerPowerOfTwo(num uint32) uint32 { - for p := uint32(1); p < 32; p++ { - if (0x1 << p) > num { - return p - 1 - } +func prevPowerOfTwo(num uint32) uint32 { + next := nextPowerOfTwo(num) + // fmt.Printf("prevPowerOfTwo(%d) next: %d", num, next) + switch { + case num == (1 << next): // num is a power of 2 + case next == 0: + default: + next = next - 1 // smaller } - - panic("unreachable") + // fmt.Printf(" = %d\n", next) + return next } diff --git a/pool_test.go b/pool_test.go index f669b3a4099deb21e0a4bb3cfffe54b8c21377a8..1681103220b2ced901047ef07b37d78b587dc53d 100644 --- a/pool_test.go +++ b/pool_test.go @@ -9,6 +9,7 @@ package mpool import ( "fmt" + "math/rand" "runtime" "runtime/debug" "sync/atomic" @@ -62,7 +63,7 @@ func TestPoolNew(t *testing.T) { s := [32]int{} p := Pool{ New: func(length int) interface{} { - idx := largerPowerOfTwo(uint32(length)) + idx := nextPowerOfTwo(uint32(length)) s[idx]++ return s[idx] }, @@ -148,10 +149,63 @@ func TestPoolStress(t *testing.T) { }() } for i := 0; i < P; i++ { + // fmt.Printf("%d/%d\n", i, P) <-done } } +func TestPoolStressByteSlicePool(t *testing.T) { + const P = 10 + chs := 10 + maxSize := uint32(1 << 16) + N := int(1e4) + if testing.Short() { + N /= 100 + } + p := ByteSlicePool + done := make(chan bool) + errs := make(chan error) + for i := 0; i < P; i++ { + go func() { + ch := make(chan []byte, chs+1) + + for i := 0; i < chs; i++ { + j := rand.Uint32() % maxSize + ch <- p.Get(j).([]byte) + } + + for j := 0; j < N; j++ { + r := uint32(0) + for i := 0; i < chs; i++ { + v := <-ch + p.Put(uint32(cap(v)), v) + r = rand.Uint32() % maxSize + v = p.Get(r).([]byte) + if uint32(len(v)) < r { + errs <- fmt.Errorf("expect len(v) >= %d, got %d", j, len(v)) + } + ch <- v + } + + if r%1000 == 0 { + runtime.GC() + } + } + done <- true + }() + } + + for i := 0; i < P; { + select { + case <-done: + i++ + // fmt.Printf("%d/%d\n", i, P) + case err := <-errs: + t.Error(err) + } + } +} + func BenchmarkPool(b *testing.B) { var p Pool b.RunParallel(func(pb *testing.PB) {