Skip to content

Commit

Permalink
Add class instantiation monitor and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LakshyAAAgrawal committed Nov 21, 2023
1 parent 4c0bf39 commit 19fcef7
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 0 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ with lsp.start_server():
result3 = lsp.request_references(
...
)
result4 = lsp.request_document_symbols(
...
)
...
```

Expand Down Expand Up @@ -177,6 +180,9 @@ A monitor under the Monitor-Guided Decoding framework, is instantiated using `mu
#### Switch-Enum Monitor
[src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py) provides the instantiation of `Monitor` for generating valid named enum constants in C#. Unit tests for the switch-enum monitor are present in [tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py](tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py), which also provide usage examples for the switch-enum monitor.

#### Class Instantiation Monitor
[src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py) provides the instantiation of `Monitor` for generating valid class instantiations following `'new '` in a Java code base. Unit tests for the class-instantiation monitor, which also provide usage examples, are present in [tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py](tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py).

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
This module provides the class-instantiation monitor, which is invoked when "new " is typed to instantiate a new class
"""

import os
import re

from pathlib import PurePath
from typing import List
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
from monitors4codegen.multilspy.multilspy_utils import TextUtils, FileUtils
from monitors4codegen.multilspy import multilspy_types

class ClassInstantiationMonitor(DereferencesMonitor):
    """
    Class Instantiation Monitor that is invoked when "new " is typed to instantiate new classes.

    When the text before the cursor ends with "new ", the monitor constrains decoding to the
    names of the non-abstract classes declared in the repository's ".java" files.
    """

    def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None:
        """
        Forward all arguments unchanged to the DereferencesMonitor constructor.
        """
        super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state)

    async def pre(self) -> None:
        """
        Decide the decoder state from the text preceding the cursor.

        The state becomes Constrained (with `legal_completions` set) only when the text up to
        the cursor ends with "new " and at least one completion is available; otherwise it is S0.
        """
        file_buffer = self.monitor_file_buffer
        # Fetch the buffer text once; it is needed both to locate the cursor and to slice the prefix.
        file_text = file_buffer.lsp.get_open_file_text(file_buffer.file_path)
        cursor_idx = TextUtils.get_index_from_line_col(
            file_text,
            file_buffer.current_lc[0],
            file_buffer.current_lc[1],
        )
        text_upto_cursor = file_text[:cursor_idx]

        # TODO: pre can be improved by checking for "new", and obtaining completions, and then prefixing a whitespace
        if not text_upto_cursor.endswith("new "):
            self.decoder_state = DecoderStates.S0
            return

        completions = await self.a_phi()
        if not completions:
            self.decoder_state = DecoderStates.S0
        else:
            self.decoder_state = DecoderStates.Constrained
            self.legal_completions = completions

    async def a_phi(self) -> List[str]:
        """
        Find the set of classes in the repository
        Filter out the set of abstract classes in the repository
        Remaining classes are instantiable. Return their names as legal completions

        Every ".java" file under the repository root is walked; for each document symbol of kind
        Class, the text from the start of the symbol's range to the end of its selectionRange is
        inspected for the `abstract` keyword. Names are returned without duplicates, in the order
        they were first encountered.

        NOTE(review): file contents are read from disk while symbol positions come from the
        language server; this assumes both describe the same text -- confirm for files that
        have unsaved edits in the language server.
        """
        # Match `abstract` only as a whole word, so that identifiers or words that merely contain
        # the substring (e.g. "abstraction") do not cause a concrete class to be excluded.
        abstract_keyword = re.compile(r"\babstract\b")

        legal_completions: List[str] = []
        lsp = self.monitor_file_buffer.lsp
        repository_root_path = lsp.repository_root_path
        for path, _, files in os.walk(repository_root_path):
            for file in files:
                if not file.endswith(".java"):
                    continue
                absolute_file_path = str(PurePath(path, file))
                filecontents = FileUtils.read_file(lsp.logger, absolute_file_path)
                relative_file_path = str(PurePath(os.path.relpath(absolute_file_path, repository_root_path)))
                document_symbols, _ = await lsp.request_document_symbols(relative_file_path)
                for symbol in document_symbols:
                    if symbol["kind"] != multilspy_types.SymbolKind.Class:
                        continue
                    decl_start = symbol["range"]["start"]
                    # presumably selectionRange covers the class name, so the slice spans the
                    # modifiers preceding it -- verify against the language server's output
                    decl_end = symbol["selectionRange"]["end"]
                    decl_start_idx = TextUtils.get_index_from_line_col(filecontents, decl_start["line"], decl_start["character"])
                    decl_end_idx = TextUtils.get_index_from_line_col(filecontents, decl_end["line"], decl_end["character"])
                    decl_text = filecontents[decl_start_idx:decl_end_idx]
                    if abstract_keyword.search(decl_text) is None:
                        legal_completions.append(symbol["name"])

        # Drop repeated names while keeping first-seen order.
        return list(dict.fromkeys(legal_completions))

    async def update(self, generated_token: str) -> None:
        """
        Updates the monitor state based on the generated token

        The token is appended to the file buffer when this monitor is responsible for it. In the
        Constrained state, a token containing any break character resets the state to S0 and
        clears `legal_completions`; otherwise the completions are narrowed to those that start
        with the token, with that prefix removed.
        """
        if self.responsible_for_file_buffer_state:
            self.monitor_file_buffer.append_text(generated_token)

        if self.decoder_state != DecoderStates.Constrained:
            # Nothing to be done in other states
            return

        if any(break_char in generated_token for break_char in self.all_break_chars):
            self.decoder_state = DecoderStates.S0
            self.legal_completions = None
            return

        # No breaking characters found. Continue in constrained state
        self.legal_completions = [
            legal_completion[len(generated_token) :]
            for legal_completion in self.legal_completions
            if legal_completion.startswith(generated_token)
        ]
121 changes: 121 additions & 0 deletions tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
This file contains tests for Monitor-Guided Decoding for valid class instantiations in Java
"""

import torch
import transformers
import pytest

from pathlib import PurePath
from monitors4codegen.multilspy.language_server import SyncLanguageServer
from monitors4codegen.multilspy.multilspy_config import Language
from tests.test_utils import create_test_context, is_cuda_available
from transformers import AutoTokenizer, AutoModelForCausalLM
from monitors4codegen.multilspy.multilspy_utils import TextUtils
from monitors4codegen.monitor_guided_decoding.monitors.class_instantiation_monitor import ClassInstantiationMonitor
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from monitors4codegen.multilspy.multilspy_types import Position
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper

pytest_plugins = ("pytest_asyncio",)

@pytest.mark.asyncio
async def test_multilspy_java_example_repo_class_instantiation() -> None:
    """
    Test the working of ClassInstantiationMonitor with Java repository - ExampleRepo

    The test deletes the text following a `new ` in Main.java, generates a continuation with
    santacoder using greedy decoding both without and with the monitor, and compares each
    result against a pinned expected string.
    """
    code_language = Language.JAVA
    params = {
        "code_language": code_language,
        "repo_url": "https://github.com/LakshyAAAgrawal/ExampleRepo/",
        "repo_commit": "f3762fd55a457ff9c6b0bf3b266de2b203a766ab",
    }

    device = torch.device('cuda' if is_cuda_available() else 'cpu')

    model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        "bigcode/santacoder", trust_remote_code=True
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")

    # NOTE(review): the expected string literals in this test are reproduced exactly as found;
    # their whitespace (and the email-like placeholder) may have been altered by HTML
    # rendering -- verify against the repository before relying on them.
    with create_test_context(params) as context:
        lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
        with lsp.start_server():
            completions_filepath = "Main.java"
            with lsp.open_file(completions_filepath):
                deleted_text = lsp.delete_text_between_positions(
                    completions_filepath,
                    Position(line=16, character=24),
                    Position(line=36, character=5)
                )
                assert deleted_text == """Student("Alice", 10);
Person p2 = new Teacher("Bob", "Science");
// Create some course objects
Course c1 = new Course("Math 101", t1, mathStudents);
Course c2 = new Course("English 101", t2, englishStudents);
// Print some information about the objects
System.out.println("Person p1's name is " + p1.getName());
System.out.println("Student s1's name is " + s1.getName());
System.out.println("Student s1's id is " + s1.getId());
System.out.println("Teacher t1's name is " + t1.getName());
System.out.println("Teacher t1's subject is " + t1.getSubject());
System.out.println("Course c1's name is " + c1.getName());
System.out.println("Course c1's teacher is " + c1.getTeacher().getName());
"""
                prompt_pos = (16, 24)

                with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f:
                    filecontent = f.read()

                pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1])
                assert filecontent[:pos_idx].endswith('new ')

                prompt = filecontent[:pos_idx]
                assert filecontent[pos_idx-1] == " "
                # Move the prompt to the same device as the model; an unconditional .cuda()
                # would fail on machines without CUDA even though `device` falls back to CPU.
                prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").to(device)[:, -(2048 - 512) :]

                generated_code_without_mgd = model.generate(
                    prompt_tokenized, do_sample=False, max_new_tokens=30, early_stopping=True
                )
                generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -30:])

                assert (
                    generated_code_without_mgd
                    == " Person(\"John\", \"Doe\", \"123-4567\", \"[email protected]\", \"1234"
                )

                filebuffer = MonitorFileBuffer(
                    lsp.language_server,
                    completions_filepath,
                    prompt_pos,
                    prompt_pos,
                    code_language,
                )
                monitor = ClassInstantiationMonitor(HFTokenizerWrapper(tokenizer), filebuffer)
                mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)

                # Generate code using santacoder model with the MGD logits processor and greedy decoding
                logits_processor = LogitsProcessorList([mgd_logits_processor])
                generated_code = model.generate(
                    prompt_tokenized,
                    do_sample=False,
                    max_new_tokens=30,
                    logits_processor=logits_processor,
                    early_stopping=True,
                )

                generated_code = tokenizer.decode(generated_code[0, -30:])

                assert (
                    generated_code
                    == "Student(\"John\", 1001);\n Person p2 = new Student(\"Mary\", 1002);\n Person p"
                )

0 comments on commit 19fcef7

Please sign in to comment.