Skip to content

Commit

Permalink
Allowed token ids (#598)
Browse files Browse the repository at this point in the history
  • Loading branch information
hiworldwzj authored Oct 30, 2024
1 parent 8c58650 commit e7184fc
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 1 deletion.
10 changes: 10 additions & 0 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
stop_sequences: List[List[int]] = [],
input_penalty: bool = False,
regular_constraint: Optional[str] = None,
allowed_token_ids: Optional[List[int]] = None,
) -> None:
self.best_of = best_of
self.do_sample = do_sample
Expand All @@ -60,8 +61,17 @@ def __init__(
self.regular_constraint = regular_constraint
self.regex_guide = None
self.fsm_current_state: int = 0
self.allowed_token_ids = allowed_token_ids
# TODO: this check is not well placed here; move it to a more suitable location.
if self.allowed_token_ids is not None:
if not all(e < vocab_size for e in self.allowed_token_ids):
logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids")
self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size]
return

def has_constraint_setting(self) -> bool:
    """Report whether this request restricts which tokens may be produced.

    Returns:
        True when ``regular_constraint`` or ``allowed_token_ids`` is set
        (i.e. is not ``None``); False when both are ``None``.
    """
    # An empty allowed_token_ids list still counts as "set": only None means
    # that no constraint was requested.
    unconstrained = self.regular_constraint is None and self.allowed_token_ids is None
    return not unconstrained


class InferReq:
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def decode_batch(self, batch_id):

logits = self.model.forward(**kwargs)

all_has_no_constraint = all([e.sampling_param.regular_constraint is None for e in run_reqs])
all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs])
if not all_has_no_constraint:
mask = torch.ones_like(logits, dtype=torch.bool)
for i, run_obj in enumerate(run_reqs):
Expand Down Expand Up @@ -146,6 +146,8 @@ def _mask_req_out_token(self, i, run_obj: InferReq, mask):
regex_guide: RegexGuide = sample_params.regex_guide
ok_token_id_list = regex_guide.get_next_instruction(sample_params.fsm_current_state).tokens
mask[i, ok_token_id_list] = False
elif sample_params.allowed_token_ids is not None:
mask[i, sample_params.allowed_token_ids] = False
else:
mask[i, :] = False
return
18 changes: 18 additions & 0 deletions lightllm/server/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def __init__(
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
input_penalty: bool = DEFAULT_INPUT_PENALTY,
regular_constraint: Optional[str] = None, # Regular expressions constrain the output.
# If provided, the engine will construct a logits processor
# that only retains scores for the given token ids. Defaults to None.
# allowed_token_ids can only be used when the server is started with "--simple_constraint_mode".
allowed_token_ids: Optional[List[int]] = None,
) -> None:
self.best_of = best_of
self.n = n
Expand All @@ -51,6 +55,7 @@ def __init__(
self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
self.print_eos_token = print_eos_token
self.regular_constraint = regular_constraint
self.allowed_token_ids = allowed_token_ids
if self.do_sample is False:
self.temperature = 1.0
self.top_p = 1.0
Expand Down Expand Up @@ -131,6 +136,18 @@ def verify(self):

self._verify_stop_sentences()

self._verify_allowed_token_ids()

return

def _verify_allowed_token_ids(self):
    """Validate ``allowed_token_ids`` before the request is accepted.

    Does nothing when ``allowed_token_ids`` is ``None``. Otherwise the value
    must be a list of non-negative integers and must not be combined with
    ``regular_constraint``.

    Raises:
        ValueError: if ``allowed_token_ids`` is not a list of ints, contains
            a bool or a negative id, or is set together with
            ``regular_constraint``.
    """
    if self.allowed_token_ids is None:
        return

    # bool is a subclass of int, so it has to be excluded explicitly;
    # otherwise True/False would silently be treated as token ids 1/0.
    well_formed = isinstance(self.allowed_token_ids, list) and all(
        isinstance(token_id, int) and not isinstance(token_id, bool) for token_id in self.allowed_token_ids
    )
    if not well_formed:
        raise ValueError(f"allowed_token_ids need format List[int], but get {self.allowed_token_ids}")

    # The ids are later used to index a per-request logits mask, where a
    # negative value would wrap around and unmask an unintended token
    # instead of failing, so reject it up front.
    if any(token_id < 0 for token_id in self.allowed_token_ids):
        raise ValueError(f"allowed_token_ids must be non-negative token ids, but get {self.allowed_token_ids}")

    if self.regular_constraint is not None:
        raise ValueError("allowed_token_ids and regular_constraint can not be used in same time")
    return

def _verify_stop_sentences(self):
Expand Down Expand Up @@ -187,4 +204,5 @@ def to_dict(self):
ret["best_of"] = self.best_of
ret["input_penalty"] = self.input_penalty
ret["regular_constraint"] = self.regular_constraint
ret["allowed_token_ids"] = self.allowed_token_ids
return ret
10 changes: 10 additions & 0 deletions test/test_constraint_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,13 @@ def run(self):
}
thread = RequestThread(url, headers, data)
thread.start()

time.sleep(10)

for i in range(20):
data = {
"inputs": "Are dog a man? ",
"parameters": {"do_sample": False, "ignore_eos": True, "max_new_tokens": 200, "allowed_token_ids": [2, 3]},
}
thread = RequestThread(url, headers, data)
thread.start()

0 comments on commit e7184fc

Please sign in to comment.