From 0921f16099c727eadbc54141d2b01997ff27d901 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Fri, 29 Nov 2024 19:25:57 +0000 Subject: [PATCH 1/5] refactor: unify message handler execution This commit removes the separate handling of requests and notifications, opting instead for a generic `_execute_handler` method that can handle both types. The difference between the two can be captured in how the results of the handlers are processed. Request handlers are given a callback that send the result onto the client, while notification handlers are simply checked for errors. This commit also uses a future to report the result of a synchronous handler function. While not necessary, it allows synchronous handlers to treated the same as the other async execution types, simplifying the code overall. --- pygls/protocol/json_rpc.py | 113 ++++++++++++++++++------------ pygls/protocol/language_server.py | 9 ++- 2 files changed, 75 insertions(+), 47 deletions(-) diff --git a/pygls/protocol/json_rpc.py b/pygls/protocol/json_rpc.py index 7e0fb93c..a9e2cfb9 100644 --- a/pygls/protocol/json_rpc.py +++ b/pygls/protocol/json_rpc.py @@ -23,16 +23,11 @@ import logging import sys import traceback +import typing import uuid from concurrent.futures import Future from functools import partial -from typing import ( - TYPE_CHECKING, - Any, - Optional, - Type, - Union, -) +from typing import Any, Callable, Type, Union import attrs from cattrs.errors import ClassValidationError @@ -55,12 +50,14 @@ ) from pygls.feature_manager import FeatureManager, is_thread_function -if TYPE_CHECKING: +if typing.TYPE_CHECKING: from cattrs import Converter from pygls.io_ import AsyncWriter, Writer from pygls.server import JsonRPCServer + MessageHandler = Union[Callable[[Any], Any],] + MessageCallback = Callable[[Future[Any]], None] logger = logging.getLogger(__name__) @@ -119,8 +116,8 @@ def __init__(self, server: JsonRPCServer, converter: Converter): self._shutdown = False # Book keeping for in-flight requests - self._request_futures: dict[str, Future[Any]] = {} - self._result_types: dict[str, Any] = {} + self._request_futures: dict[str | int, Future[Any]] = {} + self._result_types: dict[str | int, Any] = {} self.fm = FeatureManager(server, converter) self.writer: AsyncWriter | Writer | None = None @@ -129,50 +126,58 @@ def __init__(self, server: JsonRPCServer, converter: Converter): def __call__(self): return self - def _execute_notification(self, handler, *params): - """Execute the given notification message handler.""" - if asyncio.iscoroutinefunction(handler): - future = asyncio.ensure_future(handler(*params)) - future.add_done_callback(self._execute_notification_callback) - - elif is_thread_function(handler): - future = self._server.thread_pool.submit(handler, *params) - future.add_done_callback(self._execute_notification_callback) + def _execute_handler( + self, + msg_id: str | int, + handler: MessageHandler, + params: Any, + callback: MessageCallback, + ): + """Execute the given message handler. - else: - handler(*params) + Parameters + ---------- + msg_id + The id of the message being handled - def _execute_notification_callback(self, future): - """Callback used for async/threaded notification message handler.""" - if future.exception(): - try: - raise future.exception() - except Exception: - error = JsonRpcInternalError.of(sys.exc_info()) - logger.exception('Exception occurred in notification: "%s"', error) + handler + The request handler to call - # Revisit. 
Client does not support response with msg_id = None - # https://stackoverflow.com/questions/31091376/json-rpc-2-0-allow-notifications-to-have-an-error-response - # self._send_response(None, error=error) + params + The parameters object to pass to the handler - def _execute_request(self, msg_id, handler, params): - """Execute the given request message handler.""" + callback + An optional callback function to call upon completion of the handler + """ if asyncio.iscoroutinefunction(handler): future = asyncio.ensure_future(handler(params)) self._request_futures[msg_id] = future - future.add_done_callback(partial(self._execute_request_callback, msg_id)) + future.add_done_callback(callback) elif is_thread_function(handler): future = self._server.thread_pool.submit(handler, params) self._request_futures[msg_id] = future - - future.add_done_callback(partial(self._execute_request_callback, msg_id)) + future.add_done_callback(callback) else: - self._send_response(msg_id, handler(params)) + # While a future is not necessary for a synchronous function, it allows us to use a single + # pattern across all handler types + future: Future[Any] = Future() + future.add_done_callback(callback) + + try: + result = handler(params) + future.set_result(result) + except Exception as exc: + future.set_exception(exc) + + def _send_handler_result(self, future: Future[Any], *, msg_id: str | int): + """Callback function that sends the result of the given future to the client. + + Used to respond to request messages. + """ + self._request_futures.pop(msg_id, None) - def _execute_request_callback(self, msg_id, future): - """Callback used for async/threaded request message handler.""" try: if not future.cancelled(): self._send_response(msg_id, result=future.result()) @@ -183,12 +188,23 @@ def _execute_request_callback(self, msg_id, future): f'Request with id "{msg_id}" is canceled' ).to_response_error(), ) - self._request_futures.pop(msg_id, None) except Exception: error = JsonRpcInternalError.of(sys.exc_info()) logger.exception('Exception occurred for message "%s": %s', msg_id, error) self._send_response(msg_id, error=error.to_response_error()) + def _check_handler_result(self, future: Future[Any]): + """Check the result of the future to see if an error occurred. 
+ + Used when handling notification messages + """ + if future.exception(): + try: + raise future.exception() + except Exception: + error = JsonRpcInternalError.of(sys.exc_info()) + self._server._report_server_error(error, FeatureNotificationError) + def _get_handler(self, feature_name): """Returns builtin or used defined feature by name if exists.""" @@ -220,7 +236,9 @@ def _handle_notification(self, method_name, params): try: handler = self._get_handler(method_name) - self._execute_notification(handler, params) + self._execute_handler( + str(uuid.uuid4()), handler, params, self._check_handler_result + ) except JsonRpcMethodNotFound: logger.warning("Ignoring notification for unknown method %r", method_name) except Exception as error: @@ -241,7 +259,12 @@ def _handle_request(self, msg_id, method_name, params): if method_name == WORKSPACE_EXECUTE_COMMAND: handler(params, msg_id) else: - self._execute_request(msg_id, handler, params) + self._execute_handler( + msg_id, + handler, + params, + callback=partial(self._send_handler_result, msg_id=msg_id), + ) except JsonRpcMethodNotFound as error: logger.warning( @@ -430,11 +453,11 @@ def set_writer( self.writer = writer self._include_headers = include_headers - def get_message_type(self, method: str) -> Optional[Type]: + def get_message_type(self, method: str) -> Type[Any] | None: """Return the type definition of the message associated with the given method.""" return None - def get_result_type(self, method: str) -> Optional[Type]: + def get_result_type(self, method: str) -> Type[Any] | None: """Return the type definition of the result associated with the given method.""" return None diff --git a/pygls/protocol/language_server.py b/pygls/protocol/language_server.py index 1353d6d0..5a7bb76a 100644 --- a/pygls/protocol/language_server.py +++ b/pygls/protocol/language_server.py @@ -22,7 +22,7 @@ import logging import sys import typing -from functools import lru_cache +from functools import lru_cache, partial from itertools import zip_longest from typing import ( Callable, @@ -261,7 +261,12 @@ def lsp_workspace__execute_command( ) -> None: """Executes commands with passed arguments and returns a value.""" cmd_handler = self.fm.commands[params.command] - self._execute_request(msg_id, cmd_handler, params.arguments) + self._execute_handler( + msg_id, + cmd_handler, + params.arguments, + partial(self._send_handler_result, msg_id=msg_id), + ) @lsp_method(types.WINDOW_WORK_DONE_PROGRESS_CANCEL) def lsp_work_done_progress_cancel( From b4580b860a0663102505d31e5dff219a52280037 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Fri, 29 Nov 2024 21:40:57 +0000 Subject: [PATCH 2/5] chore: add missing version number --- .vscode/launch.json | 1 + 1 file changed, 1 insertion(+) diff --git a/.vscode/launch.json b/.vscode/launch.json index 02ecd063..3a498f44 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,4 +1,5 @@ { + "version": "0.2.0", "configurations": [ { "name": "pygls: Debug Server", From 2d9513f009544795a13aa0f27647196382996a44 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Sat, 30 Nov 2024 14:25:26 +0000 Subject: [PATCH 3/5] refactor: re-implement pygls' builtin handlers using generators The underlying cause of #433 is that pygls' current implementation of builtin feature handlers cannot guarantee that an async user handler will finish executing before pygls responds with the answer generated from the builtin handler. This commit adds support for another execution model, generators. 
A generator handler can yield to another sub-handler method like so ``` yield handler_func, args, kwargs ``` The `JsonRPCProtocol` class with then schedule the execution of `handler_func(*args, **kwargs)` as if it were a normal handler function (meaning `handler_func could be async, threaded, sync or a generator itself!) The result of the sub-handler is then sent back into the generator handler allowing the top-level handler to continue and even make use of the result! This gives pygls' built-in handlers much greater control over exactly when a user handler is called, allowing us to fix #433 and opens up a lot other exciting possibilities! This also removes the need for the `LSPMeta` metaclass, so it and the corresponding module have been deleted. --- pygls/protocol/__init__.py | 6 +- pygls/protocol/json_rpc.py | 113 ++++++++++++++++++++++------ pygls/protocol/language_server.py | 118 ++++++++++++++++++++---------- pygls/protocol/lsp_meta.py | 51 ------------- tests/test_protocol.py | 31 -------- 5 files changed, 173 insertions(+), 146 deletions(-) delete mode 100644 pygls/protocol/lsp_meta.py diff --git a/pygls/protocol/__init__.py b/pygls/protocol/__init__.py index 1a30b485..6c61e877 100644 --- a/pygls/protocol/__init__.py +++ b/pygls/protocol/__init__.py @@ -1,7 +1,6 @@ import json -from typing import Any - from collections import namedtuple +from typing import Any from lsprotocol import converters @@ -12,7 +11,6 @@ JsonRPCResponseMessage, ) from pygls.protocol.language_server import LanguageServerProtocol, lsp_method -from pygls.protocol.lsp_meta import LSPMeta, call_user_feature def _dict_to_object(d: Any): @@ -68,8 +66,6 @@ def default_converter(): "JsonRPCRequestMessage", "JsonRPCResponseMessage", "JsonRPCNotification", - "LSPMeta", - "call_user_feature", "_dict_to_object", "_params_field_structure_hook", "_result_field_structure_hook", diff --git a/pygls/protocol/json_rpc.py b/pygls/protocol/json_rpc.py index a9e2cfb9..b4f68d59 100644 --- a/pygls/protocol/json_rpc.py +++ b/pygls/protocol/json_rpc.py @@ -34,7 +34,6 @@ from lsprotocol.types import ( CANCEL_REQUEST, EXIT, - WORKSPACE_EXECUTE_COMMAND, ResponseError, ResponseErrorMessage, ) @@ -51,6 +50,8 @@ from pygls.feature_manager import FeatureManager, is_thread_function if typing.TYPE_CHECKING: + from collections.abc import Generator + from cattrs import Converter from pygls.io_ import AsyncWriter, Writer @@ -130,8 +131,9 @@ def _execute_handler( self, msg_id: str | int, handler: MessageHandler, - params: Any, callback: MessageCallback, + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, ): """Execute the given message handler. 
@@ -143,22 +145,41 @@ def _execute_handler( handler The request handler to call - params - The parameters object to pass to the handler - callback An optional callback function to call upon completion of the handler + + args + Positional arguments to pass to the handler + + kwargs + Keyword arguments to pass to the handler """ + args = args or tuple() + kwargs = kwargs or {} + if asyncio.iscoroutinefunction(handler): - future = asyncio.ensure_future(handler(params)) + future = asyncio.ensure_future(handler(*args, **kwargs)) self._request_futures[msg_id] = future future.add_done_callback(callback) elif is_thread_function(handler): - future = self._server.thread_pool.submit(handler, params) + future = self._server.thread_pool.submit(handler, *args, **kwargs) self._request_futures[msg_id] = future future.add_done_callback(callback) + + elif inspect.isgeneratorfunction(handler): + future: Future[Any] = Future() + self._request_futures[msg_id] = future + future.add_done_callback(callback) + + try: + self._run_generator( + future=None, gen=handler(*args, **kwargs), result_future=future + ) + except Exception as exc: + future.set_exception(exc) + else: # While a future is not necessary for a synchronous function, it allows us to use a single # pattern across all handler types @@ -166,11 +187,61 @@ def _execute_handler( future.add_done_callback(callback) try: - result = handler(params) + result = handler(*args, **kwargs) future.set_result(result) except Exception as exc: future.set_exception(exc) + def _run_generator( + self, + future: Future[Any] | None, + *, + gen: Generator[Any, Any, Any], + result_future: Future[Any], + ): + """Run the next portion of the given generator. + + Generator handlers are designed to ``yield`` to other handlers that are executed + separately before their results are sent back into the generator allowing + execution to continue. + + Generator handlers are primarily used in the implementation of pygls' builtin + feature handlers. + + Parameters + ---------- + future + The future that contains the result of the previously executed handler, if any + + gen + The generator to run + + result_future + The future to send the final result to once the generator stops. + """ + + if result_future.cancelled(): + return + + try: + value = future.result() if future is not None else None + handler, args, kwargs = gen.send(value) + + self._execute_handler( + str(uuid.uuid4()), + handler, + args=args, + kwargs=kwargs, + callback=partial( + self._run_generator, gen=gen, result_future=result_future + ), + ) + except StopIteration as result: + result_future.set_result(result.value) + + except Exception as exc: + result_future.set_exception(exc) + def _send_handler_result(self, future: Future[Any], *, msg_id: str | int): """Callback function that sends the result of the given future to the client. @@ -192,6 +263,7 @@ def _send_handler_result(self, future: Future[Any], *, msg_id: str | int): error = JsonRpcInternalError.of(sys.exc_info()) logger.exception('Exception occurred for message "%s": %s', msg_id, error) self._send_response(msg_id, error=error.to_response_error()) + self._server._report_server_error(error, FeatureRequestError) def _check_handler_result(self, future: Future[Any]): """Check the result of the future to see if an error occurred. 
@@ -237,7 +309,10 @@ def _handle_notification(self, method_name, params): try: handler = self._get_handler(method_name) self._execute_handler( - str(uuid.uuid4()), handler, params, self._check_handler_result + msg_id=str(uuid.uuid4()), + handler=handler, + args=(params,), + callback=self._check_handler_result, ) except JsonRpcMethodNotFound: logger.warning("Ignoring notification for unknown method %r", method_name) @@ -255,16 +330,12 @@ def _handle_request(self, msg_id, method_name, params): try: handler = self._get_handler(method_name) - # workspace/executeCommand is a special case - if method_name == WORKSPACE_EXECUTE_COMMAND: - handler(params, msg_id) - else: - self._execute_handler( - msg_id, - handler, - params, - callback=partial(self._send_handler_result, msg_id=msg_id), - ) + self._execute_handler( + msg_id=msg_id, + handler=handler, + args=(params,), + callback=partial(self._send_handler_result, msg_id=msg_id), + ) except JsonRpcMethodNotFound as error: logger.warning( @@ -369,10 +440,10 @@ def handle_message(self, message): if hasattr(message, "method"): if hasattr(message, "id"): - logger.debug("Request message received.") + logger.debug("Request %r received", message.method) self._handle_request(message.id, message.method, message.params) else: - logger.debug("Notification message received.") + logger.debug("Notification %r received", message.method) self._handle_notification(message.method, message.params) else: if hasattr(message, "error"): diff --git a/pygls/protocol/language_server.py b/pygls/protocol/language_server.py index 5a7bb76a..9a879924 100644 --- a/pygls/protocol/language_server.py +++ b/pygls/protocol/language_server.py @@ -22,29 +22,25 @@ import logging import sys import typing -from functools import lru_cache, partial +from functools import lru_cache from itertools import zip_longest -from typing import ( - Callable, - Optional, - Type, - TypeVar, -) from lsprotocol import types from pygls.capabilities import ServerCapabilitiesBuilder from pygls.protocol.json_rpc import JsonRPCProtocol -from pygls.protocol.lsp_meta import LSPMeta from pygls.uris import from_fs_path from pygls.workspace import Workspace if typing.TYPE_CHECKING: + from collections.abc import Generator + from typing import Any, Callable, Optional, Type, TypeVar + from cattrs import Converter from pygls.lsp.server import LanguageServer -F = TypeVar("F", bound=Callable) + F = TypeVar("F", bound=Callable) logger = logging.getLogger(__name__) @@ -57,7 +53,7 @@ def decorator(f: F) -> F: return decorator -class LanguageServerProtocol(JsonRPCProtocol, metaclass=LSPMeta): +class LanguageServerProtocol(JsonRPCProtocol): """A class that represents language server protocol. It contains implementations for generic LSP features. 
@@ -105,17 +101,22 @@ def workspace(self) -> Workspace: return self._workspace @lru_cache() - def get_message_type(self, method: str) -> Optional[Type]: + def get_message_type(self, method: str) -> Type[Any] | None: """Return LSP type definitions, as provided by `lsprotocol`""" return types.METHOD_TO_TYPES.get(method, (None,))[0] @lru_cache() - def get_result_type(self, method: str) -> Optional[Type]: + def get_result_type(self, method: str) -> Type[Any] | None: return types.METHOD_TO_TYPES.get(method, (None, None))[1] @lsp_method(types.EXIT) def lsp_exit(self, *args) -> None: """Stops the server process.""" + + # Ensure that the user handler is called first + if (user_handler := self.fm.features.get(types.EXIT)) is not None: + yield user_handler, args, None + returncode = 0 if self._shutdown else 1 if self.writer is None: sys.exit(returncode) @@ -176,13 +177,19 @@ def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResu ) @lsp_method(types.INITIALIZED) - def lsp_initialized(self, *args) -> None: + def lsp_initialized(self, *args): """Notification received when client and server are connected.""" - pass + + if (user_handler := self.fm.features.get(types.INITIALIZED)) is not None: + yield user_handler, args, None @lsp_method(types.SHUTDOWN) def lsp_shutdown(self, *args) -> None: """Request from client which asks server to shutdown.""" + + if (user_handler := self.fm.features.get(types.SHUTDOWN)) is not None: + yield user_handler, args, None + for future in self._request_futures.values(): future.cancel() @@ -190,59 +197,86 @@ def lsp_shutdown(self, *args) -> None: return None @lsp_method(types.TEXT_DOCUMENT_DID_CHANGE) - def lsp_text_document__did_change( - self, params: types.DidChangeTextDocumentParams - ) -> None: + def lsp_text_document__did_change(self, params: types.DidChangeTextDocumentParams): """Updates document's content. 
(Incremental(from server capabilities); not configurable for now) """ for change in params.content_changes: self.workspace.update_text_document(params.text_document, change) + if ( + user_handler := self.fm.features.get(types.TEXT_DOCUMENT_DID_CHANGE) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.TEXT_DOCUMENT_DID_CLOSE) - def lsp_text_document__did_close( - self, params: types.DidCloseTextDocumentParams - ) -> None: + def lsp_text_document__did_close(self, params: types.DidCloseTextDocumentParams): """Removes document from workspace.""" self.workspace.remove_text_document(params.text_document.uri) + if ( + user_handler := self.fm.features.get(types.TEXT_DOCUMENT_DID_CLOSE) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.TEXT_DOCUMENT_DID_OPEN) - def lsp_text_document__did_open( - self, params: types.DidOpenTextDocumentParams - ) -> None: + def lsp_text_document__did_open(self, params: types.DidOpenTextDocumentParams): """Puts document to the workspace.""" self.workspace.put_text_document(params.text_document) + if ( + user_handler := self.fm.features.get(types.TEXT_DOCUMENT_DID_OPEN) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.NOTEBOOK_DOCUMENT_DID_OPEN) def lsp_notebook_document__did_open( self, params: types.DidOpenNotebookDocumentParams - ) -> None: + ): """Put a notebook document into the workspace""" self.workspace.put_notebook_document(params) + if ( + user_handler := self.fm.features.get(types.NOTEBOOK_DOCUMENT_DID_OPEN) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.NOTEBOOK_DOCUMENT_DID_CHANGE) def lsp_notebook_document__did_change( self, params: types.DidChangeNotebookDocumentParams - ) -> None: + ): """Update a notebook's contents""" self.workspace.update_notebook_document(params) + if ( + user_handler := self.fm.features.get(types.NOTEBOOK_DOCUMENT_DID_CHANGE) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.NOTEBOOK_DOCUMENT_DID_CLOSE) def lsp_notebook_document__did_close( self, params: types.DidCloseNotebookDocumentParams - ) -> None: + ): """Remove a notebook document from the workspace.""" self.workspace.remove_notebook_document(params) + if ( + user_handler := self.fm.features.get(types.NOTEBOOK_DOCUMENT_DID_CLOSE) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.SET_TRACE) def lsp_set_trace(self, params: types.SetTraceParams) -> None: """Changes server trace value.""" self.trace = params.value + if (user_handler := self.fm.features.get(types.SET_TRACE)) is not None: + yield user_handler, (params,), None + @lsp_method(types.WORKSPACE_DID_CHANGE_WORKSPACE_FOLDERS) def lsp_workspace__did_change_workspace_folders( self, params: types.DidChangeWorkspaceFoldersParams - ) -> None: + ): """Adds/Removes folders from the workspace.""" logger.info("Workspace folders changed: %s", params) @@ -255,23 +289,26 @@ def lsp_workspace__did_change_workspace_folders( if f_remove: self.workspace.remove_folder(f_remove.uri) + if ( + user_handler := self.fm.features.get( + types.WORKSPACE_DID_CHANGE_WORKSPACE_FOLDERS + ) + ) is not None: + yield user_handler, (params,), None + @lsp_method(types.WORKSPACE_EXECUTE_COMMAND) def lsp_workspace__execute_command( - self, params: types.ExecuteCommandParams, msg_id: str - ) -> None: + self, params: types.ExecuteCommandParams + ) -> Generator[Any, Any, Any]: """Executes commands with passed arguments and returns a value.""" cmd_handler = self.fm.commands[params.command] - 
self._execute_handler( - msg_id, - cmd_handler, - params.arguments, - partial(self._send_handler_result, msg_id=msg_id), - ) + + # Call the user's command implementation + result = yield cmd_handler, (params.arguments,), None + return result @lsp_method(types.WINDOW_WORK_DONE_PROGRESS_CANCEL) - def lsp_work_done_progress_cancel( - self, params: types.WorkDoneProgressCancelParams - ) -> None: + def lsp_work_done_progress_cancel(self, params: types.WorkDoneProgressCancelParams): """Received a progress cancellation from client.""" future = self.progress.tokens.get(params.token) if future is None: @@ -280,3 +317,8 @@ def lsp_work_done_progress_cancel( ) else: future.cancel() + + if ( + user_handler := self.fm.features.get(types.WINDOW_WORK_DONE_PROGRESS_CANCEL) + ) is not None: + yield user_handler, (params,), None diff --git a/pygls/protocol/lsp_meta.py b/pygls/protocol/lsp_meta.py deleted file mode 100644 index 0dc52db0..00000000 --- a/pygls/protocol/lsp_meta.py +++ /dev/null @@ -1,51 +0,0 @@ -import functools -import logging -from pygls.constants import ATTR_FEATURE_TYPE -from pygls.feature_manager import assign_help_attrs - - -logger = logging.getLogger(__name__) - - -def call_user_feature(base_func, method_name): - """Wraps generic LSP features and calls user registered feature - immediately after it. - """ - - @functools.wraps(base_func) - def decorator(self, *args, **kwargs): - ret_val = base_func(self, *args, **kwargs) - - try: - user_func = self.fm.features[method_name] - self._execute_notification(user_func, *args, **kwargs) - except KeyError: - pass - except Exception: - logger.exception( - 'Failed to handle user defined notification "%s": %s', method_name, args - ) - - return ret_val - - return decorator - - -class LSPMeta(type): - """Wraps LSP built-in features (`lsp_` naming convention). - - Built-in features cannot be overridden but user defined features with - the same LSP name will be called after them. 
- """ - - def __new__(mcs, cls_name, cls_bases, cls): - for attr_name, attr_val in cls.items(): - if callable(attr_val) and hasattr(attr_val, "method_name"): - method_name = attr_val.method_name - wrapped = call_user_feature(attr_val, method_name) - assign_help_attrs(wrapped, method_name, ATTR_FEATURE_TYPE) - cls[attr_name] = wrapped - - logger.debug('Added decorator for lsp method: "%s"', attr_name) - - return super().__new__(mcs, cls_name, cls_bases, cls) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 9a526370..73d6047f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -16,22 +16,17 @@ ############################################################################ import io import json -from pathlib import Path from typing import Optional -from unittest.mock import Mock import attrs import pytest from lsprotocol.types import ( PROGRESS, TEXT_DOCUMENT_COMPLETION, - ClientCapabilities, CompletionItem, CompletionItemKind, CompletionParams, CompletionResponse, - InitializeParams, - InitializeResult, Position, ProgressParams, ShutdownResponse, @@ -521,29 +516,3 @@ def test_serialize_request_message(method, params, expected): actual = json.loads(buffer.getvalue()) assert actual == expected - - -def test_initialize_should_return_server_capabilities(client_server): - _, server = client_server - params = InitializeParams( - process_id=1234, - root_uri=Path(__file__).parent.as_uri(), - capabilities=ClientCapabilities(), - ) - - server_capabilities = server.protocol.lsp_initialize(params) - - assert isinstance(server_capabilities, InitializeResult) - - -def test_ignore_unknown_notification(client_server): - _, server = client_server - - fn = server.protocol._execute_notification - server.protocol._execute_notification = Mock() - - server.protocol._handle_notification("random/notification", None) - assert not server.protocol._execute_notification.called - - # Remove mock - server.protocol._execute_notification = fn From e4862c18238c5afc34e90b9f347f670d510ed1c9 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Sat, 30 Nov 2024 19:36:16 +0000 Subject: [PATCH 4/5] fix: don't cancel the future handling the shutdown request This commit makes use of a `ContextVar` to keep track of the current request's id, allowing handlers to reference it. Most importantly so that the shutdown request handler does not cancel its own future! --- pygls/protocol/json_rpc.py | 43 +++++++++++++++++++++++-------- pygls/protocol/language_server.py | 8 ++++-- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/pygls/protocol/json_rpc.py b/pygls/protocol/json_rpc.py index b4f68d59..15daaad8 100644 --- a/pygls/protocol/json_rpc.py +++ b/pygls/protocol/json_rpc.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio +import contextvars import enum import inspect import json @@ -62,6 +63,10 @@ logger = logging.getLogger(__name__) +# cattrs needs access to this type definition so we cannot include it in the +# TYPE_CHECKING block above +MsgId = Union[str | int] + @attrs.define class JsonRPCNotification: @@ -80,7 +85,7 @@ class JsonRPCRequestMessage: Used as a fallback for unknown types. """ - id: Union[int, str] + id: MsgId method: str jsonrpc: str params: Any @@ -92,7 +97,7 @@ class JsonRPCResponseMessage: Used as a fallback for unknown types. 
""" - id: Union[int, str] + id: MsgId jsonrpc: str result: Any @@ -117,8 +122,11 @@ def __init__(self, server: JsonRPCServer, converter: Converter): self._shutdown = False # Book keeping for in-flight requests - self._request_futures: dict[str | int, Future[Any]] = {} - self._result_types: dict[str | int, Any] = {} + self._ctx_msg_id: contextvars.ContextVar[MsgId | None] = contextvars.ContextVar( + "msg_id", default=None + ) + self._request_futures: dict[MsgId, Future[Any]] = {} + self._result_types: dict[MsgId, Any] = {} self.fm = FeatureManager(server, converter) self.writer: AsyncWriter | Writer | None = None @@ -127,9 +135,15 @@ def __init__(self, server: JsonRPCServer, converter: Converter): def __call__(self): return self + @property + def msg_id(self) -> MsgId | None: + """Returns the id of the current context (if it exists).""" + ctx = contextvars.copy_context() + return ctx.get(self._ctx_msg_id) + def _execute_handler( self, - msg_id: str | int, + msg_id: MsgId, handler: MessageHandler, callback: MessageCallback, args: tuple[Any, ...] | None = None, @@ -300,7 +314,7 @@ def _handle_cancel_notification(self, msg_id): if future.cancel(): logger.info('Cancelled request with id "%s"', msg_id) - def _handle_notification(self, method_name, params): + def _handle_notification(self, method_name: str, params: Any): """Handles a notification from the client.""" if method_name == CANCEL_REQUEST: self._handle_cancel_notification(params.id) @@ -325,11 +339,13 @@ def _handle_notification(self, method_name, params): ) self._server._report_server_error(error, FeatureNotificationError) - def _handle_request(self, msg_id, method_name, params): + def _handle_request(self, msg_id: MsgId, method_name: str, params: Any): """Handles a request from the client.""" try: handler = self._get_handler(method_name) + # Set the request id within the current context. + self._ctx_msg_id.set(msg_id) self._execute_handler( msg_id=msg_id, handler=handler, @@ -438,20 +454,25 @@ def handle_message(self, message): logger.warning("Server shutting down. No more requests!") return + # Run each handler within its own context. + ctx = contextvars.copy_context() + if hasattr(message, "method"): if hasattr(message, "id"): logger.debug("Request %r received", message.method) - self._handle_request(message.id, message.method, message.params) + ctx.run( + self._handle_request, message.id, message.method, message.params + ) else: logger.debug("Notification %r received", message.method) - self._handle_notification(message.method, message.params) + ctx.run(self._handle_notification, message.method, message.params) else: if hasattr(message, "error"): logger.debug("Error message received.") - self._handle_response(message.id, None, message.error) + ctx.run(self._handle_response, message.id, None, message.error) else: logger.debug("Response message received.") - self._handle_response(message.id, message.result) + ctx.run(self._handle_response, message.id, message.result) def _send_data(self, data): """Sends data to the client.""" diff --git a/pygls/protocol/language_server.py b/pygls/protocol/language_server.py index 9a879924..709631d9 100644 --- a/pygls/protocol/language_server.py +++ b/pygls/protocol/language_server.py @@ -190,8 +190,12 @@ def lsp_shutdown(self, *args) -> None: if (user_handler := self.fm.features.get(types.SHUTDOWN)) is not None: yield user_handler, args, None - for future in self._request_futures.values(): - future.cancel() + # Don't cancel the future for this request! 
+ current_id = self.msg_id + + for msg_id, future in self._request_futures.items(): + if msg_id != current_id and not future.done(): + future.cancel() self._shutdown = True return None From 1035549ced4f1fe8150bf6a6151b937c6ec61e31 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Sat, 30 Nov 2024 20:16:04 +0000 Subject: [PATCH 5/5] feat: yield to the user's initialize before calculating capabilities This should enable the dynamic registration of features during initialization, as discussed in #381 --- pygls/capabilities.py | 51 ++++++++++++++++--------------- pygls/protocol/language_server.py | 41 ++++++++++++++++--------- tests/test_feature_manager.py | 6 ++-- 3 files changed, 55 insertions(+), 43 deletions(-) diff --git a/pygls/capabilities.py b/pygls/capabilities.py index fcc9cf02..9557370d 100644 --- a/pygls/capabilities.py +++ b/pygls/capabilities.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # # limitations under the License. # ############################################################################ -from functools import reduce -from typing import Any, Dict, List, Optional, Set, Union, TypeVar import logging +from functools import reduce +from typing import Any, Dict, List, Optional, Set, TypeVar, Union from lsprotocol import types - logger = logging.getLogger(__name__) T = TypeVar("T") @@ -62,6 +61,7 @@ def __init__( commands: List[str], text_document_sync_kind: types.TextDocumentSyncKind, notebook_document_sync: Optional[types.NotebookDocumentSyncOptions] = None, + position_encoding: types.PositionEncodingKind = types.PositionEncodingKind.Utf16, ): self.client_capabilities = client_capabilities self.features = features @@ -71,12 +71,35 @@ def __init__( self.notebook_document_sync = notebook_document_sync self.server_cap = types.ServerCapabilities() + self.server_cap.position_encoding = position_encoding def _provider_options(self, feature: str, default: T) -> Optional[Union[T, Any]]: if feature in self.features: return self.feature_options.get(feature, default) return None + @classmethod + def choose_position_encoding( + cls, client_capabilities: types.ClientCapabilities + ) -> types.PositionEncodingKind: + server_encoding = types.PositionEncodingKind.Utf16 + + if (general := client_capabilities.general) is None: + return server_encoding + + if (encodings := general.position_encodings) is None: + return server_encoding + + # We match client preference where this an overlap between its and our supported encodings. + for client_encoding in encodings: + if client_encoding in _SUPPORTED_ENCODINGS: + server_encoding = client_encoding + return server_encoding + + logger.warning(f"Unknown `PositionEncoding`s: {encodings}") + + return server_encoding + def _with_text_document_sync(self): open_close = ( types.TEXT_DOCUMENT_DID_OPEN in self.features @@ -415,27 +438,6 @@ def _with_inline_value_provider(self): self.server_cap.inline_value_provider = value return self - def _with_position_encodings(self): - self.server_cap.position_encoding = types.PositionEncodingKind.Utf16 - - general = self.client_capabilities.general - if general is None: - return self - - encodings = general.position_encodings - if encodings is None: - return self - - # We match client preference where this an overlap between its and our supported encodings. 
- for encoding in encodings: - if encoding in _SUPPORTED_ENCODINGS: - self.server_cap.position_encoding = encoding - return self - - logger.warning(f"Unknown `PositionEncoding`s: {encodings}") - - return self - def _build(self): return self.server_cap @@ -474,6 +476,5 @@ def build(self): ._with_workspace_capabilities() ._with_diagnostic_provider() ._with_inline_value_provider() - ._with_position_encodings() ._build() ) diff --git a/pygls/protocol/language_server.py b/pygls/protocol/language_server.py index 709631d9..9f5b431d 100644 --- a/pygls/protocol/language_server.py +++ b/pygls/protocol/language_server.py @@ -130,7 +130,9 @@ def lsp_exit(self, *args) -> None: sys.exit(returncode) @lsp_method(types.INITIALIZE) - def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResult: + def lsp_initialize( + self, params: types.InitializeParams + ) -> Generator[Any, Any, types.InitializeResult]: """Method that initializes language server. It will compute and return server capabilities based on registered features. @@ -142,19 +144,9 @@ def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResu text_document_sync_kind = self._server._text_document_sync_kind notebook_document_sync = self._server._notebook_document_sync - # Initialize server capabilities self.client_capabilities = params.capabilities - self.server_capabilities = ServerCapabilitiesBuilder( - self.client_capabilities, - set({**self.fm.features, **self.fm.builtin_features}.keys()), - self.fm.feature_options, - list(self.fm.commands.keys()), - text_document_sync_kind, - notebook_document_sync, - ).build() - logger.debug( - "Server capabilities: %s", - json.dumps(self.server_capabilities, default=self._serialize_message), + position_encoding = ServerCapabilitiesBuilder.choose_position_encoding( + self.client_capabilities ) root_path = params.root_path @@ -162,13 +154,32 @@ def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResu if root_path is not None and root_uri is None: root_uri = from_fs_path(root_path) - # Initialize the workspace + # Initialize the workspace before yielding to the user's initialize handler workspace_folders = params.workspace_folders or [] self._workspace = Workspace( root_uri, text_document_sync_kind, workspace_folders, - self.server_capabilities.position_encoding, + position_encoding, + ) + + if (user_handler := self.fm.features.get(types.INITIALIZE)) is not None: + yield user_handler, (params,), None + + # Now that the user has had the opportunity to setup additional features, calculate + # the server's capabilities + self.server_capabilities = ServerCapabilitiesBuilder( + self.client_capabilities, + set({**self.fm.features, **self.fm.builtin_features}.keys()), + self.fm.feature_options, + list(self.fm.commands.keys()), + text_document_sync_kind, + notebook_document_sync, + position_encoding, + ).build() + logger.debug( + "Server capabilities: %s", + json.dumps(self.server_capabilities, default=self._serialize_message), ) return types.InitializeResult( diff --git a/tests/test_feature_manager.py b/tests/test_feature_manager.py index 119af551..5ceadc2a 100644 --- a/tests/test_feature_manager.py +++ b/tests/test_feature_manager.py @@ -18,6 +18,8 @@ from typing import Any import pytest +from lsprotocol import types as lsp + from pygls.capabilities import ServerCapabilitiesBuilder from pygls.exceptions import ( CommandAlreadyRegisteredError, @@ -29,7 +31,6 @@ has_ls_param_or_annotation, wrap_with_server, ) -from lsprotocol import types as lsp class 
Temp: @@ -704,13 +705,13 @@ def _(): [], None, None, + ServerCapabilitiesBuilder.choose_position_encoding(capabilities), ).build() assert expected == actual def test_register_prepare_rename_no_client_support(feature_manager: FeatureManager): - @feature_manager.feature(lsp.TEXT_DOCUMENT_RENAME) def _(): pass @@ -734,7 +735,6 @@ def _(): def test_register_prepare_rename_with_client_support(feature_manager: FeatureManager): - @feature_manager.feature(lsp.TEXT_DOCUMENT_RENAME) def _(): pass
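
To make the mechanism in PATCH 3/5 concrete, here is a minimal, self-contained sketch of the yield protocol the commit message describes. It is not pygls' actual implementation: the `execute` driver and the two handlers below are hypothetical stand-ins for `JsonRPCProtocol._execute_handler`/`_run_generator` and a built-in/user handler pair.

```
import asyncio
import inspect


async def execute(handler, *args, **kwargs):
    """Run a handler that may be sync, async, or a generator that yields
    (sub_handler, args, kwargs) tuples and receives each sub-result back."""
    if inspect.isgeneratorfunction(handler):
        gen = handler(*args, **kwargs)
        result = None
        while True:
            try:
                sub_handler, sub_args, sub_kwargs = gen.send(result)
            except StopIteration as stop:
                return stop.value
            # The sub-handler goes through the same driver, so it can itself
            # be sync, async, threaded, or another generator.
            result = await execute(sub_handler, *(sub_args or ()), **(sub_kwargs or {}))
    elif asyncio.iscoroutinefunction(handler):
        return await handler(*args, **kwargs)
    else:
        return handler(*args, **kwargs)


async def user_initialized(params):
    # Stand-in for a user's (possibly async) handler.
    await asyncio.sleep(0)
    return f"user saw {params}"


def builtin_initialized(params):
    # A built-in-style handler: yield to the user's handler and continue
    # only once it has finished, using its result if needed.
    user_result = yield user_initialized, (params,), None
    return f"builtin done after: {user_result}"


print(asyncio.run(execute(builtin_initialized, {"processId": 1234})))
```

Because the sub-handler is scheduled through the same machinery, a built-in handler can wait for the user's async handler to complete before it continues, which is the guarantee needed to fix #433.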
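
Similarly, a minimal sketch of the `ContextVar` pattern from PATCH 4/5, assuming simplified, hypothetical names (`current_msg_id`, `handle_request`) rather than pygls' `_ctx_msg_id` and `_handle_request`:

```
import contextvars

# Hypothetical stand-in for JsonRPCProtocol._ctx_msg_id.
current_msg_id = contextvars.ContextVar("current_msg_id", default=None)


def handle_request(msg_id):
    # The id set here is visible to everything this handler calls, but only
    # within the context the handler was run in.
    current_msg_id.set(msg_id)
    print("handling request", current_msg_id.get())


for msg_id in (1, 2):
    # Run each incoming message in its own copy of the current context, as
    # handle_message does in the patch above.
    ctx = contextvars.copy_context()
    ctx.run(handle_request, msg_id)

print("outside any request:", current_msg_id.get())  # -> None
```

Because each message runs in its own copied context, the shutdown handler can read its own request id and skip that future when cancelling the remaining in-flight requests.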