Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for batch requests to JSON RPC server #4688

Merged
merged 5 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 105 additions & 42 deletions pyk/src/pyk/rpc/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import json
import logging
from dataclasses import dataclass
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, NamedTuple

from typing_extensions import Protocol

Expand Down Expand Up @@ -71,37 +72,80 @@ class JsonRpcMethod(Protocol):
def __call__(self, **kwargs: Any) -> Any: ...


class JsonRpcRequestHandler(BaseHTTPRequestHandler):
methods: dict[str, JsonRpcMethod]
class JsonRpcRequest(NamedTuple):
id: str | int
method: str
params: Any

def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) -> None:
self.methods = methods
super().__init__(*args, **kwargs)

def send_json_error(self, code: int, message: str, id: Any = None) -> None:
error_dict = {
class JsonRpcBatchRequest(NamedTuple):
requests: tuple[JsonRpcRequest]


class JsonRpcResult:

def encode(self) -> bytes:
raise NotImplementedError('Subclasses must implement this method')
RaoulSchaffranek marked this conversation as resolved.
Show resolved Hide resolved


@dataclass(frozen=True)
class JsonRpcError(JsonRpcResult):

code: int
message: str
id: str | int | None

def to_json(self) -> dict[str, Any]:
return {
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
'error': {
'code': code,
'message': message,
'code': self.code,
'message': self.message,
},
'id': id,
'id': self.id,
}
error_bytes = json.dumps(error_dict).encode('ascii')
self.set_response()
self.wfile.write(error_bytes)

def send_json_success(self, result: Any, id: Any) -> None:
response_dict = {
def encode(self) -> bytes:
return json.dumps(self.to_json()).encode('ascii')


@dataclass(frozen=True)
class JsonRpcSuccess(JsonRpcResult):
payload: Any
id: Any

def to_json(self) -> dict[str, Any]:
return {
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
'result': result,
'id': id,
'result': self.payload,
'id': self.id,
}
response_bytes = json.dumps(response_dict).encode('ascii')
self.set_response()

def encode(self) -> bytes:
return json.dumps(self.to_json()).encode('ascii')


@dataclass(frozen=True)
class JsonRpcBatchResult(JsonRpcResult):
results: tuple[JsonRpcError | JsonRpcSuccess, ...]

def encode(self) -> bytes:
return json.dumps([result.to_json() for result in self.results]).encode('ascii')


class JsonRpcRequestHandler(BaseHTTPRequestHandler):
methods: dict[str, JsonRpcMethod]

def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) -> None:
self.methods = methods
super().__init__(*args, **kwargs)

def _send_response(self, response: JsonRpcResult) -> None:
self.send_response_headers()
response_bytes = response.encode()
self.wfile.write(response_bytes)

def set_response(self) -> None:
def send_response_headers(self) -> None:
self.send_response(200)
self.send_header('Content-type', 'text/html')
self.end_headers()
Expand All @@ -113,44 +157,63 @@ def do_POST(self) -> None: # noqa: N802
content = self.rfile.read(int(content_len))
_LOGGER.debug(f'Received bytes: {content.decode()}')

request: dict
request: dict[str, Any] | list[dict[str, Any]]
try:
request = json.loads(content)
_LOGGER.info(f'Received request: {request}')
except json.JSONDecodeError:
_LOGGER.warning(f'Invalid JSON: {content.decode()}')
self.send_json_error(-32700, 'Invalid JSON')
json_error = JsonRpcError(-32700, 'Invalid JSON', None)
self._send_response(json_error)
return

required_fields = ['jsonrpc', 'method', 'id']
for field in required_fields:
if field not in request:
_LOGGER.warning(f'Missing required field "{field}": {request}')
self.send_json_error(-32600, f'Invalid request: missing field "{field}"', request.get('id', None))
return
response: JsonRpcResult
if isinstance(request, list):
response = self._batch_request(request)
else:
response = self._single_request(request)

jsonrpc_version = request['jsonrpc']
if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION:
_LOGGER.warning(f'Bad JSON-RPC version: {jsonrpc_version}')
self.send_json_error(-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request['id'])
return
self._send_response(response)

method_name = request['method']
if method_name not in self.methods:
_LOGGER.warning(f'Method not found: {method_name}')
self.send_json_error(-32601, f'Method "{method_name}" not found.', request['id'])
return
def _batch_request(self, requests: list[dict[str, Any]]) -> JsonRpcBatchResult:
return JsonRpcBatchResult(tuple(self._single_request(request) for request in requests))

def _single_request(self, request: dict[str, Any]) -> JsonRpcError | JsonRpcSuccess:
validation_result = self._validate_request(request)
if isinstance(validation_result, JsonRpcError):
return validation_result

id, method_name, params = validation_result
method = self.methods[method_name]
params = request.get('params', None)
_LOGGER.info(f'Executing method {method_name}')
result: Any
if type(params) is dict:
result = method(**params)
elif type(params) is list:
result = method(*params)
elif params is None:
result = method()
else:
self.send_json_error(-32602, 'Unrecognized method parameter format.')
return JsonRpcError(-32602, 'Unrecognized method parameter format.', id)
_LOGGER.debug(f'Got response {result}')
self.send_json_success(result, request['id'])
return JsonRpcSuccess(result, id)

def _validate_request(self, request_dict: Any) -> JsonRpcRequest | JsonRpcError:
required_fields = ['jsonrpc', 'method', 'id']
for field in required_fields:
if field not in request_dict:
return JsonRpcError(-32600, f'Invalid request: missing field "{field}"', request_dict.get('id', None))

jsonrpc_version = request_dict['jsonrpc']
if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION:
return JsonRpcError(
-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request_dict.get('id', None)
)

method_name = request_dict['method']
if method_name not in self.methods.keys():
return JsonRpcError(-32601, f'Method "{method_name}" not found.', request_dict.get('id', None))

return JsonRpcRequest(
method=request_dict['method'], params=request_dict.get('params', None), id=request_dict.get('id', None)
)
126 changes: 125 additions & 1 deletion pyk/src/tests/integration/test_json_rpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import json
from http.client import HTTPConnection
from threading import Thread
from time import sleep
from typing import TYPE_CHECKING

from pyk.cterm import CTerm
from pyk.kast.inner import KApply, KSequence, KSort, KToken
Expand All @@ -11,6 +14,9 @@
from pyk.rpc.rpc import JsonRpcServer, ServeRpcOptions
from pyk.testing import KRunTest

if TYPE_CHECKING:
from typing import Any


class StatefulKJsonRpcServer(JsonRpcServer):
krun: KRun
Expand Down Expand Up @@ -67,7 +73,7 @@ def exec_add(self) -> int:
return int(k_cell.token)


class TestJsonRPCServer(KRunTest):
class TestJsonKRPCServer(KRunTest):
KOMPILE_DEFINITION = """
module JSON-RPC-EXAMPLE-SYNTAX
imports INT-SYNTAX
Expand Down Expand Up @@ -133,3 +139,121 @@ def wait_until_ready() -> None:

server.shutdown()
thread.join()


class StatefulJsonRpcServer(JsonRpcServer):

x: int = 42
y: int = 43

def __init__(self, options: ServeRpcOptions) -> None:
super().__init__(options)

self.register_method('get_x', self.exec_get_x)
self.register_method('get_y', self.exec_get_y)
self.register_method('set_x', self.exec_set_x)
self.register_method('set_y', self.exec_set_y)
self.register_method('add', self.exec_add)

def exec_get_x(self) -> int:
return self.x

def exec_get_y(self) -> int:
return self.y

def exec_set_x(self, n: int) -> None:
self.x = n

def exec_set_y(self, n: int) -> None:
self.y = n

def exec_add(self) -> int:
return self.x + self.y


class TestJsonRPCServer(KRunTest):

def test_json_rpc_server(self) -> None:
server = StatefulJsonRpcServer(ServeRpcOptions({'port': 0}))

def run_server() -> None:
server.serve()

def wait_until_server_is_up() -> None:
while True:
try:
server.port()
return
except ValueError:
sleep(0.1)

thread = Thread(target=run_server)
thread.start()

wait_until_server_is_up()

http_client = HTTPConnection('localhost', server.port())
rpc_client = SimpleClient(http_client)

def wait_until_ready() -> None:
while True:
try:
rpc_client.request('get_x', [])
except ConnectionRefusedError:
sleep(0.1)
continue
break

wait_until_ready()

rpc_client.request('set_x', [123])
res = rpc_client.request('get_x')
assert res == 123

rpc_client.request('set_y', [456])
res = rpc_client.request('get_y')
assert res == 456

res = rpc_client.request('add', [])
assert res == (123 + 456)

res = rpc_client.batch_request(('set_x', [1]), ('set_y', [2]), ('add', []))
assert len(res) == 3
assert res[2]['result'] == 1 + 2

server.shutdown()
thread.join()


class SimpleClient:

client: HTTPConnection
_request_id: int = 0

def __init__(self, client: HTTPConnection) -> None:
self.client = client

def request_id(self) -> int:
self._request_id += 1
return self._request_id

def request(self, method: str, params: Any = None) -> Any:
body = json.dumps({'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()})

self.client.request('POST', '/', body)
response = self.client.getresponse()
result = json.loads(response.read())
return result['result']

def batch_request(self, *requests: tuple[str, Any]) -> list[Any]:
body = json.dumps(
[
{'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()}
for method, params in requests
]
)

self.client.request('POST', '/', body)
response = self.client.getresponse()
result = json.loads(response.read())
return result
Loading