diff --git a/README.md b/README.md index f39a768..eaf2f94 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,9 @@ with lsp.start_server(): result3 = lsp.request_references( ... ) + result4 = lsp.request_document_symbols( + ... + ) ... ``` @@ -177,6 +180,9 @@ A monitor under the Monitor-Guided Decoding framework, is instantiated using `mu #### Switch-Enum Monitor [src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py) provides the instantiation of `Monitor` for generating valid named enum constants in C#. Unit tests for the switch-enum monitor are present in [tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py](tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py), which also provide usage examples for the switch-enum monitor. +#### Class Instantiation Monitor +[src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py) provides the instantiation of `Monitor` for generating valid class instantiations following `'new '` in a Java code base. Unit tests for the class-instantiation monitor, which provide example usages, are present in [tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py](tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py). + ## Contributing This project welcomes contributions and suggestions. 
Most contributions require you to agree to a diff --git a/src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py b/src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py new file mode 100644 index 0000000..fb2096c --- /dev/null +++ b/src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py @@ -0,0 +1,92 @@ +""" +This module provides the class-instantiation monitor, that is invoked when "new " is typed to instantiate new classes +""" + +import os + +from pathlib import PurePath +from typing import List +from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates +from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer +from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper +from monitors4codegen.multilspy.multilspy_utils import TextUtils, FileUtils +from monitors4codegen.multilspy import multilspy_types + +class ClassInstantiationMonitor(DereferencesMonitor): + """ + Class Instantiation Monitor that is invoked when "new " is typed to instantiate new classes + """ + + def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None: + super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state) + + async def pre(self) -> None: + cursor_idx = TextUtils.get_index_from_line_col( + self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path), + self.monitor_file_buffer.current_lc[0], + self.monitor_file_buffer.current_lc[1], + ) + text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[ + :cursor_idx + ] + + # TODO: pre can be improved by checking for "new", and obtaining completions, and then prefixing a whitespace + if not text_upto_cursor.endswith("new "): + self.decoder_state = DecoderStates.S0 + 
return + + completions = await self.a_phi() + if len(completions) == 0: + self.decoder_state = DecoderStates.S0 + else: + self.decoder_state = DecoderStates.Constrained + self.legal_completions = completions + + async def a_phi(self) -> List[str]: + """ + Find the set of classes in the repository + Filter out the set of abstract classes in the repository + Remaining classes are instantiable. Return their names as legal completions + """ + + legal_completions: List[str] = [] + repository_root_path = self.monitor_file_buffer.lsp.repository_root_path + for path, _, files in os.walk(repository_root_path): + for file in files: + if file.endswith(".java"): + filecontents = FileUtils.read_file(self.monitor_file_buffer.lsp.logger, str(PurePath(path, file))) + relative_file_path = str(PurePath(os.path.relpath(str(PurePath(path, file)), repository_root_path))) + document_symbols, _ = await self.monitor_file_buffer.lsp.request_document_symbols(relative_file_path) + for symbol in document_symbols: + if symbol["kind"] != multilspy_types.SymbolKind.Class: + continue + decl_start_idx = TextUtils.get_index_from_line_col(filecontents, symbol["range"]["start"]["line"], symbol["range"]["start"]["character"]) + decl_end_idx = TextUtils.get_index_from_line_col(filecontents, symbol["selectionRange"]["end"]["line"], symbol["selectionRange"]["end"]["character"]) + decl_text = filecontents[decl_start_idx:decl_end_idx] + if "abstract" not in decl_text: + legal_completions.append(symbol["name"]) + + return legal_completions + + async def update(self, generated_token: str): + """ + Updates the monitor state based on the generated token + """ + if self.responsible_for_file_buffer_state: + self.monitor_file_buffer.append_text(generated_token) + if self.decoder_state == DecoderStates.Constrained: + for break_char in self.all_break_chars: + if break_char in generated_token: + self.decoder_state = DecoderStates.S0 + self.legal_completions = None + return + + # No breaking characters found. 
Continue in constrained state + self.legal_completions = [ + legal_completion[len(generated_token) :] + for legal_completion in self.legal_completions + if legal_completion.startswith(generated_token) + ] + else: + # Nothing to be done in other states + return \ No newline at end of file diff --git a/tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py b/tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py new file mode 100644 index 0000000..eb87feb --- /dev/null +++ b/tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py @@ -0,0 +1,121 @@ +""" +This file contains tests for Monitor-Guided Decoding for valid class instantiations in Java +""" + +import torch +import transformers +import pytest + +from pathlib import PurePath +from monitors4codegen.multilspy.language_server import SyncLanguageServer +from monitors4codegen.multilspy.multilspy_config import Language +from tests.test_utils import create_test_context, is_cuda_available +from transformers import AutoTokenizer, AutoModelForCausalLM +from monitors4codegen.multilspy.multilspy_utils import TextUtils +from monitors4codegen.monitor_guided_decoding.monitors.class_instantiation_monitor import ClassInstantiationMonitor +from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer +from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor +from transformers.generation.utils import LogitsProcessorList +from monitors4codegen.multilspy.multilspy_types import Position +from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper + +pytest_plugins = ("pytest_asyncio",) + +@pytest.mark.asyncio +async def test_multilspy_java_example_repo_class_instantiation() -> None: + """ + Test the working of ClassInstantiationMonitor with Java repository - ExampleRepo + """ + code_language = Language.JAVA + params = { + "code_language": code_language, + "repo_url": "https://github.com/LakshyAAAgrawal/ExampleRepo/", + 
"repo_commit": "f3762fd55a457ff9c6b0bf3b266de2b203a766ab", + } + + device = torch.device('cuda' if is_cuda_available() else 'cpu') + + model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained( + "bigcode/santacoder", trust_remote_code=True + ).to(device) + tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder") + + with create_test_context(params) as context: + lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) + with lsp.start_server(): + completions_filepath = "Main.java" + with lsp.open_file(completions_filepath): + deleted_text = lsp.delete_text_between_positions( + completions_filepath, + Position(line=16, character=24), + Position(line=36, character=5) + ) + assert deleted_text == """Student("Alice", 10); + Person p2 = new Teacher("Bob", "Science"); + + // Create some course objects + Course c1 = new Course("Math 101", t1, mathStudents); + Course c2 = new Course("English 101", t2, englishStudents); + + // Print some information about the objects + + System.out.println("Person p1's name is " + p1.getName()); + + System.out.println("Student s1's name is " + s1.getName()); + System.out.println("Student s1's id is " + s1.getId()); + + System.out.println("Teacher t1's name is " + t1.getName()); + System.out.println("Teacher t1's subject is " + t1.getSubject()); + + System.out.println("Course c1's name is " + c1.getName()); + System.out.println("Course c1's teacher is " + c1.getTeacher().getName()); + + """ + prompt_pos = (16, 24) + + with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: + filecontent = f.read() + + pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1]) + assert filecontent[:pos_idx].endswith('new ') + + prompt = filecontent[:pos_idx] + assert filecontent[pos_idx-1] == " " + prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").to(device)[:, -(2048 - 512) :] + + generated_code_without_mgd = 
model.generate( + prompt_tokenized, do_sample=False, max_new_tokens=30, early_stopping=True + ) + generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -30:]) + + assert ( + generated_code_without_mgd + == " Person(\"John\", \"Doe\", \"123-4567\", \"kenaa@example.com\", \"1234" + ) + + filebuffer = MonitorFileBuffer( + lsp.language_server, + completions_filepath, + prompt_pos, + prompt_pos, + code_language, + ) + monitor = ClassInstantiationMonitor(HFTokenizerWrapper(tokenizer), filebuffer) + mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop) + + # Generate code using santacoder model with the MGD logits processor and greedy decoding + logits_processor = LogitsProcessorList([mgd_logits_processor]) + generated_code = model.generate( + prompt_tokenized, + do_sample=False, + max_new_tokens=30, + logits_processor=logits_processor, + early_stopping=True, + ) + + generated_code = tokenizer.decode(generated_code[0, -30:]) + + assert ( + generated_code + == "Student(\"John\", 1001);\n Person p2 = new Student(\"Mary\", 1002);\n Person p" + )