@@ -38,10 +38,10 @@ type Conn struct {
38
38
writeDataLock chan struct {}
39
39
writeFrameLock chan struct {}
40
40
41
- readDataLock chan struct {}
42
- readData chan header
43
- readDone chan struct {}
44
- readLoopDone chan struct {}
41
+ readMsgLock chan struct {}
42
+ readMsg chan header
43
+ readMsgDone chan struct {}
44
+ readFrameLock chan struct {}
45
45
46
46
setReadTimeout chan context.Context
47
47
setWriteTimeout chan context.Context
@@ -90,17 +90,15 @@ func (c *Conn) close(err error) {
90
90
91
91
close (c .closed )
92
92
93
+ // This ensures every goroutine that interacts
94
+ // with the conn closes before it can interact with the connection
95
+ c .readFrameLock <- struct {}{}
96
+ c .writeFrameLock <- struct {}{}
97
+
93
98
// See comment in dial.go
94
99
if c .client {
95
- go func () {
96
- <- c .readLoopDone
97
- // TODO this does not work if reader errors out so skip for now.
98
- // c.readDataLock <- struct{}{}
99
- // c.writeFrameLock <- struct{}{}
100
- //
101
- // returnBufioReader(c.br)
102
- // returnBufioWriter(c.bw)
103
- }()
100
+ returnBufioReader (c .br )
101
+ returnBufioWriter (c .bw )
104
102
}
105
103
})
106
104
}
@@ -119,10 +117,10 @@ func (c *Conn) init() {
119
117
c .writeDataLock = make (chan struct {}, 1 )
120
118
c .writeFrameLock = make (chan struct {}, 1 )
121
119
122
- c .readData = make (chan header )
123
- c .readDone = make (chan struct {})
124
- c .readDataLock = make (chan struct {}, 1 )
125
- c .readLoopDone = make (chan struct {})
120
+ c .readMsg = make (chan header )
121
+ c .readMsgDone = make (chan struct {})
122
+ c .readMsgLock = make (chan struct {}, 1 )
123
+ c .readFrameLock = make (chan struct {}, 1 )
126
124
127
125
c .setReadTimeout = make (chan context.Context )
128
126
c .setWriteTimeout = make (chan context.Context )
@@ -141,8 +139,8 @@ func (c *Conn) init() {
141
139
142
140
// We never mask inside here because our mask key is always 0,0,0,0.
143
141
// See comment on secWebSocketKey.
144
- func (c * Conn ) writeFrame (ctx context.Context , h header , p []byte ) error {
145
- err : = c .acquireLock (ctx , c .writeFrameLock )
142
+ func (c * Conn ) writeFrame (ctx context.Context , h header , p []byte ) ( err error ) {
143
+ err = c .acquireLock (ctx , c .writeFrameLock )
146
144
if err != nil {
147
145
return err
148
146
}
@@ -164,27 +162,33 @@ func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error {
164
162
}
165
163
}()
166
164
165
+ defer func () {
166
+ if err != nil {
167
+ // We need to always release the lock first before closing the connection to ensure
168
+ // the lock can be acquired inside close.
169
+ c .releaseLock (c .writeFrameLock )
170
+ c .close (err )
171
+ }
172
+ }()
173
+
167
174
h .masked = c .client
168
175
h .payloadLength = int64 (len (p ))
169
176
170
177
b2 := marshalHeader (h )
171
178
_ , err = c .bw .Write (b2 )
172
179
if err != nil {
173
- c .close (xerrors .Errorf ("failed to write to connection: %w" , err ))
174
- return c .closeErr
180
+ return xerrors .Errorf ("failed to write to connection: %w" , err )
175
181
}
176
182
_ , err = c .bw .Write (p )
177
183
if err != nil {
178
- c .close (xerrors .Errorf ("failed to write to connection: %w" , err ))
179
- return c .closeErr
184
+ return xerrors .Errorf ("failed to write to connection: %w" , err )
180
185
181
186
}
182
187
183
188
if h .fin {
184
189
err := c .bw .Flush ()
185
190
if err != nil {
186
- c .close (xerrors .Errorf ("failed to write to connection: %w" , err ))
187
- return c .closeErr
191
+ return xerrors .Errorf ("failed to write to connection: %w" , err )
188
192
}
189
193
}
190
194
@@ -279,9 +283,9 @@ func (c *Conn) handleControl(h header) {
279
283
280
284
func (c * Conn ) readTillData () (header , error ) {
281
285
for {
282
- h , err := readHeader ( c . br )
286
+ h , err := c . readHeader ( )
283
287
if err != nil {
284
- return header {}, xerrors . Errorf ( "failed to read header: %w" , err )
288
+ return header {}, err
285
289
}
286
290
287
291
if h .rsv1 || h .rsv2 || h .rsv3 {
@@ -312,9 +316,22 @@ func (c *Conn) readTillData() (header, error) {
312
316
}
313
317
}
314
318
315
- func (c * Conn ) readLoop () {
316
- defer close (c .readLoopDone )
319
+ func (c * Conn ) readHeader () (header , error ) {
320
+ err := c .acquireLock (context .Background (), c .readFrameLock )
321
+ if err != nil {
322
+ return header {}, err
323
+ }
324
+ defer c .releaseLock (c .readFrameLock )
317
325
326
+ h , err := readHeader (c .br )
327
+ if err != nil {
328
+ return header {}, xerrors .Errorf ("failed to read header: %w" , err )
329
+ }
330
+
331
+ return h , nil
332
+ }
333
+
334
+ func (c * Conn ) readLoop () {
318
335
for {
319
336
h , err := c .readTillData ()
320
337
if err != nil {
@@ -325,13 +342,13 @@ func (c *Conn) readLoop() {
325
342
select {
326
343
case <- c .closed :
327
344
return
328
- case c .readData <- h :
345
+ case c .readMsg <- h :
329
346
}
330
347
331
348
select {
332
349
case <- c .closed :
333
350
return
334
- case <- c .readDone :
351
+ case <- c .readMsgDone :
335
352
}
336
353
}
337
354
}
@@ -374,7 +391,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
374
391
// Definitely worth seeing what popular browsers do later.
375
392
p , err := ce .bytes ()
376
393
if err != nil {
377
- fmt .Fprintf (os .Stderr , "failed to marshal close frame: %v\n " , err )
394
+ fmt .Fprintf (os .Stderr , "websocket: failed to marshal close frame: %v\n " , err )
378
395
ce = CloseError {
379
396
Code : StatusInternalError ,
380
397
}
@@ -415,7 +432,11 @@ func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
415
432
}
416
433
417
434
func (c * Conn ) releaseLock (lock chan struct {}) {
418
- <- lock
435
+ // Allow multiple releases.
436
+ select {
437
+ case <- lock :
438
+ default :
439
+ }
419
440
}
420
441
421
442
func (c * Conn ) writeMessage (ctx context.Context , opcode opcode , p []byte ) error {
@@ -572,7 +593,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
572
593
}
573
594
574
595
func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
575
- err = c .acquireLock (ctx , c .readDataLock )
596
+ err = c .acquireLock (ctx , c .readMsgLock )
576
597
if err != nil {
577
598
return 0 , nil , err
578
599
}
@@ -582,7 +603,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
582
603
return 0 , nil , c .closeErr
583
604
case <- ctx .Done ():
584
605
return 0 , nil , ctx .Err ()
585
- case h := <- c .readData :
606
+ case h := <- c .readMsg :
586
607
if h .opcode == opContinuation {
587
608
ce := CloseError {
588
609
Code : StatusProtocolError ,
@@ -631,7 +652,7 @@ func (r *messageReader) read(p []byte) (int, error) {
631
652
select {
632
653
case <- r .c .closed :
633
654
return 0 , r .c .closeErr
634
- case h := <- r .c .readData :
655
+ case h := <- r .c .readMsg :
635
656
if h .opcode != opContinuation {
636
657
ce := CloseError {
637
658
Code : StatusProtocolError ,
@@ -654,7 +675,12 @@ func (r *messageReader) read(p []byte) (int, error) {
654
675
case r .c .setReadTimeout <- r .ctx :
655
676
}
656
677
678
+ err := r .c .acquireLock (r .ctx , r .c .readFrameLock )
679
+ if err != nil {
680
+ return 0 , err
681
+ }
657
682
n , err := io .ReadFull (r .c .br , p )
683
+ r .c .releaseLock (r .c .readFrameLock )
658
684
659
685
select {
660
686
case <- r .c .closed :
@@ -676,11 +702,11 @@ func (r *messageReader) read(p []byte) (int, error) {
676
702
select {
677
703
case <- r .c .closed :
678
704
return n , r .c .closeErr
679
- case r .c .readDone <- struct {}{}:
705
+ case r .c .readMsgDone <- struct {}{}:
680
706
}
681
707
if r .h .fin {
682
708
r .eofed = true
683
- r .c .releaseLock (r .c .readDataLock )
709
+ r .c .releaseLock (r .c .readMsgLock )
684
710
return n , io .EOF
685
711
}
686
712
r .maskPos = 0
0 commit comments