Commit 1cc48d47 authored by Juan Benet's avatar Juan Benet

Merge pull request #26 from whyrusleeping/master

Fix interface issues and implement more things on sub-datastores
parents 7d6acaf7 0f0f9bfc
......@@ -21,6 +21,11 @@
"ImportPath": "github.com/dustin/randbo",
"Rev": "7f1b564ca7242d22bcc6e2128beb90d9fa38b9f0"
},
{
"ImportPath": "github.com/fzzy/radix/redis",
"Comment": "v0.5.1",
"Rev": "27a863cdffdb0998d13e1e11992b18489aeeaa25"
},
{
"ImportPath": "github.com/hashicorp/golang-lru",
"Rev": "4dfff096c4973178c8f35cf6dd1a732a0a139370"
......
package redis
import (
"bufio"
"errors"
"net"
"strings"
"time"
"github.com/jbenet/go-datastore/Godeps/_workspace/src/github.com/fzzy/radix/redis/resp"
)
const (
bufSize int = 4096
)
//* Common errors
var LoadingError error = errors.New("server is busy loading dataset in memory")
var PipelineQueueEmptyError error = errors.New("pipeline queue empty")
//* Client
// Client describes a Redis client.
type Client struct {
// The connection the client talks to redis over. Don't touch this unless
// you know what you're doing.
Conn net.Conn
timeout time.Duration
reader *bufio.Reader
pending []*request
completed []*Reply
}
// request describes a client's request to the redis server
type request struct {
cmd string
args []interface{}
}
// Dial connects to the given Redis server with the given timeout, which will be
// used as the read/write timeout when communicating with redis
func DialTimeout(network, addr string, timeout time.Duration) (*Client, error) {
// establish a connection
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
c := new(Client)
c.Conn = conn
c.timeout = timeout
c.reader = bufio.NewReaderSize(conn, bufSize)
return c, nil
}
// Dial connects to the given Redis server.
func Dial(network, addr string) (*Client, error) {
return DialTimeout(network, addr, time.Duration(0))
}
//* Public methods
// Close closes the connection.
func (c *Client) Close() error {
return c.Conn.Close()
}
// Cmd calls the given Redis command.
func (c *Client) Cmd(cmd string, args ...interface{}) *Reply {
err := c.writeRequest(&request{cmd, args})
if err != nil {
return &Reply{Type: ErrorReply, Err: err}
}
return c.ReadReply()
}
// Append adds the given call to the pipeline queue.
// Use GetReply() to read the reply.
func (c *Client) Append(cmd string, args ...interface{}) {
c.pending = append(c.pending, &request{cmd, args})
}
// GetReply returns the reply for the next request in the pipeline queue.
// Error reply with PipelineQueueEmptyError is returned,
// if the pipeline queue is empty.
func (c *Client) GetReply() *Reply {
if len(c.completed) > 0 {
r := c.completed[0]
c.completed = c.completed[1:]
return r
}
c.completed = nil
if len(c.pending) == 0 {
return &Reply{Type: ErrorReply, Err: PipelineQueueEmptyError}
}
nreqs := len(c.pending)
err := c.writeRequest(c.pending...)
c.pending = nil
if err != nil {
return &Reply{Type: ErrorReply, Err: err}
}
r := c.ReadReply()
c.completed = make([]*Reply, nreqs-1)
for i := 0; i < nreqs-1; i++ {
c.completed[i] = c.ReadReply()
}
return r
}
//* Private methods
func (c *Client) setReadTimeout() {
if c.timeout != 0 {
c.Conn.SetReadDeadline(time.Now().Add(c.timeout))
}
}
func (c *Client) setWriteTimeout() {
if c.timeout != 0 {
c.Conn.SetWriteDeadline(time.Now().Add(c.timeout))
}
}
// This will read a redis reply off of the connection without sending anything
// first (useful after you've sent a SUSBSCRIBE command). This will block until
// a reply is received or the timeout is reached. On timeout an ErrorReply will
// be returned, you can check if it's a timeout like so:
//
// r := conn.ReadReply()
// if r.Err != nil {
// if t, ok := r.Err.(*net.OpError); ok && t.Timeout() {
// // Is timeout
// } else {
// // Not timeout
// }
// }
//
// Note: this is a more low-level function, you really shouldn't have to
// actually use it unless you're writing your own pub/sub code
func (c *Client) ReadReply() *Reply {
c.setReadTimeout()
return c.parse()
}
func (c *Client) writeRequest(requests ...*request) error {
c.setWriteTimeout()
for i := range requests {
req := make([]interface{}, 0, len(requests[i].args)+1)
req = append(req, requests[i].cmd)
req = append(req, requests[i].args...)
err := resp.WriteArbitraryAsFlattenedStrings(c.Conn, req)
if err != nil {
c.Close()
return err
}
}
return nil
}
func (c *Client) parse() *Reply {
m, err := resp.ReadMessage(c.reader)
if err != nil {
if t, ok := err.(*net.OpError); !ok || !t.Timeout() {
// close connection except timeout
c.Close()
}
return &Reply{Type: ErrorReply, Err: err}
}
r, err := messageToReply(m)
if err != nil {
return &Reply{Type: ErrorReply, Err: err}
}
return r
}
// The error return parameter is for bubbling up parse errors and the like, if
// the error is sent by redis itself as an Err message type, then it will be
// sent back as an actual Reply (wrapped in a CmdError)
func messageToReply(m *resp.Message) (*Reply, error) {
r := &Reply{}
switch m.Type {
case resp.Err:
errMsg, err := m.Err()
if err != nil {
return nil, err
}
if strings.HasPrefix(errMsg.Error(), "LOADING") {
err = LoadingError
} else {
err = &CmdError{errMsg}
}
r.Type = ErrorReply
r.Err = err
case resp.SimpleStr:
status, err := m.Bytes()
if err != nil {
return nil, err
}
r.Type = StatusReply
r.buf = status
case resp.Int:
i, err := m.Int()
if err != nil {
return nil, err
}
r.Type = IntegerReply
r.int = i
case resp.BulkStr:
b, err := m.Bytes()
if err != nil {
return nil, err
}
r.Type = BulkReply
r.buf = b
case resp.Nil:
r.Type = NilReply
case resp.Array:
ms, err := m.Array()
if err != nil {
return nil, err
}
r.Type = MultiReply
r.Elems = make([]*Reply, len(ms))
for i := range ms {
r.Elems[i], err = messageToReply(ms[i])
if err != nil {
return nil, err
}
}
}
return r, nil
}
package redis
import (
"bufio"
"bytes"
"github.com/stretchr/testify/assert"
. "testing"
"time"
)
func dial(t *T) *Client {
client, err := DialTimeout("tcp", "127.0.0.1:6379", 10*time.Second)
assert.Nil(t, err)
return client
}
func TestCmd(t *T) {
c := dial(t)
v, _ := c.Cmd("echo", "Hello, World!").Str()
assert.Equal(t, "Hello, World!", v)
// Test that a bad command properly returns a *CmdError
err := c.Cmd("non-existant-cmd").Err
assert.NotEqual(t, "", err.(*CmdError).Error())
// Test that application level errors propagate correctly
c.Cmd("sadd", "foo", "bar")
_, err = c.Cmd("get", "foo").Str()
assert.NotEqual(t, "", err.(*CmdError).Error())
}
func TestPipeline(t *T) {
c := dial(t)
c.Append("echo", "foo")
c.Append("echo", "bar")
c.Append("echo", "zot")
v, _ := c.GetReply().Str()
assert.Equal(t, "foo", v)
v, _ = c.GetReply().Str()
assert.Equal(t, "bar", v)
v, _ = c.GetReply().Str()
assert.Equal(t, "zot", v)
r := c.GetReply()
assert.Equal(t, ErrorReply, r.Type)
assert.Equal(t, PipelineQueueEmptyError, r.Err)
}
func TestParse(t *T) {
c := dial(t)
parseString := func(b string) *Reply {
c.reader = bufio.NewReader(bytes.NewBufferString(b))
return c.parse()
}
// missing \n trailing
r := parseString("foo")
assert.Equal(t, ErrorReply, r.Type)
assert.NotNil(t, r.Err)
// error reply
r = parseString("-ERR unknown command 'foobar'\r\n")
assert.Equal(t, ErrorReply, r.Type)
assert.Equal(t, "ERR unknown command 'foobar'", r.Err.Error())
// LOADING error
r = parseString("-LOADING Redis is loading the dataset in memory\r\n")
assert.Equal(t, ErrorReply, r.Type)
assert.Equal(t, LoadingError, r.Err)
// status reply
r = parseString("+OK\r\n")
assert.Equal(t, StatusReply, r.Type)
assert.Equal(t, []byte("OK"), r.buf)
// integer reply
r = parseString(":1337\r\n")
assert.Equal(t, IntegerReply, r.Type)
assert.Equal(t, int64(1337), r.int)
// null bulk reply
r = parseString("$-1\r\n")
assert.Equal(t, NilReply, r.Type)
// bulk reply
r = parseString("$6\r\nfoobar\r\n")
assert.Equal(t, BulkReply, r.Type)
assert.Equal(t, []byte("foobar"), r.buf)
// null multi bulk reply
r = parseString("*-1\r\n")
assert.Equal(t, NilReply, r.Type)
// multi bulk reply
r = parseString("*5\r\n:0\r\n:1\r\n:2\r\n:3\r\n$6\r\nfoobar\r\n")
assert.Equal(t, MultiReply, r.Type)
assert.Equal(t, 5, len(r.Elems))
for i := 0; i < 4; i++ {
assert.Equal(t, int64(i), r.Elems[i].int)
}
assert.Equal(t, []byte("foobar"), r.Elems[4].buf)
}
// A simple client for connecting and interacting with redis.
//
// To import inside your package do:
//
// import "github.com/fzzy/radix/redis"
//
// Connecting
//
// Use either Dial or DialTimeout:
//
// client, err := redis.Dial("tcp", "localhost:6379")
// if err != nil {
// // handle err
// }
//
// Make sure to call Close on the client if you want to clean it up before the
// end of the program.
//
// Cmd and Reply
//
// The Cmd method returns a Reply, which has methods for converting to various
// types. Each of these methods returns an error which can either be a
// connection error (e.g. timeout), an application error (e.g. key is wrong
// type), or a conversion error (e.g. cannot convert to integer). You can also
// directly check the error using the Err field:
//
// foo, err := client.Cmd("GET", "foo").Str()
// if err != nil {
// // handle err
// }
//
// // Checking Err field directly
//
// err = client.Cmd("PING").Err
// if err != nil {
// // handle err
// }
//
// Multi Replies
//
// The elements to Multi replies can be accessed as strings using List or
// ListBytes, or you can use the Elems field for more low-level access:
//
// r := client.Cmd("MGET", "foo", "bar", "baz")
//
// // This:
// for _, elemStr := range r.List() {
// fmt.Println(elemStr)
// }
//
// // is equivalent to this:
// for i := range r.Elems {
// elemStr, _ := r.Elems[i].Str()
// fmt.Println(elemStr)
// }
//
// Pipelining
//
// Pipelining is when the client sends a bunch of commands to the server at
// once, and only once all the commands have been sent does it start reading the
// replies off the socket. This is supported using the Append and GetReply
// methods. Append will simply append the command to a buffer without sending
// it, the first time GetReply is called it will send all the commands in the
// buffer and return the Reply for the first command that was sent. Subsequent
// calls to GetReply return Replys for subsequent commands:
//
// client.Append("GET", "foo")
// client.Append("SET", "bar", "foo")
// client.Append("DEL", "baz")
//
// // Read GET foo reply
// foo, err := client.GetReply().Str()
// if err != nil {
// // handle err
// }
//
// // Read SET bar foo reply
// if err := client.GetReply().Err; err != nil {
// // handle err
// }
//
// // Read DEL baz reply
// if err := client.GetReply().Err; err != nil {
// // handle err
// }
//
package redis
package redis
import (
"errors"
"strconv"
"strings"
)
// A CmdError implements the error interface and is what is returned when the
// server returns an error on the application level (e.g. key doesn't exist or
// is the wrong type), as opposed to a connection/transport error.
//
// You can test if a reply is a CmdError like so:
//
// r := conn.Cmd("GET", "key-which-isnt-a-string")
// if r.Err != nil {
// if cerr, ok := r.Err.(*redis.CmdError); ok {
// // Is CmdError
// } else {
// // Is other error
// }
// }
type CmdError struct {
Err error
}
func (cerr *CmdError) Error() string {
return cerr.Err.Error()
}
// Returns true if error returned was due to the redis server being read only
func (cerr *CmdError) Readonly() bool {
return strings.HasPrefix(cerr.Err.Error(), "READONLY")
}
//* Reply
/*
ReplyType describes type of a reply.
Possible values are:
StatusReply -- status reply
ErrorReply -- error reply
IntegerReply -- integer reply
NilReply -- nil reply
BulkReply -- bulk reply
MultiReply -- multi bulk reply
*/
type ReplyType uint8
const (
StatusReply ReplyType = iota
ErrorReply
IntegerReply
NilReply
BulkReply
MultiReply
)
// Reply holds a Redis reply.
type Reply struct {
Type ReplyType // Reply type
Elems []*Reply // Sub-replies
Err error // Reply error
buf []byte
int int64
}
// Bytes returns the reply value as a byte string or
// an error, if the reply type is not StatusReply or BulkReply.
func (r *Reply) Bytes() ([]byte, error) {
if r.Type == ErrorReply {
return nil, r.Err
}
if !(r.Type == StatusReply || r.Type == BulkReply) {
return nil, errors.New("string value is not available for this reply type")
}
return r.buf, nil
}
// Str is a convenience method for calling Reply.Bytes() and converting it to string
func (r *Reply) Str() (string, error) {
b, err := r.Bytes()
if err != nil {
return "", err
}
return string(b), nil
}
// Int64 returns the reply value as a int64 or an error,
// if the reply type is not IntegerReply or the reply type
// BulkReply could not be parsed to an int64.
func (r *Reply) Int64() (int64, error) {
if r.Type == ErrorReply {
return 0, r.Err
}
if r.Type != IntegerReply {
s, err := r.Str()
if err == nil {
i64, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, errors.New("failed to parse integer value from string value")
} else {
return i64, nil
}
}
return 0, errors.New("integer value is not available for this reply type")
}
return r.int, nil
}
// Int is a convenience method for calling Reply.Int64() and converting it to int.
func (r *Reply) Int() (int, error) {
i64, err := r.Int64()
if err != nil {
return 0, err
}
return int(i64), nil
}
// Bool returns false, if the reply value equals to 0 or "0", otherwise true; or
// an error, if the reply type is not IntegerReply or BulkReply.
func (r *Reply) Bool() (bool, error) {
if r.Type == ErrorReply {
return false, r.Err
}
i, err := r.Int()
if err == nil {
if i == 0 {
return false, nil
}
return true, nil
}
s, err := r.Str()
if err == nil {
if s == "0" {
return false, nil
}
return true, nil
}
return false, errors.New("boolean value is not available for this reply type")
}
// List returns a multi bulk reply as a slice of strings or an error.
// The reply type must be MultiReply and its elements' types must all be either BulkReply or NilReply.
// Nil elements are returned as empty strings.
// Useful for list commands.
func (r *Reply) List() ([]string, error) {
// Doing all this in two places instead of just calling ListBytes() so we don't have
// to iterate twice
if r.Type == ErrorReply {
return nil, r.Err
}
if r.Type != MultiReply {
return nil, errors.New("reply type is not MultiReply")
}
strings := make([]string, len(r.Elems))
for i, v := range r.Elems {
if v.Type == BulkReply {
strings[i] = string(v.buf)
} else if v.Type == NilReply {
strings[i] = ""
} else {
return nil, errors.New("element type is not BulkReply or NilReply")
}
}
return strings, nil
}
// ListBytes returns a multi bulk reply as a slice of bytes slices or an error.
// The reply type must be MultiReply and its elements' types must all be either BulkReply or NilReply.
// Nil elements are returned as nil.
// Useful for list commands.
func (r *Reply) ListBytes() ([][]byte, error) {
if r.Type == ErrorReply {
return nil, r.Err
}
if r.Type != MultiReply {
return nil, errors.New("reply type is not MultiReply")
}
bufs := make([][]byte, len(r.Elems))
for i, v := range r.Elems {
if v.Type == BulkReply {
bufs[i] = v.buf
} else if v.Type == NilReply {
bufs[i] = nil
} else {
return nil, errors.New("element type is not BulkReply or NilReply")
}
}
return bufs, nil
}
// Hash returns a multi bulk reply as a map[string]string or an error.
// The reply type must be MultiReply,
// it must have an even number of elements,
// they must be in a "key value key value..." order and
// values must all be either BulkReply or NilReply.
// Nil values are returned as empty strings.
// Useful for hash commands.
func (r *Reply) Hash() (map[string]string, error) {
if r.Type == ErrorReply {
return nil, r.Err
}
rmap := map[string]string{}
if r.Type != MultiReply {
return nil, errors.New("reply type is not MultiReply")
}
if len(r.Elems)%2 != 0 {
return nil, errors.New("reply has odd number of elements")
}
for i := 0; i < len(r.Elems)/2; i++ {
var val string
key, err := r.Elems[i*2].Str()
if err != nil {
return nil, errors.New("key element has no string reply")
}
v := r.Elems[i*2+1]
if v.Type == BulkReply {
val = string(v.buf)
rmap[key] = val
} else if v.Type == NilReply {
} else {
return nil, errors.New("value element type is not BulkReply or NilReply")
}
}
return rmap, nil
}
// String returns a string representation of the reply and its sub-replies.
// This method is for debugging.
// Use method Reply.Str() for reading string reply.
func (r *Reply) String() string {
switch r.Type {
case ErrorReply:
return r.Err.Error()
case StatusReply:
fallthrough
case BulkReply:
return string(r.buf)
case IntegerReply:
return strconv.FormatInt(r.int, 10)
case NilReply:
return "<nil>"
case MultiReply:
s := "[ "
for _, e := range r.Elems {
s = s + e.String() + " "
}
return s + "]"
}
// This should never execute
return ""
}
package redis
import (
"github.com/stretchr/testify/assert"
. "testing"
)
func TestStr(t *T) {
r := &Reply{Type: ErrorReply, Err: LoadingError}
_, err := r.Str()
assert.Equal(t, LoadingError, err)
r = &Reply{Type: IntegerReply}
_, err = r.Str()
assert.NotNil(t, err)
r = &Reply{Type: StatusReply, buf: []byte("foo")}
b, err := r.Str()
assert.Nil(t, err)
assert.Equal(t, "foo", b)
r = &Reply{Type: BulkReply, buf: []byte("foo")}
b, err = r.Str()
assert.Nil(t, err)
assert.Equal(t, "foo", b)
}
func TestBytes(t *T) {
r := &Reply{Type: BulkReply, buf: []byte("foo")}
b, err := r.Bytes()
assert.Nil(t, err)
assert.Equal(t, []byte("foo"), b)
}
func TestInt64(t *T) {
r := &Reply{Type: ErrorReply, Err: LoadingError}
_, err := r.Int64()
assert.Equal(t, LoadingError, err)
r = &Reply{Type: IntegerReply, int: 5}
b, err := r.Int64()
assert.Nil(t, err)
assert.Equal(t, int64(5), b)
r = &Reply{Type: BulkReply, buf: []byte("5")}
b, err = r.Int64()
assert.Nil(t, err)
assert.Equal(t, int64(5), b)
r = &Reply{Type: BulkReply, buf: []byte("foo")}
_, err = r.Int64()
assert.NotNil(t, err)
}
func TestInt(t *T) {
r := &Reply{Type: IntegerReply, int: 5}
b, err := r.Int()
assert.Nil(t, err)
assert.Equal(t, 5, b)
}
func TestBool(t *T) {
r := &Reply{Type: IntegerReply, int: 0}
b, err := r.Bool()
assert.Nil(t, err)
assert.Equal(t, false, b)
r = &Reply{Type: StatusReply, buf: []byte("0")}
b, err = r.Bool()
assert.Nil(t, err)
assert.Equal(t, false, b)
r = &Reply{Type: IntegerReply, int: 2}
b, err = r.Bool()
assert.Nil(t, err)
assert.Equal(t, true, b)
r = &Reply{Type: NilReply}
_, err = r.Bool()
assert.NotNil(t, err)
}
func TestList(t *T) {
r := &Reply{Type: MultiReply}
r.Elems = make([]*Reply, 3)
r.Elems[0] = &Reply{Type: BulkReply, buf: []byte("0")}
r.Elems[1] = &Reply{Type: NilReply}
r.Elems[2] = &Reply{Type: BulkReply, buf: []byte("2")}
l, err := r.List()
assert.Nil(t, err)
assert.Equal(t, 3, len(l))
assert.Equal(t, "0", l[0])
assert.Equal(t, "", l[1])
assert.Equal(t, "2", l[2])
}
func TestBytesList(t *T) {
r := &Reply{Type: MultiReply}
r.Elems = make([]*Reply, 3)
r.Elems[0] = &Reply{Type: BulkReply, buf: []byte("0")}
r.Elems[1] = &Reply{Type: NilReply}
r.Elems[2] = &Reply{Type: BulkReply, buf: []byte("2")}
l, err := r.ListBytes()
assert.Nil(t, err)
assert.Equal(t, 3, len(l))
assert.Equal(t, []byte("0"), l[0])
assert.Nil(t, l[1])
assert.Equal(t, []byte("2"), l[2])
}
func TestHash(t *T) {
r := &Reply{Type: MultiReply}
r.Elems = make([]*Reply, 6)
r.Elems[0] = &Reply{Type: BulkReply, buf: []byte("a")}
r.Elems[1] = &Reply{Type: BulkReply, buf: []byte("0")}
r.Elems[2] = &Reply{Type: BulkReply, buf: []byte("b")}
r.Elems[3] = &Reply{Type: NilReply}
r.Elems[4] = &Reply{Type: BulkReply, buf: []byte("c")}
r.Elems[5] = &Reply{Type: BulkReply, buf: []byte("2")}
h, err := r.Hash()
assert.Nil(t, err)
assert.Equal(t, "0", h["a"])
assert.Equal(t, "", h["b"])
assert.Equal(t, "2", h["c"])
}
// This package provides an easy to use interface for creating and parsing
// messages encoded in the REdis Serialization Protocol (RESP). You can check
// out more details about the protocol here: http://redis.io/topics/protocol
package resp
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"reflect"
"strconv"
)
var (
delim = []byte{'\r', '\n'}
delimEnd = delim[len(delim)-1]
)
type Type int
const (
SimpleStr Type = iota
Err
Int
BulkStr
Array
Nil
)
const (
simpleStrPrefix = '+'
errPrefix = '-'
intPrefix = ':'
bulkStrPrefix = '$'
arrayPrefix = '*'
)
// Parse errors
var (
badType = errors.New("wrong type")
parseErr = errors.New("parse error")
)
type Message struct {
Type
val interface{}
raw []byte
}
// NewMessagePParses the given raw message and returns a Message struct
// representing it
func NewMessage(b []byte) (*Message, error) {
return ReadMessage(bytes.NewReader(b))
}
// Can be used when writing to a resp stream to write a simple-string-style
// stream (e.g. +OK\r\n) instead of the default bulk-string-style strings.
//
// foo := NewSimpleString("foo")
// bar := NewSimpleString("bar")
// baz := NewSimpleString("baz")
// resp.WriteArbitrary(w, foo)
// resp.WriteArbitrary(w, []interface{}{bar, baz})
//
func NewSimpleString(s string) *Message {
b := append(make([]byte, 0, len(s) + 3), '+')
b = append(b, []byte(s)...)
b = append(b, '\r', '\n')
return &Message{
Type: SimpleStr,
val: s,
raw: b,
}
}
// ReadMessage attempts to read a message object from the given io.Reader, parse
// it, and return a Message struct representing it
func ReadMessage(reader io.Reader) (*Message, error) {
r := bufio.NewReader(reader)
return bufioReadMessage(r)
}
func bufioReadMessage(r *bufio.Reader) (*Message, error) {
b, err := r.Peek(1)
if err != nil {
return nil, err
}
switch b[0] {
case simpleStrPrefix:
return readSimpleStr(r)
case errPrefix:
return readError(r)
case intPrefix:
return readInt(r)
case bulkStrPrefix:
return readBulkStr(r)
case arrayPrefix:
return readArray(r)
default:
return nil, badType
}
}
func readSimpleStr(r *bufio.Reader) (*Message, error) {
b, err := r.ReadBytes(delimEnd)
if err != nil {
return nil, err
}
return &Message{Type: SimpleStr, val: b[1 : len(b)-2], raw: b}, nil
}
func readError(r *bufio.Reader) (*Message, error) {
b, err := r.ReadBytes(delimEnd)
if err != nil {
return nil, err
}
return &Message{Type: Err, val: b[1 : len(b)-2], raw: b}, nil
}
func readInt(r *bufio.Reader) (*Message, error) {
b, err := r.ReadBytes(delimEnd)
if err != nil {
return nil, err
}
i, err := strconv.ParseInt(string(b[1:len(b)-2]), 10, 64)
if err != nil {
return nil, parseErr
}
return &Message{Type: Int, val: i, raw: b}, nil
}
func readBulkStr(r *bufio.Reader) (*Message, error) {
b, err := r.ReadBytes(delimEnd)
if err != nil {
return nil, err
}
size, err := strconv.ParseInt(string(b[1:len(b)-2]), 10, 64)
if err != nil {
return nil, parseErr
}
if size < 0 {
return &Message{Type: Nil, raw: b}, nil
}
total := make([]byte, size)
b2 := total
var n int
for len(b2) > 0 {
n, err = r.Read(b2)
if err != nil {
return nil, err
}
b2 = b2[n:]
}
// There's a hanging \r\n there, gotta read past it
trail := make([]byte, 2)
for i := 0; i < 2; i++ {
if c, err := r.ReadByte(); err != nil {
return nil, err
} else {
trail[i] = c
}
}
blens := len(b) + len(total)
raw := make([]byte, 0, blens+2)
raw = append(raw, b...)
raw = append(raw, total...)
raw = append(raw, trail...)
return &Message{Type: BulkStr, val: total, raw: raw}, nil
}
func readArray(r *bufio.Reader) (*Message, error) {
b, err := r.ReadBytes(delimEnd)
if err != nil {
return nil, err
}
size, err := strconv.ParseInt(string(b[1:len(b)-2]), 10, 64)
if err != nil {
return nil, parseErr
}
if size < 0 {
return &Message{Type: Nil, raw: b}, nil
}
arr := make([]*Message, size)
for i := range arr {
m, err := bufioReadMessage(r)
if err != nil {
return nil, err
}
arr[i] = m
b = append(b, m.raw...)
}
return &Message{Type: Array, val: arr, raw: b}, nil
}
// Bytes returns a byte slice representing the value of the Message. Only valid
// for a Message of type SimpleStr, Err, and BulkStr. Others will return an
// error
func (m *Message) Bytes() ([]byte, error) {
if b, ok := m.val.([]byte); ok {
return b, nil
}
return nil, badType
}
// Str is a Convenience method around Bytes which converts the output to a
// string
func (m *Message) Str() (string, error) {
b, err := m.Bytes()
if err != nil {
return "", err
}
return string(b), nil
}
// Int returns an int64 representing the value of the Message. Only valid for
// Int messages
func (m *Message) Int() (int64, error) {
if i, ok := m.val.(int64); ok {
return i, nil
}
return 0, badType
}
// Err returns an error representing the value of the Message. Only valid for
// Err messages
func (m *Message) Err() (error, error) {
if m.Type != Err {
return nil, badType
}
s, err := m.Str()
if err != nil {
return nil, err
}
return errors.New(s), nil
}
// Array returns the Message slice encompassed by this Messsage, assuming the
// Message is of type Array
func (m *Message) Array() ([]*Message, error) {
if a, ok := m.val.([]*Message); ok {
return a, nil
}
return nil, badType
}
// WriteMessage takes in the given Message and writes its encoded form to the
// given io.Writer
func WriteMessage(w io.Writer, m *Message) error {
_, err := w.Write(m.raw)
return err
}
// WriteArbitrary takes in any primitive golang value, or Message, and writes
// its encoded form to the given io.Writer, inferring types where appropriate.
func WriteArbitrary(w io.Writer, m interface{}) error {
b := format(m, false)
_, err := w.Write(b)
return err
}
// WriteArbitraryAsString is similar to WriteArbitraryAsFlattenedString except
// that it won't flatten any embedded arrays.
func WriteArbitraryAsString(w io.Writer, m interface{}) error {
b := format(m, true)
_, err := w.Write(b)
return err
}
// WriteArbitraryAsFlattenedStrings is similar to WriteArbitrary except that it
// will encode all types except Array as a BulkStr, converting the argument into
// a string first as necessary. It will also flatten any embedded arrays into a
// single long array. This is useful because commands to a redis server must be
// given as an array of bulk strings. If the argument isn't already in a slice
// or map it will be wrapped so that it is written as an Array of size one.
//
// Note that if a Message type is found it will *not* be encoded to a BulkStr,
// but will simply be passed through as whatever type it already represents.
func WriteArbitraryAsFlattenedStrings(w io.Writer, m interface{}) error {
fm := flatten(m)
return WriteArbitraryAsString(w, fm)
}
func format(m interface{}, forceString bool) []byte {
switch mt := m.(type) {
case []byte:
return formatStr(mt)
case string:
return formatStr([]byte(mt))
case bool:
if mt {
return formatStr([]byte("1"))
} else {
return formatStr([]byte("0"))
}
case nil:
if forceString {
return formatStr([]byte{})
} else {
return formatNil()
}
case int:
return formatInt(int64(mt), forceString)
case int8:
return formatInt(int64(mt), forceString)
case int16:
return formatInt(int64(mt), forceString)
case int32:
return formatInt(int64(mt), forceString)
case int64:
return formatInt(mt, forceString)
case uint:
return formatInt(int64(mt), forceString)
case uint8:
return formatInt(int64(mt), forceString)
case uint16:
return formatInt(int64(mt), forceString)
case uint32:
return formatInt(int64(mt), forceString)
case uint64:
return formatInt(int64(mt), forceString)
case float32:
ft := strconv.FormatFloat(float64(mt), 'f', -1, 32)
return formatStr([]byte(ft))
case float64:
ft := strconv.FormatFloat(mt, 'f', -1, 64)
return formatStr([]byte(ft))
case error:
if forceString {
return formatStr([]byte(mt.Error()))
} else {
return formatErr(mt)
}
// We duplicate the below code here a bit, since this is the common case and
// it'd be better to not get the reflect package involved here
case []interface{}:
l := len(mt)
b := make([]byte, 0, l*1024)
b = append(b, '*')
b = append(b, []byte(strconv.Itoa(l))...)
b = append(b, []byte("\r\n")...)
for i := 0; i < l; i++ {
b = append(b, format(mt[i], forceString)...)
}
return b
case *Message:
return mt.raw
default:
// Fallback to reflect-based.
switch reflect.TypeOf(m).Kind() {
case reflect.Slice:
rm := reflect.ValueOf(mt)
l := rm.Len()
b := make([]byte, 0, l*1024)
b = append(b, '*')
b = append(b, []byte(strconv.Itoa(l))...)
b = append(b, []byte("\r\n")...)
for i := 0; i < l; i++ {
vv := rm.Index(i).Interface()
b = append(b, format(vv, forceString)...)
}
return b
case reflect.Map:
rm := reflect.ValueOf(mt)
l := rm.Len() * 2
b := make([]byte, 0, l*1024)
b = append(b, '*')
b = append(b, []byte(strconv.Itoa(l))...)
b = append(b, []byte("\r\n")...)
keys := rm.MapKeys()
for _, k := range keys {
kv := k.Interface()
vv := rm.MapIndex(k).Interface()
b = append(b, format(kv, forceString)...)
b = append(b, format(vv, forceString)...)
}
return b
default:
return formatStr([]byte(fmt.Sprint(m)))
}
}
}
var typeOfBytes = reflect.TypeOf([]byte(nil))
func flatten(m interface{}) []interface{} {
t := reflect.TypeOf(m)
// If it's a byte-slice we don't want to flatten
if t == typeOfBytes {
return []interface{}{m}
}
switch t.Kind() {
case reflect.Slice:
rm := reflect.ValueOf(m)
l := rm.Len()
ret := make([]interface{}, 0, l)
for i := 0; i < l; i++ {
ret = append(ret, flatten(rm.Index(i).Interface())...)
}
return ret
case reflect.Map:
rm := reflect.ValueOf(m)
l := rm.Len() * 2
keys := rm.MapKeys()
ret := make([]interface{}, 0, l)
for _, k := range keys {
kv := k.Interface()
vv := rm.MapIndex(k).Interface()
ret = append(ret, flatten(kv)...)
ret = append(ret, flatten(vv)...)
}
return ret
default:
return []interface{}{m}
}
}
func formatStr(b []byte) []byte {
l := strconv.Itoa(len(b))
bs := make([]byte, 0, len(l)+len(b)+5)
bs = append(bs, bulkStrPrefix)
bs = append(bs, []byte(l)...)
bs = append(bs, delim...)
bs = append(bs, b...)
bs = append(bs, delim...)
return bs
}
func formatErr(ierr error) []byte {
ierrstr := []byte(ierr.Error())
bs := make([]byte, 0, len(ierrstr)+3)
bs = append(bs, errPrefix)
bs = append(bs, ierrstr...)
bs = append(bs, delim...)
return bs
}
func formatInt(i int64, forceString bool) []byte {
istr := strconv.FormatInt(i, 10)
if forceString {
return formatStr([]byte(istr))
}
bs := make([]byte, 0, len(istr)+3)
bs = append(bs, intPrefix)
bs = append(bs, istr...)
bs = append(bs, delim...)
return bs
}
var nilFormatted = []byte("$-1\r\n")
func formatNil() []byte {
return nilFormatted
}
package resp
import (
"bytes"
"errors"
"github.com/stretchr/testify/assert"
. "testing"
)
func TestRead(t *T) {
var m *Message
var err error
_, err = NewMessage(nil)
assert.NotNil(t, err)
_, err = NewMessage([]byte{})
assert.NotNil(t, err)
// Simple string
m, _ = NewMessage([]byte("+ohey\r\n"))
assert.Equal(t, SimpleStr, m.Type)
assert.Equal(t, []byte("ohey"), m.val)
// Empty simple string
m, _ = NewMessage([]byte("+\r\n"))
assert.Equal(t, SimpleStr, m.Type)
assert.Equal(t, []byte(""), m.val.([]byte))
// Error
m, _ = NewMessage([]byte("-ohey\r\n"))
assert.Equal(t, Err, m.Type)
assert.Equal(t, []byte("ohey"), m.val.([]byte))
// Empty error
m, _ = NewMessage([]byte("-\r\n"))
assert.Equal(t, Err, m.Type)
assert.Equal(t, []byte(""), m.val.([]byte))
// Int
m, _ = NewMessage([]byte(":1024\r\n"))
assert.Equal(t, Int, m.Type)
assert.Equal(t, int64(1024), m.val.(int64))
// Bulk string
m, _ = NewMessage([]byte("$3\r\nfoo\r\n"))
assert.Equal(t, BulkStr, m.Type)
assert.Equal(t, []byte("foo"), m.val.([]byte))
// Empty bulk string
m, _ = NewMessage([]byte("$0\r\n\r\n"))
assert.Equal(t, BulkStr, m.Type)
assert.Equal(t, []byte(""), m.val.([]byte))
// Nil bulk string
m, _ = NewMessage([]byte("$-1\r\n"))
assert.Equal(t, Nil, m.Type)
// Array
m, _ = NewMessage([]byte("*2\r\n+foo\r\n+bar\r\n"))
assert.Equal(t, Array, m.Type)
assert.Equal(t, 2, len(m.val.([]*Message)))
assert.Equal(t, SimpleStr, m.val.([]*Message)[0].Type)
assert.Equal(t, []byte("foo"), m.val.([]*Message)[0].val.([]byte))
assert.Equal(t, SimpleStr, m.val.([]*Message)[1].Type)
assert.Equal(t, []byte("bar"), m.val.([]*Message)[1].val.([]byte))
// Empty array
m, _ = NewMessage([]byte("*0\r\n"))
assert.Equal(t, Array, m.Type)
assert.Equal(t, 0, len(m.val.([]*Message)))
// Nil Array
m, _ = NewMessage([]byte("*-1\r\n"))
assert.Equal(t, Nil, m.Type)
// Embedded Array
m, _ = NewMessage([]byte("*3\r\n+foo\r\n+bar\r\n*2\r\n+foo\r\n+bar\r\n"))
assert.Equal(t, Array, m.Type)
assert.Equal(t, 3, len(m.val.([]*Message)))
assert.Equal(t, SimpleStr, m.val.([]*Message)[0].Type)
assert.Equal(t, []byte("foo"), m.val.([]*Message)[0].val.([]byte))
assert.Equal(t, SimpleStr, m.val.([]*Message)[1].Type)
assert.Equal(t, []byte("bar"), m.val.([]*Message)[1].val.([]byte))
m = m.val.([]*Message)[2]
assert.Equal(t, 2, len(m.val.([]*Message)))
assert.Equal(t, SimpleStr, m.val.([]*Message)[0].Type)
assert.Equal(t, []byte("foo"), m.val.([]*Message)[0].val.([]byte))
assert.Equal(t, SimpleStr, m.val.([]*Message)[1].Type)
assert.Equal(t, []byte("bar"), m.val.([]*Message)[1].val.([]byte))
// Test that two bulks in a row read correctly
m, _ = NewMessage([]byte("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"))
assert.Equal(t, Array, m.Type)
assert.Equal(t, 2, len(m.val.([]*Message)))
assert.Equal(t, BulkStr, m.val.([]*Message)[0].Type)
assert.Equal(t, []byte("foo"), m.val.([]*Message)[0].val.([]byte))
assert.Equal(t, BulkStr, m.val.([]*Message)[1].Type)
assert.Equal(t, []byte("bar"), m.val.([]*Message)[1].val.([]byte))
}
type arbitraryTest struct {
val interface{}
expect []byte
}
var nilMessage, _ = NewMessage([]byte("$-1\r\n"))
var arbitraryTests = []arbitraryTest{
{[]byte("OHAI"), []byte("$4\r\nOHAI\r\n")},
{"OHAI", []byte("$4\r\nOHAI\r\n")},
{true, []byte("$1\r\n1\r\n")},
{false, []byte("$1\r\n0\r\n")},
{nil, []byte("$-1\r\n")},
{80, []byte(":80\r\n")},
{int64(-80), []byte(":-80\r\n")},
{uint64(80), []byte(":80\r\n")},
{float32(0.1234), []byte("$6\r\n0.1234\r\n")},
{float64(0.1234), []byte("$6\r\n0.1234\r\n")},
{errors.New("hi"), []byte("-hi\r\n")},
{nilMessage, []byte("$-1\r\n")},
{[]int{1, 2, 3}, []byte("*3\r\n:1\r\n:2\r\n:3\r\n")},
{map[int]int{1: 2}, []byte("*2\r\n:1\r\n:2\r\n")},
{NewSimpleString("OK"), []byte("+OK\r\n")},
}
var arbitraryAsStringTests = []arbitraryTest{
{[]byte("OHAI"), []byte("$4\r\nOHAI\r\n")},
{"OHAI", []byte("$4\r\nOHAI\r\n")},
{true, []byte("$1\r\n1\r\n")},
{false, []byte("$1\r\n0\r\n")},
{nil, []byte("$0\r\n\r\n")},
{80, []byte("$2\r\n80\r\n")},
{int64(-80), []byte("$3\r\n-80\r\n")},
{uint64(80), []byte("$2\r\n80\r\n")},
{float32(0.1234), []byte("$6\r\n0.1234\r\n")},
{float64(0.1234), []byte("$6\r\n0.1234\r\n")},
{errors.New("hi"), []byte("$2\r\nhi\r\n")},
{nilMessage, []byte("$-1\r\n")},
{[]int{1, 2, 3}, []byte("*3\r\n$1\r\n1\r\n$1\r\n2\r\n$1\r\n3\r\n")},
{map[int]int{1: 2}, []byte("*2\r\n$1\r\n1\r\n$1\r\n2\r\n")},
{NewSimpleString("OK"), []byte("+OK\r\n")},
}
var arbitraryAsFlattenedStringsTests = []arbitraryTest{
{
[]interface{}{"wat", map[string]interface{}{
"foo": 1,
}},
[]byte("*3\r\n$3\r\nwat\r\n$3\r\nfoo\r\n$1\r\n1\r\n"),
},
}
func TestWriteArbitrary(t *T) {
var err error
buf := bytes.NewBuffer([]byte{})
for _, test := range arbitraryTests {
t.Logf("Checking test %v", test)
buf.Reset()
err = WriteArbitrary(buf, test.val)
assert.Nil(t, err)
assert.Equal(t, test.expect, buf.Bytes())
}
}
func TestWriteArbitraryAsString(t *T) {
var err error
buf := bytes.NewBuffer([]byte{})
for _, test := range arbitraryAsStringTests {
t.Logf("Checking test %v", test)
buf.Reset()
err = WriteArbitraryAsString(buf, test.val)
assert.Nil(t, err)
assert.Equal(t, test.expect, buf.Bytes())
}
}
func TestWriteArbitraryAsFlattenedStrings(t *T) {
var err error
buf := bytes.NewBuffer([]byte{})
for _, test := range arbitraryAsFlattenedStringsTests {
t.Logf("Checking test %v", test)
buf.Reset()
err = WriteArbitraryAsFlattenedStrings(buf, test.val)
assert.Nil(t, err)
assert.Equal(t, test.expect, buf.Bytes())
}
}
func TestMessageWrite(t *T) {
var err error
var m *Message
buf := bytes.NewBuffer([]byte{})
for _, test := range arbitraryTests {
t.Logf("Checking test; %v", test)
buf.Reset()
m, err = NewMessage(test.expect)
assert.Nil(t, err)
err = WriteMessage(buf, m)
assert.Nil(t, err)
assert.Equal(t, test.expect, buf.Bytes())
}
}
package datastore
import (
"io"
"log"
dsq "github.com/jbenet/go-datastore/query"
......@@ -67,6 +68,10 @@ func (d *MapDatastore) Batch() (Batch, error) {
return NewBasicBatch(d), nil
}
func (d *MapDatastore) Close() error {
return nil
}
// NullDatastore stores nothing, but conforms to the API.
// Useful to test with.
type NullDatastore struct {
......@@ -106,6 +111,10 @@ func (d *NullDatastore) Batch() (Batch, error) {
return NewBasicBatch(d), nil
}
func (d *NullDatastore) Close() error {
return nil
}
// LogDatastore logs all accesses through the datastore.
type LogDatastore struct {
Name string
......@@ -165,9 +174,16 @@ func (d *LogDatastore) Query(q dsq.Query) (dsq.Results, error) {
func (d *LogDatastore) Batch() (Batch, error) {
log.Printf("%s: Batch\n", d.Name)
bds, ok := d.child.(BatchingDatastore)
if !ok {
return nil, ErrBatchUnsupported
if bds, ok := d.child.(Batching); ok {
return bds.Batch()
}
return bds.Batch()
return nil, ErrBatchUnsupported
}
func (d *LogDatastore) Close() error {
log.Printf("%s: Close\n", d.Name)
if cds, ok := d.child.(io.Closer); ok {
return cds.Close()
}
return nil
}
package coalesce
import (
"io"
"sync"
ds "github.com/jbenet/go-datastore"
......@@ -124,3 +125,16 @@ func (d *datastore) Query(q dsq.Query) (dsq.Results, error) {
// query not coalesced yet.
return d.child.Query(q)
}
func (d *datastore) Close() error {
d.reqmu.Lock()
defer d.reqmu.Unlock()
for _, s := range d.req {
<-s.done
}
if c, ok := d.child.(io.Closer); ok {
return c.Close()
}
return nil
}
......@@ -69,7 +69,7 @@ type Datastore interface {
Query(q query.Query) (query.Results, error)
}
type BatchingDatastore interface {
type Batching interface {
Datastore
Batch() (Batch, error)
......
......@@ -323,6 +323,10 @@ func (fs *Datastore) enumerateKeys(fi os.FileInfo, res []query.Entry) ([]query.E
return res, nil
}
func (fs *Datastore) Close() error {
return nil
}
type flatfsBatch struct {
puts map[datastore.Key]interface{}
deletes map[datastore.Key]struct{}
......
......@@ -149,3 +149,11 @@ func isFile(path string) bool {
return !finfo.IsDir()
}
func (d *Datastore) Close() error {
return nil
}
func (d *Datastore) Batch() (ds.Batch, error) {
return ds.NewBasicBatch(d), nil
}
......@@ -16,8 +16,6 @@ type KeyTransform interface {
type Datastore interface {
ds.Shim
KeyTransform
Batch() (ds.Batch, error)
}
// Wrap wraps a given datastore with a KeyTransform function.
......
package keytransform
import (
"io"
ds "github.com/jbenet/go-datastore"
dsq "github.com/jbenet/go-datastore/query"
)
......@@ -74,8 +76,15 @@ func (d *ktds) Query(q dsq.Query) (dsq.Results, error) {
return dsq.DerivedResults(qr, ch), nil
}
func (d *ktds) Close() error {
if c, ok := d.child.(io.Closer); ok {
return c.Close()
}
return nil
}
func (d *ktds) Batch() (ds.Batch, error) {
bds, ok := d.child.(ds.BatchingDatastore)
bds, ok := d.child.(ds.Batching)
if !ok {
return nil, ds.ErrBatchUnsupported
}
......
package leveldb
import (
"io"
ds "github.com/jbenet/go-datastore"
"github.com/jbenet/go-datastore/Godeps/_workspace/src/github.com/jbenet/goprocess"
"github.com/jbenet/go-datastore/Godeps/_workspace/src/github.com/syndtr/goleveldb/leveldb"
......@@ -11,18 +9,13 @@ import (
dsq "github.com/jbenet/go-datastore/query"
)
type Datastore interface {
ds.ThreadSafeDatastore
io.Closer
}
type datastore struct {
DB *leveldb.DB
}
type Options opt.Options
func NewDatastore(path string, opts *Options) (Datastore, error) {
func NewDatastore(path string, opts *Options) (*datastore, error) {
var nopts opt.Options
if opts != nil {
nopts = opt.Options(*opts)
......@@ -148,6 +141,11 @@ func (d *datastore) runQuery(worker goprocess.Process, qrb *dsq.ResultBuilder) {
}
}
func (d *datastore) Batch() (ds.Batch, error) {
// TODO: implement batch on leveldb
return nil, ds.ErrBatchUnsupported
}
// LevelDB needs to be closed.
func (d *datastore) Close() (err error) {
return d.DB.Close()
......
......@@ -25,7 +25,7 @@ var testcases = map[string]string{
//
// d, close := newDS(t)
// defer close()
func newDS(t *testing.T) (Datastore, func()) {
func newDS(t *testing.T) (*datastore, func()) {
path, err := ioutil.TempDir("/tmp", "testing_leveldb_")
if err != nil {
t.Fatal(err)
......@@ -41,7 +41,7 @@ func newDS(t *testing.T) (Datastore, func()) {
}
}
func addTestCases(t *testing.T, d Datastore, testcases map[string]string) {
func addTestCases(t *testing.T, d *datastore, testcases map[string]string) {
for k, v := range testcases {
dsk := ds.NewKey(k)
if err := d.Put(dsk, []byte(v)); err != nil {
......
......@@ -54,3 +54,11 @@ func (d *Datastore) Delete(key ds.Key) (err error) {
func (d *Datastore) Query(q dsq.Query) (dsq.Results, error) {
return nil, errors.New("KeyList not implemented.")
}
func (d *Datastore) Close() error {
return nil
}
func (d *Datastore) Batch() (ds.Batch, error) {
return nil, ds.ErrBatchUnsupported
}
......@@ -3,6 +3,7 @@
package measure
import (
"io"
"time"
"github.com/jbenet/go-datastore"
......@@ -18,17 +19,12 @@ const (
maxSize = int64(1 << 32)
)
type DatastoreCloser interface {
datastore.Datastore
Close() error
}
// New wraps the datastore, providing metrics on the operations. The
// metrics are registered with names starting with prefix and a dot.
//
// If prefix is not unique, New will panic. Call Close to release the
// prefix.
func New(prefix string, ds datastore.Datastore) DatastoreCloser {
func New(prefix string, ds datastore.Datastore) *measure {
m := &measure{
backend: ds,
......@@ -84,7 +80,6 @@ type measure struct {
}
var _ datastore.Datastore = (*measure)(nil)
var _ DatastoreCloser = (*measure)(nil)
func recordLatency(h *metrics.Histogram, start time.Time) {
elapsed := time.Now().Sub(start) / time.Microsecond
......@@ -159,7 +154,7 @@ type measuredBatch struct {
}
func (m *measure) Batch() (datastore.Batch, error) {
bds, ok := m.backend.(datastore.BatchingDatastore)
bds, ok := m.backend.(datastore.Batching)
if !ok {
return nil, datastore.ErrBatchUnsupported
}
......@@ -245,5 +240,9 @@ func (m *measure) Close() error {
m.queryNum.Remove()
m.queryErr.Remove()
m.queryLatency.Remove()
if c, ok := m.backend.(io.Closer); ok {
return c.Close()
}
return nil
}
......@@ -4,6 +4,7 @@ package mount
import (
"errors"
"io"
"strings"
"github.com/jbenet/go-datastore"
......@@ -115,6 +116,18 @@ func (d *Datastore) Query(q query.Query) (query.Results, error) {
return r, nil
}
func (d *Datastore) Close() error {
for _, d := range d.mounts {
if c, ok := d.Datastore.(io.Closer); ok {
err := c.Close()
if err != nil {
return err
}
}
}
return nil
}
type mountBatch struct {
mounts map[string]datastore.Batch
......@@ -132,7 +145,7 @@ func (mt *mountBatch) lookupBatch(key datastore.Key) (datastore.Batch, datastore
child, loc, rest := mt.d.lookup(key)
t, ok := mt.mounts[loc.String()]
if !ok {
bds, ok := child.(datastore.BatchingDatastore)
bds, ok := child.(datastore.Batching)
if !ok {
return nil, datastore.NewKey(""), datastore.ErrBatchUnsupported
}
......
......@@ -2,6 +2,7 @@ package sync
import (
"fmt"
"io"
"os"
ds "github.com/jbenet/go-datastore"
......@@ -67,6 +68,26 @@ func (d *datastore) Query(q dsq.Query) (dsq.Results, error) {
return r, nil
}
func (d *datastore) Close() error {
if c, ok := d.child.(io.Closer); ok {
err := c.Close()
if err != nil {
fmt.Fprintf(os.Stdout, "panic datastore: %s", err)
panic("panic datastore: Close failed")
}
}
return nil
}
func (d *datastore) Batch() (ds.Batch, error) {
b, err := d.child.(ds.Batching).Batch()
if err != nil {
return nil, err
}
return &panicBatch{b}, nil
}
type panicBatch struct {
t ds.Batch
}
......
......@@ -6,9 +6,8 @@ import (
"sync"
"time"
"github.com/fzzy/radix/redis"
datastore "github.com/jbenet/go-datastore"
"github.com/jbenet/go-datastore/Godeps/_workspace/src/github.com/fzzy/radix/redis"
query "github.com/jbenet/go-datastore/query"
)
......@@ -17,14 +16,14 @@ var _ datastore.ThreadSafeDatastore = &Datastore{}
var ErrInvalidType = errors.New("redis datastore: invalid type error. this datastore only supports []byte values")
func NewExpiringDatastore(client *redis.Client, ttl time.Duration) (datastore.ThreadSafeDatastore, error) {
func NewExpiringDatastore(client *redis.Client, ttl time.Duration) (*Datastore, error) {
return &Datastore{
client: client,
ttl: ttl,
}, nil
}
func NewDatastore(client *redis.Client) (datastore.ThreadSafeDatastore, error) {
func NewDatastore(client *redis.Client) (*Datastore, error) {
return &Datastore{
client: client,
}, nil
......@@ -83,3 +82,11 @@ func (ds *Datastore) Query(q query.Query) (query.Results, error) {
}
func (ds *Datastore) IsThreadSafe() {}
func (ds *Datastore) Batch() (datastore.Batch, error) {
return nil, datastore.ErrBatchUnsupported
}
func (ds *Datastore) Close() error {
return ds.client.Close()
}
......@@ -6,8 +6,8 @@ import (
"testing"
"time"
"github.com/fzzy/radix/redis"
datastore "github.com/jbenet/go-datastore"
"github.com/jbenet/go-datastore/Godeps/_workspace/src/github.com/fzzy/radix/redis"
dstest "github.com/jbenet/go-datastore/test"
)
......
package sync
import (
"io"
"sync"
ds "github.com/jbenet/go-datastore"
......@@ -67,7 +68,7 @@ func (d *MutexDatastore) Query(q dsq.Query) (dsq.Results, error) {
func (d *MutexDatastore) Batch() (ds.Batch, error) {
d.RLock()
defer d.RUnlock()
bds, ok := d.child.(ds.BatchingDatastore)
bds, ok := d.child.(ds.Batching)
if !ok {
return nil, ds.ErrBatchUnsupported
}
......@@ -81,6 +82,15 @@ func (d *MutexDatastore) Batch() (ds.Batch, error) {
}, nil
}
func (d *MutexDatastore) Close() error {
d.RWMutex.Lock()
defer d.RWMutex.Unlock()
if c, ok := d.child.(io.Closer); ok {
return c.Close()
}
return nil
}
type syncBatch struct {
lk sync.Mutex
batch ds.Batch
......
......@@ -9,7 +9,7 @@ import (
rand "github.com/jbenet/go-datastore/Godeps/_workspace/src/github.com/dustin/randbo"
)
func RunBatchTest(t *testing.T, ds dstore.BatchingDatastore) {
func RunBatchTest(t *testing.T, ds dstore.Batching) {
batch, err := ds.Batch()
if err != nil {
t.Fatal(err)
......@@ -58,7 +58,7 @@ func RunBatchTest(t *testing.T, ds dstore.BatchingDatastore) {
}
}
func RunBatchDeleteTest(t *testing.T, ds dstore.BatchingDatastore) {
func RunBatchDeleteTest(t *testing.T, ds dstore.Batching) {
r := rand.New()
var keys []dstore.Key
for i := 0; i < 20; i++ {
......
......@@ -13,7 +13,7 @@ type tiered []ds.Datastore
// New returns a tiered datastore. Puts and Deletes will write-through to
// all datastores, Has and Get will try each datastore sequentially, and
// Query will always try the last one (most complete) first.
func New(dses ...ds.Datastore) ds.Datastore {
func New(dses ...ds.Datastore) tiered {
return tiered(dses)
}
......
......@@ -49,19 +49,19 @@ func TestTiered(t *testing.T) {
td := New(d1, d2, d3, d4)
td.Put(ds.NewKey("foo"), "bar")
testHas(t, []ds.Datastore{td}, ds.NewKey("foo"), "bar")
testHas(t, td.(tiered), ds.NewKey("foo"), "bar") // all children
testHas(t, td, ds.NewKey("foo"), "bar") // all children
// remove it from, say, caches.
d1.Delete(ds.NewKey("foo"))
d2.Delete(ds.NewKey("foo"))
testHas(t, []ds.Datastore{td}, ds.NewKey("foo"), "bar")
testHas(t, td.(tiered)[2:], ds.NewKey("foo"), "bar")
testNotHas(t, td.(tiered)[:2], ds.NewKey("foo"))
testHas(t, td[2:], ds.NewKey("foo"), "bar")
testNotHas(t, td[:2], ds.NewKey("foo"))
// write it again.
td.Put(ds.NewKey("foo"), "bar2")
testHas(t, []ds.Datastore{td}, ds.NewKey("foo"), "bar2")
testHas(t, td.(tiered), ds.NewKey("foo"), "bar2")
testHas(t, td, ds.NewKey("foo"), "bar2")
}
func TestQueryCallsLast(t *testing.T) {
......
package timecache
import (
"io"
"sync"
"time"
......@@ -24,13 +25,13 @@ type datastore struct {
ttls map[ds.Key]time.Time
}
func WithTTL(ttl time.Duration) ds.Datastore {
func WithTTL(ttl time.Duration) *datastore {
return WithCache(ds.NewMapDatastore(), ttl)
}
// WithCache wraps a given datastore as a timecache.
// Get + Has requests are considered expired after a TTL.
func WithCache(d ds.Datastore, ttl time.Duration) ds.Datastore {
func WithCache(d ds.Datastore, ttl time.Duration) *datastore {
return &datastore{cache: d, ttl: ttl, ttls: make(map[ds.Key]time.Time)}
}
......@@ -94,3 +95,10 @@ func (d *datastore) Delete(key ds.Key) (err error) {
func (d *datastore) Query(q dsq.Query) (dsq.Results, error) {
return d.cache.Query(q)
}
func (d *datastore) Close() error {
if c, ok := d.cache.(io.Closer); ok {
return c.Close()
}
return nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment