Skip to content

Commit 28129c1

Browse files
committed
Add headers to websocket handler object
1 parent 56af8ae commit 28129c1

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

tests/test_query_parse.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from websocket_server import WebSocketHandler
2+
3+
4+
def test_websocket_handler_query_parse():
5+
case1 = WebSocketHandler.parse_query("GET /?a=hello HTTP/1.1")
6+
case2 = WebSocketHandler.parse_query("GET / HTTP/1.1")
7+
case3 = WebSocketHandler.parse_query("GET /?a=hello&b=world HTTP/1.1")
8+
case4 = WebSocketHandler.parse_query("GET /?a=hello&a=world HTTP/1.1")
9+
assert case1 == {'a': ['hello']}
10+
assert case2 == {}
11+
assert case3 == {'a': ['hello'], 'b': ['world']}
12+
assert case4 == {'a': ['hello', 'world']}

websocket_server/websocket_server.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import errno
1212
import threading
1313
from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler
14+
from urllib.parse import urlparse, parse_qs
1415

1516
from websocket_server.thread import WebsocketServerThread
1617

@@ -261,6 +262,8 @@ class WebSocketHandler(StreamRequestHandler):
261262

262263
def __init__(self, socket, addr, server):
263264
self.server = server
265+
self.headers = {}
266+
self.query_params = {}
264267
assert not hasattr(self, "_send_lock"), "_send_lock already exists"
265268
self._send_lock = threading.Lock()
266269
if server.key and server.cert:
@@ -412,6 +415,16 @@ def send_text(self, message, opcode=OPCODE_TEXT):
412415
with self._send_lock:
413416
self.request.send(header + payload)
414417

418+
@staticmethod
419+
def parse_query(http_get):
420+
"""
421+
Parses the query parameters from the first line.
422+
Example: "GET /?q=hello HTTP/1.1" will be parsed to {'q': ['hello']}
423+
"""
424+
query = http_get.split(" ")[1] # example: http_get = "GET /?q=hello HTTP/1.1"
425+
parsed_url = urlparse(query)
426+
return parse_qs(parsed_url.query)
427+
415428
def read_http_headers(self):
416429
headers = {}
417430
# first line should be HTTP GET
@@ -424,6 +437,8 @@ def read_http_headers(self):
424437
break
425438
head, value = header.split(':', 1)
426439
headers[head.lower().strip()] = value.strip()
440+
self.headers = headers
441+
self.query_params = WebSocketHandler.parse_query(http_get)
427442
return headers
428443

429444
def handshake(self):

0 commit comments

Comments
 (0)