Skip to content

Commit a97030c

Browse files
committed
Enable buffer reuse again for clients
1 parent 579ce18 commit a97030c

File tree

1 file changed

+64
-38
lines changed

1 file changed

+64
-38
lines changed

websocket.go

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ type Conn struct {
3838
writeDataLock chan struct{}
3939
writeFrameLock chan struct{}
4040

41-
readDataLock chan struct{}
42-
readData chan header
43-
readDone chan struct{}
44-
readLoopDone chan struct{}
41+
readMsgLock chan struct{}
42+
readMsg chan header
43+
readMsgDone chan struct{}
44+
readFrameLock chan struct{}
4545

4646
setReadTimeout chan context.Context
4747
setWriteTimeout chan context.Context
@@ -90,17 +90,15 @@ func (c *Conn) close(err error) {
9090

9191
close(c.closed)
9292

93+
// This ensures every goroutine that interacts
94+
// with the conn closes before it can interact with the connection
95+
c.readFrameLock <- struct{}{}
96+
c.writeFrameLock <- struct{}{}
97+
9398
// See comment in dial.go
9499
if c.client {
95-
go func() {
96-
<-c.readLoopDone
97-
// TODO this does not work if reader errors out so skip for now.
98-
// c.readDataLock <- struct{}{}
99-
// c.writeFrameLock <- struct{}{}
100-
//
101-
// returnBufioReader(c.br)
102-
// returnBufioWriter(c.bw)
103-
}()
100+
returnBufioReader(c.br)
101+
returnBufioWriter(c.bw)
104102
}
105103
})
106104
}
@@ -119,10 +117,10 @@ func (c *Conn) init() {
119117
c.writeDataLock = make(chan struct{}, 1)
120118
c.writeFrameLock = make(chan struct{}, 1)
121119

122-
c.readData = make(chan header)
123-
c.readDone = make(chan struct{})
124-
c.readDataLock = make(chan struct{}, 1)
125-
c.readLoopDone = make(chan struct{})
120+
c.readMsg = make(chan header)
121+
c.readMsgDone = make(chan struct{})
122+
c.readMsgLock = make(chan struct{}, 1)
123+
c.readFrameLock = make(chan struct{}, 1)
126124

127125
c.setReadTimeout = make(chan context.Context)
128126
c.setWriteTimeout = make(chan context.Context)
@@ -141,8 +139,8 @@ func (c *Conn) init() {
141139

142140
// We never mask inside here because our mask key is always 0,0,0,0.
143141
// See comment on secWebSocketKey.
144-
func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error {
145-
err := c.acquireLock(ctx, c.writeFrameLock)
142+
func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) (err error) {
143+
err = c.acquireLock(ctx, c.writeFrameLock)
146144
if err != nil {
147145
return err
148146
}
@@ -164,27 +162,33 @@ func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error {
164162
}
165163
}()
166164

165+
defer func() {
166+
if err != nil {
167+
// We need to always release the lock first before closing the connection to ensure
168+
// the lock can be acquired inside close.
169+
c.releaseLock(c.writeFrameLock)
170+
c.close(err)
171+
}
172+
}()
173+
167174
h.masked = c.client
168175
h.payloadLength = int64(len(p))
169176

170177
b2 := marshalHeader(h)
171178
_, err = c.bw.Write(b2)
172179
if err != nil {
173-
c.close(xerrors.Errorf("failed to write to connection: %w", err))
174-
return c.closeErr
180+
return xerrors.Errorf("failed to write to connection: %w", err)
175181
}
176182
_, err = c.bw.Write(p)
177183
if err != nil {
178-
c.close(xerrors.Errorf("failed to write to connection: %w", err))
179-
return c.closeErr
184+
return xerrors.Errorf("failed to write to connection: %w", err)
180185

181186
}
182187

183188
if h.fin {
184189
err := c.bw.Flush()
185190
if err != nil {
186-
c.close(xerrors.Errorf("failed to write to connection: %w", err))
187-
return c.closeErr
191+
return xerrors.Errorf("failed to write to connection: %w", err)
188192
}
189193
}
190194

@@ -279,9 +283,9 @@ func (c *Conn) handleControl(h header) {
279283

280284
func (c *Conn) readTillData() (header, error) {
281285
for {
282-
h, err := readHeader(c.br)
286+
h, err := c.readHeader()
283287
if err != nil {
284-
return header{}, xerrors.Errorf("failed to read header: %w", err)
288+
return header{}, err
285289
}
286290

287291
if h.rsv1 || h.rsv2 || h.rsv3 {
@@ -312,9 +316,22 @@ func (c *Conn) readTillData() (header, error) {
312316
}
313317
}
314318

315-
func (c *Conn) readLoop() {
316-
defer close(c.readLoopDone)
319+
func (c *Conn) readHeader() (header, error) {
320+
err := c.acquireLock(context.Background(), c.readFrameLock)
321+
if err != nil {
322+
return header{}, err
323+
}
324+
defer c.releaseLock(c.readFrameLock)
317325

326+
h, err := readHeader(c.br)
327+
if err != nil {
328+
return header{}, xerrors.Errorf("failed to read header: %w", err)
329+
}
330+
331+
return h, nil
332+
}
333+
334+
func (c *Conn) readLoop() {
318335
for {
319336
h, err := c.readTillData()
320337
if err != nil {
@@ -325,13 +342,13 @@ func (c *Conn) readLoop() {
325342
select {
326343
case <-c.closed:
327344
return
328-
case c.readData <- h:
345+
case c.readMsg <- h:
329346
}
330347

331348
select {
332349
case <-c.closed:
333350
return
334-
case <-c.readDone:
351+
case <-c.readMsgDone:
335352
}
336353
}
337354
}
@@ -374,7 +391,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
374391
// Definitely worth seeing what popular browsers do later.
375392
p, err := ce.bytes()
376393
if err != nil {
377-
fmt.Fprintf(os.Stderr, "failed to marshal close frame: %v\n", err)
394+
fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err)
378395
ce = CloseError{
379396
Code: StatusInternalError,
380397
}
@@ -415,7 +432,11 @@ func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
415432
}
416433

417434
func (c *Conn) releaseLock(lock chan struct{}) {
418-
<-lock
435+
// Allow multiple releases.
436+
select {
437+
case <-lock:
438+
default:
439+
}
419440
}
420441

421442
func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error {
@@ -572,7 +593,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
572593
}
573594

574595
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
575-
err = c.acquireLock(ctx, c.readDataLock)
596+
err = c.acquireLock(ctx, c.readMsgLock)
576597
if err != nil {
577598
return 0, nil, err
578599
}
@@ -582,7 +603,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
582603
return 0, nil, c.closeErr
583604
case <-ctx.Done():
584605
return 0, nil, ctx.Err()
585-
case h := <-c.readData:
606+
case h := <-c.readMsg:
586607
if h.opcode == opContinuation {
587608
ce := CloseError{
588609
Code: StatusProtocolError,
@@ -631,7 +652,7 @@ func (r *messageReader) read(p []byte) (int, error) {
631652
select {
632653
case <-r.c.closed:
633654
return 0, r.c.closeErr
634-
case h := <-r.c.readData:
655+
case h := <-r.c.readMsg:
635656
if h.opcode != opContinuation {
636657
ce := CloseError{
637658
Code: StatusProtocolError,
@@ -654,7 +675,12 @@ func (r *messageReader) read(p []byte) (int, error) {
654675
case r.c.setReadTimeout <- r.ctx:
655676
}
656677

678+
err := r.c.acquireLock(r.ctx, r.c.readFrameLock)
679+
if err != nil {
680+
return 0, err
681+
}
657682
n, err := io.ReadFull(r.c.br, p)
683+
r.c.releaseLock(r.c.readFrameLock)
658684

659685
select {
660686
case <-r.c.closed:
@@ -676,11 +702,11 @@ func (r *messageReader) read(p []byte) (int, error) {
676702
select {
677703
case <-r.c.closed:
678704
return n, r.c.closeErr
679-
case r.c.readDone <- struct{}{}:
705+
case r.c.readMsgDone <- struct{}{}:
680706
}
681707
if r.h.fin {
682708
r.eofed = true
683-
r.c.releaseLock(r.c.readDataLock)
709+
r.c.releaseLock(r.c.readMsgLock)
684710
return n, io.EOF
685711
}
686712
r.maskPos = 0

0 commit comments

Comments
 (0)