From 246681fc2dd760ded17285a1f7b5a91b0fa7b6db Mon Sep 17 00:00:00 2001
From: Lakshya A Agrawal
Date: Mon, 6 Nov 2023 00:43:40 +0000
Subject: [PATCH] Add Monitor Guided Decoding - dereferences monitor

---
 README.md                                    |  12 +-
 .../dereferences_monitor.py                  | 244 ++++++++++++++++++
 monitor_guided_decoding/mgd_utils.py         |  84 ++++++
 monitor_guided_decoding/monitor.py           | 140 ++++++++++
 monitor_guided_decoding/tokenizer_wrapper.py | 136 ++++++++++
 .../test_dereferences_monitor_java.py        | 235 +++++++++++++++++
 6 files changed, 845 insertions(+), 6 deletions(-)
 create mode 100644 monitor_guided_decoding/dereferences_monitor.py
 create mode 100644 monitor_guided_decoding/mgd_utils.py
 create mode 100644 monitor_guided_decoding/monitor.py
 create mode 100644 monitor_guided_decoding/tokenizer_wrapper.py
 create mode 100644 tests/monitor_guided_decoding/test_dereferences_monitor_java.py

diff --git a/README.md b/README.md
index b9fe116..bcdf7e2 100644
--- a/README.md
+++ b/README.md
@@ -110,6 +110,11 @@ Some of the analyses results that `multilspy` can provide are:
 - Finding the callers of a function or the instantiations of a class ([textDocument/references](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_references))
 - Providing type-based dereference completions ([textDocument/completion](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_completion))
 
+The file [multilspy/language_server.py](multilspy/language_server.py) provides the `multilspy` API. The tests under [tests/multilspy/](tests/multilspy/) provide detailed usage examples for `multilspy` and can be executed by running:
+```bash
+pytest tests/multilspy
+```
+
 Example usage:
 ```python
 from multilspy import SyncLanguageServer
@@ -140,14 +145,37 @@ async with lsp.start_server():
     ...
 ```
 
-Several tests for `multilspy` present under [tests/multilspy/](tests/multilspy/) provide detailed usage examples for `multilspy`. The tests can be executed by running:
-```bash
-pytest tests/multilspy
-```
-
 ## 5. Monitor-Guided Decoding
 
-Coming Soon...
+Under the Monitor-Guided Decoding framework, a monitor is instantiated with `multilspy` as the LSP client and is used as a logits processor that guides the LM during decoding. [monitor_guided_decoding/monitor.py](monitor_guided_decoding/monitor.py) provides the `MGDLogitsProcessor` class, which can be used with any HuggingFace language model as a `LogitsProcessor` to guide the LM with MGD. [monitor_guided_decoding/dereferences_monitor.py](monitor_guided_decoding/dereferences_monitor.py) provides the instantiation of the dereferences monitor. Unit tests for the dereferences monitor are present in [tests/monitor_guided_decoding/test_dereferences_monitor_java.py](tests/monitor_guided_decoding/test_dereferences_monitor_java.py), and also provide usage examples for the dereferences monitor.
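+
+For illustration, a minimal sketch of wiring these pieces together with a HuggingFace model (adapted from the unit tests; `config`, `logger`, `repo_root_path`, the file path, the cursor position, and `prompt_tokenized` are placeholders to be filled in):
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.utils import LogitsProcessorList
+from multilspy.language_server import SyncLanguageServer
+from multilspy.multilspy_config import Language
+from monitor_guided_decoding.dereferences_monitor import DereferencesMonitor
+from monitor_guided_decoding.monitor import MonitorFileBuffer, MGDLogitsProcessor
+from monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper
+
+model = AutoModelForCausalLM.from_pretrained("bigcode/santacoder", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")
+
+lsp = SyncLanguageServer.create(config, logger, repo_root_path)  # as in the multilspy example above
+with lsp.start_server():
+    with lsp.open_file("src/Main.java"):
+        # (line, character): the cursor position at which generation starts
+        filebuffer = MonitorFileBuffer(lsp.language_server, "src/Main.java", (10, 5), (10, 5), Language.JAVA, "")
+        monitor = DereferencesMonitor(HFTokenizerWrapper(tokenizer), filebuffer)
+        mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)
+        generated = model.generate(
+            prompt_tokenized,  # the prompt tokenized up to the cursor position
+            do_sample=False,
+            max_new_tokens=30,
+            logits_processor=LogitsProcessorList([mgd_logits_processor]),
+        )
+```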
 
 ## Contributing
 
diff --git a/monitor_guided_decoding/dereferences_monitor.py b/monitor_guided_decoding/dereferences_monitor.py
new file mode 100644
index 0000000..dee212d
--- /dev/null
+++ b/monitor_guided_decoding/dereferences_monitor.py
@@ -0,0 +1,244 @@
+"""
+This module provides the instantiation of the dereferences monitor
+"""
+
+import code_tokenize as ctok
+
+from typing import List, Union, Set
+from enum import Enum
+
+from multilspy import multilspy_types
+from multilspy.multilspy_config import Language
+from multilspy.multilspy_utils import TextUtils
+from multilspy.multilspy_types import Position
+
+from monitor_guided_decoding.monitor import Monitor, MonitorFileBuffer
+from monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
+from monitor_guided_decoding.mgd_utils import PLUtils
+
+
+class DecoderStates(Enum):
+    """
+    Enum for the state of the decoder
+    """
+
+    UnInitialized = 0
+    S0 = 1
+    Constrained = 2
+
+
+class DereferencesMonitor(Monitor):
+    """
+    Provides the dereferences monitor
+    """
+
+    def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None:
+        super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state)
+        self.decoder_state = DecoderStates.UnInitialized
+        self.all_break_chars = DereferencesMonitor.find_all_break_chars(
+            self.tokenizer.tokenizer_char_set, monitor_file_buffer.language
+        )
+        self.prompt_len: Union[None, int] = None
+        self.legal_completions: Union[List[str], None] = None
+
+    @staticmethod
+    def find_all_break_chars(charset: Set[str], language: Language) -> Set[str]:
+        """
+        Finds the set of characters which, when appended to the end of an identifier, cause the identifier to be terminated.
+        For example, "," is a breaking character, since "abc," is two tokens, "abc" and ","
+        On the other hand, "a" is not a breaking character, since "abca" is a single token
+        """
+        all_break_chars: Set[str] = set()
+        for vocab_char in charset:
+            toks: List[ctok.tokens.ASTToken] = PLUtils.tokenizer_pl("abc" + vocab_char, language)
+            toks = [t for t in toks if t.text != ""]
+            if len(toks) == 0 or toks[0].text == "abc":
+                all_break_chars.add(vocab_char)
+        return all_break_chars
+
+    async def initialize(self, input_ids: List[int], input_text: str) -> None:
+        """
+        Initializes the monitor when it is invoked for the first time with inputs
+        """
+        self.prompt_len = len(input_ids)
+        await self.pre()
+
+    async def pre(self) -> None:
+        """
+        Checks if the static analysis should be performed at this point.
+        In the case of the dereferences monitor, the last character should be a dot
+        """
+        cursor_idx = TextUtils.get_index_from_line_col(
+            self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path),
+            self.monitor_file_buffer.current_lc[0],
+            self.monitor_file_buffer.current_lc[1],
+        )
+        text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[
+            :cursor_idx
+        ]
+        if text_upto_cursor[-1] != ".":
+            self.decoder_state = DecoderStates.S0
+            return
+
+        completions = await self.a_phi()
+
+        if len(completions) == 0:
+            self.decoder_state = DecoderStates.S0
+        else:
+            self.decoder_state = DecoderStates.Constrained
+            self.legal_completions = completions
+
+    async def maskgen(self, input_ids: List[int]) -> List[int]:
+        """
+        Takes the list of input tokens and returns the list of tokens to be blacklisted in the next step
+
+        maskgen is invoked for every new token to be generated
+        The first time it is invoked, maskgen performs the initialization
+        Subsequent invocations are handled based on the current state of the decoder
+        """
+
+        input_text = self.tokenizer.decode(
+            input_ids,
+            clean_up_tokenization_spaces=False,
+            skip_special_tokens=True,
+        )
+
+        if self.decoder_state == DecoderStates.UnInitialized:
+            # Handle initialization. This is the first time the monitor is being invoked
+            await self.initialize(input_ids, input_text)
+        else:
+            # A new token has been generated. Handle the new token by calling update
+            gen_so_far = self.tokenizer.decode(
+                input_ids[self.prompt_len :], clean_up_tokenization_spaces=False, skip_special_tokens=True
+            )
+            assert gen_so_far.startswith(self.monitor_file_buffer.gen_text), (gen_so_far, self.monitor_file_buffer.gen_text)
+            assert input_text.endswith(gen_so_far)
+            new_gen_text = gen_so_far[len(self.monitor_file_buffer.gen_text) :]
+
+            await self.update(new_gen_text)
+
+        if self.decoder_state == DecoderStates.S0:
+            # If the decoder is in state S0, then we need to check pre()
+            # pre() will determine if the decoder should transition to the constrained state
+            # If so, it invokes a_phi() and transitions the decoder state
+            await self.pre()
+
+        if self.decoder_state == DecoderStates.Constrained:
+            # If the decoder is in the constrained state, then generate the set of blacklisted tokens
+            # based on the current state of the monitor, legal_completions
+            possible_token_ids: Set[int] = set()
+            for legal_suffix in self.legal_completions:
+                # If a token consists of the end of an identifier followed by breaking characters, then we allow it.
+                # This allows decoding of tokens like 'abc<', that is, tokens that span identifier boundaries
+                if self.tokenizer.vocab_trie.has_node(legal_suffix) != 0:
+                    for suffix_token, suffix_token_id in self.tokenizer.vocab_trie.iteritems(prefix=legal_suffix):
+                        if suffix_token[len(legal_suffix) : len(legal_suffix) + 1] in self.all_break_chars:
+                            possible_token_ids.add(suffix_token_id)
+
+                # If a token is a prefix of the remaining suffix, then we allow it
+                for suffix_token, suffix_token_id in self.tokenizer.vocab_trie.prefixes(legal_suffix):
+                    possible_token_ids.add(suffix_token_id)
+
+            blacklisted_ids = [i for i in self.tokenizer.all_token_ids if i not in possible_token_ids]
+        else:
+            blacklisted_ids = []
+
+        return blacklisted_ids
+
+    async def a_phi(self) -> List[str]:
+        """
+        Uses multilspy to perform static analysis and returns the list of type-compliant dereferences
+        at the current cursor position (which ends with a dot)
+        """
+        relative_file_path = self.monitor_file_buffer.file_path
+        line, column = self.monitor_file_buffer.current_lc
+
+        with self.monitor_file_buffer.lsp.open_file(relative_file_path):
+            legal_completions1 = await self.monitor_file_buffer.lsp.request_completions(
+                relative_file_path, line, column, allow_incomplete=True
+            )
+            legal_completions1 = [
+                completion["completionText"]
+                for completion in legal_completions1
+                if completion["kind"] != multilspy_types.CompletionItemKind.Keyword
+            ]
+            lsp_text = self.monitor_file_buffer.lsp.get_open_file_text(relative_file_path)
+            request_idx = TextUtils.get_index_from_line_col(lsp_text, line, column)
+            opening_bracket_stream = PLUtils.get_opening_bracket_stream(
+                lsp_text[:request_idx], self.monitor_file_buffer.language
+            )
+            if len(opening_bracket_stream) == 0:
+                return legal_completions1
+
+            closing_bracket_stream = PLUtils.get_closing_bracket_stream(
+                lsp_text[request_idx:], self.monitor_file_buffer.language
+            )
+            if len(opening_bracket_stream) <= len(closing_bracket_stream):
+                return legal_completions1
+
+            err = False
+            for j in range(len(closing_bracket_stream)):
+                if closing_bracket_stream[-j - 1] == "}" and opening_bracket_stream[j] != "{":
+                    err = True
+                    break
+                elif closing_bracket_stream[-j - 1] == ")" and opening_bracket_stream[j] != "(":
+                    err = True
+                    break
+                elif closing_bracket_stream[-j - 1] == "]" and opening_bracket_stream[j] != "[":
+                    err = True
+                    break
+
+            if err:
+                return legal_completions1
+
+            opening_bracket_stream = opening_bracket_stream[len(closing_bracket_stream) :]
+            remaining_close_brackets = list(
+                map(lambda x: "}" if x == "{" else (")" if x == "(" else "]"), opening_bracket_stream)
+            )[::-1]
+
+            insert_text = "".join(remaining_close_brackets)
+            updated_cursor = self.monitor_file_buffer.lsp.insert_text_at_position(
+                relative_file_path, line, column, insert_text
+            )
+            assert updated_cursor["line"] == line
+            assert updated_cursor["character"] == column + len(insert_text)
+            legal_completions2 = await self.monitor_file_buffer.lsp.request_completions(
+                relative_file_path, line, column, allow_incomplete=True
+            )
+            legal_completions2 = [
+                completion["completionText"]
+                for completion in legal_completions2
+                if completion["kind"] != multilspy_types.CompletionItemKind.Keyword
+            ]
+
+            deleted_text = self.monitor_file_buffer.lsp.delete_text_between_positions(
+                relative_file_path,
+                Position(line=line, character=column),
+                Position(line=line, character=column + len(insert_text)),
+            )
+            assert deleted_text == insert_text
+
+            return list(set(legal_completions1 + legal_completions2))
+
+    async def update(self, generated_token: str):
+        """
+        Updates the monitor state based on the generated token
+        """
+        if self.responsible_for_file_buffer_state:
+            self.monitor_file_buffer.append_text(generated_token)
+        if self.decoder_state == DecoderStates.Constrained:
+            for break_char in self.all_break_chars:
+                if break_char in generated_token:
+                    self.decoder_state = DecoderStates.S0
+                    self.legal_completions = None
+                    return
+
+            # No breaking characters found. Continue in the constrained state
+            self.legal_completions = [
+                legal_completion[len(generated_token) :]
+                for legal_completion in self.legal_completions
+                if legal_completion.startswith(generated_token)
+            ]
+        else:
+            # Nothing to be done in other states
+            return
diff --git a/monitor_guided_decoding/mgd_utils.py b/monitor_guided_decoding/mgd_utils.py
new file mode 100644
index 0000000..dceb141
--- /dev/null
+++ b/monitor_guided_decoding/mgd_utils.py
@@ -0,0 +1,84 @@
+"""
+This module provides utility functions for handling programming language text
+"""
+
+import code_tokenize as ctok
+
+from typing import List
+from multilspy.multilspy_config import Language
+
+
+class PLUtils:
+    """
+    This class provides various utility functions for handling programming language text
+    """
+
+    @staticmethod
+    def tokenizer_pl(inp_text: str, lang: Language) -> List[ctok.tokens.ASTToken]:
+        """
+        Tokenizes the given text using code_tokenize
+        """
+        lang_s = str(lang) if lang != Language.CSHARP else "c-sharp"
+        if inp_text.strip() == "":
+            return []
+        lsp_text_lang_tokenized: List[ctok.tokens.ASTToken] = ctok.tokenize(
+            inp_text, lang=lang_s, syntax_error="ignore"
+        )
+        lsp_text_lang_tokenized = [tok for tok in lsp_text_lang_tokenized if tok.text != ""]
+        return lsp_text_lang_tokenized
+
+    @staticmethod
+    def get_opening_bracket_stream(inp_text: str, lang: Language) -> List[str]:
+        """
+        Returns the list of brackets opened in the given text that are not closed within it
+        """
+        bracket_stream: List[str] = []
+        err = False
+        lsp_text_lang_tokenized = PLUtils.tokenizer_pl(inp_text, lang)
+        for tok in lsp_text_lang_tokenized:
+            if tok.type in ["{", "(", "["]:
+                bracket_stream.append(tok.type)
+            elif tok.type in ["}", ")", "]"]:
+                if len(bracket_stream) == 0:
+                    err = True
+                    break
+                if (
+                    (tok.type == "}" and bracket_stream[-1] == "{")
+                    or (tok.type == ")" and bracket_stream[-1] == "(")
+                    or (tok.type == "]" and bracket_stream[-1] == "[")
+                ):
+                    bracket_stream.pop()
+                else:
+                    err = True
+                    break
+        if err:
+            raise Exception("Invalid bracket stream")
+        return bracket_stream
+
+    @staticmethod
+    def get_closing_bracket_stream(inp_text: str, lang: Language) -> List[str]:
+        """
+        Returns the list of brackets closed in the given text that are not opened within it
+        """
+        bracket_stream: List[str] = []
+        err = False
+        lsp_text_lang_tokenized = PLUtils.tokenizer_pl(inp_text, lang)
+        for tok in lsp_text_lang_tokenized[::-1]:
+            if tok.type in ["}", ")", "]"]:
+                bracket_stream.append(tok.type)
+            elif tok.type in ["{", "(", "["]:
+                if len(bracket_stream) == 0:
+                    err = True
+                    break
+                if (
+                    (tok.type == "{" and bracket_stream[-1] == "}")
+                    or (tok.type == "(" and bracket_stream[-1] == ")")
+                    or (tok.type == "[" and bracket_stream[-1] == "]")
+                ):
+                    bracket_stream.pop()
+                else:
+                    err = True
+                    break
+        if err:
+            raise Exception("Invalid bracket stream")
+        return bracket_stream[::-1]
diff --git a/monitor_guided_decoding/monitor.py b/monitor_guided_decoding/monitor.py
new file mode 100644
index 0000000..3c77bbe
--- /dev/null
+++ b/monitor_guided_decoding/monitor.py
@@ -0,0 +1,140 @@
+"""
+Provides the definition of a monitor as per the Monitor-Guided Decoding framework
+"""
+
+import asyncio
+import torch
+
+from asyncio.events import AbstractEventLoop
+from typing import List, Tuple, Union
+from transformers import LogitsProcessor
+from monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
+from multilspy import LanguageServer
+from multilspy.multilspy_config import Language
+from dataclasses import dataclass
+from multilspy.multilspy_utils import TextUtils
+
+
+@dataclass
+class MonitorFileBuffer:
+    """
+    Dataclass for storing the state of the monitor for the prompt file in which the generation is happening
+    """
+
+    lsp: LanguageServer
+    file_path: str
+    prompt_lc: Tuple[int, int]
+    current_lc: Tuple[int, int]
+    language: Language
+    gen_text: str = ""
+
+    def append_text(self, text: str):
+        """
+        Appends the given text to the file at the current cursor position and advances the cursor
+        """
+        current_lc_index = TextUtils.get_index_from_line_col(
+            self.lsp.get_open_file_text(self.file_path), self.current_lc[0], self.current_lc[1]
+        )
+        new_lc = self.lsp.insert_text_at_position(self.file_path, self.current_lc[0], self.current_lc[1], text)
+        self.current_lc = (new_lc["line"], new_lc["character"])
+        self.gen_text += text
+        assert current_lc_index + len(text) == TextUtils.get_index_from_line_col(
+            self.lsp.get_open_file_text(self.file_path), self.current_lc[0], self.current_lc[1]
+        )
+
+
+class Monitor:
+    """
+    Provides the definition of a monitor as per the Monitor-Guided Decoding framework
+    """
+
+    def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None:
+        self.tokenizer = tokenizer
+        self.monitor_file_buffer = monitor_file_buffer
+        self.responsible_for_file_buffer_state = responsible_for_file_buffer_state
+
+    async def pre(self) -> None:
+        """
+        If the current state is UnInitialized or S0, this function checks
+        if the static analysis should be performed at this point.
+        If yes, it invokes the static analysis and updates the state.
+        """
+        raise NotImplementedError()
+
+    async def maskgen(self, input_ids: List[int]) -> List[int]:
+        """
+        Given input_ids, which is the list of token ids generated so far (or the input for the first time),
+        this function returns the list of token ids that should be masked for the next token generation.
+
+        This is the function that is invoked at every token decoding step.
+        """
+        raise NotImplementedError()
+
+    def a_phi(self):
+        """
+        This function defines the implementation of the static analysis
+        and returns the result of the static analysis.
+        It is invoked primarily by pre()
+        """
+        raise NotImplementedError()
+
+    def update(self, generated_token: str):
+        """
+        This function updates the state of the monitor, given the generated token.
+        """
+        raise NotImplementedError()
+
+
+class MGDLogitsProcessor(LogitsProcessor):
+    """
+    Provides the logits processor for monitor-guided decoding
+    """
+
+    loop: AbstractEventLoop
+
+    def __init__(self, monitors: List[Monitor], loop: Union[None, AbstractEventLoop] = None) -> None:
+        super().__init__()
+
+        if loop is None:
+            self.loop = asyncio.get_event_loop()
+        else:
+            self.loop = loop
+
+        self.monitors: List[Monitor] = monitors
+
+    async def process_scores_for_single_input_id(
+        self, segment_idx: int, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        """
+        Asynchronously processes the scores for a single input id using the MGD framework
+        """
+        blacklisted_ids: List[int] = await self.monitors[segment_idx].maskgen(input_ids.tolist())
+        # Set the scores of all blacklisted token ids to -inf
+        output_scores: torch.FloatTensor = torch.where(
+            torch.tensor([True if i in blacklisted_ids else False for i in range(scores.shape[0])]).to(scores.device),
+            float("-inf") * torch.ones(scores.shape[0]).to(scores.device),
+            scores,
+        ).to(scores)
+        return output_scores
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        This method is called by the HuggingFace decoder for every token generation, with
+        the input_ids (seen so far, including the prompt) and scores (for the next token).
+        This method processes the scores using the MGD framework.
+        """
+        assert len(input_ids.shape) == 2
+        assert input_ids.shape[0] == len(self.monitors)
+        assert len(scores.shape) == 2
+
+        async def f(input_ids_arg: torch.LongTensor, scores_arg: torch.FloatTensor):
+            new_score_coroutines = [
+                self.process_scores_for_single_input_id(i, input_ids_arg[i], scores_arg[i])
+                for i in range(input_ids_arg.shape[0])
+            ]
+            new_scores = await asyncio.gather(*new_score_coroutines)
+            return tuple(new_scores)
+
+        future = asyncio.run_coroutine_threadsafe(f(input_ids, scores), self.loop)
+        results = future.result()
+        new_scores = torch.stack(results, dim=0).to(scores)
+        return new_scores
diff --git a/monitor_guided_decoding/tokenizer_wrapper.py b/monitor_guided_decoding/tokenizer_wrapper.py
new file mode 100644
index 0000000..ac0d517
--- /dev/null
+++ b/monitor_guided_decoding/tokenizer_wrapper.py
@@ -0,0 +1,136 @@
+"""
+This file provides a tokenizer wrapper that exposes a common interface over
+HF tokenizers and TikToken tokenizers
+"""
+
+import torch
+import tiktoken
+
+from typing import List, Set, Union
+from pygtrie import CharTrie
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+
+class TokenizerWrapper:
+    """
+    This class provides a common interface over HF tokenizers and TikToken tokenizers
+    """
+
+    def __init__(self, tokenizer: Union[PreTrainedTokenizerBase, tiktoken.core.Encoding]):
+        """
+        Initializes the tokenizer wrapper
+        """
+        self.tokenizer = tokenizer
+        self.vocab_trie = CharTrie()
+        self.tokenizer_char_set: Set[str] = set()
+        self.all_token_ids: Set[int] = set()
+
+    def decode(self, *args, **kwargs) -> str:
+        """
+        Decodes the given token ids to a string
+
+        Params:
+        token_ids, clean_up_tokenization_spaces, skip_special_tokens
+        """
+        raise NotImplementedError()
+
+    def convert_ids_to_tokens(self, x) -> List[str]:
+        """
+        Converts the given token ids to a list of tokens
+        """
+        raise NotImplementedError()
+
+    def convert_tokens_to_string(self, x) -> str:
+        """
+        Converts the given list of tokens to a string
+        """
+        raise NotImplementedError()
+
+
+class HFTokenizerWrapper(TokenizerWrapper):
+    """
+    This class provides an instance of TokenizerWrapper for HF tokenizers
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+        self.__dict__.update(tokenizer.__dict__)
+        for k, v in tokenizer.vocab.items():
+            decoded_token = tokenizer.decode(v, clean_up_tokenization_spaces=False, skip_special_tokens=True)
+            if decoded_token != "":
+                self.tokenizer_char_set.update(decoded_token)
+                self.vocab_trie[decoded_token] = v
+        self.all_token_ids = set(tokenizer.vocab.values())
+
+    def decode(self, *args, **kwargs) -> str:
+        """
+        Decodes the given token ids to a string
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def convert_ids_to_tokens(self, x) -> List[str]:
+        """
+        Converts the given token ids to a list of tokens
+        """
+        return self.tokenizer.convert_ids_to_tokens(x)
+
+    def convert_tokens_to_string(self, x) -> str:
+        """
+        Converts the given list of tokens to a string
+        """
+        return self.tokenizer.convert_tokens_to_string(x)
+
+
+class TikTokenWrapper(TokenizerWrapper):
+    """
+    This class provides an instance of TokenizerWrapper for TikToken tokenizers
+    """
+
+    def __init__(self, tokenizer: tiktoken.core.Encoding):
+        super().__init__(tokenizer)
+
+        assert len(tokenizer.special_tokens_set) == 1
+        self.all_special_ids = {tokenizer.encode_single_token(token) for token in tokenizer.special_tokens_set}
+        for k_ in tokenizer.token_byte_values():
+            v = tokenizer.encode_single_token(k_)
+            decoded_token = tokenizer.decode([tokenizer.encode_single_token(k_)])
+            if decoded_token != "":
+                self.tokenizer_char_set.update(decoded_token)
+                self.vocab_trie[decoded_token] = v
+            self.all_token_ids.add(v)
+
+    def decode(self, token_ids: torch.Tensor, *args, **kwargs) -> str:
+        """
+        Decodes the given token ids to a string
+        """
+        clean_up_tokenization_spaces, skip_special_tokens = None, None
+        if len(args) == 0:
+            pass
+        elif len(args) == 1:
+            skip_special_tokens: bool = args[0]
+        elif len(args) == 2:
+            skip_special_tokens, clean_up_tokenization_spaces = args[0], args[1]
+
+        if clean_up_tokenization_spaces is None:
+            clean_up_tokenization_spaces = kwargs.get("clean_up_tokenization_spaces", True)
+        if skip_special_tokens is None:
+            skip_special_tokens = kwargs.get("skip_special_tokens", False)
+
+        assert not clean_up_tokenization_spaces
+        assert skip_special_tokens
+        assert isinstance(token_ids, torch.Tensor)
+        token_ids = token_ids.tolist()
+
+        token_ids = [i for i in token_ids if i not in self.all_special_ids]
+
+        return self.tokenizer.decode(token_ids)
+
+    def convert_ids_to_tokens(self, x) -> List[str]:
+        """
+        Converts the given token ids to a list of tokens
+        """
+        return [self.tokenizer.decode([i]) for i in x]
+
+    def convert_tokens_to_string(self, x) -> str:
+        """
+        Converts the given list of tokens to a string
+        """
+        return "".join(x)
diff --git a/tests/monitor_guided_decoding/test_dereferences_monitor_java.py b/tests/monitor_guided_decoding/test_dereferences_monitor_java.py
new file mode 100644
index 0000000..22ee4f0
--- /dev/null
+++ b/tests/monitor_guided_decoding/test_dereferences_monitor_java.py
@@ -0,0 +1,235 @@
+"""
+This file contains tests for Monitor-Guided Decoding for dereferences in Java
+"""
+import transformers
+import pytest
+
+from pathlib import PurePath
+from multilspy.language_server import SyncLanguageServer
+from multilspy.multilspy_config import Language
+from tests.test_utils import create_test_context
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from multilspy.multilspy_utils import TextUtils
+from monitor_guided_decoding.dereferences_monitor import DereferencesMonitor
+from monitor_guided_decoding.monitor import MonitorFileBuffer, MGDLogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+from multilspy.multilspy_types import Position
+from monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper
+
+pytest_plugins = ("pytest_asyncio",)
+
+
+@pytest.mark.asyncio
+async def test_multilspy_java_clickhouse_highlevel_sinker() -> None:
+    """
+    Test the working of the dereferences monitor with the Java repository - clickhouse-highlevel-sinker
+    """
+    code_language = Language.JAVA
+    params = {
+        "code_language": code_language,
+        "repo_url": "https://github.com/Index103000/clickhouse-highlevel-sinker/",
+        "repo_commit": "ee31d278918fe5e64669a6840c4d8fb53889e573",
+    }
+
+    model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained(
+        "bigcode/santacoder", trust_remote_code=True
+    ).cuda()
+    tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")
+
+    with create_test_context(params) as context:
+        lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
+        filepath = "src/main/java/com/xlvchao/clickhouse/datasource/ClickHouseDataSource.java"
+        # All the communication with the language server must be performed inside the context manager
+        # The server process is started when the context manager is entered and is terminated when it is exited
+        with lsp.start_server():
+            with lsp.open_file(filepath):
+                filebuffer = MonitorFileBuffer(lsp.language_server, filepath, (74, 17), (74, 17), code_language, "")
+                deleted_text = filebuffer.lsp.delete_text_between_positions(
+                    filepath, Position(line=74, character=17), Position(line=78, character=4)
+                )
+                assert (
+                    deleted_text
+                    == """newServerNode()
+                .withIp(arr[0])
+                .withPort(Integer.parseInt(arr[1]))
+                .build();
+    """
+                )
+                monitor = DereferencesMonitor(HFTokenizerWrapper(tokenizer), filebuffer)
+                mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)
+
+                with open(str(PurePath(context.source_directory, filepath)), "r") as f:
+                    filecontent = f.read()
+
+                pos_idx = TextUtils.get_index_from_line_col(filecontent, 74, 17)
+                assert filecontent[pos_idx] == "n"
+                prompt = filecontent[:pos_idx]
+                assert prompt[-1] == "."
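+
+                # The slice below keeps only the last (2048 - 512) prompt tokens:
+                # santacoder's 2048-token context must also fit the newly generated tokens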
+                prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, -(2048 - 512) :]
+
+                generated_code_without_mgd = model.generate(
+                    prompt_tokenized, do_sample=False, max_new_tokens=30, early_stopping=True
+                )
+                generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -30:])
+                assert (
+                    generated_code_without_mgd
+                    == "builder()\n                .hostAddress(arr[0])\n                .port(Integer.parseInt(arr[1]))\n                .build();\n    }\n\n    "
+                )
+
+                # Generate code using the santacoder model with the MGD logits processor and greedy decoding
+                logits_processor = LogitsProcessorList([mgd_logits_processor])
+                generated_code = model.generate(
+                    prompt_tokenized,
+                    do_sample=False,
+                    max_new_tokens=30,
+                    logits_processor=logits_processor,
+                    early_stopping=True,
+                )
+
+                generated_code = tokenizer.decode(generated_code[0, -30:])
+                assert (
+                    generated_code
+                    == "newServerNode()\n                .withIp(arr[0])\n                .withPort(Integer.parseInt(arr[1]))\n                .build();"
+                )
+
+
+@pytest.mark.asyncio
+async def test_multilspy_java_clickhouse_highlevel_sinker_modified():
+    """
+    Test the working of the dereferences monitor with the Java repository - clickhouse-highlevel-sinker modified
+    """
+    code_language = Language.JAVA
+    params = {
+        "code_language": code_language,
+        "repo_url": "https://github.com/LakshyAAAgrawal/clickhouse-highlevel-sinker/",
+        "repo_commit": "5775fd7a67e7b60998e1614cf44a8a1fc3190ab0"
+    }
+
+    model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained(
+        "bigcode/santacoder", trust_remote_code=True
+    ).cuda()
+    tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")
+
+    with create_test_context(params) as context:
+        lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
+        # All the communication with the language server must be performed inside the context manager
+        # The server process is started when the context manager is entered and is terminated when it is exited
+        with lsp.start_server():
+            completions_filepath = "src/main/java/com/xlvchao/clickhouse/datasource/ClickHouseDataSource.java"
+            with lsp.open_file(completions_filepath):
+                deleted_text = lsp.delete_text_between_positions(
+                    completions_filepath,
+                    Position(line=75, character=17),
+                    Position(line=77, character=4)
+                )
+                assert deleted_text == """withIpPort(arr[0], Integer.parseInt(arr[1]))
+                .build();
+    """
+
+                prompt_pos = (75, 17)
+
+                with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f:
+                    filecontent = f.read()
+
+                pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1])
+                assert filecontent[pos_idx] == "w"
+                prompt = filecontent[:pos_idx]
+                assert prompt[-1] == "."
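+
+                # As in the previous test, keep only the last (2048 - 512) prompt tokens to fit
+                # the model context. Without MGD, greedy decoding below produces the
+                # hostAddress()/port() API, which does not exist in this modified repository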
+                prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, -(2048 - 512) :]
+
+                gen = model.generate(
+                    prompt_tokenized, do_sample=False, max_new_tokens=30, early_stopping=True
+                )
+                generated_code_without_mgd = tokenizer.decode(gen[0, -30:])
+
+                assert (
+                    generated_code_without_mgd ==
+                    "hostAddress(arr[0])\n                .port(Integer.parseInt(arr[1]))\n                .build();\n    }\n\n    private List"
+                )
+
+                filebuffer = MonitorFileBuffer(
+                    lsp.language_server,
+                    completions_filepath,
+                    prompt_pos,
+                    prompt_pos,
+                    code_language,
+                )
+                monitor = DereferencesMonitor(HFTokenizerWrapper(tokenizer), filebuffer)
+                mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)
+
+                # Generate code using the santacoder model with the MGD logits processor and greedy decoding
+                logits_processor = LogitsProcessorList([mgd_logits_processor])
+                gen = model.generate(
+                    prompt_tokenized,
+                    do_sample=False,
+                    max_new_tokens=30,
+                    logits_processor=logits_processor,
+                    early_stopping=True,
+                )
+
+                generated_code = tokenizer.decode(gen[0, -30:])
+
+                assert (
+                    generated_code ==
+                    "arr[0])\n                .withIpPort(arr[1])\n                .build();\n    }\n\n    private List convertToList("
+                )