diff --git a/README.md b/README.md index 7a5f77c..f39a768 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,9 @@ A monitor under the Monitor-Guided Decoding framework, is instantiated using `mu #### Dereferences Monitor [src/monitors4codegen/monitor_guided_decoding/monitors/dereferences_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/dereferences_monitor.py) provides the instantiation of `Monitor` class for dereferences monitor. It can be used to guide LMs to generate valid identifier dereferences. Unit tests for the dereferences monitor are present in [tests/monitor_guided_decoding/test_dereferences_monitor_java.py](tests/monitor_guided_decoding/test_dereferences_monitor_java.py), which also provide usage examples for the dereferences monitor. +#### Switch-Enum Monitor +[src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py) provides the instantiation of `Monitor` for generating valid named enum constants in C#. Unit tests for the switch-enum monitor are present in [tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py](tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py), which also provide usage examples for the switch-enum monitor. + ## Contributing This project welcomes contributions and suggestions. 
"""
This module provides the switch-enum monitor, that is invoked when "case " is typed in a switch statement
"""

import copy

from typing import List
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
from monitors4codegen.multilspy.multilspy_utils import TextUtils
from monitors4codegen.multilspy import multilspy_types

class SwitchEnumMonitor(DereferencesMonitor):
    """
    Provides the switch-enum monitor, that is invoked when "case " is typed in a switch statement to provide
    enum values as completions
    """
    def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None:
        super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state)
        # '.' must not break a completion here, since enum constants are dotted
        # (e.g. "AccessSize.Byte"). Copy before removing: `all_break_chars` is not
        # created in this __init__, so it presumably comes from the base class and
        # may be shared across instances — an in-place remove would silently change
        # the behavior of every other monitor. (TODO confirm against DereferencesMonitor.)
        self.all_break_chars = copy.copy(self.all_break_chars)
        self.all_break_chars.remove('.')

    async def pre(self) -> None:
        """
        Decide the decoder state before the next token is sampled.

        Enters the constrained state only when the text up to the cursor ends with
        "case " and the language server reports at least one enum-member completion;
        otherwise falls back to the unconstrained state S0.
        """
        cursor_idx = TextUtils.get_index_from_line_col(
            self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path),
            self.monitor_file_buffer.current_lc[0],
            self.monitor_file_buffer.current_lc[1],
        )
        text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[
            :cursor_idx
        ]

        # TODO: pre can be improved by checking for r"switch.*case", and obtaining completions, and then prefixing a whitespace
        if not text_upto_cursor.endswith("case "):
            self.decoder_state = DecoderStates.S0
            return

        completions = await self.a_phi()
        if len(completions) == 0:
            # No enum members to offer — do not constrain decoding.
            self.decoder_state = DecoderStates.S0
        else:
            self.decoder_state = DecoderStates.Constrained
            self.legal_completions = completions

    async def a_phi(self) -> List[str]:
        """
        Query the language server for completions at the current cursor position and
        return only the enum-member completion texts — the legal continuations of "case ".
        """
        relative_file_path = self.monitor_file_buffer.file_path
        line, column = self.monitor_file_buffer.current_lc

        with self.monitor_file_buffer.lsp.open_file(relative_file_path):
            legal_completions = await self.monitor_file_buffer.lsp.request_completions(
                relative_file_path, line, column
            )
            # Keep only named enum constants; all other completion kinds are illegal here.
            legal_completions = [
                completion["completionText"]
                for completion in legal_completions
                if completion["kind"] == multilspy_types.CompletionItemKind.EnumMember
            ]

        return legal_completions

    async def update(self, generated_token: str):
        """
        Updates the monitor state based on the generated token
        """
        if self.responsible_for_file_buffer_state:
            self.monitor_file_buffer.append_text(generated_token)
        if self.decoder_state == DecoderStates.Constrained:
            for break_char in self.all_break_chars:
                if break_char in generated_token:
                    # A break character ends the identifier — leave the constrained state.
                    self.decoder_state = DecoderStates.S0
                    self.legal_completions = None
                    return

            # No breaking characters found. Continue in constrained state,
            # consuming the generated prefix from each remaining legal completion.
            self.legal_completions = [
                legal_completion[len(generated_token) :]
                for legal_completion in self.legal_completions
                if legal_completion.startswith(generated_token)
            ]
        else:
            # Nothing to be done in other states
            return
"""
This file contains tests for Monitor-Guided Decoding for switch-enum in C#
"""

import torch
import transformers
import pytest

from pathlib import PurePath
from monitors4codegen.multilspy.language_server import SyncLanguageServer
from monitors4codegen.multilspy.multilspy_config import Language
from tests.test_utils import create_test_context, is_cuda_available
from transformers import AutoTokenizer, AutoModelForCausalLM
from monitors4codegen.multilspy.multilspy_utils import TextUtils
from monitors4codegen.monitor_guided_decoding.monitors.switch_enum_monitor import SwitchEnumMonitor
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from monitors4codegen.multilspy.multilspy_types import Position
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper

pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_multilspy_csharp_ryujinx_switch_enum() -> None:
    """
    Test the working of SwitchEnumMonitor with C# repository - Ryujinx

    Deletes a known `switch`-over-enum region, then checks that greedy decoding
    without MGD produces invalid case labels while MGD-guided decoding produces
    valid enum constants. Exercises two files (Arm64 and X86 code generators).
    """
    code_language = Language.CSHARP
    params = {
        "code_language": code_language,
        "repo_url": "https://github.com/Ryujinx/Ryujinx/",
        "repo_commit": "e768a54f17b390c3ac10904c7909e3bef020edbd"
    }

    device = torch.device('cuda' if is_cuda_available() else 'cpu')

    model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        "bigcode/santacoder", trust_remote_code=True
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")

    with create_test_context(params) as context:
        lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
        with lsp.start_server():
            completions_filepath = "src/ARMeilleure/CodeGen/Arm64/CodeGenerator.cs"
            with lsp.open_file(completions_filepath):
                deleted_text = lsp.delete_text_between_positions(
                    completions_filepath,
                    Position(line=1369, character=21),
                    Position(line=1385, character=8)
                )
                assert deleted_text == """AccessSize.Byte:
                    context.Assembler.Stlxrb(desired, address, result);
                    break;
                case AccessSize.Hword:
                    context.Assembler.Stlxrh(desired, address, result);
                    break;
                default:
                    context.Assembler.Stlxr(desired, address, result);
                    break;
            }

            context.Assembler.Cbnz(result, startOffset - context.StreamOffset); // Retry if store failed.

            context.JumpHere();

            context.Assembler.Clrex();
        """
                filebuffer = MonitorFileBuffer(
                    lsp.language_server,
                    completions_filepath,
                    (1369, 21),
                    (1369, 21),
                    code_language,
                )
                monitor = SwitchEnumMonitor(HFTokenizerWrapper(tokenizer), filebuffer)
                mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)

                with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f:
                    filecontent = f.read()

                pos_idx = TextUtils.get_index_from_line_col(filecontent, 1369, 21)
                assert filecontent[:pos_idx].endswith('case ')
                prompt = filecontent[:pos_idx]
                assert prompt[-1] == " "
                # Fix: use the selected `device` (was an unconditional `.cuda()`, which
                # crashes on CPU-only machines even though `device` is computed above).
                prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").to(device)[:, -(2048 - 512) :]

                generated_code_without_mgd = model.generate(
                    prompt_tokenized, do_sample=False, max_new_tokens=100, early_stopping=True
                )
                generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -100:])

                # Without MGD the model hallucinates integer case labels.
                assert (
                    generated_code_without_mgd
                    == "1:\n                    context.Assembler.Stb(result, Register(ZrRegister, result.Type));\n                    break;\n                case 2:\n                    context.Assembler.Stw(result, Register(ZrRegister, result.Type));\n                    break;\n                case 4:\n                    context.Assembler.Std(result, Register(ZrRegister, result.Type));\n                    break;\n                case 8:\n                    context.Assembler.Stq(result, Register(Zr"
                )

                # Generate code using santacoder model with the MGD logits processor and greedy decoding
                logits_processor = LogitsProcessorList([mgd_logits_processor])
                generated_code = model.generate(
                    prompt_tokenized,
                    do_sample=False,
                    max_new_tokens=100,
                    logits_processor=logits_processor,
                    early_stopping=True,
                )

                generated_code = tokenizer.decode(generated_code[0, -100:])

                # With MGD the case labels are valid AccessSize enum members.
                assert (
                    generated_code
                    == "AccessSize.Byte:\n                    context.Assembler.Staxrb(actual, address);\n                    break;\n                case AccessSize.Hword:\n                    context.Assembler.Staxrh(actual, address);\n                    break;\n                default:\n                    context.Assembler.Staxr(actual, address);\n                    break;\n            }\n\n            context.Assembler.Cmp(actual, desired);\n\n            context.JumpToNear(ArmCondition.Eq);\n\n            context.Assembler.Staxr(result,"
                )

            completions_filepath = "src/ARMeilleure/CodeGen/X86/CodeGenerator.cs"
            with lsp.open_file(completions_filepath):
                deleted_text = lsp.delete_text_between_positions(
                    completions_filepath,
                    Position(line=224, character=37),
                    Position(line=243, character=28)
                )
                assert deleted_text == """Intrinsic.X86Comisdlt:
                                    context.Assembler.Comisd(src1, src2);
                                    context.Assembler.Setcc(dest, X86Condition.Below);
                                    break;

                                case Intrinsic.X86Comisseq:
                                    context.Assembler.Comiss(src1, src2);
                                    context.Assembler.Setcc(dest, X86Condition.Equal);
                                    break;

                                case Intrinsic.X86Comissge:
                                    context.Assembler.Comiss(src1, src2);
                                    context.Assembler.Setcc(dest, X86Condition.AboveOrEqual);
                                    break;

                                case Intrinsic.X86Comisslt:
                                    context.Assembler.Comiss(src1, src2);
                                    context.Assembler.Setcc(dest, X86Condition.Below);
                                    break;
                            """
                filebuffer = MonitorFileBuffer(
                    lsp.language_server,
                    completions_filepath,
                    (224, 37),
                    (224, 37),
                    code_language,
                )
                monitor = SwitchEnumMonitor(HFTokenizerWrapper(tokenizer), filebuffer)
                mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)

                with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f:
                    filecontent = f.read()

                pos_idx = TextUtils.get_index_from_line_col(filecontent, 224, 37)
                assert filecontent[:pos_idx].endswith('case ')
                prompt = filecontent[:pos_idx]
                assert prompt[-1] == " "
                # Fix: `.to(device)` instead of `.cuda()` (see above).
                prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").to(device)[:, -(2048 - 512) :]

                generated_code_without_mgd = model.generate(
                    prompt_tokenized, do_sample=False, max_new_tokens=50, early_stopping=True
                )
                generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -50:])

                # Without MGD the model invents a non-existent enum member (X86Comisdgt).
                assert (
                    generated_code_without_mgd
                    == " Intrinsic.X86Comisdgt:\n                                    context.Assembler.Comisd(src1, src2);\n                                    context.Assembler.Setcc(dest, X86Condition.GreaterThan);\n                                    break;\n\n                                case In"
                )

                # Generate code using santacoder model with the MGD logits processor and greedy decoding
                logits_processor = LogitsProcessorList([mgd_logits_processor])
                generated_code = model.generate(
                    prompt_tokenized,
                    do_sample=False,
                    max_new_tokens=50,
                    logits_processor=logits_processor,
                    early_stopping=True,
                )

                generated_code = tokenizer.decode(generated_code[0, -50:])

                assert (
                    generated_code
                    == "Intrinsic.X86Comisdlt:\n                                    context.Assembler.Comisd(src1, src2);\n                                    context.Assembler.Setcc(dest, X86Condition.LessThan);\n                                    break;\n\n                                case Intrinsic"
                )
@pytest.mark.asyncio
@pytest.mark.skip(reason="TODO: This runs too slow. Reimplement joint monitoring")
async def test_multilspy_csharp_ryujinx_joint_switch_enum_dereferences() -> None:
    """
    Test the working of Joint monitoring with SwitchEnumMonitor and DereferencesMonitor with C# repository - Ryujinx

    Runs two independent language servers over the same repository so each monitor
    owns its own file-buffer state, then applies both logits processors jointly.
    """

    code_language = Language.CSHARP
    params = {
        "code_language": code_language,
        "repo_url": "https://github.com/Ryujinx/Ryujinx/",
        "repo_commit": "e768a54f17b390c3ac10904c7909e3bef020edbd"
    }

    device = torch.device('cuda' if is_cuda_available() else 'cpu')

    model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        "bigcode/santacoder", trust_remote_code=True
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")

    with create_test_context(params) as context:
        lsp1 = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
        lsp2 = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
        with lsp1.start_server(), lsp2.start_server():
            completions_filepath = "src/ARMeilleure/CodeGen/X86/CodeGenerator.cs"
            with lsp1.open_file(completions_filepath), lsp2.open_file(completions_filepath):
                # Delete the same region in both servers' buffers so they stay in sync.
                deleted_text1 = lsp1.delete_text_between_positions(
                    completions_filepath,
                    Position(line=224, character=37),
                    Position(line=243, character=28)
                )
                deleted_text2 = lsp2.delete_text_between_positions(
                    completions_filepath,
                    Position(line=224, character=37),
                    Position(line=243, character=28)
                )
                assert deleted_text1 == deleted_text2
                assert deleted_text1 == """Intrinsic.X86Comisdlt:
                                    context.Assembler.Comisd(src1, src2);
                                    context.Assembler.Setcc(dest, X86Condition.Below);
                                    break;

                                case Intrinsic.X86Comisseq:
                                    context.Assembler.Comiss(src1, src2);
                                    context.Assembler.Setcc(dest, X86Condition.Equal);
                                    break;

                                case Intrinsic.X86Comissge:
                                    context.Assembler.Comiss(src1, src2);
                                    context.Assembler.Setcc(dest, X86Condition.AboveOrEqual);
                                    break;

                                case Intrinsic.X86Comisslt:
                                    context.Assembler.Comiss(src1, src2);
                                    context.Assembler.Setcc(dest, X86Condition.Below);
                                    break;
                            """
                filebuffer_enum = MonitorFileBuffer(
                    lsp1.language_server,
                    completions_filepath,
                    (224, 37),
                    (224, 37),
                    code_language,
                )
                monitor_switch_enum = SwitchEnumMonitor(HFTokenizerWrapper(tokenizer), filebuffer_enum)
                mgd_logits_processor_switch_enum = MGDLogitsProcessor([monitor_switch_enum], lsp1.language_server.server.loop)

                filebuffer_dereferences = MonitorFileBuffer(
                    lsp2.language_server,
                    completions_filepath,
                    (224, 37),
                    (224, 37),
                    code_language,
                )
                monitor_dereferences = DereferencesMonitor(HFTokenizerWrapper(tokenizer), filebuffer_dereferences)
                mgd_logits_processor_dereferences = MGDLogitsProcessor([monitor_dereferences], lsp2.language_server.server.loop)

                with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f:
                    filecontent = f.read()

                pos_idx = TextUtils.get_index_from_line_col(filecontent, 224, 37)
                assert filecontent[:pos_idx].endswith('case ')
                prompt = filecontent[:pos_idx]
                assert prompt[-1] == " "
                # Fix: use the selected `device` (was an unconditional `.cuda()`, which
                # crashes on CPU-only machines even though `device` is computed above).
                prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").to(device)[:, -(2048 - 512) :]

                generated_code_without_mgd = model.generate(
                    prompt_tokenized, do_sample=False, max_new_tokens=50, early_stopping=True
                )
                generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -50:])

                assert (
                    generated_code_without_mgd
                    == " Intrinsic.X86Comisdgt:\n                                    context.Assembler.Comisd(src1, src2);\n                                    context.Assembler.Setcc(dest, X86Condition.GreaterThan);\n                                    break;\n\n                                case In"
                )

                # Generate code using santacoder model with the MGD logits processor and greedy decoding
                logits_processor = LogitsProcessorList([mgd_logits_processor_switch_enum, mgd_logits_processor_dereferences])
                generated_code = model.generate(
                    prompt_tokenized,
                    do_sample=False,
                    max_new_tokens=50,
                    logits_processor=logits_processor,
                    early_stopping=True,
                )

                generated_code = tokenizer.decode(generated_code[0, -50:])

                assert (
                    generated_code
                    == "Intrinsic.X86Comisdlt:\n                                    context.Assembler.Comisd(src1, src2);\n                                    context.Assembler.Setcc(dest, X86Condition.Below);\n                                    break;\n\n                                case Intrinsic"
                )