Skip to content

Commit

Permalink
lsp: Migrate to pygls v2
Browse files Browse the repository at this point in the history
  • Loading branch information
alcarney committed Aug 24, 2024
1 parent b875c03 commit e32fa7f
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 53 deletions.
2 changes: 1 addition & 1 deletion lib/esbonio/esbonio/server/_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ async def update_workspace_configuration(self):
)

try:
results = await self.server.get_configuration_async(params)
results = await self.server.workspace_configuration_async(params)
except Exception:
self.logger.error("Unable to get workspace configuration", exc_info=True)
return
Expand Down
4 changes: 3 additions & 1 deletion lib/esbonio/esbonio/server/features/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def emit(self, record: logging.LogRecord) -> None:
return

log = self.format(record).strip()
self.server.show_message_log(log)
self.server.window_log_message(
types.LogMessageParams(message=log, type=types.MessageType.Log)
)


@attrs.define
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def show_preview_uri(self) -> Optional[Uri]:
self.logger.info("Preview available at: %s", uri.as_string(encode=False))

if self.supports_show_document:
result = await self.server.show_document_async(
result = await self.server.window_show_document_async(
types.ShowDocumentParams(
uri=uri.as_string(encode=False), external=True, take_focus=False
)
Expand Down
23 changes: 10 additions & 13 deletions lib/esbonio/esbonio/server/features/preview_manager/webview.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lsprotocol import types
from pygls.protocol import JsonRPCProtocol
from pygls.protocol import default_converter
from pygls.server import Server
from pygls.server import JsonRPCServer
from pygls.server import WebSocketTransportAdapter
from websockets.server import serve

Expand All @@ -21,20 +21,20 @@
from .config import PreviewConfig


class WebviewServer(Server):
class WebviewServer(JsonRPCServer):
"""The webview server controlls the webpage hosting the preview.
Used to implement automatic reloads and features like sync scrolling.
"""

lsp: JsonRPCProtocol
protocol: JsonRPCProtocol

def __init__(self, logger: logging.Logger, config: PreviewConfig, *args, **kwargs):
super().__init__(JsonRPCProtocol, default_converter, *args, **kwargs)

self.config = config
self.logger = logger.getChild("WebviewServer")
self.lsp._send_only_body = True
self.protocol._send_only_body = True

self._connected = False
self._ws_server: WebSocketServer | None = None
Expand Down Expand Up @@ -74,13 +74,10 @@ def connected(self) -> bool:
"""Indicates when we have an active connection to the client."""
return self._connected

def feature(self, feature_name: str, options=None):
return self.lsp.fm.feature(feature_name, options)

def reload(self):
"""Reload the current view."""
if self.connected:
self.lsp.notify("view/reload", {})
self.protocol.notify("view/reload", {})

def scroll(self, uri: str, line: int):
"""Called by the editor to scroll the current webview."""
Expand All @@ -93,7 +90,7 @@ def scroll(self, uri: str, line: int):

self._current_uri = uri
self._editor_in_control = asyncio.create_task(self.cooldown("editor"))
self.lsp.notify("view/scroll", {"uri": uri, "line": line})
self.protocol.notify("view/scroll", {"uri": uri, "line": line})

async def cooldown(self, name: str):
"""Create a cooldown."""
Expand Down Expand Up @@ -128,13 +125,13 @@ async def connection(websocket):
loop = asyncio.get_running_loop()
transport = WebSocketTransportAdapter(websocket, loop)

self.lsp.connection_made(transport) # type: ignore[arg-type]
self.protocol.connection_made(transport) # type: ignore[arg-type]
self._connected = True
self.logger.debug("Connected")

async for message in websocket:
self.lsp._procedure_handler(
json.loads(message, object_hook=self.lsp._deserialize_message)
self.protocol._procedure_handler(
json.loads(message, object_hook=self.protocol._deserialize_message)
)

self.logger.debug("Connection lost")
Expand Down Expand Up @@ -168,7 +165,7 @@ def on_scroll(ls: WebviewServer, params):

server._view_in_control = asyncio.create_task(server.cooldown("view"))

esbonio.lsp.show_document(
esbonio.window_show_document(
types.ShowDocumentParams(
uri=params.uri,
external=False,
Expand Down
28 changes: 17 additions & 11 deletions lib/esbonio/esbonio/server/features/sphinx_manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ async def trigger_build(self, uri: Uri):
known_src_uris = await project.get_src_uris()

for src_uri in known_src_uris:
doc = self.server.workspace.get_document(str(src_uri))
doc = self.server.workspace.get_text_document(str(src_uri))
doc_version = doc.version or 0
saved_version = getattr(doc, "saved_version", 0)

Expand All @@ -197,7 +197,9 @@ async def trigger_build(self, uri: Uri):
try:
result = await client.build(content_overrides=content_overrides)
except Exception as exc:
self.server.show_message(f"{exc}", lsp.MessageType.Error)
self.server.window_show_message(
lsp.ShowMessageParams(message=f"{exc}", type=lsp.MessageType.Error)
)
return
finally:
self.stop_progress(client)
Expand Down Expand Up @@ -268,7 +270,7 @@ async def _create_or_replace_client(

# If there was a previous client, stop it.
if (previous_client := self.clients.pop(event.scope, None)) is not None:
self.server.lsp.notify(
self.server.protocol.notify(
"sphinx/clientDestroyed",
ClientDestroyedNotification(id=previous_client.id),
)
Expand All @@ -282,7 +284,7 @@ async def _create_or_replace_client(
self.clients[event.scope] = client = self.client_factory(self, resolved)
client.add_listener("state-change", partial(self._on_state_change, event.scope))

self.server.lsp.notify(
self.server.protocol.notify(
"sphinx/clientCreated",
ClientCreatedNotification(id=client.id, scope=event.scope, config=resolved),
)
Expand All @@ -303,7 +305,7 @@ def _on_state_change(
if old_state == ClientState.Starting and new_state == ClientState.Running:
if (sphinx_info := client.sphinx_info) is not None:
self.project_manager.register_project(scope, client.db)
self.server.lsp.notify(
self.server.protocol.notify(
"sphinx/appCreated",
AppCreatedNotification(id=client.id, application=sphinx_info),
)
Expand All @@ -318,8 +320,10 @@ def _on_state_change(
traceback.format_exception(type(exc), exc, exc.__traceback__)
)

self.server.lsp.show_message(error, lsp.MessageType.Error)
self.server.lsp.notify(
self.server.window_show_message(
lsp.ShowMessageParams(message=error, type=lsp.MessageType.Error)
)
self.server.protocol.notify(
"sphinx/clientErrored",
ClientErroredNotification(id=client.id, error=error, detail=detail),
)
Expand All @@ -331,13 +335,13 @@ async def start_progress(self, client: SphinxClient):
self.logger.debug("Starting progress: '%s'", token)

try:
await self.server.progress.create_async(token)
await self.server.work_done_progress.create_async(token)
except Exception as exc:
self.logger.debug("Unable to create progress token: %s", exc)
return

self._progress_tokens[client.id] = token
self.server.progress.begin(
self.server.work_done_progress.begin(
token,
lsp.WorkDoneProgressBegin(title="sphinx-build", cancellable=False),
)
Expand All @@ -346,7 +350,9 @@ def stop_progress(self, client: SphinxClient):
if (token := self._progress_tokens.pop(client.id, None)) is None:
return

self.server.progress.end(token, lsp.WorkDoneProgressEnd(message="Finished"))
self.server.work_done_progress.end(
token, lsp.WorkDoneProgressEnd(message="Finished")
)

def report_progress(self, client: SphinxClient, progress: types.ProgressParams):
"""Report progress done for the given client."""
Expand All @@ -357,7 +363,7 @@ def report_progress(self, client: SphinxClient, progress: types.ProgressParams):
if (token := self._progress_tokens.get(client.id, None)) is None:
return

self.server.progress.report(
self.server.work_done_progress.report(
token,
lsp.WorkDoneProgressReport(
message=progress.message,
Expand Down
22 changes: 12 additions & 10 deletions lib/esbonio/esbonio/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import cattrs
from lsprotocol import types
from pygls.capabilities import get_capability
from pygls.server import LanguageServer
from pygls.lsp.server import LanguageServer
from pygls.workspace import TextDocument
from pygls.workspace import Workspace

Expand All @@ -35,19 +35,19 @@
class EsbonioWorkspace(Workspace):
"""A modified version of pygls' workspace that ensures uris are always resolved."""

def get_document(self, doc_uri: str) -> TextDocument:
def get_text_document(self, doc_uri: str) -> TextDocument:
uri = str(Uri.parse(doc_uri).resolve())
return super().get_text_document(uri)

def put_document(self, text_document: types.TextDocumentItem):
def put_text_document(self, text_document: types.TextDocumentItem):
text_document.uri = str(Uri.parse(text_document.uri).resolve())
return super().put_text_document(text_document)

def remove_document(self, doc_uri: str):
def remove_text_document(self, doc_uri: str):
doc_uri = str(Uri.parse(doc_uri).resolve())
return super().remove_text_document(doc_uri)

def update_document(
def update_text_document(
self,
text_doc: types.VersionedTextDocumentIdentifier,
change: types.TextDocumentContentChangeEvent,
Expand Down Expand Up @@ -99,7 +99,7 @@ def ready(self) -> asyncio.Future:
@property
def converter(self) -> cattrs.Converter:
"""The cattrs converter instance we should use."""
return self.lsp._converter
return self.protocol._converter

def _finish_task(self, task: asyncio.Task[Any]):
"""Cleanup a finished task."""
Expand Down Expand Up @@ -133,7 +133,7 @@ def initialize(self, params: types.InitializeParams):
self.logger.info("Language client: %s %s", client.name, client.version)

# TODO: Propose patch to pygls for providing custom Workspace implementations.
self.lsp._workspace = EsbonioWorkspace(
self.protocol._workspace = EsbonioWorkspace(
self.workspace.root_uri,
self.workspace._sync_kind,
list(self.workspace.folders.values()),
Expand Down Expand Up @@ -307,7 +307,9 @@ def sync_diagnostics(self) -> None:

for uri, diag_list in diagnostics.items():
self.logger.debug("Publishing %d diagnostics for: %s", len(diag_list), uri)
self.publish_diagnostics(str(uri), diag_list.data)
self.text_document_publish_diagnostics(
types.PublishDiagnosticsParams(uri=str(uri), diagnostics=diag_list.data)
)

async def _register_did_change_watched_files_handler(self):
"""Register the server's handler for ``workspace/didChangeWatchedFiles``."""
Expand All @@ -326,7 +328,7 @@ async def _register_did_change_watched_files_handler(self):
return

try:
await self.register_capability_async(
await self.client_register_capability_async(
types.RegistrationParams(
registrations=[
types.Registration(
Expand Down Expand Up @@ -375,7 +377,7 @@ async def _register_did_change_configuration_handler(self):
return

try:
await self.register_capability_async(
await self.client_register_capability_async(
types.RegistrationParams(
registrations=[
types.Registration(
Expand Down
6 changes: 2 additions & 4 deletions lib/esbonio/esbonio/server/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def on_document_save(
ls: EsbonioLanguageServer, params: types.DidSaveTextDocumentParams
):
# Record the version number of the document
doc = ls.workspace.get_document(params.text_document.uri)
doc = ls.workspace.get_text_document(params.text_document.uri)
doc.saved_version = doc.version or 0

await call_features(ls, "document_save", params)
Expand Down Expand Up @@ -135,9 +135,7 @@ async def on_workspace_diagnostic(
)
)

# Typing issues should be fixed in a future version of lsprotocol
# see: https://github.com/microsoft/lsprotocol/pull/285
return types.WorkspaceDiagnosticReport(items=reports) # type: ignore[arg-type]
return types.WorkspaceDiagnosticReport(items=reports)

@server.feature(types.TEXT_DOCUMENT_DOCUMENT_SYMBOL)
async def on_document_symbol(
Expand Down
11 changes: 11 additions & 0 deletions lib/esbonio/esbonio/sphinx_agent/handlers/webview.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


STATIC_DIR = (pathlib.Path(__file__).parent.parent / "static").resolve()
ALLOWED_MODULES = {"docutils.nodes", "sphinx.addnodes"}


def has_source(node):
Expand All @@ -31,6 +32,16 @@ def has_source(node):
if isinstance(node, addnodes.toctree) and version_info[0] < 7:
return False

# It's not only limited to `toctreenodes`!
#
# The identical error is thrown when using esbonio with the `ablog` extension
# See: https://github.com/swyddfa/esbonio/issues/874
#
# I think for now, the safest approach is to only handle nodes defined by Sphinx or
# docutils.
if node.__module__ not in ALLOWED_MODULES:
return False

return (node.line or 0) > 0 and node.source is not None


Expand Down
7 changes: 6 additions & 1 deletion lib/esbonio/hatch.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[metadata]
allow-direct-references = true

[version]
path = "esbonio/server/server.py"
validate-bump = false
Expand All @@ -10,7 +13,9 @@ packages = ["esbonio"]

[envs.hatch-test]
default-args = ["tests/server"]
extra-dependencies = ["pytest-lsp>=0.3.1,<1"]
extra-dependencies = [
"pytest-lsp @ git+file:///home/alex/Projects/swyddfa/lsp-devtools/pygls-v2/#subdirectory=lib/pytest-lsp",
]
matrix-name-format = "{variable}{value}"

[[envs.hatch-test.matrix]]
Expand Down
2 changes: 1 addition & 1 deletion lib/esbonio/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"aiosqlite",
"platformdirs",
"docutils",
"pygls>=1.1.0",
"pygls @ git+https://github.com/openlawlibrary/pygls@main",
"tomli ; python_version<'3.11'",
"websockets",
]
Expand Down
4 changes: 2 additions & 2 deletions lib/esbonio/tests/e2e/test_e2e_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def test_rst_directive_completions(
types.DidChangeTextDocumentParams(
text_document=types.VersionedTextDocumentIdentifier(uri=uri, version=2),
content_changes=[
types.TextDocumentContentChangeEvent_Type1(
types.TextDocumentContentChangePartial(
text=text,
range=types.Range(
start=types.Position(line=linum, character=0),
Expand Down Expand Up @@ -186,7 +186,7 @@ async def test_myst_directive_completions(
types.DidChangeTextDocumentParams(
text_document=types.VersionedTextDocumentIdentifier(uri=uri, version=2),
content_changes=[
types.TextDocumentContentChangeEvent_Type1(
types.TextDocumentContentChangePartial(
text=text,
range=types.Range(
start=types.Position(line=linum, character=0),
Expand Down
Loading

0 comments on commit e32fa7f

Please sign in to comment.