Add class instantiation monitor and unit tests
Commit 19fcef7 (1 parent: 4c0bf39)
Showing 3 changed files with 219 additions and 0 deletions.
src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py (92 additions, 0 deletions)
@@ -0,0 +1,92 @@
""" | ||
This module provides the class-instantiation monitor, that is invoked when "new " is typed to instantiate new classes | ||
""" | ||
|
||
import os | ||
|
||
from pathlib import PurePath | ||
from typing import List | ||
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates | ||
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer | ||
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper | ||
from monitors4codegen.multilspy.multilspy_utils import TextUtils, FileUtils | ||
from monitors4codegen.multilspy import multilspy_types | ||
|
||
class ClassInstantiationMonitor(DereferencesMonitor): | ||
""" | ||
Class Instantiation Monitor that is invoked when "new " is typed to instantiate new classes | ||
""" | ||
|
||
def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None: | ||
super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state) | ||
|
||
async def pre(self) -> None: | ||
cursor_idx = TextUtils.get_index_from_line_col( | ||
self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path), | ||
self.monitor_file_buffer.current_lc[0], | ||
self.monitor_file_buffer.current_lc[1], | ||
) | ||
text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[ | ||
:cursor_idx | ||
] | ||
|
||
# TODO: pre can be improved by checking for "new", and obtaining completions, and then prefixing a whitespace | ||
if not text_upto_cursor.endswith("new "): | ||
self.decoder_state = DecoderStates.S0 | ||
return | ||
|
||
completions = await self.a_phi() | ||
if len(completions) == 0: | ||
self.decoder_state = DecoderStates.S0 | ||
else: | ||
self.decoder_state = DecoderStates.Constrained | ||
self.legal_completions = completions | ||
|
||
async def a_phi(self) -> List[str]: | ||
""" | ||
Find the set of classes in the repository | ||
Filter out the set of abstract classes in the repository | ||
Remaining classes are instantiable. Return their names as legal completions | ||
""" | ||
|
||
legal_completions: List[str] = [] | ||
repository_root_path = self.monitor_file_buffer.lsp.repository_root_path | ||
for path, _, files in os.walk(repository_root_path): | ||
for file in files: | ||
if file.endswith(".java"): | ||
filecontents = FileUtils.read_file(self.monitor_file_buffer.lsp.logger, str(PurePath(path, file))) | ||
relative_file_path = str(PurePath(os.path.relpath(str(PurePath(path, file)), repository_root_path))) | ||
document_symbols, _ = await self.monitor_file_buffer.lsp.request_document_symbols(relative_file_path) | ||
for symbol in document_symbols: | ||
if symbol["kind"] != multilspy_types.SymbolKind.Class: | ||
continue | ||
decl_start_idx = TextUtils.get_index_from_line_col(filecontents, symbol["range"]["start"]["line"], symbol["range"]["start"]["character"]) | ||
decl_end_idx = TextUtils.get_index_from_line_col(filecontents, symbol["selectionRange"]["end"]["line"], symbol["selectionRange"]["end"]["character"]) | ||
decl_text = filecontents[decl_start_idx:decl_end_idx] | ||
if "abstract" not in decl_text: | ||
legal_completions.append(symbol["name"]) | ||
|
||
return legal_completions | ||
|
||
async def update(self, generated_token: str): | ||
""" | ||
Updates the monitor state based on the generated token | ||
""" | ||
if self.responsible_for_file_buffer_state: | ||
self.monitor_file_buffer.append_text(generated_token) | ||
if self.decoder_state == DecoderStates.Constrained: | ||
for break_char in self.all_break_chars: | ||
if break_char in generated_token: | ||
self.decoder_state = DecoderStates.S0 | ||
self.legal_completions = None | ||
return | ||
|
||
# No breaking characters found. Continue in constrained state | ||
self.legal_completions = [ | ||
legal_completion[len(generated_token) :] | ||
for legal_completion in self.legal_completions | ||
if legal_completion.startswith(generated_token) | ||
] | ||
else: | ||
# Nothing to be done in other states | ||
return |
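For context, here is a standalone sketch (not part of the commit) of the two checks the monitor performs: a_phi treats a class as instantiable when the text of its declaration header does not contain "abstract", and update narrows the surviving completions by prefix after each generated token, dropping the constraint as soon as a break character appears. The class names, declaration headers, and break characters below are hypothetical illustrations, not taken from the repository under test.

# Standalone sketch of the monitor's two core checks; all names below are hypothetical.
from typing import List, Optional

decls = {
    "Person": "public abstract class Person",   # abstract: filtered out by the a_phi-style check
    "Student": "public class Student",          # concrete: instantiable
    "Teacher": "public class Teacher",          # concrete: instantiable
}

# a_phi-style filter: keep only classes whose declaration header lacks "abstract"
legal_completions: List[str] = [name for name, header in decls.items() if "abstract" not in header]
print(legal_completions)  # ['Student', 'Teacher']

BREAK_CHARS = set(" (.;")  # hypothetical stand-in for all_break_chars

def step(generated_token: str, completions: List[str]) -> Optional[List[str]]:
    """update-style transition: None means the decoder returns to state S0."""
    if any(ch in generated_token for ch in BREAK_CHARS):
        return None
    return [c[len(generated_token):] for c in completions if c.startswith(generated_token)]

print(step("Stu", legal_completions))  # ['dent'] - only "Student" is still a legal completion
print(step("(", legal_completions))    # None - a break character ends the constrained state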
tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py (121 additions, 0 deletions)
@@ -0,0 +1,121 @@
""" | ||
This file contains tests for Monitor-Guided Decoding for valid class instantiations in Java | ||
""" | ||
|
||
import torch | ||
import transformers | ||
import pytest | ||
|
||
from pathlib import PurePath | ||
from monitors4codegen.multilspy.language_server import SyncLanguageServer | ||
from monitors4codegen.multilspy.multilspy_config import Language | ||
from tests.test_utils import create_test_context, is_cuda_available | ||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
from monitors4codegen.multilspy.multilspy_utils import TextUtils | ||
from monitors4codegen.monitor_guided_decoding.monitors.class_instantiation_monitor import ClassInstantiationMonitor | ||
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer | ||
from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor | ||
from transformers.generation.utils import LogitsProcessorList | ||
from monitors4codegen.multilspy.multilspy_types import Position | ||
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper | ||
|
||
pytest_plugins = ("pytest_asyncio",) | ||
|
||
@pytest.mark.asyncio | ||
async def test_multilspy_java_example_repo_class_instantiation() -> None: | ||
""" | ||
Test the working of ClassInstantiationMonitor with Java repository - ExampleRepo | ||
""" | ||
code_language = Language.JAVA | ||
params = { | ||
"code_language": code_language, | ||
"repo_url": "https://github.com/LakshyAAAgrawal/ExampleRepo/", | ||
"repo_commit": "f3762fd55a457ff9c6b0bf3b266de2b203a766ab", | ||
} | ||
|
||
device = torch.device('cuda' if is_cuda_available() else 'cpu') | ||
|
||
model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained( | ||
"bigcode/santacoder", trust_remote_code=True | ||
).to(device) | ||
tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder") | ||
|
||
with create_test_context(params) as context: | ||
lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) | ||
with lsp.start_server(): | ||
completions_filepath = "Main.java" | ||
with lsp.open_file(completions_filepath): | ||
deleted_text = lsp.delete_text_between_positions( | ||
completions_filepath, | ||
Position(line=16, character=24), | ||
Position(line=36, character=5) | ||
) | ||
assert deleted_text == """Student("Alice", 10); | ||
Person p2 = new Teacher("Bob", "Science"); | ||
// Create some course objects | ||
Course c1 = new Course("Math 101", t1, mathStudents); | ||
Course c2 = new Course("English 101", t2, englishStudents); | ||
// Print some information about the objects | ||
System.out.println("Person p1's name is " + p1.getName()); | ||
System.out.println("Student s1's name is " + s1.getName()); | ||
System.out.println("Student s1's id is " + s1.getId()); | ||
System.out.println("Teacher t1's name is " + t1.getName()); | ||
System.out.println("Teacher t1's subject is " + t1.getSubject()); | ||
System.out.println("Course c1's name is " + c1.getName()); | ||
System.out.println("Course c1's teacher is " + c1.getTeacher().getName()); | ||
""" | ||
prompt_pos = (16, 24) | ||
|
||
with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: | ||
filecontent = f.read() | ||
|
||
pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1]) | ||
assert filecontent[:pos_idx].endswith('new ') | ||
|
||
prompt = filecontent[:pos_idx] | ||
assert filecontent[pos_idx-1] == " " | ||
prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, -(2048 - 512) :] | ||
|
||
generated_code_without_mgd = model.generate( | ||
prompt_tokenized, do_sample=False, max_new_tokens=30, early_stopping=True | ||
) | ||
generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -30:]) | ||
|
||
assert ( | ||
generated_code_without_mgd | ||
== " Person(\"John\", \"Doe\", \"123-4567\", \"[email protected]\", \"1234" | ||
) | ||
|
||
filebuffer = MonitorFileBuffer( | ||
lsp.language_server, | ||
completions_filepath, | ||
prompt_pos, | ||
prompt_pos, | ||
code_language, | ||
) | ||
monitor = ClassInstantiationMonitor(HFTokenizerWrapper(tokenizer), filebuffer) | ||
mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop) | ||
|
||
# Generate code using santacoder model with the MGD logits processor and greedy decoding | ||
logits_processor = LogitsProcessorList([mgd_logits_processor]) | ||
generated_code = model.generate( | ||
prompt_tokenized, | ||
do_sample=False, | ||
max_new_tokens=30, | ||
logits_processor=logits_processor, | ||
early_stopping=True, | ||
) | ||
|
||
generated_code = tokenizer.decode(generated_code[0, -30:]) | ||
|
||
assert ( | ||
generated_code | ||
== "Student(\"John\", 1001);\n Person p2 = new Student(\"Mary\", 1002);\n Person p" | ||
) |
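The test above also doubles as a usage recipe. Condensed into a sketch, and assuming a SyncLanguageServer, tokenizer, model, and a prompt that ends right after "new " already exist (as they do in the test), the wiring looks roughly like this; the file path "Main.java" and position (16, 24) are placeholders for wherever the completion point actually is.

# Condensed wiring sketch based on the test above; "Main.java" and (16, 24) are placeholders.
filebuffer = MonitorFileBuffer(
    lsp.language_server,      # underlying async language server of the SyncLanguageServer
    "Main.java",              # file being completed (placeholder)
    (16, 24),                 # start of the completion point (placeholder)
    (16, 24),                 # end of the completion point (placeholder)
    Language.JAVA,
)
monitor = ClassInstantiationMonitor(HFTokenizerWrapper(tokenizer), filebuffer)
mgd_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)

# Any HuggingFace generate() call can then be constrained by passing the processor
output = model.generate(
    prompt_tokenized,
    do_sample=False,
    max_new_tokens=30,
    logits_processor=LogitsProcessorList([mgd_processor]),
)
print(tokenizer.decode(output[0, -30:]))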