
Commit 7d9306f

Making sure only the response gets evaluated for log_probs in the together.ai API for choice sampling.

PiperOrigin-RevId: 688477264
Change-Id: Ifd14023b4b8128e6cd11d0586b99b34719ac3f72
vezhnick authored and copybara-github committed Oct 22, 2024
1 parent 281847f commit 7d9306f
Showing 1 changed file with 26 additions and 2 deletions.
concordia/language_model/together_ai.py
@@ -49,6 +49,28 @@
_NUM_INITIAL_TOKENS = 500


+def _find_response_start_index(tokens):
+  r"""Finds the start of the response in the prompt.
+
+  Args:
+    tokens: A list of strings.
+
+  Returns:
+    The index of the first token after the last occurrence of the sequence
+    '<start_of_turn>', 'model', '\n', which corresponds to the start of the
+    response.
+
+  Raises:
+    ValueError: If the sequence is not found.
+  """
+  assert len(tokens) >= 3, "Response doesn't match expectation."
+  for i in range(len(tokens) - 3, -1, -1):
+    if (
+        tokens[i] == '<start_of_turn>'
+        and tokens[i + 1] == 'model'
+        and tokens[i + 2] == '\n'
+    ):
+      return i + 3  # Return the index after the sequence.
+  raise ValueError("Response doesn't match expectation.")


def _ensure_prompt_not_too_long(
    prompt: str,
    num_response_tokens: int,
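
To illustrate the new helper, here is a minimal sketch (not part of the commit). Given a Gemma-style token list like the one the patched code reads from result.prompt[0].logprobs.tokens, the helper returns the index of the first token after the final '<start_of_turn>', 'model', '\n' sequence, i.e. the start of the model's response. The token values below are invented for illustration.

tokens = [
    '<bos>', '<start_of_turn>', 'user', '\n', 'Pick', ' a', ' letter',
    '<end_of_turn>', '\n', '<start_of_turn>', 'model', '\n', 'a',
]

# The last '<start_of_turn>', 'model', '\n' run starts at index 9, so the
# response starts at index 12 (the token 'a').
assert _find_response_start_index(tokens) == 12
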
@@ -273,8 +295,10 @@ def _sample_choice(
        )
        continue
      else:
-        # removing the first token since it is always scored with None.
-        score = sum(result.prompt[0].logprobs.token_logprobs[1:])
+        logprobs = result.prompt[0].logprobs
+        response_idx = _find_response_start_index(logprobs.tokens)
+        response_log_probs = logprobs.token_logprobs[response_idx:]
+        score = sum(response_log_probs)
        return score

    raise language_model.InvalidResponseError(
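
The effect on scoring can be sketched with a hypothetical mock of the result.prompt[0].logprobs object (the field names tokens and token_logprobs come from the diff above; the values and the use of types.SimpleNamespace are invented for illustration):

import types

logprobs = types.SimpleNamespace(
    tokens=['<start_of_turn>', 'model', '\n', 'a'],
    token_logprobs=[None, -0.5, -0.1, -0.02],
)

# Old behavior: skip only the first (None-scored) token, so the prompt
# tokens 'model' and '\n' leak into the score.
old_score = sum(logprobs.token_logprobs[1:])  # -0.62

# New behavior: score only the tokens after the final '<start_of_turn>',
# 'model', '\n' marker, i.e. the response itself.
response_idx = _find_response_start_index(logprobs.tokens)  # 3
new_score = sum(logprobs.token_logprobs[response_idx:])  # -0.02

Only the response tokens now contribute to the score, so prompt tokens shared by all candidate choices no longer distort the comparison between them.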
