Skip to content

Commit ecbf6db

Browse files
authored
Merge pull request #159 from nhooyr/close
Add websocket.CloseStatus
2 parents 96e5af1 + 0919bdb commit ecbf6db

File tree

9 files changed

+152
-88
lines changed

9 files changed

+152
-88
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ go get nhooyr.io/websocket
3333

3434
For a production quality example that shows off the full API, see the [echo example on the godoc](https://godoc.org/nhooyr.io/websocket#example-package--Echo). On github, the example is at [example_echo_test.go](./example_echo_test.go).
3535

36-
Use the [errors.As](https://golang.org/pkg/errors/#As) function [new in Go 1.13](https://golang.org/doc/go1.13#error_wrapping) to check for [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError). See the [CloseError godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseError).
36+
Use the [errors.As](https://golang.org/pkg/errors/#As) function [new in Go 1.13](https://golang.org/doc/go1.13#error_wrapping) to check for [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError).
37+
There is also [websocket.CloseStatus](https://godoc.org/nhooyr.io/websocket#CloseStatus) to quickly grab the close status code out of a [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError).
38+
See the [CloseError godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseError).
3739

3840
### Server
3941

assert_test.go

Lines changed: 6 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,75 +2,19 @@ package websocket_test
22

33
import (
44
"context"
5-
"fmt"
65
"math/rand"
7-
"reflect"
86
"strings"
97
"time"
108

11-
"github.com/google/go-cmp/cmp"
12-
139
"nhooyr.io/websocket"
10+
"nhooyr.io/websocket/internal/assert"
1411
"nhooyr.io/websocket/wsjson"
1512
)
1613

1714
func init() {
1815
rand.Seed(time.Now().UnixNano())
1916
}
2017

21-
// https://github.com/google/go-cmp/issues/40#issuecomment-328615283
22-
func cmpDiff(exp, act interface{}) string {
23-
return cmp.Diff(exp, act, deepAllowUnexported(exp, act))
24-
}
25-
26-
func deepAllowUnexported(vs ...interface{}) cmp.Option {
27-
m := make(map[reflect.Type]struct{})
28-
for _, v := range vs {
29-
structTypes(reflect.ValueOf(v), m)
30-
}
31-
var typs []interface{}
32-
for t := range m {
33-
typs = append(typs, reflect.New(t).Elem().Interface())
34-
}
35-
return cmp.AllowUnexported(typs...)
36-
}
37-
38-
func structTypes(v reflect.Value, m map[reflect.Type]struct{}) {
39-
if !v.IsValid() {
40-
return
41-
}
42-
switch v.Kind() {
43-
case reflect.Ptr:
44-
if !v.IsNil() {
45-
structTypes(v.Elem(), m)
46-
}
47-
case reflect.Interface:
48-
if !v.IsNil() {
49-
structTypes(v.Elem(), m)
50-
}
51-
case reflect.Slice, reflect.Array:
52-
for i := 0; i < v.Len(); i++ {
53-
structTypes(v.Index(i), m)
54-
}
55-
case reflect.Map:
56-
for _, k := range v.MapKeys() {
57-
structTypes(v.MapIndex(k), m)
58-
}
59-
case reflect.Struct:
60-
m[v.Type()] = struct{}{}
61-
for i := 0; i < v.NumField(); i++ {
62-
structTypes(v.Field(i), m)
63-
}
64-
}
65-
}
66-
67-
func assertEqualf(exp, act interface{}, f string, v ...interface{}) error {
68-
if diff := cmpDiff(exp, act); diff != "" {
69-
return fmt.Errorf(f+": %v", append(v, diff)...)
70-
}
71-
return nil
72-
}
73-
7418
func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error {
7519
exp := randString(n)
7620
err := wsjson.Write(ctx, c, exp)
@@ -84,7 +28,7 @@ func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error {
8428
return err
8529
}
8630

87-
return assertEqualf(exp, act, "unexpected JSON")
31+
return assert.Equalf(exp, act, "unexpected JSON")
8832
}
8933

9034
func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) error {
@@ -94,7 +38,7 @@ func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) err
9438
return err
9539
}
9640

97-
return assertEqualf(exp, act, "unexpected JSON")
41+
return assert.Equalf(exp, act, "unexpected JSON")
9842
}
9943

10044
func randBytes(n int) []byte {
@@ -126,13 +70,13 @@ func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageTyp
12670
if err != nil {
12771
return err
12872
}
129-
err = assertEqualf(typ, typ2, "unexpected data type")
73+
err = assert.Equalf(typ, typ2, "unexpected data type")
13074
if err != nil {
13175
return err
13276
}
133-
return assertEqualf(p, p2, "unexpected payload")
77+
return assert.Equalf(p, p2, "unexpected payload")
13478
}
13579

13680
func assertSubprotocol(c *websocket.Conn, exp string) error {
137-
return assertEqualf(exp, c.Subprotocol(), "unexpected subprotocol")
81+
return assert.Equalf(exp, c.Subprotocol(), "unexpected subprotocol")
13882
}

conn_test.go

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"go.uber.org/multierr"
3333

3434
"nhooyr.io/websocket"
35+
"nhooyr.io/websocket/internal/assert"
3536
"nhooyr.io/websocket/internal/wsecho"
3637
"nhooyr.io/websocket/wsjson"
3738
"nhooyr.io/websocket/wspb"
@@ -127,7 +128,7 @@ func TestHandshake(t *testing.T) {
127128
if err != nil {
128129
return fmt.Errorf("request is missing mycookie: %w", err)
129130
}
130-
err = assertEqualf("myvalue", cookie.Value, "unexpected cookie value")
131+
err = assert.Equalf("myvalue", cookie.Value, "unexpected cookie value")
131132
if err != nil {
132133
return err
133134
}
@@ -219,7 +220,7 @@ func TestConn(t *testing.T) {
219220
}
220221
for h, exp := range headers {
221222
value := resp.Header.Get(h)
222-
err := assertEqualf(exp, value, "unexpected value for header %v", h)
223+
err := assert.Equalf(exp, value, "unexpected value for header %v", h)
223224
if err != nil {
224225
return err
225226
}
@@ -276,11 +277,11 @@ func TestConn(t *testing.T) {
276277
time.Sleep(1)
277278
nc.SetWriteDeadline(time.Now().Add(time.Second * 15))
278279

279-
err := assertEqualf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr")
280+
err := assert.Equalf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr")
280281
if err != nil {
281282
return err
282283
}
283-
err = assertEqualf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr")
284+
err = assert.Equalf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr")
284285
if err != nil {
285286
return err
286287
}
@@ -310,13 +311,13 @@ func TestConn(t *testing.T) {
310311

311312
// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
312313
err2 := assertNetConnRead(nc, "hello")
313-
err := assertEqualf(io.EOF, err2, "unexpected error")
314+
err := assert.Equalf(io.EOF, err2, "unexpected error")
314315
if err != nil {
315316
return err
316317
}
317318

318319
err2 = assertNetConnRead(nc, "hello")
319-
return assertEqualf(io.EOF, err2, "unexpected error")
320+
return assert.Equalf(io.EOF, err2, "unexpected error")
320321
},
321322
},
322323
{
@@ -585,8 +586,8 @@ func TestConn(t *testing.T) {
585586
return err
586587
}
587588
_, _, err = c.Read(ctx)
588-
cerr := &websocket.CloseError{}
589-
if !errors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError {
589+
var cerr websocket.CloseError
590+
if !errors.As(err, &cerr) || cerr.Code != websocket.StatusProtocolError {
590591
return fmt.Errorf("expected close error with StatusProtocolError: %+v", err)
591592
}
592593
return nil
@@ -772,15 +773,15 @@ func TestConn(t *testing.T) {
772773
if err != nil {
773774
return err
774775
}
775-
err = assertEqualf("hi", v, "unexpected JSON")
776+
err = assert.Equalf("hi", v, "unexpected JSON")
776777
if err != nil {
777778
return err
778779
}
779780
_, b, err := c.Read(ctx)
780781
if err != nil {
781782
return err
782783
}
783-
return assertEqualf("hi", string(b), "unexpected JSON")
784+
return assert.Equalf("hi", string(b), "unexpected JSON")
784785
},
785786
client: func(ctx context.Context, c *websocket.Conn) error {
786787
err := wsjson.Write(ctx, c, "hi")
@@ -1079,11 +1080,11 @@ func TestAutobahn(t *testing.T) {
10791080
if err != nil {
10801081
return err
10811082
}
1082-
err = assertEqualf(typ, actTyp, "unexpected message type")
1083+
err = assert.Equalf(typ, actTyp, "unexpected message type")
10831084
if err != nil {
10841085
return err
10851086
}
1086-
return assertEqualf(p, p2, "unexpected message")
1087+
return assert.Equalf(p, p2, "unexpected message")
10871088
})
10881089
}
10891090
}
@@ -1859,7 +1860,7 @@ func assertCloseStatus(err error, code websocket.StatusCode) error {
18591860
if !errors.As(err, &cerr) {
18601861
return fmt.Errorf("no websocket close error in error chain: %+v", err)
18611862
}
1862-
return assertEqualf(code, cerr.Code, "unexpected status code")
1863+
return assert.Equalf(code, cerr.Code, "unexpected status code")
18631864
}
18641865

18651866
func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error {
@@ -1871,7 +1872,7 @@ func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{})
18711872
return err
18721873
}
18731874

1874-
return assertEqualf(exp, act, "unexpected protobuf")
1875+
return assert.Equalf(exp, act, "unexpected protobuf")
18751876
}
18761877

18771878
func assertNetConnRead(r io.Reader, exp string) error {
@@ -1880,7 +1881,7 @@ func assertNetConnRead(r io.Reader, exp string) error {
18801881
if err != nil {
18811882
return err
18821883
}
1883-
return assertEqualf(exp, string(act), "unexpected net conn read")
1884+
return assert.Equalf(exp, string(act), "unexpected net conn read")
18841885
}
18851886

18861887
func assertErrorContains(err error, exp string) error {
@@ -1902,27 +1903,27 @@ func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.Op
19021903
if err != nil {
19031904
return err
19041905
}
1905-
err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP)
1906+
err = assert.Equalf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP)
19061907
if err != nil {
19071908
return err
19081909
}
1909-
return assertEqualf(p, actP, "unexpected frame %v payload", opcode)
1910+
return assert.Equalf(p, actP, "unexpected frame %v payload", opcode)
19101911
}
19111912

19121913
func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error {
19131914
actOpcode, actP, err := c.ReadFrame(ctx)
19141915
if err != nil {
19151916
return err
19161917
}
1917-
err = assertEqualf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP)
1918+
err = assert.Equalf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP)
19181919
if err != nil {
19191920
return err
19201921
}
19211922
ce, err := websocket.ParseClosePayload(actP)
19221923
if err != nil {
19231924
return fmt.Errorf("failed to parse close frame payload: %w", err)
19241925
}
1925-
return assertEqualf(ce.Code, code, "unexpected frame close frame code with payload %q", actP)
1926+
return assert.Equalf(ce.Code, code, "unexpected frame close frame code with payload %q", actP)
19261927
}
19271928

19281929
func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket.StatusCode, reason string) error {
@@ -1960,11 +1961,11 @@ func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.Mes
19601961
if err != nil {
19611962
return err
19621963
}
1963-
err = assertEqualf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP)
1964+
err = assert.Equalf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP)
19641965
if err != nil {
19651966
return err
19661967
}
1967-
return assertEqualf(p, actP, "unexpected frame %v payload", actTyp)
1968+
return assert.Equalf(p, actP, "unexpected frame %v payload", actTyp)
19681969
}
19691970

19701971
func BenchmarkConn(b *testing.B) {

doc.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// comparison with existing implementations.
1717
//
1818
// Use the errors.As function new in Go 1.13 to check for websocket.CloseError.
19+
// Or use the CloseStatus function to grab the StatusCode out of a websocket.CloseError
1920
// See the CloseError example.
2021
//
2122
// Wasm

example_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package websocket_test
44

55
import (
66
"context"
7-
"errors"
87
"log"
98
"net/http"
109
"time"
@@ -76,8 +75,7 @@ func ExampleCloseError() {
7675
defer c.Close(websocket.StatusInternalError, "the sky is falling")
7776

7877
_, _, err = c.Reader(ctx)
79-
var cerr websocket.CloseError
80-
if !errors.As(err, &cerr) || cerr.Code != websocket.StatusNormalClosure {
78+
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
8179
log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %+v", err)
8280
return
8381
}

frame.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package websocket
22

33
import (
44
"encoding/binary"
5+
"errors"
56
"fmt"
67
"io"
78
"math"
@@ -252,6 +253,17 @@ func (ce CloseError) Error() string {
252253
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
253254
}
254255

256+
// CloseStatus is a convenience wrapper around errors.As to grab
257+
// the status code from a *CloseError. If the passed error is nil
258+
// or not a *CloseError, the returned StatusCode will be -1.
259+
func CloseStatus(err error) StatusCode {
260+
var ce CloseError
261+
if errors.As(err, &ce) {
262+
return ce.Code
263+
}
264+
return -1
265+
}
266+
255267
func parseClosePayload(p []byte) (CloseError, error) {
256268
if len(p) == 0 {
257269
return CloseError{

frame_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"time"
1414

1515
"github.com/google/go-cmp/cmp"
16+
17+
"nhooyr.io/websocket/internal/assert"
1618
)
1719

1820
func init() {
@@ -376,3 +378,43 @@ func BenchmarkXOR(b *testing.B) {
376378
})
377379
}
378380
}
381+
382+
func TestCloseStatus(t *testing.T) {
383+
t.Parallel()
384+
385+
testCases := []struct {
386+
name string
387+
in error
388+
exp StatusCode
389+
}{
390+
{
391+
name: "nil",
392+
in: nil,
393+
exp: -1,
394+
},
395+
{
396+
name: "io.EOF",
397+
in: io.EOF,
398+
exp: -1,
399+
},
400+
{
401+
name: "StatusInternalError",
402+
in: CloseError{
403+
Code: StatusInternalError,
404+
},
405+
exp: StatusInternalError,
406+
},
407+
}
408+
409+
for _, tc := range testCases {
410+
tc := tc
411+
t.Run(tc.name, func(t *testing.T) {
412+
t.Parallel()
413+
414+
err := assert.Equalf(tc.exp, CloseStatus(tc.in), "unexpected close status")
415+
if err != nil {
416+
t.Fatal(err)
417+
}
418+
})
419+
}
420+
}

0 commit comments

Comments
 (0)