Skip to content

Commit 5dca20e

Browse files
authored
Merge pull request #156 from nhooyr/improve-rand
Improve usage of math/rand versus crypto/rand
2 parents 4f91d7a + e476358 commit 5dca20e

File tree

9 files changed

+47
-38
lines changed

9 files changed

+47
-38
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,18 @@ jobs:
77
container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
88
steps:
99
- uses: actions/checkout@v1
10-
with:
11-
fetch-depth: 1
1210
- run: ./ci/fmt.sh
1311
lint:
1412
runs-on: ubuntu-latest
1513
container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
1614
steps:
1715
- uses: actions/checkout@v1
18-
with:
19-
fetch-depth: 1
2016
- run: ./ci/lint.sh
2117
test:
2218
runs-on: ubuntu-latest
2319
container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
2420
steps:
2521
- uses: actions/checkout@v1
26-
with:
27-
fetch-depth: 1
2822
- run: ./ci/test.sh
2923
env:
3024
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
@@ -33,6 +27,4 @@ jobs:
3327
container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
3428
steps:
3529
- uses: actions/checkout@v1
36-
with:
37-
fetch-depth: 1
3830
- run: ./ci/wasm.sh

assert_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@ import (
66
"math/rand"
77
"reflect"
88
"strings"
9+
"time"
910

1011
"github.com/google/go-cmp/cmp"
1112

1213
"nhooyr.io/websocket"
1314
"nhooyr.io/websocket/wsjson"
1415
)
1516

17+
func init() {
18+
rand.Seed(time.Now().UnixNano())
19+
}
20+
1621
// https://github.com/google/go-cmp/issues/40#issuecomment-328615283
1722
func cmpDiff(exp, act interface{}) string {
1823
return cmp.Diff(exp, act, deepAllowUnexported(exp, act))

ci/wasm.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ go install github.com/agnivade/wasmbrowsertest
2525
export WS_ECHO_SERVER_URL
2626
GOOS=js GOARCH=wasm go test -exec=wasmbrowsertest ./...
2727

28-
kill "$wsjstestPID"
28+
kill "$wsjstestPID" || true
2929
if ! wait "$wsjstestPID"; then
3030
echo "--- wsjstest exited unsuccessfully"
3131
echo "output:"

conn.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@ package websocket
55
import (
66
"bufio"
77
"context"
8-
cryptorand "crypto/rand"
8+
"crypto/rand"
99
"errors"
1010
"fmt"
1111
"io"
1212
"io/ioutil"
1313
"log"
14-
"math/rand"
1514
"runtime"
1615
"strconv"
1716
"sync"
@@ -82,6 +81,7 @@ type Conn struct {
8281
setReadTimeout chan context.Context
8382
setWriteTimeout chan context.Context
8483

84+
pingCounter *atomicInt64
8585
activePingsMu sync.Mutex
8686
activePings map[string]chan<- struct{}
8787
}
@@ -100,6 +100,7 @@ func (c *Conn) init() {
100100
c.setReadTimeout = make(chan context.Context)
101101
c.setWriteTimeout = make(chan context.Context)
102102

103+
c.pingCounter = &atomicInt64{}
103104
c.activePings = make(map[string]chan<- struct{})
104105

105106
c.writeHeaderBuf = makeWriteHeaderBuf()
@@ -669,7 +670,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
669670
c.writeHeader.payloadLength = int64(len(p))
670671

671672
if c.client {
672-
_, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:])
673+
_, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:])
673674
if err != nil {
674675
return 0, fmt.Errorf("failed to generate masking key: %w", err)
675676
}
@@ -839,10 +840,6 @@ func (c *Conn) writeClose(p []byte, cerr error) error {
839840
return nil
840841
}
841842

842-
func init() {
843-
rand.Seed(time.Now().UnixNano())
844-
}
845-
846843
// Ping sends a ping to the peer and waits for a pong.
847844
// Use this to measure latency or ensure the peer is responsive.
848845
// Ping must be called concurrently with Reader as it does
@@ -851,10 +848,9 @@ func init() {
851848
//
852849
// TCP Keepalives should suffice for most use cases.
853850
func (c *Conn) Ping(ctx context.Context) error {
854-
id := rand.Uint64()
855-
p := strconv.FormatUint(id, 10)
851+
p := c.pingCounter.Increment(1)
856852

857-
err := c.ping(ctx, p)
853+
err := c.ping(ctx, strconv.FormatInt(p, 10))
858854
if err != nil {
859855
return fmt.Errorf("failed to ping: %w", err)
860856
}

conn_common.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -211,21 +211,22 @@ func (c *Conn) setCloseErr(err error) {
211211

212212
// See https://github.com/nhooyr/websocket/issues/153
213213
type atomicInt64 struct {
214-
v atomic.Value
214+
v int64
215215
}
216216

217217
func (v *atomicInt64) Load() int64 {
218-
i, ok := v.v.Load().(int64)
219-
if !ok {
220-
return 0
221-
}
222-
return i
218+
return atomic.LoadInt64(&v.v)
223219
}
224220

225221
func (v *atomicInt64) Store(i int64) {
226-
v.v.Store(i)
222+
atomic.StoreInt64(&v.v, i)
227223
}
228224

229225
func (v *atomicInt64) String() string {
230-
return fmt.Sprint(v.v.Load())
226+
return fmt.Sprint(v.Load())
227+
}
228+
229+
// Increment increments the value and returns the new value.
230+
func (v *atomicInt64) Increment(delta int64) int64 {
231+
return atomic.AddInt64(&v.v, delta)
231232
}

conn_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ import (
3737
"nhooyr.io/websocket/wspb"
3838
)
3939

40+
func init() {
41+
rand.Seed(time.Now().UnixNano())
42+
}
43+
4044
func TestHandshake(t *testing.T) {
4145
t.Parallel()
4246

@@ -911,10 +915,6 @@ func TestConn(t *testing.T) {
911915
}
912916
}
913917

914-
func init() {
915-
rand.Seed(time.Now().UnixNano())
916-
}
917-
918918
func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) {
919919
var conns int64
920920
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

frame_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@ import (
1010
"strconv"
1111
"strings"
1212
"testing"
13+
"time"
1314

1415
"github.com/google/go-cmp/cmp"
1516
)
1617

18+
func init() {
19+
rand.Seed(time.Now().UnixNano())
20+
}
21+
1722
func randBool() bool {
1823
return rand.Intn(1) == 0
1924
}

handshake.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ import (
66
"bufio"
77
"bytes"
88
"context"
9+
"crypto/rand"
910
"crypto/sha1"
1011
"encoding/base64"
1112
"errors"
1213
"fmt"
1314
"io"
1415
"io/ioutil"
15-
"math/rand"
1616
"net/http"
1717
"net/textproto"
1818
"net/url"
@@ -299,7 +299,11 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
299299
req.Header.Set("Connection", "Upgrade")
300300
req.Header.Set("Upgrade", "websocket")
301301
req.Header.Set("Sec-WebSocket-Version", "13")
302-
req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey())
302+
secWebSocketKey, err := makeSecWebSocketKey()
303+
if err != nil {
304+
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
305+
}
306+
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
303307
if len(opts.Subprotocols) > 0 {
304308
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
305309
}
@@ -403,8 +407,11 @@ func returnBufioWriter(bw *bufio.Writer) {
403407
bufioWriterPool.Put(bw)
404408
}
405409

406-
func makeSecWebSocketKey() string {
410+
func makeSecWebSocketKey() (string, error) {
407411
b := make([]byte, 16)
408-
rand.Read(b)
409-
return base64.StdEncoding.EncodeToString(b)
412+
_, err := io.ReadFull(rand.Reader, b)
413+
if err != nil {
414+
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
415+
}
416+
return base64.StdEncoding.EncodeToString(b), nil
410417
}

handshake_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,17 @@ func Test_verifyServerHandshake(t *testing.T) {
367367
resp := w.Result()
368368

369369
r := httptest.NewRequest("GET", "/", nil)
370-
key := makeSecWebSocketKey()
370+
key, err := makeSecWebSocketKey()
371+
if err != nil {
372+
t.Fatal(err)
373+
}
371374
r.Header.Set("Sec-WebSocket-Key", key)
372375

373376
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
374377
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
375378
}
376379

377-
err := verifyServerResponse(r, resp)
380+
err = verifyServerResponse(r, resp)
378381
if (err == nil) != tc.success {
379382
t.Fatalf("unexpected error: %+v", err)
380383
}

0 commit comments

Comments
 (0)