Add Monitor Guided Decoding - dereferences monitor
LakshyAAAgrawal committed Nov 6, 2023
1 parent 3b20749 commit 246681f
Showing 6 changed files with 845 additions and 6 deletions.
12 changes: 6 additions & 6 deletions README.md
@@ -110,6 +110,11 @@ Some of the analyses results that `multilspy` can provide are:
- Finding the callers of a function or the instantiations of a class ([textDocument/references](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_references))
- Providing type-based dereference completions ([textDocument/completion](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_completion))

The file [multilspy/language_server.py](multilspy/language_server.py) provides the `multilspy` API. Several tests for `multilspy` present under [tests/multilspy/](tests/multilspy/) provide detailed usage examples for `multilspy`. The tests can be executed by running:
```bash
pytest tests/multilspy
```

Example usage:
```python
from multilspy import SyncLanguageServer
@@ -140,14 +145,9 @@ async with lsp.start_server():
...
```

Several tests for `multilspy` present under [tests/multilspy/](tests/multilspy/) provide detailed usage examples for `multilspy`. The tests can be executed by running:
```bash
pytest tests/multilspy
```

## 5. Monitor-Guided Decoding

Coming Soon...
A monitor under the Monitor-Guided Decoding framework is instantiated with `multilspy` as the LSP client and acts as a logits processor to guide the LM decoding. [monitor_guided_decoding/monitor.py](monitor_guided_decoding/monitor.py) provides the class `MGDLogitsProcessor`, which can be used with any HuggingFace language model as a `LogitsProcessor` to guide the LM using MGD. [monitor_guided_decoding/dereferences_monitor.py](monitor_guided_decoding/dereferences_monitor.py) provides the instantiation for the dereferences monitor. Unit tests for the dereferences monitor are present in [tests/monitor_guided_decoding/test_dereferences_monitor_java.py](tests/monitor_guided_decoding/test_dereferences_monitor_java.py), which also provide usage examples for the dereferences monitor.

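As a rough sketch (not the repository's documented API), `MGDLogitsProcessor` can be plugged into HuggingFace generation through the standard `logits_processor` argument. The construction of the processor itself (monitor, language server, and constructor arguments) is elided below, and the model name is only an example; refer to the unit tests for the actual instantiation:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

# Assumed example model; any HuggingFace causal LM can be used
model = AutoModelForCausalLM.from_pretrained("bigcode/santacoder", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")

# Hypothetical placeholder: build an MGDLogitsProcessor wired to a DereferencesMonitor
# and a multilspy language server here (see the unit tests for the actual arguments)
mgd_processor = ...

prompt = "..."  # partial code ending right after a dereference dot
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=30,
    logits_processor=LogitsProcessorList([mgd_processor]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
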
## Contributing

244 changes: 244 additions & 0 deletions monitor_guided_decoding/dereferences_monitor.py
@@ -0,0 +1,244 @@
"""
This module provides the instantiation of dereferences monitor
"""

import code_tokenize as ctok

from typing import List, Union, Set
from enum import Enum

from multilspy import multilspy_types
from multilspy.multilspy_config import Language
from multilspy.multilspy_utils import TextUtils
from multilspy.multilspy_types import Position

from monitor_guided_decoding.monitor import Monitor, MonitorFileBuffer
from monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
from monitor_guided_decoding.mgd_utils import PLUtils


class DecoderStates(Enum):
"""
Enum for the state of the decoder
"""

UnInitialized = 0
S0 = 1
Constrained = 2


class DereferencesMonitor(Monitor):
"""
Provides the dereferences monitor
"""

def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None:
super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state)
self.decoder_state = DecoderStates.UnInitialized
self.all_break_chars = DereferencesMonitor.find_all_break_chars(
self.tokenizer.tokenizer_char_set, monitor_file_buffer.language
)
self.prompt_len: Union[None, int] = None
self.legal_completions: Union[List[str], None] = None

@staticmethod
def find_all_break_chars(charset: Set[str], language: Language) -> Set[str]:
"""
Finds the set of characters, which when appended to the end of an identifier, cause the identifier to be terminated
For example "," is a breaking character, since "abc," is 2 tokens, "abc" and ","
On the other hand, "a" is not a breaking character, since "abca" is a single token
"""
all_break_chars: Set[str] = set()
for vocab_char in charset:
toks: List[ctok.tokens.ASTToken] = PLUtils.tokenizer_pl("abc" + vocab_char, language)
toks = [t for t in toks if t.text != ""]
if len(toks) == 0 or toks[0].text == "abc":
all_break_chars.add(vocab_char)
return all_break_chars

    async def initialize(self, input_ids: List[int], input_text: str) -> None:
        """
        Initializes the monitor when it is invoked for the first time with inputs
        """
        self.prompt_len = len(input_ids)
        await self.pre()

    async def pre(self) -> None:
        """
        Checks if the static analysis should be performed at this point.
        In the case of the dereferences monitor, the last character should be a dot
        """
        cursor_idx = TextUtils.get_index_from_line_col(
            self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path),
            self.monitor_file_buffer.current_lc[0],
            self.monitor_file_buffer.current_lc[1],
        )
        text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[
            :cursor_idx
        ]
        if text_upto_cursor[-1] != ".":
            self.decoder_state = DecoderStates.S0
            return

        completions = await self.a_phi()

        if len(completions) == 0:
            self.decoder_state = DecoderStates.S0
        else:
            self.decoder_state = DecoderStates.Constrained
            self.legal_completions = completions

    async def maskgen(self, input_ids: List[int]) -> List[int]:
        """
        Takes the list of input tokens, and returns the list of tokens to be blacklisted in the next step
        maskgen is invoked for every new token to be generated
        The first time it is invoked, maskgen performs the initialization
        Subsequent invocations are handled based on the current state of the decoder
        """

        input_text = self.tokenizer.decode(
            input_ids,
            clean_up_tokenization_spaces=False,
            skip_special_tokens=True,
        )

        if self.decoder_state == DecoderStates.UnInitialized:
            # Handle initialization. This is the first time the monitor is being invoked
            await self.initialize(input_ids, input_text)
        else:
            # A new token has been generated. Handle the new token by calling update
            gen_so_far = self.tokenizer.decode(
                input_ids[self.prompt_len :], clean_up_tokenization_spaces=False, skip_special_tokens=True
            )
            assert gen_so_far.startswith(self.monitor_file_buffer.gen_text), (gen_so_far, self.monitor_file_buffer.gen_text)
            assert input_text.endswith(gen_so_far)
            new_gen_text = gen_so_far[len(self.monitor_file_buffer.gen_text) :]

            await self.update(new_gen_text)

        if self.decoder_state == DecoderStates.S0:
            # If the decoder is in state S0, we need to check pre()
            # pre() will determine if the decoder should transition to the Constrained state
            # If so, it invokes a_phi() and transitions the decoder state
            await self.pre()

        if self.decoder_state == DecoderStates.Constrained:
            # If the decoder is in the Constrained state, generate the set of blacklisted tokens
            # based on the current state of the monitor, legal_completions
            possible_token_ids: Set[int] = set()
            for legal_suffix in self.legal_completions:
                # If a token contains the entire remaining identifier suffix followed by breaking characters, then we allow it
                # This allows decoding of tokens like 'abc<', i.e. tokens that span identifier boundaries
                if self.tokenizer.vocab_trie.has_node(legal_suffix) != 0:
                    for suffix_token, suffix_token_id in self.tokenizer.vocab_trie.iteritems(prefix=legal_suffix):
                        if suffix_token[len(legal_suffix) : len(legal_suffix) + 1] in self.all_break_chars:
                            possible_token_ids.add(suffix_token_id)

                # If a token is a prefix of the remaining suffix, then we allow it
                for suffix_token, suffix_token_id in self.tokenizer.vocab_trie.prefixes(legal_suffix):
                    possible_token_ids.add(suffix_token_id)

            blacklisted_ids = [i for i in self.tokenizer.all_token_ids if i not in possible_token_ids]
        else:
            blacklisted_ids = []

        return blacklisted_ids

    async def a_phi(self) -> List[str]:
        """
        Uses multilspy to perform static analysis and returns the list of type-compliant dereferences
        at the current cursor position (which ends with a dot)
        """
        relative_file_path = self.monitor_file_buffer.file_path
        line, column = self.monitor_file_buffer.current_lc

        with self.monitor_file_buffer.lsp.open_file(relative_file_path):
            legal_completions1 = await self.monitor_file_buffer.lsp.request_completions(
                relative_file_path, line, column, allow_incomplete=True
            )
            legal_completions1 = [
                completion["completionText"]
                for completion in legal_completions1
                if completion["kind"] != multilspy_types.CompletionItemKind.Keyword
            ]
            lsp_text = self.monitor_file_buffer.lsp.get_open_file_text(relative_file_path)
            request_idx = TextUtils.get_index_from_line_col(lsp_text, line, column)
            opening_bracket_stream = PLUtils.get_opening_bracket_stream(
                lsp_text[:request_idx], self.monitor_file_buffer.language
            )
            if len(opening_bracket_stream) == 0:
                return legal_completions1

            closing_bracket_stream = PLUtils.get_closing_bracket_stream(
                lsp_text[request_idx:], self.monitor_file_buffer.language
            )
            if len(opening_bracket_stream) <= len(closing_bracket_stream):
                return legal_completions1

            err = False
            for j in range(len(closing_bracket_stream)):
                if closing_bracket_stream[-j - 1] == "}" and opening_bracket_stream[j] != "{":
                    err = True
                    break
                elif closing_bracket_stream[-j - 1] == ")" and opening_bracket_stream[j] != "(":
                    err = True
                    break
                elif closing_bracket_stream[-j - 1] == "]" and opening_bracket_stream[j] != "[":
                    err = True
                    break

            if err:
                return legal_completions1

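            # Some brackets opened before the cursor are not yet closed in the text after it.
            # Temporarily insert the missing closing brackets so that the file is well-bracketed,
            # request completions again, and delete the inserted text afterwards.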
            opening_bracket_stream = opening_bracket_stream[len(closing_bracket_stream) :]
            remaining_close_brackets = list(
                map(lambda x: "}" if x == "{" else (")" if x == "(" else "]"), opening_bracket_stream)
            )[::-1]

            insert_text = "".join(remaining_close_brackets)
            updated_cursor = self.monitor_file_buffer.lsp.insert_text_at_position(
                relative_file_path, line, column, insert_text
            )
            assert updated_cursor["line"] == line
            assert updated_cursor["character"] == column + len(insert_text)
            legal_completions2 = await self.monitor_file_buffer.lsp.request_completions(
                relative_file_path, line, column, allow_incomplete=True
            )
            legal_completions2 = [
                completion["completionText"]
                for completion in legal_completions2
                if completion["kind"] != multilspy_types.CompletionItemKind.Keyword
            ]

            deleted_text = self.monitor_file_buffer.lsp.delete_text_between_positions(
                relative_file_path,
                Position(line=line, character=column),
                Position(line=line, character=column + len(insert_text)),
            )
            assert deleted_text == insert_text

            return list(set(legal_completions1 + legal_completions2))

    async def update(self, generated_token: str):
        """
        Updates the monitor state based on the generated token
        """
        if self.responsible_for_file_buffer_state:
            self.monitor_file_buffer.append_text(generated_token)
        if self.decoder_state == DecoderStates.Constrained:
            for break_char in self.all_break_chars:
                if break_char in generated_token:
                    self.decoder_state = DecoderStates.S0
                    self.legal_completions = None
                    return

            # No breaking characters found. Continue in constrained state
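            # Consume the generated text from the front of each remaining legal completion,
            # keeping only the completions that start with the newly generated token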
            self.legal_completions = [
                legal_completion[len(generated_token) :]
                for legal_completion in self.legal_completions
                if legal_completion.startswith(generated_token)
            ]
        else:
            # Nothing to be done in other states
            return
84 changes: 84 additions & 0 deletions monitor_guided_decoding/mgd_utils.py
@@ -0,0 +1,84 @@
"""
This module provides the utility functions for handling programming language text
"""

import code_tokenize as ctok

from typing import List
from multilspy.multilspy_config import Language


class PLUtils:
"""
This class provides various utility functions for handling programming language text
"""

@staticmethod
def tokenizer_pl(inp_text: str, lang: Language) -> List[ctok.tokens.ASTToken]:
"""
Tokenizes the given text using code_tokenize
"""
lang_s = str(lang) if lang != Language.CSHARP else "c-sharp"
if inp_text.strip() == "":
return []
lsp_text_lang_tokenized: List[ctok.tokens.ASTToken] = ctok.tokenize(
inp_text, lang=lang_s, syntax_error="ignore"
)
lsp_text_lang_tokenized: List[ctok.tokens.ASTToken] = [tok for tok in lsp_text_lang_tokenized if tok.text != ""]
return lsp_text_lang_tokenized

    @staticmethod
    def get_opening_bracket_stream(inp_text: str, lang: Language) -> List[str]:
        """
        Returns the list of brackets opened but not closed in the given text, in the order they were opened
        """
        bracket_stream: List[str] = []
        err = False
        lsp_text_lang_tokenized = PLUtils.tokenizer_pl(inp_text, lang)
        for tok in lsp_text_lang_tokenized:
            if tok.type in ["{", "(", "["]:
                bracket_stream.append(tok.type)
            elif tok.type in ["}", ")", "]"]:
                if len(bracket_stream) == 0:
                    err = True
                    break
                if (
                    (tok.type == "}" and bracket_stream[-1] == "{")
                    or (tok.type == ")" and bracket_stream[-1] == "(")
                    or (tok.type == "]" and bracket_stream[-1] == "[")
                ):
                    bracket_stream.pop()
                else:
                    err = True
                    break
        if err:
            raise Exception("Invalid bracket stream")
        return bracket_stream

    @staticmethod
    def get_closing_bracket_stream(inp_text: str, lang: Language) -> List[str]:
        """
        Returns the list of closing brackets in the given text that have no matching opening bracket, in the order they appear
        """
        bracket_stream: List[str] = []
        err = False
        lsp_text_lang_tokenized = PLUtils.tokenizer_pl(inp_text, lang)
        for tok in lsp_text_lang_tokenized[::-1]:
            if tok.type in ["}", ")", "]"]:
                bracket_stream.append(tok.type)
            elif tok.type in ["{", "(", "["]:
                if len(bracket_stream) == 0:
                    err = True
                    break
                if (
                    (tok.type == "{" and bracket_stream[-1] == "}")
                    or (tok.type == "(" and bracket_stream[-1] == ")")
                    or (tok.type == "[" and bracket_stream[-1] == "]")
                ):
                    bracket_stream.pop()
                else:
                    err = True
                    break
        if err:
            raise Exception("Invalid bracket stream")
        return bracket_stream[::-1]