From a69dc1eae6541ac6d3a3335623dfd2e7b0041b30 Mon Sep 17 00:00:00 2001
From: KlaudiaTH
Date: Wed, 8 Nov 2023 18:31:22 +0100
Subject: [PATCH] Fix unnatural tokenizations if possible

---
 lm_eval/base.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/lm_eval/base.py b/lm_eval/base.py
index 7ca3c677af..5f84a55010 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -180,6 +180,20 @@ def loglikelihood(self, requests):
             continuation_enc = self.tok_encode(continuation)
             # continuation_enc = self.tok_encode(continuation, is_continuation=True)
 
+            context_continuation_enc = self.tok_encode(context + continuation)
+
+            if context_enc + continuation_enc != context_continuation_enc:
+                if context_continuation_enc[: len(context_enc)] == context_enc:
+                    # context_enc is correct and continuation_enc is incorrect
+                    continuation_enc = context_continuation_enc[len(context_enc) :]
+                elif context_continuation_enc[-len(continuation_enc) :] == continuation_enc:
+                    # continuation_enc is correct and context_enc is incorrect
+                    context_enc = context_continuation_enc[: -len(continuation_enc)]
+                else:
+                    # Both are incorrect
+                    print(
+                        f"WARNING: Unnatural tokenization of concatenated context ...{repr(context[-20:])} and continuation {repr(continuation)}"
+                    )
 
             new_reqs.append(((context, continuation), context_enc, continuation_enc))
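
Note (not part of the patch): the mismatch this guards against comes from BPE
merges across the context/continuation boundary, where encoding the two strings
separately yields different tokens than encoding their concatenation. A minimal
sketch, assuming the Hugging Face "gpt2" tokenizer (any BPE tokenizer behaves
similarly at such boundaries):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    context = "Hello "      # trailing space stays with the context
    continuation = "world"  # no leading space

    # Encoding the pieces separately leaves the trailing space as its own token...
    separate = tokenizer.encode(context) + tokenizer.encode(continuation)
    # ...while encoding the concatenation lets BPE merge it into " world".
    joined = tokenizer.encode(context + continuation)

    print(separate, joined)
    assert separate != joined  # the token sequences differ at the boundary

In this example the joined encoding is shorter, so comparing the separate
encodings against tok_encode(context + continuation), as the patch does, detects
the discrepancy and repairs whichever side can be recovered from the prefix or
suffix of the joined sequence.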