Skip to content

Commit

Permalink
[Optimization] Advance parser concurrently with model forward pass (#…
Browse files Browse the repository at this point in the history
…1065)

Refactors `Engine.__call__` and `TokenParser._parse` coroutine/generator
to yield a _future_ wrapping `LLInterpreter.mid_process` (running in a
`ThreadPoolExecutor`) such that `mid_process` can run concurrently with
the forward pass. This function comes directly from the rust extension,
so it releases the GIL and threading *should* be sufficient to ensure
true concurrency.

---------

Co-authored-by: Loc Huynh <[email protected]>
  • Loading branch information
hudson-ai and lochuynh1412 authored Nov 10, 2024
1 parent ed9f078 commit b6bcee7
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 86 deletions.
134 changes: 85 additions & 49 deletions guidance/_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
from typing import Any, Generator, Optional, Tuple, Union
from typing import Any, Generator, Optional, Union
from concurrent.futures import ThreadPoolExecutor, Future

import llguidance # type: ignore[import-untyped]
import numpy as np
Expand All @@ -11,7 +12,6 @@
from .models._byte_tokenizer import ByteTokenizer
from .models._tokenizer import Tokenizer


class TokenParserException(Exception):
pass

Expand Down Expand Up @@ -52,8 +52,10 @@ def __init__(
serialized_grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self._threadpool = ThreadPoolExecutor(max_workers=1)
self._generator = self._parse(prompt, ensure_bos_token)
self._done = False
self._has_pending_stop = False

def is_accepting(self) -> bool:
return self.ll_interpreter.is_accepting()
Expand All @@ -63,12 +65,13 @@ def done(self) -> bool:

def advance(
self, token: Optional[int]
) -> Tuple[Optional[GenData], EngineCallResponse]:
try:
return self._generator.send(token)
except StopIteration as e:
self._done = True
return None, e.value
) -> tuple[list[int], Future[tuple[Optional[bytes], LLInterpreterResponse]]]:
if self.done():
raise TokenParserException("Cannot advance on a done parser")
return self._generator.send(token)

def has_pending_stop(self) -> bool:
return self._has_pending_stop

def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]:
prompt_tokens = self.ll_interpreter.process_prompt(
Expand All @@ -84,55 +87,78 @@ def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]:

return self.tokenizer.recode(prompt_tokens)

def mid_process(self) -> tuple[Optional[bytes], LLInterpreterResponse]:
mask, ll_response_string = self.ll_interpreter.mid_process()
ll_response = LLInterpreterResponse.model_validate_json(ll_response_string)
return mask, ll_response

def _parse(
self,
prompt: bytes,
ensure_bos_token: bool,
) -> Generator[Tuple[Optional[GenData], EngineCallResponse], Optional[int], EngineCallResponse]:
) -> Generator[
tuple[
list[int],
Future[tuple[Optional[bytes], LLInterpreterResponse]],
],
Optional[int],
None
]:
tokens = self._process_prompt(prompt=prompt, ensure_bos_token=ensure_bos_token)

while True:
mask, resp = self.ll_interpreter.mid_process()
r = LLInterpreterResponse.model_validate_json(resp)
response = r.progress.to_engine_call_response()
if r.stop:
# Note: need to call/set has_pending_stop before spinning up the mid_process future
# as the two methods cannot be called concurrently
self._has_pending_stop = self.ll_interpreter.has_pending_stop()
mid_process_future = self._threadpool.submit(self.mid_process)
token = yield (tokens, mid_process_future)

# Upstairs should have already waited on this future
mask, ll_response = mid_process_future.result()

if ll_response.stop:
# This is the only case in which the mask is None
assert mask is None
# If we're done, our caller should NOT send us a token
if token is not None:
raise TokenParserException(f"Expected None, got token {token}")
self._done = True
break

if mask is not None:
assert r.temperature is not None
gen_data = GenData(
tokens=tokens,
mask=mask,
temperature=r.temperature,
assert mask is not None
if token is None:
raise TokenParserException("Expected token, got None")
if not mask[token]:
# Note: we could punt this probem to ll_interpreter.post_process,
# but it's a bit clearer to handle it here
raise InvalidTokenException(
token=token,
valid_tokens=[i for i in range(len(mask)) if mask[i]],
prompt_tokens=tokens
)
# Send caller the mask and response; wait for token
token = yield (gen_data, response)
if token is None:
raise TokenParserException("Expected token, got None")
if not mask[token]:
# Note: we could punt this probem to ll_interpreter.post_process,
# but it's a bit clearer to handle it here
raise InvalidTokenException(token, gen_data.valid_next_tokens, tokens)
else:
gen_data = None
token = yield (gen_data, response)
if token is not None:
raise TokenParserException(f"Expected None, got token {token}")

backtrack, ff_tokens = self.ll_interpreter.post_process(token)
if backtrack:
tokens = tokens[:-backtrack]
tokens = tokens + ff_tokens

def cleanup(self):
# Rather than having our caller send us None at the end, we'll handle that internally
# so we can (1) verify that the generator actually stops and (2) check the stop reason
# and raise if needed
if not self.done():
try:
self._generator.send(None)
except StopIteration:
pass
if not self.done():
raise TokenParserException("Tried to cleanup but parser is not done")
stop_reason = self.ll_interpreter.stop_reason()
if stop_reason not in {"NoExtension", "EndOfSentence"}:
# TODO: extend exception handling
# Will raise if there is some "bad" stop reason (like hit token limit) OR we're NOT stopped.
# TODO: raise specific exceptions for reasons such as MaxTokensTotal
raise TokenParserException(f"Unexpected stop reason: {stop_reason}")

return response


class ByteParserException(Exception):
def __init__(self, *args, **kwargs):
self.current_byte = kwargs.pop("current_byte", None)
Expand All @@ -155,6 +181,8 @@ def __init__(
self.pos = 0
self._variables: dict[str, Any] = {}
self._variables_log_probs: dict[str, Any] = {}
# Prime the parser
self._advance(None)
self.consume_bytes(prompt)

def matched(self) -> bool:
Expand All @@ -179,14 +207,26 @@ def next_byte_mask(self) -> NDArray[np.uint8]:
mask[t[0]] = 1
return mask

def consume_bytes(self, bts: bytes) -> None:
# Run underlying ll_parser and fast-forward all of our bytes
# until we have a "choice" (generation step) to make
while self.gen_data is None and not self.token_parser.done():
self.gen_data, response = self.token_parser.advance(None)
self._update_capture(response)
self.bytes += response.new_bytes
def _advance(self, token: Optional[int]) -> None:
tokens, mid_process_fut = self.token_parser.advance(token)
mask, ll_response = mid_process_fut.result()
if ll_response.stop:
assert mask is None
self.token_parser.cleanup()
self.gen_data = None
else:
assert mask is not None
assert ll_response.temperature is not None
self.gen_data = GenData(
tokens=tokens,
mask=mask,
temperature=ll_response.temperature,
)
response = ll_response.progress.to_engine_call_response()
self._update_capture(response)
self.bytes += response.new_bytes

def consume_bytes(self, bts: bytes) -> None:
if not bts:
return

Expand Down Expand Up @@ -228,9 +268,7 @@ def consume_bytes(self, bts: bytes) -> None:
consumed_bytes=self.bytes[: self.pos],
)
# Byte was good, have ll_parser consume it so we can advance further
self.gen_data, response = self.token_parser.advance(b)
self._update_capture(response)
self.bytes += response.new_bytes
self._advance(b)

# Run consume_bytes to advance ll_parser and consume the next byte
self.consume_bytes(bts)
Expand All @@ -241,9 +279,7 @@ def force_done(self):
if self.token_parser.done():
return

self.gen_data, response = self.token_parser.advance(self.tokenizer.eos_token_id)
self._update_capture(response)
self.bytes += response.new_bytes
self._advance(self.tokenizer.eos_token_id)
if not self.token_parser.done() or not self.matched():
raise ByteParserException("Hit end of input before reaching a valid state")

Expand Down
4 changes: 2 additions & 2 deletions guidance/models/_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs, force):
# seed the random number generator
self._rand_generator = np.random.default_rng(seed=42)

def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
def sample_with_temperature(self, logits, mask, temperature):
self.called_temperatures.append(temperature)
return super().get_next_token(token_ids, mask, temperature)
return super().sample_with_temperature(logits, mask, temperature)

def get_logits(self, token_ids: list[int]) -> np.ndarray:
"""Pretends to compute the logits for the given token state."""
Expand Down
102 changes: 73 additions & 29 deletions guidance/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,42 +133,86 @@ def __call__(self, prompt, grammar, ensure_bos_token=True) -> Iterator[EngineCal
"""
parser = self.start(prompt, grammar, ensure_bos_token)

has_get_logits = True
token = None
while not parser.done():
gen_data, response = parser.advance(token)

if gen_data is not None:
if parser.is_accepting() and self.tokenizer.eos_token_id is not None:
# Whenever we are in an accepting state, we will allow the model to generate whatever it wants
# but we will treat any "illegal" tokens as EOS, allowing the model to finish gracefully.
assert gen_data.mask[self.tokenizer.eos_token_id]
token = self.get_next_token(
token_ids=gen_data.tokens,
mask=None,
temperature=gen_data.temperature
)
if not gen_data.mask[token]:
token = self.tokenizer.eos_token_id
else:
token = self.get_next_token(
token_ids=gen_data.tokens,
mask=gen_data.mask,
temperature=gen_data.temperature
)
while True:
tokens, mid_process_fut = parser.advance(token)

# Note that has_pending_stop implies that the response is a stop response,
# but the converse is not true. We can therefore avoid some (but not all)
# unnecessary calls to get_logits on the final iteration.
has_pending_stop = parser.has_pending_stop()

if has_get_logits and not has_pending_stop:
try:
logits = self.get_logits(token_ids=tokens)
except NotImplementedError:
# Permanently fall-back to get_next_token if get_logits is not implemented
has_get_logits = False
logits = None
else:
token = None
logits = None

# Important: don't wait on this future until after getting the logits;
# this allows the mask to be built concurrently with model inference
mask, ll_response = mid_process_fut.result()

engine_response = ll_response.progress.to_engine_call_response()
yield engine_response

if ll_response.stop:
assert mask is None
# May raise an exception if the parser is in an bad state!
parser.cleanup()
# Ensure we break AFTER yielding the final response
break

# If there was a pending stop, we should have broken out of the loop
assert not has_pending_stop

# Help the type checker: assert that everything we need to get the next token is not None
assert mask is not None
assert ll_response.temperature is not None

can_finish_early = parser.is_accepting() and self.tokenizer.eos_token_id is not None

if can_finish_early:
# Type checker needs some help
assert self.tokenizer.eos_token_id is not None
# Should be equivalent to parser.is_accepting()
assert mask[self.tokenizer.eos_token_id]
# Whenever we are in an accepting state, we will allow the model to generate whatever it wants
# but we will treat any "illegal" tokens as EOS, allowing the model to finish gracefully.
# Hence, mask must be None
mask_for_sampling = None
else:
mask_for_sampling = mask

if logits is not None:
token = self.sample_with_temperature(
logits=logits,
mask=mask_for_sampling,
temperature=ll_response.temperature,
)
else:
token = self.get_next_token(
tokens,
mask_for_sampling,
ll_response.temperature
)

if can_finish_early and not mask[token]:
# Type checker needs some help
assert self.tokenizer.eos_token_id is not None
token = self.tokenizer.eos_token_id

yield response

def get_next_token(self, token_ids: list[int], mask: Optional[bytes], temperature: float) -> int:
"""Base implementation for getting the next token from the model which calls get_logits and sample_with_temperature.
Subclasses may override this method, e.g. if they use external APIs that do not support getting logits directly.
"""
logits = self.get_logits(token_ids)
token = self.sample_with_temperature(logits, mask, temperature)
return token
# Prefer to implement get_logits over get_next_token as it allows for concurrent mask computation
raise NotImplementedError

def get_logits(self, token_ids: list[int]) -> np.ndarray:
# Prefer to implement get_logits over get_next_token as it allows for concurrent mask computation
raise NotImplementedError

def sample_with_temperature(self, logits: np.ndarray, mask: Optional[bytes], temperature: float) -> int:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"referencing",
"requests",
"tiktoken>=0.3",
"llguidance>=0.1.7",
"llguidance>=0.3.0",
]

# Our basic list of 'extras'
Expand Down
10 changes: 5 additions & 5 deletions tests/model_integration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,18 @@ def test_associativity(selected_model: models.Model):
REMOTE_MODELS = [models.AzureGuidance]
for rm in REMOTE_MODELS:
if isinstance(selected_model, rm):
pytest.skip("Method get_next_token not available for remote models")
pytest.skip("Method get_logits not available for remote models")
prompt = "pi = "
grammar = gen("number", regex=r"\d")
engine = selected_model.engine

with patch.object(engine, "get_next_token", side_effect=engine.get_next_token) as get_next_token_1:
with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_1:
_ = selected_model + (prompt + grammar)
prompt_tokens_1 = get_next_token_1.call_args.kwargs["token_ids"]
prompt_tokens_1 = get_logits_1.call_args_list[0].kwargs["token_ids"]

with patch.object(engine, "get_next_token", side_effect=engine.get_next_token) as get_next_token_2:
with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_2:
_ = (selected_model + prompt) + grammar
prompt_tokens_2 = get_next_token_2.call_args.kwargs["token_ids"]
prompt_tokens_2 = get_logits_2.call_args_list[0].kwargs["token_ids"]

# Main assertion: the prompt tokens should be the same
assert prompt_tokens_1 == prompt_tokens_2
Expand Down

0 comments on commit b6bcee7

Please sign in to comment.