Skip to content

Commit 4735f36

Browse files
committed
Minor fixes
1 parent 5f3fa5c commit 4735f36

File tree

8 files changed

+139
-37
lines changed

8 files changed

+139
-37
lines changed

accept.go

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"net/http"
1313
"net/textproto"
1414
"net/url"
15+
"strconv"
1516
"strings"
1617

1718
"nhooyr.io/websocket/internal/errd"
@@ -208,6 +209,7 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM
208209

209210
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
210211
copts := mode.opts()
212+
copts.serverMaxWindowBits = 8
211213

212214
for _, p := range ext.params {
213215
switch p {
@@ -219,7 +221,27 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
219221
continue
220222
}
221223

222-
if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") {
224+
if strings.HasPrefix(p, "client_max_window_bits") {
225+
continue
226+
227+
// bits, ok := parseExtensionParameter(p, 15)
228+
// if !ok || bits < 8 || bits > 16 {
229+
// err := fmt.Errorf("invalid client_max_window_bits: %q", p)
230+
// http.Error(w, err.Error(), http.StatusBadRequest)
231+
// return nil, err
232+
// }
233+
// copts.clientMaxWindowBits = bits
234+
// continue
235+
}
236+
237+
if false && strings.HasPrefix(p, "server_max_window_bits") {
238+
// We always send back 8 but make sure to validate.
239+
bits, ok := parseExtensionParameter(p, 0)
240+
if !ok || bits < 8 || bits > 16 {
241+
err := fmt.Errorf("invalid server_max_window_bits: %q", p)
242+
http.Error(w, err.Error(), http.StatusBadRequest)
243+
return nil, err
244+
}
223245
continue
224246
}
225247

@@ -233,6 +255,21 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
233255
return copts, nil
234256
}
235257

258+
// parseExtensionParameter parses the value in the extension parameter p.
259+
// It falls back to defaultVal if there is no value.
260+
// If defaultVal == 0, then ok == false if there is no value.
261+
func parseExtensionParameter(p string, defaultVal int) (int, bool) {
262+
ps := strings.Split(p, "=")
263+
if len(ps) == 1 {
264+
if defaultVal > 0 {
265+
return defaultVal, true
266+
}
267+
return 0, false
268+
}
269+
i, e := strconv.Atoi(strings.Trim(ps[1], `"`))
270+
return i, e == nil
271+
}
272+
236273
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
237274
copts := mode.opts()
238275
// The peer must explicitly request it.

accept_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ func Test_acceptCompression(t *testing.T) {
327327
expCopts: &compressionOptions{
328328
clientNoContextTakeover: true,
329329
serverNoContextTakeover: true,
330+
serverMaxWindowBits: 8,
330331
},
331332
},
332333
{

compress_notjs.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package websocket
44

55
import (
6+
"fmt"
67
"io"
78
"net/http"
89
"sync"
@@ -19,7 +20,10 @@ func (m CompressionMode) opts() *compressionOptions {
1920

2021
type compressionOptions struct {
2122
clientNoContextTakeover bool
23+
clientMaxWindowBits int
24+
2225
serverNoContextTakeover bool
26+
serverMaxWindowBits int
2327
}
2428

2529
func (copts *compressionOptions) setHeader(h http.Header) {
@@ -30,6 +34,12 @@ func (copts *compressionOptions) setHeader(h http.Header) {
3034
if copts.serverNoContextTakeover {
3135
s += "; server_no_context_takeover"
3236
}
37+
if false && copts.serverMaxWindowBits > 0 {
38+
s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits)
39+
}
40+
if false && copts.clientMaxWindowBits > 0 {
41+
s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits)
42+
}
3343
h.Set("Sec-WebSocket-Extensions", s)
3444
}
3545

@@ -152,9 +162,8 @@ func (sw *slidingWindow) close() {
152162
}
153163

154164
swPoolMu.Lock()
155-
defer swPoolMu.Unlock()
156-
157165
swPool[cap(sw.buf)].Put(sw.buf)
166+
swPoolMu.Unlock()
158167
sw.buf = nil
159168
}
160169

conn_test.go

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,21 @@ func TestConn(t *testing.T) {
114114

115115
for i := 0; i < count; i++ {
116116
go func() {
117-
errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg)
117+
select {
118+
case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg):
119+
case <-tt.ctx.Done():
120+
return
121+
}
118122
}()
119123
}
120124

121125
for i := 0; i < count; i++ {
122-
err := <-errs
123-
assert.Success(t, err)
126+
select {
127+
case err := <-errs:
128+
assert.Success(t, err)
129+
case <-tt.ctx.Done():
130+
t.Fatal(tt.ctx.Err())
131+
}
124132
}
125133

126134
err := c1.Close(websocket.StatusNormalClosure, "")
@@ -171,8 +179,12 @@ func TestConn(t *testing.T) {
171179
_, err = n1.Read(nil)
172180
assert.Equal(t, "read error", err, io.EOF)
173181

174-
err = <-errs
175-
assert.Success(t, err)
182+
select {
183+
case err := <-errs:
184+
assert.Success(t, err)
185+
case <-tt.ctx.Done():
186+
t.Fatal(tt.ctx.Err())
187+
}
176188

177189
assert.Equal(t, "read msg", []byte("hello"), b)
178190
})
@@ -195,8 +207,12 @@ func TestConn(t *testing.T) {
195207
_, err := ioutil.ReadAll(n1)
196208
assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
197209

198-
err = <-errs
199-
assert.Success(t, err)
210+
select {
211+
case err := <-errs:
212+
assert.Success(t, err)
213+
case <-tt.ctx.Done():
214+
t.Fatal(tt.ctx.Err())
215+
}
200216
})
201217

202218
t.Run("wsjson", func(t *testing.T) {
@@ -218,8 +234,12 @@ func TestConn(t *testing.T) {
218234
assert.Success(t, err)
219235
assert.Equal(t, "read msg", exp, act)
220236

221-
err = <-werr
222-
assert.Success(t, err)
237+
select {
238+
case err := <-werr:
239+
assert.Success(t, err)
240+
case <-tt.ctx.Done():
241+
t.Fatal(tt.ctx.Err())
242+
}
223243

224244
err = c1.Close(websocket.StatusNormalClosure, "")
225245
assert.Success(t, err)
@@ -411,14 +431,22 @@ func BenchmarkConn(b *testing.B) {
411431

412432
go func() {
413433
for range writes {
414-
werrs <- c1.Write(bb.ctx, websocket.MessageText, msg)
434+
select {
435+
case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
436+
case <-bb.ctx.Done():
437+
return
438+
}
415439
}
416440
}()
417441
b.SetBytes(int64(len(msg)))
418442
b.ReportAllocs()
419443
b.ResetTimer()
420444
for i := 0; i < b.N; i++ {
421-
writes <- struct{}{}
445+
select {
446+
case writes <- struct{}{}:
447+
case <-bb.ctx.Done():
448+
b.Fatal(bb.ctx.Err())
449+
}
422450

423451
typ, r, err := c1.Reader(bb.ctx)
424452
if err != nil {
@@ -445,7 +473,11 @@ func BenchmarkConn(b *testing.B) {
445473
assert.Equal(b, "msg", msg, readBuf)
446474
}
447475

448-
err = <-werrs
476+
select {
477+
case err = <-werrs:
478+
case <-bb.ctx.Done():
479+
b.Fatal(bb.ctx.Err())
480+
}
449481
if err != nil {
450482
b.Fatal(err)
451483
}

dial.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"net/url"
1717
"strings"
1818
"sync"
19+
"time"
1920

2021
"nhooyr.io/websocket/internal/errd"
2122
)
@@ -91,6 +92,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
9192
if err != nil {
9293
// We read a bit of the body for easier debugging.
9394
r := io.LimitReader(respBody, 1024)
95+
96+
timer := time.AfterFunc(time.Second*3, func() {
97+
respBody.Close()
98+
})
99+
defer timer.Stop()
100+
94101
b, _ := ioutil.ReadAll(r)
95102
respBody.Close()
96103
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
@@ -148,6 +155,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
148155
}
149156
if opts.CompressionMode != CompressionDisabled {
150157
copts := opts.CompressionMode.opts()
158+
copts.clientMaxWindowBits = 8
151159
copts.setHeader(req.Header)
152160
}
153161

@@ -225,15 +233,36 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
225233
}
226234

227235
copts := &compressionOptions{}
236+
copts.clientMaxWindowBits = 8
228237
for _, p := range ext.params {
229238
switch p {
230239
case "client_no_context_takeover":
231240
copts.clientNoContextTakeover = true
241+
continue
232242
case "server_no_context_takeover":
233243
copts.serverNoContextTakeover = true
234-
default:
235-
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
244+
continue
245+
}
246+
247+
if false && strings.HasPrefix(p, "server_max_window_bits") {
248+
bits, ok := parseExtensionParameter(p, 0)
249+
if !ok || bits < 8 || bits > 16 {
250+
return nil, fmt.Errorf("invalid server_max_window_bits: %q", p)
251+
}
252+
copts.serverMaxWindowBits = bits
253+
continue
236254
}
255+
256+
if false && strings.HasPrefix(p, "client_max_window_bits") {
257+
bits, ok := parseExtensionParameter(p, 0)
258+
if !ok || bits < 8 || bits > 16 {
259+
return nil, fmt.Errorf("invalid client_max_window_bits: %q", p)
260+
}
261+
copts.clientMaxWindowBits = 8
262+
continue
263+
}
264+
265+
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
237266
}
238267

239268
return copts, nil

internal/test/assert/assert.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,27 @@ package assert
22

33
import (
44
"fmt"
5+
"reflect"
56
"strings"
67
"testing"
78

8-
"nhooyr.io/websocket/internal/test/cmp"
9+
"github.com/golang/protobuf/proto"
10+
"github.com/google/go-cmp/cmp"
11+
"github.com/google/go-cmp/cmp/cmpopts"
912
)
1013

14+
// Diff returns a human readable diff between v1 and v2
15+
func Diff(v1, v2 interface{}) string {
16+
return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool {
17+
return true
18+
}), cmp.Comparer(proto.Equal))
19+
}
20+
1121
// Equal asserts exp == act.
1222
func Equal(t testing.TB, name string, exp, act interface{}) {
1323
t.Helper()
1424

15-
if diff := cmp.Diff(exp, act); diff != "" {
25+
if diff := Diff(exp, act); diff != "" {
1626
t.Fatalf("unexpected %v: %v", name, diff)
1727
}
1828
}

internal/test/cmp/cmp.go

Lines changed: 0 additions & 16 deletions
This file was deleted.

internal/test/wstest/echo.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"time"
99

1010
"nhooyr.io/websocket"
11-
"nhooyr.io/websocket/internal/test/cmp"
11+
"nhooyr.io/websocket/internal/test/assert"
1212
"nhooyr.io/websocket/internal/test/xrand"
1313
"nhooyr.io/websocket/internal/xsync"
1414
)
@@ -76,7 +76,7 @@ func Echo(ctx context.Context, c *websocket.Conn, max int) error {
7676
}
7777

7878
if !bytes.Equal(msg, act) {
79-
return fmt.Errorf("unexpected msg read: %v", cmp.Diff(msg, act))
79+
return fmt.Errorf("unexpected msg read: %v", assert.Diff(msg, act))
8080
}
8181

8282
return nil

0 commit comments

Comments
 (0)