diff --git a/concordia/language_model/together_ai.py b/concordia/language_model/together_ai.py index f99ac5d..4ee7103 100644 --- a/concordia/language_model/together_ai.py +++ b/concordia/language_model/together_ai.py @@ -49,6 +49,28 @@ _NUM_INITIAL_TOKENS = 500 +def _find_response_start_index(tokens): + r"""Finds the start of the response in the prompt. + + Args: + tokens: A list of strings. + + Returns: + The index of the last occurrence of '' followed by 'model' + and '\n', or 1 if the sequence is not found. This corresponds to the start + of the response. + """ + assert len(tokens) >= 3, "Response doesn't match expectation." + for i in range(len(tokens) - 3, -1, -1): + if ( + tokens[i] == '' + and tokens[i + 1] == 'model' + and tokens[i + 2] == '\n' + ): + return i + 3 # Return the index after the sequence + raise ValueError("Response doesn't match expectation.") + + def _ensure_prompt_not_too_long( prompt: str, num_response_tokens: int, @@ -273,8 +295,10 @@ def _sample_choice( ) continue else: - # removing the first token since it is always scored with None. - score = sum(result.prompt[0].logprobs.token_logprobs[1:]) + logprobs = result.prompt[0].logprobs + response_idx = _find_response_start_index(logprobs.tokens) + response_log_probs = logprobs.token_logprobs[response_idx:] + score = sum(response_log_probs) return score raise language_model.InvalidResponseError(