Skip to content

Commit 856e371

Browse files
committed
ws_js: Update to match new close code
1 parent db18a31 commit 856e371

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

read.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
6464
// This function is idempotent.
6565
func (c *Conn) CloseRead(ctx context.Context) context.Context {
6666
c.closeReadMu.Lock()
67-
if c.closeReadCtx != nil {
67+
ctx2 := c.closeReadCtx
68+
if ctx2 != nil {
6869
c.closeReadMu.Unlock()
69-
return c.closeReadCtx
70+
return ctx2
7071
}
7172
ctx, cancel := context.WithCancel(ctx)
7273
c.closeReadCtx = ctx

ws_js.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ type Conn struct {
4747
// read limit for a message in bytes.
4848
msgReadLimit xsync.Int64
4949

50-
wg sync.WaitGroup
50+
closeReadMu sync.Mutex
51+
closeReadCtx context.Context
52+
5153
closingMu sync.Mutex
52-
isReadClosed xsync.Int64
5354
closeOnce sync.Once
5455
closed chan struct{}
5556
closeErrOnce sync.Once
@@ -130,7 +131,10 @@ func (c *Conn) closeWithInternal() {
130131
// Read attempts to read a message from the connection.
131132
// The maximum time spent waiting is bounded by the context.
132133
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
133-
if c.isReadClosed.Load() == 1 {
134+
c.closeReadMu.Lock()
135+
closedRead := c.closeReadCtx != nil
136+
c.closeReadMu.Unlock()
137+
if closedRead {
134138
return 0, nil, errors.New("WebSocket connection read closed")
135139
}
136140

@@ -387,14 +391,19 @@ func (w *writer) Close() error {
387391

388392
// CloseRead implements *Conn.CloseRead for wasm.
389393
func (c *Conn) CloseRead(ctx context.Context) context.Context {
390-
c.isReadClosed.Store(1)
391-
394+
c.closeReadMu.Lock()
395+
ctx2 := c.closeReadCtx
396+
if ctx2 != nil {
397+
c.closeReadMu.Unlock()
398+
return ctx2
399+
}
392400
ctx, cancel := context.WithCancel(ctx)
393-
c.wg.Add(1)
401+
c.closeReadCtx = ctx
402+
c.closeReadMu.Unlock()
403+
394404
go func() {
395-
defer c.CloseNow()
396-
defer c.wg.Done()
397405
defer cancel()
406+
defer c.CloseNow()
398407
_, _, err := c.read(ctx)
399408
if err != nil {
400409
c.Close(StatusPolicyViolation, "unexpected data message")

0 commit comments

Comments
 (0)