diff --git a/lm_eval/base.py b/lm_eval/base.py index 7ca3c677af..5f84a55010 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -180,6 +180,20 @@ def loglikelihood(self, requests): continuation_enc = self.tok_encode(continuation) # continuation_enc = self.tok_encode(continuation, is_continuation=True) + context_continuation_enc = self.tok_encode(context + continuation) + + if context_enc + continuation_enc != context_continuation_enc: + if context_continuation_enc[: len(context_enc)] == context_enc: + # continuation_enc is incorrect and context_enc is correct + continuation_enc = context_continuation_enc[len(context_enc) :] + elif context_continuation_enc[-len(continuation_enc) :] == continuation_enc: + # continuation_enc is correct and context_enc is incorrect + context_enc = context_continuation_enc[: -len(continuation_enc)] + else: + # Both are incorrect + print( + f"WARNING: Unnatural tokenization of concatenated context ...{repr(context[-20:])} and continuation {repr(continuation)}" + ) new_reqs.append(((context, continuation), context_enc, continuation_enc))