Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Use websockets for json communication" #4548

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 37 additions & 119 deletions codalab/bin/ws_server.py
Original file line number Diff line number Diff line change
@@ -1,157 +1,75 @@
# Main entry point to the CodaLab Websocket Server.
# The Websocket Server handles communication between the REST server and workers.
# Main entry point for CodaLab cl-ws-server.
import argparse
import asyncio
from collections import defaultdict
import logging
import os
import random
import re
import time
from typing import Any, Dict
import websockets
import threading

from codalab.lib.codalab_manager import CodaLabManager


class TimedLock:
"""A lock that gets automatically released after timeout_seconds.
"""

def __init__(self, timeout_seconds: float = 60):
self._lock = threading.Lock()
self._time_since_locked: float
self._timeout: float = timeout_seconds

def acquire(self, blocking=True, timeout=-1):
acquired = self._lock.acquire(blocking, timeout)
if acquired:
self._time_since_locked = time.time()
return acquired

def locked(self):
return self._lock.locked()

def release(self):
self._lock.release()
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d')

def timeout(self):
return time.time() - self._time_since_locked > self._timeout
worker_to_ws: Dict[str, Any] = {}

def release_if_timeout(self):
if self.locked() and self.timeout():
self.release()

async def rest_server_handler(websocket):
"""Handles routes of the form: /main. This route is called by the rest-server
whenever a worker needs to be pinged (to ask it to check in). The body of the
message is the worker id to ping. This function sends a message to the worker
with that worker id through an appropriate websocket.
"""
# Got a message from the rest server.
worker_id = await websocket.recv()
logger.warning(f"Got a message from the rest server, to ping worker: {worker_id}.")

worker_to_ws: Dict[str, Dict[str, Any]] = defaultdict(
dict
) # Maps worker ID to socket ID to websocket
worker_to_lock: Dict[str, Dict[str, TimedLock]] = defaultdict(
dict
) # Maps worker ID to socket ID to lock
ACK = b'a'
logger = logging.getLogger(__name__)
manager = CodaLabManager()
bundle_model = manager.model()
worker_model = manager.worker_model()
server_secret = os.getenv("CODALAB_SERVER_SECRET")
try:
worker_ws = worker_to_ws[worker_id]
await worker_ws.send(worker_id)
except KeyError:
logger.error(f"Websocket not found for worker: {worker_id}")


async def send_to_worker_handler(server_websocket, worker_id):
"""Handles routes of the form: /send_to_worker/{worker_id}. This route is called by
the rest-server or bundle-manager when either wants to send a message/stream to the worker.
"""
# Authenticate server.
received_secret = await server_websocket.recv()
if received_secret != server_secret:
logger.warning("Server unable to authenticate.")
await server_websocket.close(1008, "Server unable to authenticate.")
return

# Check if any websockets available
if worker_id not in worker_to_ws or len(worker_to_ws[worker_id]) == 0:
logger.warning(f"No websockets currently available for worker {worker_id}")
await server_websocket.close(
1011, f"No websockets currently available for worker {worker_id}"
)
return

# Send message from server to worker.
for socket_id, worker_websocket in random.sample(
worker_to_ws[worker_id].items(), len(worker_to_ws[worker_id])
):
if worker_to_lock[worker_id][socket_id].acquire(blocking=False):
data = await server_websocket.recv()
await worker_websocket.send(data)
await server_websocket.send(ACK)
worker_to_lock[worker_id][socket_id].release()
return

logger.warning(f"All websockets for worker {worker_id} are currently busy.")
await server_websocket.close(1011, f"All websockets for worker {worker_id} are currently busy.")


async def worker_connection_handler(websocket: Any, worker_id: str, socket_id: str) -> None:
"""Handles routes of the form: /worker_connect/{worker_id}/{socket_id}.
This route is called when a worker first connects to the ws-server, creating
a connection that can be used to ask the worker to check-in later.
async def worker_handler(websocket, worker_id):
"""Handles routes of the form: /worker/{id}. This route is called when
a worker first connects to the ws-server, creating a connection that can
be used to ask the worker to check-in later.
"""
# Authenticate worker.
access_token = await websocket.recv()
user_id = worker_model.get_user_id_for_worker(worker_id=worker_id)
authenticated = bundle_model.access_token_exists_for_user(
'codalab_worker_client', user_id, access_token # TODO: Avoid hard-coding this if possible.
)
logger.error(f"AUTHENTICATED: {authenticated}")
if not authenticated:
logger.warning(f"Thread {socket_id} for worker {worker_id} unable to authenticate.")
await websocket.close(
1008, f"Thread {socket_id} for worker {worker_id} unable to authenticate."
)
return

# Establish a connection with worker and keep it alive.
worker_to_ws[worker_id][socket_id] = websocket
worker_to_lock[worker_id][socket_id] = TimedLock()
logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections")
# runs on worker connect
worker_to_ws[worker_id] = websocket
logger.warning(f"Connected to worker {worker_id}!")

while True:
try:
await asyncio.wait_for(websocket.recv(), timeout=60)
worker_to_lock[worker_id][
socket_id
].release_if_timeout() # Failsafe in case not released
except asyncio.futures.TimeoutError:
pass
except websockets.exceptions.ConnectionClosed:
logger.warning(f"Socket connection closed with worker {worker_id}.")
logger.error(f"Socket connection closed with worker {worker_id}.")
break
del worker_to_ws[worker_id][socket_id]
del worker_to_lock[worker_id][socket_id]
logger.warning(f"Worker {worker_id} now has {len(worker_to_ws[worker_id])} connections")


ROUTES = (
(r'^.*/main$', rest_server_handler),
(r'^.*/worker/(.+)$', worker_handler),
)


async def ws_handler(websocket, *args):
"""Handler for websocket connections. Routes websockets to the appropriate
route handler defined in ROUTES."""
ROUTES = (
(r'^.*/send_to_worker/(.+)$', send_to_worker_handler),
(r'^.*/worker_connect/(.+)/(.+)$', worker_connection_handler),
)
logger.info(f"websocket handler, path: {websocket.path}.")
logger.warning(f"websocket handler, path: {websocket.path}.")
for (pattern, handler) in ROUTES:
match = re.match(pattern, websocket.path)
if match:
return await handler(websocket, *match.groups())
return await websocket.close(1011, f"Path {websocket.path} is not valid.")
assert False


async def async_main():
"""Main function that runs the websocket server."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--port', help='Port to run the server on.', type=int, required=False, default=2901
)
parser.add_argument('--port', help='Port to run the server on.', type=int, required=True)
args = parser.parse_args()
logging.debug(f"Running ws-server on 0.0.0.0:{args.port}")
async with websockets.serve(ws_handler, "0.0.0.0", args.port):
Expand Down
9 changes: 1 addition & 8 deletions codalab/lib/codalab_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,6 @@ def ws_server(self):
ws_port = self.config['ws-server']['ws_port']
return f"ws://ws-server:{ws_port}"

@property # type: ignore
@cached
def server_secret(self):
return os.getenv("CODALAB_SERVER_SECRET")

@property # type: ignore
@cached
def worker_socket_dir(self):
Expand Down Expand Up @@ -385,9 +380,7 @@ def model(self):

@cached
def worker_model(self):
return WorkerModel(
self.model().engine, self.worker_socket_dir, self.ws_server, self.server_secret
)
return WorkerModel(self.model().engine, self.worker_socket_dir, self.ws_server)

@cached
def upload_manager(self):
Expand Down
14 changes: 9 additions & 5 deletions codalab/lib/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _get_target_info_within_bundle(self, target, depth):
read_args = {'type': 'get_target_info', 'depth': depth}
self._send_read_message(worker, response_socket_id, target, read_args)
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
result = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
result = self._worker_model.get_json_message(sock, 60)
if result is None: # dead workers are a fact of life now
logging.info('Unable to reach worker, bundle state {}'.format(bundle_state))
raise NotFoundError(
Expand Down Expand Up @@ -365,7 +365,9 @@ def _send_read_message(self, worker, response_socket_id, target, read_args):
'path': target.subpath,
'read_args': read_args,
}
if not self._worker_model.send_json_message(message, worker['worker_id']):
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
): # dead workers are a fact of life now
logging.info('Unable to reach worker')

def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
Expand All @@ -376,19 +378,21 @@ def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
'port': port,
'message': message,
}
if not self._worker_model.send_json_message(message, worker['worker_id']):
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
): # dead workers are a fact of life now
logging.info('Unable to reach worker')

def _get_read_response_stream(self, response_socket_id):
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
header_message = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
header_message = self._worker_model.get_json_message(sock, 60)
precondition(header_message is not None, 'Unable to reach worker')
if 'error_code' in header_message:
raise http_error_to_exception(
header_message['error_code'], header_message['error_message']
)

fileobj = self._worker_model.recv_stream(sock, 60)
fileobj = self._worker_model.get_stream(sock, 60)
precondition(fileobj is not None, 'Unable to reach worker')
return fileobj

Expand Down
21 changes: 1 addition & 20 deletions codalab/model/bundle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2881,7 +2881,7 @@ def get_oauth2_token(self, access_token=None, refresh_token=None):

return OAuth2Token(self, **row)

def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.utcnow()):
def find_oauth2_token(self, client_id, user_id, expires_after):
with self.engine.begin() as connection:
row = connection.execute(
select([oauth2_token])
Expand All @@ -2900,25 +2900,6 @@ def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.

return OAuth2Token(self, **row)

def access_token_exists_for_user(self, client_id: str, user_id: str, access_token: str) -> bool:
"""Check that the provided access_token exists in the database for the provided user_id.
"""
with self.engine.begin() as connection:
row = connection.execute(
select([oauth2_token])
.where(
and_(
oauth2_token.c.client_id == client_id,
oauth2_token.c.user_id == user_id,
oauth2_token.c.access_token == access_token,
oauth2_token.c.expires > datetime.datetime.utcnow(),
)
)
.limit(1)
).fetchone()

return row is not None

def save_oauth2_token(self, token):
with self.engine.begin() as connection:
result = connection.execute(oauth2_token.insert().values(token.columns))
Expand Down
Loading