diff --git a/lm_eval/models/utils.py b/lm_eval/models/utils.py index 89281708c8..38c74359ff 100644 --- a/lm_eval/models/utils.py +++ b/lm_eval/models/utils.py @@ -750,6 +750,9 @@ def segmented_tok_encode(string: SegmentedString, tokenizer: PreTrainedTokenizer is a list of segment labels. """ + if type(string)!=SegmentedString: + raise ValueError(f"Input must be of type SegmentedString (found type {str(type(string))}).\n" + f"Do not use smart truncation strategy for language modeling tasks.") assert type(string) == SegmentedString, "string must be a SegmentedString" encoding = tokenizer( string,