Commit a8f03e4e authored by Marten Seemann's avatar Marten Seemann

add a context to NewStream, remove the NewStreamTimeout

parent 67680fbd
package multiplex package multiplex
import ( import (
"context"
"io" "io"
"math/rand" "math/rand"
"net" "net"
...@@ -64,7 +65,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) { ...@@ -64,7 +65,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
streamPairs := make([][]*Stream, 0) streamPairs := make([][]*Stream, 0)
for i := 0; i < mp; i++ { for i := 0; i < mp; i++ {
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
...@@ -190,7 +191,7 @@ func benchmarkPackets(b *testing.B, msgs [][]byte) { ...@@ -190,7 +191,7 @@ func benchmarkPackets(b *testing.B, msgs [][]byte) {
func benchmarkPacketsWithConn(b *testing.B, parallelism int, msgs [][]byte, mpa, mpb *Multiplex) { func benchmarkPacketsWithConn(b *testing.B, parallelism int, msgs [][]byte, mpa, mpb *Multiplex) {
streamPairs := make([][]*Stream, 0) streamPairs := make([][]*Stream, 0)
for i := 0; i < parallelism*runtime.GOMAXPROCS(0); i++ { for i := 0; i < parallelism*runtime.GOMAXPROCS(0); i++ {
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
......
...@@ -39,7 +39,6 @@ var errTimeout = timeout{} ...@@ -39,7 +39,6 @@ var errTimeout = timeout{}
var errStreamClosed = errors.New("stream closed") var errStreamClosed = errors.New("stream closed")
var ( var (
NewStreamTimeout = time.Minute
ResetStreamTimeout = 2 * time.Minute ResetStreamTimeout = 2 * time.Minute
WriteCoalesceDelay = 100 * time.Microsecond WriteCoalesceDelay = 100 * time.Microsecond
...@@ -291,12 +290,12 @@ func (mp *Multiplex) nextChanID() uint64 { ...@@ -291,12 +290,12 @@ func (mp *Multiplex) nextChanID() uint64 {
} }
// NewStream creates a new stream. // NewStream creates a new stream.
func (mp *Multiplex) NewStream() (*Stream, error) { func (mp *Multiplex) NewStream(ctx context.Context) (*Stream, error) {
return mp.NewNamedStream("") return mp.NewNamedStream(ctx, "")
} }
// NewNamedStream creates a new named stream. // NewNamedStream creates a new named stream.
func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) { func (mp *Multiplex) NewNamedStream(ctx context.Context, name string) (*Stream, error) {
mp.chLock.Lock() mp.chLock.Lock()
// We could call IsClosed but this is faster (given that we already have // We could call IsClosed but this is faster (given that we already have
...@@ -319,11 +318,11 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) { ...@@ -319,11 +318,11 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
mp.channels[s.id] = s mp.channels[s.id] = s
mp.chLock.Unlock() mp.chLock.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), NewStreamTimeout)
defer cancel()
err := mp.sendMsg(ctx.Done(), nil, header, []byte(name)) err := mp.sendMsg(ctx.Done(), nil, header, []byte(name))
if err != nil { if err != nil {
if err == errTimeout {
return nil, ctx.Err()
}
return nil, err return nil, err
} }
......
package multiplex package multiplex
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -28,7 +29,7 @@ func TestSlowReader(t *testing.T) { ...@@ -28,7 +29,7 @@ func TestSlowReader(t *testing.T) {
mes := []byte("Hello world") mes := []byte("Hello world")
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -85,7 +86,7 @@ func TestBasicStreams(t *testing.T) { ...@@ -85,7 +86,7 @@ func TestBasicStreams(t *testing.T) {
} }
}() }()
s, err := mpa.NewStream() s, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -105,6 +106,32 @@ func TestBasicStreams(t *testing.T) { ...@@ -105,6 +106,32 @@ func TestBasicStreams(t *testing.T) {
mpb.Close() mpb.Close()
} }
func TestOpenStreamDeadline(t *testing.T) {
a, _ := net.Pipe()
mp := NewMultiplex(a, false)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
var counter int
var deadlineExceeded bool
for i := 0; i < 1000; i++ {
if _, err := mp.NewStream(ctx); err != nil {
if err != context.DeadlineExceeded {
t.Fatalf("expected the error to be a deadline error, got %s", err.Error())
}
deadlineExceeded = true
break
}
counter++
}
if counter == 0 {
t.Fatal("expected at least some streams to open successfully")
}
if !deadlineExceeded {
t.Fatal("expected a deadline error to occur at some point")
}
}
func TestWriteAfterClose(t *testing.T) { func TestWriteAfterClose(t *testing.T) {
a, b := net.Pipe() a, b := net.Pipe()
...@@ -134,7 +161,7 @@ func TestWriteAfterClose(t *testing.T) { ...@@ -134,7 +161,7 @@ func TestWriteAfterClose(t *testing.T) {
close(done) close(done)
}() }()
s, err := mpa.NewStream() s, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -178,7 +205,7 @@ func TestEcho(t *testing.T) { ...@@ -178,7 +205,7 @@ func TestEcho(t *testing.T) {
io.Copy(s, s) io.Copy(s, s)
}() }()
s, err := mpa.NewStream() s, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -214,7 +241,7 @@ func TestFullClose(t *testing.T) { ...@@ -214,7 +241,7 @@ func TestFullClose(t *testing.T) {
mes := make([]byte, 40960) mes := make([]byte, 40960)
rand.Read(mes) rand.Read(mes)
s, err := mpa.NewStream() s, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -281,7 +308,7 @@ func TestHalfClose(t *testing.T) { ...@@ -281,7 +308,7 @@ func TestHalfClose(t *testing.T) {
} }
}() }()
s, err := mpa.NewStream() s, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -339,7 +366,7 @@ func TestClosing(t *testing.T) { ...@@ -339,7 +366,7 @@ func TestClosing(t *testing.T) {
mpa := NewMultiplex(a, false) mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true) mpb := NewMultiplex(b, true)
_, err := mpb.NewStream() _, err := mpb.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -373,7 +400,7 @@ func TestReset(t *testing.T) { ...@@ -373,7 +400,7 @@ func TestReset(t *testing.T) {
defer mpa.Close() defer mpa.Close()
defer mpb.Close() defer mpb.Close()
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -425,7 +452,7 @@ func TestCancelRead(t *testing.T) { ...@@ -425,7 +452,7 @@ func TestCancelRead(t *testing.T) {
defer mpa.Close() defer mpa.Close()
defer mpb.Close() defer mpb.Close()
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -486,7 +513,7 @@ func TestCancelWrite(t *testing.T) { ...@@ -486,7 +513,7 @@ func TestCancelWrite(t *testing.T) {
defer mpa.Close() defer mpa.Close()
defer mpb.Close() defer mpb.Close()
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -560,7 +587,7 @@ func TestCancelReadAfterWrite(t *testing.T) { ...@@ -560,7 +587,7 @@ func TestCancelReadAfterWrite(t *testing.T) {
defer mpa.Close() defer mpa.Close()
defer mpb.Close() defer mpb.Close()
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -603,7 +630,7 @@ func TestResetAfterEOF(t *testing.T) { ...@@ -603,7 +630,7 @@ func TestResetAfterEOF(t *testing.T) {
defer mpa.Close() defer mpa.Close()
defer mpb.Close() defer mpb.Close()
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -632,7 +659,7 @@ func TestOpenAfterClose(t *testing.T) { ...@@ -632,7 +659,7 @@ func TestOpenAfterClose(t *testing.T) {
mpa := NewMultiplex(a, false) mpa := NewMultiplex(a, false)
mpb := NewMultiplex(b, true) mpb := NewMultiplex(b, true)
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -646,12 +673,12 @@ func TestOpenAfterClose(t *testing.T) { ...@@ -646,12 +673,12 @@ func TestOpenAfterClose(t *testing.T) {
mpa.Close() mpa.Close()
s, err := mpa.NewStream() s, err := mpa.NewStream(context.Background())
if err == nil || s != nil { if err == nil || s != nil {
t.Fatal("opened a stream on a closed connection") t.Fatal("opened a stream on a closed connection")
} }
s, err = mpa.NewStream() s, err = mpa.NewStream(context.Background())
if err == nil || s != nil { if err == nil || s != nil {
t.Fatal("opened a stream on a closed connection") t.Fatal("opened a stream on a closed connection")
} }
...@@ -668,7 +695,7 @@ func TestDeadline(t *testing.T) { ...@@ -668,7 +695,7 @@ func TestDeadline(t *testing.T) {
defer mpa.Close() defer mpa.Close()
defer mpb.Close() defer mpb.Close()
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -694,7 +721,7 @@ func TestReadAfterClose(t *testing.T) { ...@@ -694,7 +721,7 @@ func TestReadAfterClose(t *testing.T) {
defer mpa.Close() defer mpa.Close()
defer mpb.Close() defer mpb.Close()
sa, err := mpa.NewStream() sa, err := mpa.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -735,7 +762,7 @@ func TestFuzzCloseStream(t *testing.T) { ...@@ -735,7 +762,7 @@ func TestFuzzCloseStream(t *testing.T) {
streams := make([]*Stream, 100) streams := make([]*Stream, 100)
for i := range streams { for i := range streams {
var err error var err error
streams[i], err = mpb.NewStream() streams[i], err = mpb.NewStream(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment