diff --git a/README.md b/README.md index f39a768..eaf2f94 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,9 @@ with lsp.start_server(): result3 = lsp.request_references( ... ) + result4 = lsp.request_document_symbols( + ... + ) ... ``` @@ -177,6 +180,9 @@ A monitor under the Monitor-Guided Decoding framework, is instantiated using `mu #### Switch-Enum Monitor [src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py) provides the instantiation of `Monitor` for generating valid named enum constants in C#. Unit tests for the switch-enum monitor are present in [tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py](tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py), which also provide usage examples for the switch-enum monitor. +#### Class Instantiation Monitor +[src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py) provides the instantiation of `Monitor` for generating valid class instantiations following `'new '` in a Java code base. Unit tests for the class-instantiation monitor, which provide examples usages are present in [tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py](tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py). + ## Contributing This project welcomes contributions and suggestions. Most contributions require you to agree to a diff --git a/src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py b/src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py new file mode 100644 index 0000000..fb2096c --- /dev/null +++ b/src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py @@ -0,0 +1,92 @@ +""" +This module provides the class-instantiation monitor, that is invoked when "new " is typed to instantiate new classes +""" + +import os + +from pathlib import PurePath +from typing import List +from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates +from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer +from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper +from monitors4codegen.multilspy.multilspy_utils import TextUtils, FileUtils +from monitors4codegen.multilspy import multilspy_types + +class ClassInstantiationMonitor(DereferencesMonitor): + """ + Class Instantiation Monitor that is invoked when "new " is typed to instantiate new classes + """ + + def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None: + super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state) + + async def pre(self) -> None: + cursor_idx = TextUtils.get_index_from_line_col( + self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path), + self.monitor_file_buffer.current_lc[0], + self.monitor_file_buffer.current_lc[1], + ) + text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[ + :cursor_idx + ] + + # TODO: pre can be improved by checking for "new", and obtaining completions, and then prefixing a whitespace + if not text_upto_cursor.endswith("new "): + self.decoder_state = DecoderStates.S0 + return + + completions = await self.a_phi() + if len(completions) == 0: + self.decoder_state = DecoderStates.S0 + else: + self.decoder_state = DecoderStates.Constrained + self.legal_completions = completions + + async def a_phi(self) -> List[str]: + """ + Find the set of classes in the repository + Filter out the set of abstract classes in the repository + Remaining classes are instantiable. Return their names as legal completions + """ + + legal_completions: List[str] = [] + repository_root_path = self.monitor_file_buffer.lsp.repository_root_path + for path, _, files in os.walk(repository_root_path): + for file in files: + if file.endswith(".java"): + filecontents = FileUtils.read_file(self.monitor_file_buffer.lsp.logger, str(PurePath(path, file))) + relative_file_path = str(PurePath(os.path.relpath(str(PurePath(path, file)), repository_root_path))) + document_symbols, _ = await self.monitor_file_buffer.lsp.request_document_symbols(relative_file_path) + for symbol in document_symbols: + if symbol["kind"] != multilspy_types.SymbolKind.Class: + continue + decl_start_idx = TextUtils.get_index_from_line_col(filecontents, symbol["range"]["start"]["line"], symbol["range"]["start"]["character"]) + decl_end_idx = TextUtils.get_index_from_line_col(filecontents, symbol["selectionRange"]["end"]["line"], symbol["selectionRange"]["end"]["character"]) + decl_text = filecontents[decl_start_idx:decl_end_idx] + if "abstract" not in decl_text: + legal_completions.append(symbol["name"]) + + return legal_completions + + async def update(self, generated_token: str): + """ + Updates the monitor state based on the generated token + """ + if self.responsible_for_file_buffer_state: + self.monitor_file_buffer.append_text(generated_token) + if self.decoder_state == DecoderStates.Constrained: + for break_char in self.all_break_chars: + if break_char in generated_token: + self.decoder_state = DecoderStates.S0 + self.legal_completions = None + return + + # No breaking characters found. Continue in constrained state + self.legal_completions = [ + legal_completion[len(generated_token) :] + for legal_completion in self.legal_completions + if legal_completion.startswith(generated_token) + ] + else: + # Nothing to be done in other states + return \ No newline at end of file diff --git a/src/monitors4codegen/multilspy/language_server.py b/src/monitors4codegen/multilspy/language_server.py index 1a6f734..995d381 100644 --- a/src/monitors4codegen/multilspy/language_server.py +++ b/src/monitors4codegen/multilspy/language_server.py @@ -26,7 +26,7 @@ from .multilspy_exceptions import MultilspyException from .multilspy_utils import PathUtils, FileUtils, TextUtils from pathlib import PurePath -from typing import AsyncIterator, Iterator, List, Dict, Union +from typing import AsyncIterator, Iterator, List, Dict, Union, Tuple from .type_helpers import ensure_all_methods_implemented @@ -550,6 +550,51 @@ async def request_completions( for json_repr in set([json.dumps(item, sort_keys=True) for item in completions_list]) ] + async def request_document_symbols(self, relative_file_path: str) -> Tuple[List[multilspy_types.UnifiedSymbolInformation], Union[List[multilspy_types.TreeRepr], None]]: + """ + Raise a [textDocument/documentSymbol](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_documentSymbol) request to the Language Server + to find symbols in the given file. Wait for the response and return the result. + + :param relative_file_path: The relative path of the file that has the symbols + + :return Tuple[List[multilspy_types.UnifiedSymbolInformation], Union[List[multilspy_types.TreeRepr], None]]: A list of symbols in the file, and the tree representation of the symbols + """ + with self.open_file(relative_file_path): + response = await self.server.send.document_symbol( + { + "textDocument": { + "uri": pathlib.Path(os.path.join(self.repository_root_path, relative_file_path)).as_uri() + } + } + ) + + ret: List[multilspy_types.UnifiedSymbolInformation] = [] + l_tree = None + assert isinstance(response, list) + for item in response: + assert isinstance(item, dict) + assert LSPConstants.NAME in item + assert LSPConstants.KIND in item + + if LSPConstants.CHILDREN in item: + # TODO: l_tree should be a list of TreeRepr. Define the following function to return TreeRepr as well + + def visit_tree_nodes_and_build_tree_repr(tree: LSPTypes.DocumentSymbol) -> List[multilspy_types.UnifiedSymbolInformation]: + l: List[multilspy_types.UnifiedSymbolInformation] = [] + children = tree['children'] if 'children' in tree else [] + if 'children' in tree: + del tree['children'] + l.append(multilspy_types.UnifiedSymbolInformation(**tree)) + for child in children: + l.extend(visit_tree_nodes_and_build_tree_repr(child)) + return l + + ret.extend(visit_tree_nodes_and_build_tree_repr(item)) + else: + ret.append(multilspy_types.UnifiedSymbolInformation(**item)) + + return ret, l_tree + @ensure_all_methods_implemented(LanguageServer) class SyncLanguageServer: """ @@ -689,3 +734,17 @@ def request_completions( self.loop, ).result() return result + + def request_document_symbols(self, relative_file_path: str) -> Tuple[List[multilspy_types.UnifiedSymbolInformation], Union[List[multilspy_types.TreeRepr], None]]: + """ + Raise a [textDocument/documentSymbol](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_documentSymbol) request to the Language Server + to find symbols in the given file. Wait for the response and return the result. + + :param relative_file_path: The relative path of the file that has the symbols + + :return Tuple[List[multilspy_types.UnifiedSymbolInformation], Union[List[multilspy_types.TreeRepr], None]]: A list of symbols in the file, and the tree representation of the symbols + """ + result = asyncio.run_coroutine_threadsafe( + self.language_server.request_document_symbols(relative_file_path), self.loop + ).result() + return result \ No newline at end of file diff --git a/src/monitors4codegen/multilspy/lsp_protocol_handler/lsp_constants.py b/src/monitors4codegen/multilspy/lsp_protocol_handler/lsp_constants.py index ceede5f..6026802 100644 --- a/src/monitors4codegen/multilspy/lsp_protocol_handler/lsp_constants.py +++ b/src/monitors4codegen/multilspy/lsp_protocol_handler/lsp_constants.py @@ -48,3 +48,12 @@ class LSPConstants: # key used to represent the changes made to a document CONTENT_CHANGES = "contentChanges" + + # key used to represent name of symbols + NAME = "name" + + # key used to represent the kind of symbols + KIND = "kind" + + # key used to represent children in document symbols + CHILDREN = "children" diff --git a/src/monitors4codegen/multilspy/multilspy_types.py b/src/monitors4codegen/multilspy/multilspy_types.py index e9f3563..659a288 100644 --- a/src/monitors4codegen/multilspy/multilspy_types.py +++ b/src/monitors4codegen/multilspy/multilspy_types.py @@ -2,8 +2,10 @@ Defines wrapper objects around the types returned by LSP to ensure decoupling between LSP versions and multilspy """ +from __future__ import annotations + from enum import IntEnum -from typing import TypedDict +from typing_extensions import NotRequired, TypedDict, List, Dict URI = str DocumentUri = str @@ -123,4 +125,87 @@ class CompletionItem(TypedDict): kind: CompletionItemKind """ The kind of this completion item. Based of the kind - an icon is chosen by the editor. """ \ No newline at end of file + an icon is chosen by the editor. """ + +class SymbolKind(IntEnum): + """A symbol kind.""" + + File = 1 + Module = 2 + Namespace = 3 + Package = 4 + Class = 5 + Method = 6 + Property = 7 + Field = 8 + Constructor = 9 + Enum = 10 + Interface = 11 + Function = 12 + Variable = 13 + Constant = 14 + String = 15 + Number = 16 + Boolean = 17 + Array = 18 + Object = 19 + Key = 20 + Null = 21 + EnumMember = 22 + Struct = 23 + Event = 24 + Operator = 25 + TypeParameter = 26 + +class SymbolTag(IntEnum): + """Symbol tags are extra annotations that tweak the rendering of a symbol. + + @since 3.16""" + + Deprecated = 1 + """ Render a symbol as obsolete, usually using a strike-out. """ + +class UnifiedSymbolInformation(TypedDict): + """Represents information about programming constructs like variables, classes, + interfaces etc.""" + + deprecated: NotRequired[bool] + """ Indicates if this symbol is deprecated. + + @deprecated Use tags instead """ + location: NotRequired[Location] + """ The location of this symbol. The location's range is used by a tool + to reveal the location in the editor. If the symbol is selected in the + tool the range's start information is used to position the cursor. So + the range usually spans more than the actual symbol's name and does + normally include things like visibility modifiers. + + The range doesn't have to denote a node range in the sense of an abstract + syntax tree. It can therefore not be used to re-construct a hierarchy of + the symbols. """ + name: str + """ The name of this symbol. """ + kind: SymbolKind + """ The kind of this symbol. """ + tags: NotRequired[List[SymbolTag]] + """ Tags for this symbol. + + @since 3.16.0 """ + containerName: NotRequired[str] + """ The name of the symbol containing this symbol. This information is for + user interface purposes (e.g. to render a qualifier in the user interface + if necessary). It can't be used to re-infer a hierarchy for the document + symbols. """ + + detail: NotRequired[str] + """ More detail for this symbol, e.g the signature of a function. """ + + range: NotRequired[Range] + """ The range enclosing this symbol not including leading/trailing whitespace but everything else + like comments. This information is typically used to determine if the clients cursor is + inside the symbol to reveal in the symbol in the UI. """ + selectionRange: NotRequired[Range] + """ The range that should be selected and revealed when this symbol is being picked, e.g the name of a function. + Must be contained by the `range`. """ + +TreeRepr = Dict[int, List['TreeRepr']] diff --git a/tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py b/tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py new file mode 100644 index 0000000..eb87feb --- /dev/null +++ b/tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py @@ -0,0 +1,121 @@ +""" +This file contains tests for Monitor-Guided Decoding for valid class instantiations in Java +""" + +import torch +import transformers +import pytest + +from pathlib import PurePath +from monitors4codegen.multilspy.language_server import SyncLanguageServer +from monitors4codegen.multilspy.multilspy_config import Language +from tests.test_utils import create_test_context, is_cuda_available +from transformers import AutoTokenizer, AutoModelForCausalLM +from monitors4codegen.multilspy.multilspy_utils import TextUtils +from monitors4codegen.monitor_guided_decoding.monitors.class_instantiation_monitor import ClassInstantiationMonitor +from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer +from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor +from transformers.generation.utils import LogitsProcessorList +from monitors4codegen.multilspy.multilspy_types import Position +from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper + +pytest_plugins = ("pytest_asyncio",) + +@pytest.mark.asyncio +async def test_multilspy_java_example_repo_class_instantiation() -> None: + """ + Test the working of ClassInstantiationMonitor with Java repository - ExampleRepo + """ + code_language = Language.JAVA + params = { + "code_language": code_language, + "repo_url": "https://github.com/LakshyAAAgrawal/ExampleRepo/", + "repo_commit": "f3762fd55a457ff9c6b0bf3b266de2b203a766ab", + } + + device = torch.device('cuda' if is_cuda_available() else 'cpu') + + model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained( + "bigcode/santacoder", trust_remote_code=True + ).to(device) + tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder") + + with create_test_context(params) as context: + lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) + with lsp.start_server(): + completions_filepath = "Main.java" + with lsp.open_file(completions_filepath): + deleted_text = lsp.delete_text_between_positions( + completions_filepath, + Position(line=16, character=24), + Position(line=36, character=5) + ) + assert deleted_text == """Student("Alice", 10); + Person p2 = new Teacher("Bob", "Science"); + + // Create some course objects + Course c1 = new Course("Math 101", t1, mathStudents); + Course c2 = new Course("English 101", t2, englishStudents); + + // Print some information about the objects + + System.out.println("Person p1's name is " + p1.getName()); + + System.out.println("Student s1's name is " + s1.getName()); + System.out.println("Student s1's id is " + s1.getId()); + + System.out.println("Teacher t1's name is " + t1.getName()); + System.out.println("Teacher t1's subject is " + t1.getSubject()); + + System.out.println("Course c1's name is " + c1.getName()); + System.out.println("Course c1's teacher is " + c1.getTeacher().getName()); + + """ + prompt_pos = (16, 24) + + with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: + filecontent = f.read() + + pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1]) + assert filecontent[:pos_idx].endswith('new ') + + prompt = filecontent[:pos_idx] + assert filecontent[pos_idx-1] == " " + prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, -(2048 - 512) :] + + generated_code_without_mgd = model.generate( + prompt_tokenized, do_sample=False, max_new_tokens=30, early_stopping=True + ) + generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -30:]) + + assert ( + generated_code_without_mgd + == " Person(\"John\", \"Doe\", \"123-4567\", \"kenaa@example.com\", \"1234" + ) + + filebuffer = MonitorFileBuffer( + lsp.language_server, + completions_filepath, + prompt_pos, + prompt_pos, + code_language, + ) + monitor = ClassInstantiationMonitor(HFTokenizerWrapper(tokenizer), filebuffer) + mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop) + + # Generate code using santacoder model with the MGD logits processor and greedy decoding + logits_processor = LogitsProcessorList([mgd_logits_processor]) + generated_code = model.generate( + prompt_tokenized, + do_sample=False, + max_new_tokens=30, + logits_processor=logits_processor, + early_stopping=True, + ) + + generated_code = tokenizer.decode(generated_code[0, -30:]) + + assert ( + generated_code + == "Student(\"John\", 1001);\n Person p2 = new Student(\"Mary\", 1002);\n Person p" + ) diff --git a/tests/multilspy/test_multilspy_java.py b/tests/multilspy/test_multilspy_java.py index ac15360..03b72f9 100644 --- a/tests/multilspy/test_multilspy_java.py +++ b/tests/multilspy/test_multilspy_java.py @@ -206,3 +206,142 @@ async def test_multilspy_java_clickhouse_highlevel_sinker_modified(): completions = await lsp.request_completions(completions_filepath, 136, 23) completions = [completion["completionText"] for completion in completions if completion["kind"] == CompletionItemKind.Constructor] assert completions == ['ClickHouseSinkBuffer'] + +@pytest.mark.asyncio +async def test_multilspy_java_example_repo_document_symbols() -> None: + """ + Test the working of multilspy with Java repository - clickhouse-highlevel-sinker + """ + code_language = Language.JAVA + params = { + "code_language": code_language, + "repo_url": "https://github.com/LakshyAAAgrawal/ExampleRepo/", + "repo_commit": "f3762fd55a457ff9c6b0bf3b266de2b203a766ab", + } + with create_test_context(params) as context: + lsp = LanguageServer.create(context.config, context.logger, context.source_directory) + + # All the communication with the language server must be performed inside the context manager + # The server process is started when the context manager is entered and is terminated when the context manager is exited. + async with lsp.start_server(): + filepath = str(PurePath("Person.java")) + result = await lsp.request_document_symbols(filepath) + + assert result == ( + [ + { + "name": "Person", + "kind": 5, + "range": { + "start": {"line": 0, "character": 0}, + "end": {"line": 14, "character": 1}, + }, + "selectionRange": { + "start": {"line": 1, "character": 22}, + "end": {"line": 1, "character": 28}, + }, + "detail": "", + }, + { + "name": "name", + "kind": 8, + "range": { + "start": {"line": 2, "character": 4}, + "end": {"line": 3, "character": 24}, + }, + "selectionRange": { + "start": {"line": 3, "character": 19}, + "end": {"line": 3, "character": 23}, + }, + "detail": "", + }, + { + "name": "Person(String)", + "kind": 9, + "range": { + "start": {"line": 5, "character": 4}, + "end": {"line": 8, "character": 5}, + }, + "selectionRange": { + "start": {"line": 6, "character": 11}, + "end": {"line": 6, "character": 17}, + }, + "detail": "", + }, + { + "name": "getName()", + "kind": 6, + "range": { + "start": {"line": 10, "character": 4}, + "end": {"line": 13, "character": 5}, + }, + "selectionRange": { + "start": {"line": 11, "character": 18}, + "end": {"line": 11, "character": 25}, + }, + "detail": " : String", + }, + ], + None, + ) + + filepath = str(PurePath("Student.java")) + result = await lsp.request_document_symbols(filepath) + + assert result == ( + [ + { + "name": "Student", + "kind": 5, + "range": { + "start": {"line": 0, "character": 0}, + "end": {"line": 16, "character": 1}, + }, + "selectionRange": { + "start": {"line": 1, "character": 13}, + "end": {"line": 1, "character": 20}, + }, + "detail": "", + }, + { + "name": "id", + "kind": 8, + "range": { + "start": {"line": 2, "character": 4}, + "end": {"line": 3, "character": 19}, + }, + "selectionRange": { + "start": {"line": 3, "character": 16}, + "end": {"line": 3, "character": 18}, + }, + "detail": "", + }, + { + "name": "Student(String, int)", + "kind": 9, + "range": { + "start": {"line": 5, "character": 4}, + "end": {"line": 10, "character": 5}, + }, + "selectionRange": { + "start": {"line": 6, "character": 11}, + "end": {"line": 6, "character": 18}, + }, + "detail": "", + }, + { + "name": "getId()", + "kind": 6, + "range": { + "start": {"line": 12, "character": 4}, + "end": {"line": 15, "character": 5}, + }, + "selectionRange": { + "start": {"line": 13, "character": 15}, + "end": {"line": 13, "character": 20}, + }, + "detail": " : int", + }, + ], + None, + )