diff --git a/jupyter_server_proxy/handlers.py b/jupyter_server_proxy/handlers.py index a5987fc8..9427b5a0 100644 --- a/jupyter_server_proxy/handlers.py +++ b/jupyter_server_proxy/handlers.py @@ -172,6 +172,9 @@ async def http_get(self, host, port, proxy_path=""): "Subclasses of ProxyHandler should implement http_get" ) + async def get(self, *args, **kwargs): + return await self.http_get(*args, **kwargs) + def post(self, host, port, proxy_path=""): raise NotImplementedError( "Subclasses of ProxyHandler should implement this post" @@ -333,6 +336,14 @@ async def proxy(self, host, port, proxied_path): "See https://jupyter-server-proxy.readthedocs.io/en/latest/arbitrary-ports-hosts.html for info.", ) + self._record_activity() + + if ( + self.request.method == "GET" + and self.request.headers.get("Upgrade", "").lower() == "websocket" + ): + return await ensure_async(self.get_websocket(proxied_path)) + # Remove hop-by-hop headers that don't necessarily apply to the request we are making # to the backend. See https://github.com/jupyterhub/jupyter-server-proxy/pull/328 # for more information @@ -351,16 +362,6 @@ async def proxy(self, host, port, proxied_path): if header_to_remove in self.request.headers: del self.request.headers[header_to_remove] - self._record_activity() - - if self.request.headers.get("Upgrade", "").lower() == "websocket": - # We wanna websocket! - # jupyterhub/jupyter-server-proxy@36b3214 - self.log.info( - "we wanna websocket, but we don't define WebSocketProxyHandler" - ) - self.set_status(500) - body = self.request.body if not body: if self.request.method in {"POST", "PUT"}: diff --git a/jupyter_server_proxy/rawsocket.py b/jupyter_server_proxy/rawsocket.py index 89159168..bc222691 100644 --- a/jupyter_server_proxy/rawsocket.py +++ b/jupyter_server_proxy/rawsocket.py @@ -55,6 +55,12 @@ def _create_ws_connection(self, proto: asyncio.BaseProtocol): return loop.create_connection(proto, "localhost", self.port) async def proxy(self, port, path): + if ( + self.request.method == "GET" + and self.request.headers.get("Upgrade", "").lower() == "websocket" + ): + return await super().proxy(port, path) + raise web.HTTPError( 405, "this raw_socket_proxy backend only supports websocket connections" ) diff --git a/jupyter_server_proxy/websocket.py b/jupyter_server_proxy/websocket.py index a43b7795..608edefe 100644 --- a/jupyter_server_proxy/websocket.py +++ b/jupyter_server_proxy/websocket.py @@ -96,8 +96,5 @@ def undisallow(*args2, **kwargs2): setattr(self, method, wrapper(method)) nextparent.__init__(self, *args, **kwargs) - async def get(self, *args, **kwargs): - if self.request.headers.get("Upgrade", "").lower() != "websocket": - return await self.http_get(*args, **kwargs) - else: - await ensure_async(super().get(*args, **kwargs)) + async def get_websocket(self, *args, **kwargs): + await ensure_async(super().get(*args, **kwargs)) diff --git a/tests/resources/jupyter_server_config.py b/tests/resources/jupyter_server_config.py index ac1e0dfe..8bbe541d 100644 --- a/tests/resources/jupyter_server_config.py +++ b/tests/resources/jupyter_server_config.py @@ -18,6 +18,10 @@ def mappathf(path): return p +def mappathf_socket(path): + return path + "socket" + + def translate_ciao(path, host, response, orig_response, port): # Assume that the body has not been modified by any previous rewrite assert response.body == orig_response.body @@ -83,6 +87,10 @@ def my_env(): "X-Custom-Header": "pytest-23456", }, }, + "python-websocket-mappathf_socket": { + "command": [sys.executable, _get_path("websocket.py"), "--port={port}"], + "mappath": mappathf_socket, + }, "python-eventstream": { "command": [sys.executable, _get_path("eventstream.py"), "--port={port}"] }, diff --git a/tests/test_proxies.py b/tests/test_proxies.py index 4573517c..d7a0430f 100644 --- a/tests/test_proxies.py +++ b/tests/test_proxies.py @@ -401,6 +401,19 @@ async def test_server_proxy_websocket_headers(a_server_port_and_token: Tuple[int assert headers["X-Custom-Header"] == "pytest-23456" +async def test_server_proxy_websocket_messages_mappath( + a_server_port_and_token: Tuple[int, str] +) -> None: + PORT, TOKEN = a_server_port_and_token + # Mappath is configured to add "socket" to websocket paths + url = f"ws://{LOCALHOST}:{PORT}/python-websocket-mappathf_socket/echo?token={TOKEN}" + conn = await websocket_connect(url) + expected_msg = "Hello, world!" + await conn.write_message(expected_msg) + msg = await conn.read_message() + assert msg == expected_msg + + @pytest.mark.parametrize( "client_requested,server_received,server_responded,proxy_responded", [