Skip to content

Commit 0a6805b

Browse files
committed
pass metadata along with stream creation command
1 parent 6aa95ef commit 0a6805b

File tree

3 files changed

+44
-10
lines changed

3 files changed

+44
-10
lines changed

session.go

+15-4
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
102102
}
103103

104104
// OpenStream is used to create a new stream
105-
func (s *Session) OpenStream() (*Stream, error) {
105+
func (s *Session) OpenStream(metadata ...byte) (*Stream, error) {
106106
if s.IsClosed() {
107107
return nil, errors.WithStack(io.ErrClosedPipe)
108108
}
@@ -123,9 +123,11 @@ func (s *Session) OpenStream() (*Stream, error) {
123123
}
124124
s.nextStreamIDLock.Unlock()
125125

126-
stream := newStream(sid, s.config.MaxFrameSize, s)
126+
stream := newStream(sid, metadata, s.config.MaxFrameSize, s)
127127

128-
if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
128+
frame := newFrame(cmdSYN, sid)
129+
frame.data = metadata
130+
if _, err := s.writeFrame(frame); err != nil {
129131
return nil, errors.WithStack(err)
130132
}
131133

@@ -307,7 +309,16 @@ func (s *Session) recvLoop() {
307309
case cmdSYN:
308310
s.streamLock.Lock()
309311
if _, ok := s.streams[sid]; !ok {
310-
stream := newStream(sid, s.config.MaxFrameSize, s)
312+
var newbuf []byte
313+
if hdr.Length() > 0 {
314+
newbuf = defaultAllocator.Get(int(hdr.Length()))
315+
if _, err := io.ReadFull(s.conn, newbuf); err != nil {
316+
s.notifyReadError(errors.WithStack(err))
317+
s.streamLock.Unlock()
318+
return
319+
}
320+
}
321+
stream := newStream(sid, append([]byte(nil), newbuf...), s.config.MaxFrameSize, s)
311322
s.streams[sid] = stream
312323
select {
313324
case s.chAccepts <- stream:

session_test.go

+19-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package smux
22

33
import (
4+
"bytes"
45
crand "crypto/rand"
56
"encoding/binary"
67
"fmt"
@@ -25,7 +26,7 @@ func init() {
2526
// setupServer starts new server listening on a random localhost port and
2627
// returns address of the server, function to stop the server, new client
2728
// connection to this server or an error.
28-
func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) {
29+
func setupServer(tb testing.TB, metadata ...byte) (addr string, stopfunc func(), client net.Conn, err error) {
2930
ln, err := net.Listen("tcp", "localhost:0")
3031
if err != nil {
3132
return "", nil, nil, err
@@ -35,7 +36,7 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn,
3536
if err != nil {
3637
return
3738
}
38-
go handleConnection(conn)
39+
go handleConnection(tb, conn, metadata...)
3940
}()
4041
addr = ln.Addr().String()
4142
conn, err := net.Dial("tcp", addr)
@@ -46,10 +47,13 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn,
4647
return ln.Addr().String(), func() { ln.Close() }, conn, nil
4748
}
4849

49-
func handleConnection(conn net.Conn) {
50+
func handleConnection(tb testing.TB, conn net.Conn, metadata ...byte) {
5051
session, _ := Server(conn, nil)
5152
for {
5253
if stream, err := session.AcceptStream(); err == nil {
54+
if !bytes.Equal(metadata, stream.Metadata()) {
55+
tb.Fatal("metadata mismatch")
56+
}
5357
go func(s io.ReadWriteCloser) {
5458
buf := make([]byte, 65536)
5559
for {
@@ -66,6 +70,18 @@ func handleConnection(conn net.Conn) {
6670
}
6771
}
6872

73+
func TestMetadata(t *testing.T) {
74+
metadata := []byte("hello, world")
75+
_, stop, cli, err := setupServer(t, metadata...)
76+
if err != nil {
77+
t.Fatal(err)
78+
}
79+
defer stop()
80+
session, _ := Client(cli, nil)
81+
session.OpenStream(metadata...)
82+
session.Close()
83+
}
84+
6985
func TestEcho(t *testing.T) {
7086
_, stop, cli, err := setupServer(t)
7187
if err != nil {

stream.go

+10-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ import (
1212

1313
// Stream implements net.Conn
1414
type Stream struct {
15-
id uint32
16-
sess *Session
15+
id uint32
16+
metadata []byte
17+
sess *Session
1718

1819
buffers [][]byte
1920
heads [][]byte // slice heads kept for recycle
@@ -38,9 +39,10 @@ type Stream struct {
3839
}
3940

4041
// newStream initiates a Stream struct
41-
func newStream(id uint32, frameSize int, sess *Session) *Stream {
42+
func newStream(id uint32, metadata []byte, frameSize int, sess *Session) *Stream {
4243
s := new(Stream)
4344
s.id = id
45+
s.metadata = metadata
4446
s.chReadEvent = make(chan struct{}, 1)
4547
s.frameSize = frameSize
4648
s.sess = sess
@@ -54,6 +56,11 @@ func (s *Stream) ID() uint32 {
5456
return s.id
5557
}
5658

59+
// Metadata returns stream metadata which was provided when opening stream.
60+
func (s *Stream) Metadata() []byte {
61+
return s.metadata
62+
}
63+
5764
// Read implements net.Conn
5865
func (s *Stream) Read(b []byte) (n int, err error) {
5966
if len(b) == 0 {

0 commit comments

Comments
 (0)