Skip to content

Commit

Permalink
Add Switch-Enum monitor
Browse files Browse the repository at this point in the history
  • Loading branch information
LakshyAAAgrawal committed Nov 20, 2023
1 parent ad4055a commit c6e5a7f
Show file tree
Hide file tree
Showing 5 changed files with 390 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ A monitor under the Monitor-Guided Decoding framework, is instantiated using `mu
#### Dereferences Monitor
[src/monitors4codegen/monitor_guided_decoding/monitors/dereferences_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/dereferences_monitor.py) provides the instantiation of `Monitor` class for dereferences monitor. It can be used to guide LMs to generate valid identifier dereferences. Unit tests for the dereferences monitor are present in [tests/monitor_guided_decoding/test_dereferences_monitor_java.py](tests/monitor_guided_decoding/test_dereferences_monitor_java.py), which also provide usage examples for the dereferences monitor.

#### Switch-Enum Monitor
[src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py) provides the instantiation of `Monitor` for generating valid named enum constants in C#. Unit tests for the switch-enum monitor are present in [tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py](tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py), which also provide usage examples for the switch-enum monitor.

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"jedi-language-server==0.41.1",
"pydantic==1.10.5",
"code-tokenize==0.2.0",
"code-ast @ git+https://github.com/cedricrupb/code_ast@982940d04b1d721e5ac9a97d433f36d1fb47e8e0",
"openai==1.3.3",
"torch==1.12.0",
"transformers==4.30.0",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ pytest-asyncio==0.21.1
pygtrie==2.5.0
openai==1.3.3
code-tokenize==0.2.0
code-ast @ git+https://github.com/cedricrupb/code_ast@982940d04b1d721e5ac9a97d433f36d1fb47e8e0
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.12.0+cu113
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
This module provides the switch-enum monitor, that is invoked when "case " is typed in a switch statement
"""

from typing import List
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
from monitors4codegen.multilspy.multilspy_utils import TextUtils
from monitors4codegen.multilspy import multilspy_types

class SwitchEnumMonitor(DereferencesMonitor):
"""
Provides the switch-enum monitor, that is invoked when "case " is typed in a switch statement to provide
enum values as completions
"""
def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None:
super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state)
self.all_break_chars.remove('.')

async def pre(self) -> None:
cursor_idx = TextUtils.get_index_from_line_col(
self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path),
self.monitor_file_buffer.current_lc[0],
self.monitor_file_buffer.current_lc[1],
)
text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[
:cursor_idx
]

# TODO: pre can be improved by checking for r"switch.*case", and obtaining completions, and then prefixing a whitespace
if not text_upto_cursor.endswith("case "):
self.decoder_state = DecoderStates.S0
return

completions = await self.a_phi()
if len(completions) == 0:
self.decoder_state = DecoderStates.S0
else:
self.decoder_state = DecoderStates.Constrained
self.legal_completions = completions

async def a_phi(self) -> List[str]:
relative_file_path = self.monitor_file_buffer.file_path
line, column = self.monitor_file_buffer.current_lc

with self.monitor_file_buffer.lsp.open_file(relative_file_path):
legal_completions = await self.monitor_file_buffer.lsp.request_completions(
relative_file_path, line, column
)
legal_completions = [
completion["completionText"]
for completion in legal_completions
if completion["kind"] == multilspy_types.CompletionItemKind.EnumMember
]

return legal_completions

async def update(self, generated_token: str):
"""
Updates the monitor state based on the generated token
"""
if self.responsible_for_file_buffer_state:
self.monitor_file_buffer.append_text(generated_token)
if self.decoder_state == DecoderStates.Constrained:
for break_char in self.all_break_chars:
if break_char in generated_token:
self.decoder_state = DecoderStates.S0
self.legal_completions = None
return

# No breaking characters found. Continue in constrained state
self.legal_completions = [
legal_completion[len(generated_token) :]
for legal_completion in self.legal_completions
if legal_completion.startswith(generated_token)
]
else:
# Nothing to be done in other states
return
Loading

0 comments on commit c6e5a7f

Please sign in to comment.