From 516e999b8afbb8c937b9dba447baea5a37c9afa4 Mon Sep 17 00:00:00 2001 From: sada-sigsci Date: Fri, 14 Jun 2019 13:34:03 -0700 Subject: [PATCH 1/8] Added a hook called when a websocket message is received or sent --- forward/fwd.go | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/forward/fwd.go b/forward/fwd.go index b8500489..95b86a18 100644 --- a/forward/fwd.go +++ b/forward/fwd.go @@ -43,6 +43,9 @@ type ReqRewriter interface { Rewrite(r *http.Request) } +// WsHook websocket message hook called when message is received or sent +type WsHook func(req *http.Request, messageType int, reader io.Reader) (io.Reader, error) + type optSetter func(f *Forwarder) error // PassHostHeader specifies if a client's Host header field should be delegated @@ -137,6 +140,22 @@ func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) } } +// WebsocketMessageReceivedHook defines a hook called when websocket message is received +func WebsocketMessageReceivedHook(hook WsHook) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketMessageReceivedHook = hook + return nil + } +} + +// WebsocketMessageSentHook defines a hook called when websocket message is sent +func WebsocketMessageSentHook(hook WsHook) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketMessageSentHook = hook + return nil + } +} + // ResponseModifier defines a response modifier for the HTTP forwarder func ResponseModifier(responseModifier func(*http.Response) error) optSetter { return func(f *Forwarder) error { @@ -201,6 +220,8 @@ type httpForwarder struct { bufferPool httputil.BufferPool websocketConnectionClosedHook func(req *http.Request, conn net.Conn) + websocketMessageReceivedHook WsHook + websocketMessageSentHook WsHook } const defaultFlushInterval = time.Duration(100) * time.Millisecond @@ -403,7 +424,7 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, errClient := make(chan error, 1) errBackend := make(chan error, 1) - replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { + replicateWebsocketConn := func(dst, src *websocket.Conn, websocketMessageHook WsHook, errc chan error) { forward := func(messageType int, reader io.Reader) error { writer, err := dst.NextWriter(messageType) @@ -448,6 +469,12 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } break } + if websocketMessageHook != nil { + if reader, err = websocketMessageHook(req, msgType, reader); err != nil { + errc <- err + break + } + } err = forward(msgType, reader) if err != nil { errc <- err @@ -456,8 +483,8 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } } - go replicateWebsocketConn(underlyingConn, targetConn, errClient) - go replicateWebsocketConn(targetConn, underlyingConn, errBackend) + go replicateWebsocketConn(underlyingConn, targetConn, f.websocketMessageSentHook, errClient) + go replicateWebsocketConn(targetConn, underlyingConn, f.websocketMessageReceivedHook, errBackend) var message string select { From 4173a4d1ca4efac37f00ccd10345499bfdeb2c1e Mon Sep 17 00:00:00 2001 From: nickg-sigsci Date: Wed, 8 Apr 2020 21:19:02 -0700 Subject: [PATCH 2/8] Update go.mod fix module declaration --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 5219aacc..9582890a 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/vulcand/oxy +module github.com/signalsciences/oxy go 1.12 From e4364cd09c23bb52cd5cefd0facfe456a7629340 Mon Sep 17 00:00:00 2001 From: Brian Rectanus Date: Fri, 19 Jun 2020 12:50:16 -0400 Subject: [PATCH 3/8] Revert "Update go.mod" This reverts commit 4173a4d1ca4efac37f00ccd10345499bfdeb2c1e. --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 9582890a..5219aacc 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/signalsciences/oxy +module github.com/vulcand/oxy go 1.12 From e81853a74f82add4ac860bbebec11616d68a61b8 Mon Sep 17 00:00:00 2001 From: Clifton Kaznocha Date: Thu, 28 May 2020 13:31:25 -0700 Subject: [PATCH 4/8] Do not use global websocket.DefaultDialer This change makes it so that each forward gets its own dialer rather then all sharing the global `websocket.DefaultDialer`. It fixes the flaky `TestWebSocketNumGoRoutine` test and allows `WebsocketTLSClientConfig` to set a different TLS config than the one used in the http `RoundTripper`, the TLS config in the http `RoundTripper` will still be used as a fallback if one wasn't set by the user. Adds the new `optSetter` `WebsocketNetDialContext` to set a custom DialContet for WebSocket use. - `go test -run=TestWebSocketNumGoRoutine -count=100 ./forward` now passes. Removed the skip directive. - Closes https://github.com/vulcand/oxy/issues/199 - Closes https://github.com/vulcand/oxy/issues/125 --- forward/fwd.go | 35 +++++++++++++++++++++++++---------- forward/fwd_websocket_test.go | 1 - 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/forward/fwd.go b/forward/fwd.go index 95b86a18..7407321e 100644 --- a/forward/fwd.go +++ b/forward/fwd.go @@ -5,6 +5,7 @@ package forward import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -73,10 +74,21 @@ func Rewriter(r ReqRewriter) optSetter { } } -// WebsocketTLSClientConfig define the websocker client TLS configuration +// WebsocketTLSClientConfig define the websocket client TLS configuration func WebsocketTLSClientConfig(tcc *tls.Config) optSetter { return func(f *Forwarder) error { - f.httpForwarder.tlsClientConfig = tcc + f.websocketDialer.TLSClientConfig = tcc.Clone() + // WebSocket is only in http/1.1 + f.websocketDialer.TLSClientConfig.NextProtos = []string{"http/1.1"} + + return nil + } +} + +// WebsocketNetDialContext define the websocket client DialContext function +func WebsocketNetDialContext(dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)) optSetter { + return func(f *Forwarder) error { + f.websocketDialer.NetDialContext = dialContext return nil } } @@ -222,6 +234,7 @@ type httpForwarder struct { websocketConnectionClosedHook func(req *http.Request, conn net.Conn) websocketMessageReceivedHook WsHook websocketMessageSentHook WsHook + websocketDialer *websocket.Dialer } const defaultFlushInterval = time.Duration(100) * time.Millisecond @@ -241,6 +254,12 @@ func New(setters ...optSetter) (*Forwarder, error) { httpForwarder: &httpForwarder{log: &internalLogger{Logger: log.StandardLogger()}}, handlerContext: &handlerContext{}, } + + f.websocketDialer = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + } + for _, s := range setters { if err := s(f); err != nil { return nil, err @@ -272,6 +291,9 @@ func New(setters ...optSetter) (*Forwarder, error) { if f.tlsClientConfig == nil { if ht, ok := f.httpForwarder.roundTripper.(*http.Transport); ok { f.tlsClientConfig = ht.TLSClientConfig + if f.websocketDialer.TLSClientConfig == nil && ht.TLSClientConfig != nil { + _ = WebsocketTLSClientConfig(ht.TLSClientConfig)(f) + } } } @@ -358,14 +380,7 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, outReq := f.copyWebSocketRequest(req) - dialer := websocket.DefaultDialer - - if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil { - dialer.TLSClientConfig = f.tlsClientConfig.Clone() - // WebSocket is only in http/1.1 - dialer.TLSClientConfig.NextProtos = []string{"http/1.1"} - } - targetConn, resp, err := dialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header) + targetConn, resp, err := f.websocketDialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header) if err != nil { if resp == nil { ctx.errHandler.ServeHTTP(w, req, err) diff --git a/forward/fwd_websocket_test.go b/forward/fwd_websocket_test.go index 35ed6b40..cb1d5461 100644 --- a/forward/fwd_websocket_test.go +++ b/forward/fwd_websocket_test.go @@ -277,7 +277,6 @@ func TestWebSocketPassHost(t *testing.T) { } func TestWebSocketNumGoRoutine(t *testing.T) { - t.Skip("Flaky on goroutine") f, err := New() require.NoError(t, err) From e788601d0fdf6fc231ccfddb8f2b4f2f07bdf725 Mon Sep 17 00:00:00 2001 From: Clifton Kaznocha Date: Thu, 28 May 2020 16:46:58 -0700 Subject: [PATCH 5/8] Revert un-skipping test --- forward/fwd_websocket_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/forward/fwd_websocket_test.go b/forward/fwd_websocket_test.go index cb1d5461..35ed6b40 100644 --- a/forward/fwd_websocket_test.go +++ b/forward/fwd_websocket_test.go @@ -277,6 +277,7 @@ func TestWebSocketPassHost(t *testing.T) { } func TestWebSocketNumGoRoutine(t *testing.T) { + t.Skip("Flaky on goroutine") f, err := New() require.NoError(t, err) From e5b151cf03782f97a2a5caad6af283d6f7517a2e Mon Sep 17 00:00:00 2001 From: nickg-sigsci Date: Wed, 8 Apr 2020 21:19:02 -0700 Subject: [PATCH 6/8] Update go.mod fix module declaration --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 601e5bec..d903ecd2 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/vulcand/oxy +module github.com/signalsciences/oxy go 1.13 From f2e2bb7b054a0554acec1697ba7ce6a8b70f7e11 Mon Sep 17 00:00:00 2001 From: Brian Rectanus Date: Fri, 19 Jun 2020 12:50:16 -0400 Subject: [PATCH 7/8] Revert "Update go.mod" This reverts commit 4173a4d1ca4efac37f00ccd10345499bfdeb2c1e. --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index d903ecd2..601e5bec 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/signalsciences/oxy +module github.com/vulcand/oxy go 1.13 From 2fbb821b571f9deb2d6ca8c5a0dec4abf62d441f Mon Sep 17 00:00:00 2001 From: sada-sigsci Date: Fri, 14 Jun 2019 13:34:03 -0700 Subject: [PATCH 8/8] Added a hook called when a websocket message is received or sent --- forward/fwd.go | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/forward/fwd.go b/forward/fwd.go index 8f1cb1ce..c05e551b 100644 --- a/forward/fwd.go +++ b/forward/fwd.go @@ -42,6 +42,9 @@ type ReqRewriter interface { Rewrite(r *http.Request) } +// WsHook websocket message hook called when message is received or sent +type WsHook func(req *http.Request, messageType int, reader io.Reader) (io.Reader, error) + type optSetter func(f *Forwarder) error // PassHostHeader specifies if a client's Host header field should be delegated. @@ -136,7 +139,23 @@ func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) } } -// ResponseModifier defines a response modifier for the HTTP forwarder. +// WebsocketMessageReceivedHook defines a hook called when websocket message is received. +func WebsocketMessageReceivedHook(hook WsHook) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketMessageReceivedHook = hook + return nil + } +} + +// WebsocketMessageSentHook defines a hook called when websocket message is sent. +func WebsocketMessageSentHook(hook WsHook) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketMessageSentHook = hook + return nil + } +} + +// ResponseModifier defines a response modifier for the HTTP forwarder func ResponseModifier(responseModifier func(*http.Response) error) optSetter { return func(f *Forwarder) error { f.httpForwarder.modifyResponse = responseModifier @@ -180,6 +199,8 @@ type httpForwarder struct { bufferPool httputil.BufferPool websocketConnectionClosedHook func(req *http.Request, conn net.Conn) + websocketMessageReceivedHook WsHook + websocketMessageSentHook WsHook } const defaultFlushInterval = 100 * clock.Millisecond @@ -383,7 +404,8 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, errClient := make(chan error, 1) errBackend := make(chan error, 1) - replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { + replicateWebsocketConn := func(dst, src *websocket.Conn, websocketMessageHook WsHook, errc chan error) { + forward := func(messageType int, reader io.Reader) error { writer, err := dst.NextWriter(messageType) if err != nil { @@ -424,6 +446,12 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } break } + if websocketMessageHook != nil { + if reader, err = websocketMessageHook(req, msgType, reader); err != nil { + errc <- err + break + } + } err = forward(msgType, reader) if err != nil { errc <- err @@ -432,8 +460,8 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } } - go replicateWebsocketConn(underlyingConn, targetConn, errClient) - go replicateWebsocketConn(targetConn, underlyingConn, errBackend) + go replicateWebsocketConn(underlyingConn, targetConn, f.websocketMessageSentHook, errClient) + go replicateWebsocketConn(targetConn, underlyingConn, f.websocketMessageReceivedHook, errBackend) var message string select {