Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add class instantiation monitor and support for textDocument/documentSymbol request to multilspy #11

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ with lsp.start_server():
result3 = lsp.request_references(
...
)
result4 = lsp.request_document_symbols(
...
)
...
```

Expand Down Expand Up @@ -177,6 +180,9 @@ A monitor under the Monitor-Guided Decoding framework, is instantiated using `mu
#### Switch-Enum Monitor
[src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py) provides the instantiation of `Monitor` for generating valid named enum constants in C#. Unit tests for the switch-enum monitor are present in [tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py](tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py), which also provide usage examples for the switch-enum monitor.

#### Class Instantiation Monitor
[src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py) provides the instantiation of `Monitor` for generating valid class instantiations following `'new '` in a Java code base. Unit tests for the class-instantiation monitor, which also provide usage examples, are present in [tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py](tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py).

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
This module provides the class-instantiation monitor, that is invoked when "new " is typed to instantiate new classes
"""

import os

from pathlib import PurePath
from typing import List
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
from monitors4codegen.multilspy.multilspy_utils import TextUtils, FileUtils
from monitors4codegen.multilspy import multilspy_types

class ClassInstantiationMonitor(DereferencesMonitor):
    """
    Class Instantiation Monitor that is invoked when "new " is typed to instantiate new classes.

    When the text up to the cursor ends with "new ", the monitor constrains decoding to the
    names of the non-abstract classes declared in the ".java" files of the repository.
    """

    def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None:
        """
        Forward all arguments unchanged to the DereferencesMonitor constructor.
        """
        super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state)

    @staticmethod
    def _declares_abstract_class(decl_text: str) -> bool:
        """
        Return True if `decl_text` (the source text from the start of a class symbol's range
        up to the end of its name) carries the Java `abstract` modifier.

        Comments and string literals are blanked out first, and `abstract` is matched as a
        whole word, so that a Javadoc sentence or an annotation argument that merely mentions
        "abstract" does not make a concrete class look abstract.
        """
        # Function-scope import: `re` is not among the imports of this module.
        import re

        # String literals are matched first so that a "//" inside a string (e.g. a URL in an
        # annotation argument) is not mistaken for the start of a line comment.
        code_only = re.sub(r'"(?:\\.|[^"\\\n])*"|/\*.*?\*/|//[^\n]*', " ", decl_text, flags=re.DOTALL)
        return re.search(r"\babstract\b", code_only) is not None

    async def pre(self) -> None:
        """
        Decide the decoder state for the current cursor position.

        If the text up to the cursor does not end with "new ", or no instantiable class is
        found, the state is set to DecoderStates.S0. Otherwise the state becomes
        DecoderStates.Constrained and `legal_completions` holds the candidate class names.
        """
        file_buffer = self.monitor_file_buffer
        # Fetch the open file's text once and reuse it for both the index lookup and the slice.
        file_text = file_buffer.lsp.get_open_file_text(file_buffer.file_path)
        cursor_idx = TextUtils.get_index_from_line_col(
            file_text,
            file_buffer.current_lc[0],
            file_buffer.current_lc[1],
        )
        text_upto_cursor = file_text[:cursor_idx]

        # TODO: pre can be improved by checking for "new", and obtaining completions, and then prefixing a whitespace
        if not text_upto_cursor.endswith("new "):
            self.decoder_state = DecoderStates.S0
            return

        completions = await self.a_phi()
        if not completions:
            self.decoder_state = DecoderStates.S0
        else:
            self.decoder_state = DecoderStates.Constrained
            self.legal_completions = completions

    async def a_phi(self) -> List[str]:
        """
        Find the set of classes in the repository
        Filter out the set of abstract classes in the repository
        Remaining classes are instantiable. Return their names as legal completions

        Every file ending in ".java" under the language server's repository root is read and
        queried through `request_document_symbols`; only symbols of kind `SymbolKind.Class`
        are considered.
        """
        legal_completions: List[str] = []
        lsp = self.monitor_file_buffer.lsp
        repository_root_path = lsp.repository_root_path
        for path, _, files in os.walk(repository_root_path):
            for file in files:
                if not file.endswith(".java"):
                    continue
                absolute_file_path = str(PurePath(path, file))
                filecontents = FileUtils.read_file(lsp.logger, absolute_file_path)
                relative_file_path = str(PurePath(os.path.relpath(absolute_file_path, repository_root_path)))
                document_symbols, _ = await lsp.request_document_symbols(relative_file_path)
                for symbol in document_symbols:
                    if symbol["kind"] != multilspy_types.SymbolKind.Class:
                        continue
                    # NOTE: "range" and "selectionRange" are optional keys of
                    # UnifiedSymbolInformation; a class symbol lacking them raises KeyError here.
                    decl_start = symbol["range"]["start"]
                    name_end = symbol["selectionRange"]["end"]
                    decl_start_idx = TextUtils.get_index_from_line_col(filecontents, decl_start["line"], decl_start["character"])
                    decl_end_idx = TextUtils.get_index_from_line_col(filecontents, name_end["line"], name_end["character"])
                    # Text from the start of the symbol's range to the end of the class name;
                    # this is where the declaration's modifiers are looked for.
                    decl_text = filecontents[decl_start_idx:decl_end_idx]
                    if not self._declares_abstract_class(decl_text):
                        legal_completions.append(symbol["name"])

        return legal_completions

    async def update(self, generated_token: str):
        """
        Updates the monitor state based on the generated token

        :param generated_token: The text of the token that was just generated
        """
        if self.responsible_for_file_buffer_state:
            self.monitor_file_buffer.append_text(generated_token)

        if self.decoder_state != DecoderStates.Constrained:
            # Nothing to be done in other states
            return

        # `all_break_chars` is not defined in this class; it is expected on the instance.
        if any(break_char in generated_token for break_char in self.all_break_chars):
            self.decoder_state = DecoderStates.S0
            self.legal_completions = None
            return

        # No breaking characters found. Continue in constrained state, keeping only the
        # completions that start with the generated token, with that prefix consumed.
        self.legal_completions = [
            legal_completion[len(generated_token) :]
            for legal_completion in self.legal_completions
            if legal_completion.startswith(generated_token)
        ]
61 changes: 60 additions & 1 deletion src/monitors4codegen/multilspy/language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .multilspy_exceptions import MultilspyException
from .multilspy_utils import PathUtils, FileUtils, TextUtils
from pathlib import PurePath
from typing import AsyncIterator, Iterator, List, Dict, Union
from typing import AsyncIterator, Iterator, List, Dict, Union, Tuple
from .type_helpers import ensure_all_methods_implemented


Expand Down Expand Up @@ -550,6 +550,51 @@ async def request_completions(
for json_repr in set([json.dumps(item, sort_keys=True) for item in completions_list])
]

async def request_document_symbols(self, relative_file_path: str) -> Tuple[List[multilspy_types.UnifiedSymbolInformation], Union[List[multilspy_types.TreeRepr], None]]:
    """
    Raise a [textDocument/documentSymbol](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_documentSymbol) request to the Language Server
    to find symbols in the given file. Wait for the response and return the result.

    Hierarchical responses (items carrying a "children" key) are flattened in pre-order: each
    symbol is followed by all of its descendants, and the "children" key is dropped from every
    returned entry. A `null` response from the server yields an empty list.

    :param relative_file_path: The relative path of the file that has the symbols

    :return Tuple[List[multilspy_types.UnifiedSymbolInformation], Union[List[multilspy_types.TreeRepr], None]]: A list of symbols in the file, and the tree representation of the symbols
        (the tree representation is not built yet and is currently always None)
    """
    with self.open_file(relative_file_path):
        response = await self.server.send.document_symbol(
            {
                "textDocument": {
                    "uri": pathlib.Path(os.path.join(self.repository_root_path, relative_file_path)).as_uri()
                }
            }
        )

    ret: List[multilspy_types.UnifiedSymbolInformation] = []
    # TODO: l_tree should be a list of TreeRepr. Extend flatten_symbol_tree to return TreeRepr as well
    l_tree = None

    # The LSP specification allows the server to answer with null; treat it as "no symbols".
    if response is None:
        return ret, l_tree

    def flatten_symbol_tree(tree: LSPTypes.DocumentSymbol) -> List[multilspy_types.UnifiedSymbolInformation]:
        # `or []` also covers a server sending an explicit `"children": null`.
        children = tree.get(LSPConstants.CHILDREN) or []
        # Copy every field except the children instead of deleting from the response in place.
        own_fields = {key: value for key, value in tree.items() if key != LSPConstants.CHILDREN}
        flattened: List[multilspy_types.UnifiedSymbolInformation] = [
            multilspy_types.UnifiedSymbolInformation(**own_fields)
        ]
        for child in children:
            flattened.extend(flatten_symbol_tree(child))
        return flattened

    assert isinstance(response, list)
    for item in response:
        assert isinstance(item, dict)
        assert LSPConstants.NAME in item
        assert LSPConstants.KIND in item
        # An item without children flattens to a single entry, so one code path serves both
        # hierarchical and flat response items.
        ret.extend(flatten_symbol_tree(item))

    return ret, l_tree

@ensure_all_methods_implemented(LanguageServer)
class SyncLanguageServer:
"""
Expand Down Expand Up @@ -689,3 +734,17 @@ def request_completions(
self.loop,
).result()
return result

def request_document_symbols(self, relative_file_path: str) -> Tuple[List[multilspy_types.UnifiedSymbolInformation], Union[List[multilspy_types.TreeRepr], None]]:
    """
    Synchronous counterpart of the asynchronous `request_document_symbols`: raise a
    [textDocument/documentSymbol](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_documentSymbol) request to the Language Server
    for the given file and block until the response is available.

    :param relative_file_path: The relative path of the file that has the symbols

    :return Tuple[List[multilspy_types.UnifiedSymbolInformation], Union[List[multilspy_types.TreeRepr], None]]: A list of symbols in the file, and the tree representation of the symbols
    """
    # Schedule the coroutine on the event loop held in self.loop and wait for its result.
    pending_request = asyncio.run_coroutine_threadsafe(
        self.language_server.request_document_symbols(relative_file_path),
        self.loop,
    )
    return pending_request.result()
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,12 @@ class LSPConstants:

# key used to represent the changes made to a document
CONTENT_CHANGES = "contentChanges"

# key under which a symbol's name is stored in a textDocument/documentSymbol response item
NAME = "name"

# key under which a symbol's kind is stored in a textDocument/documentSymbol response item
KIND = "kind"

# key under which a hierarchical document symbol lists its nested (child) symbols
CHILDREN = "children"
89 changes: 87 additions & 2 deletions src/monitors4codegen/multilspy/multilspy_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
Defines wrapper objects around the types returned by LSP to ensure decoupling between LSP versions and multilspy
"""

from __future__ import annotations

from enum import IntEnum
from typing import TypedDict
from typing_extensions import NotRequired, TypedDict, List, Dict

URI = str
DocumentUri = str
Expand Down Expand Up @@ -123,4 +125,87 @@ class CompletionItem(TypedDict):

kind: CompletionItemKind
""" The kind of this completion item. Based of the kind
an icon is chosen by the editor. """
an icon is chosen by the editor. """

class SymbolKind(IntEnum):
    """Integer codes identifying what kind of program element a symbol is."""

    # Containers and scopes
    File = 1
    Module = 2
    Namespace = 3
    Package = 4

    # Types and their members
    Class = 5
    Method = 6
    Property = 7
    Field = 8
    Constructor = 9
    Enum = 10
    Interface = 11

    # Callables and values
    Function = 12
    Variable = 13
    Constant = 14

    # Literal-like kinds
    String = 15
    Number = 16
    Boolean = 17
    Array = 18
    Object = 19
    Key = 20
    Null = 21

    # Remaining kinds
    EnumMember = 22
    Struct = 23
    Event = 24
    Operator = 25
    TypeParameter = 26

class SymbolTag(IntEnum):
    """Extra annotations attached to a symbol that tweak how it is rendered.

    @since 3.16"""

    # Render a symbol as obsolete, usually using a strike-out.
    Deprecated = 1

class UnifiedSymbolInformation(TypedDict):
"""Represents information about programming constructs like variables, classes,
interfaces etc."""

deprecated: NotRequired[bool]
""" Indicates if this symbol is deprecated.

@deprecated Use tags instead """
location: NotRequired[Location]
""" The location of this symbol. The location's range is used by a tool
to reveal the location in the editor. If the symbol is selected in the
tool the range's start information is used to position the cursor. So
the range usually spans more than the actual symbol's name and does
normally include things like visibility modifiers.

The range doesn't have to denote a node range in the sense of an abstract
syntax tree. It can therefore not be used to re-construct a hierarchy of
the symbols. """
name: str
""" The name of this symbol. """
kind: SymbolKind
""" The kind of this symbol. """
tags: NotRequired[List[SymbolTag]]
""" Tags for this symbol.

@since 3.16.0 """
containerName: NotRequired[str]
""" The name of the symbol containing this symbol. This information is for
user interface purposes (e.g. to render a qualifier in the user interface
if necessary). It can't be used to re-infer a hierarchy for the document
symbols. """

detail: NotRequired[str]
""" More detail for this symbol, e.g the signature of a function. """

range: NotRequired[Range]
""" The range enclosing this symbol not including leading/trailing whitespace but everything else
like comments. This information is typically used to determine if the clients cursor is
inside the symbol to reveal in the symbol in the UI. """
selectionRange: NotRequired[Range]
""" The range that should be selected and revealed when this symbol is being picked, e.g the name of a function.
Must be contained by the `range`. """

# Recursive alias for a tree representation of document symbols: an int key maps to a list
# of sub-trees. NOTE(review): the meaning of the int key is not visible here (presumably an
# index into the flat symbol list) — confirm once a producer of TreeRepr exists.
TreeRepr = Dict[int, List['TreeRepr']]
Loading