diff --git a/lm_eval/base.py b/lm_eval/base.py index 6261761b3c..887da9ed9a 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -915,12 +915,12 @@ def fewshot_context( if num_fewshot == 0: labeled_examples = "" - fewshotex, fewshotidx, fewshotsource = [], [], None + fewshotex, fewshotidx, self.fewshotsource = [], [], None else: # for sets with no training docs, draw from other set *but ensure no overlap with current doc* if self.has_training_docs(): fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd) - fewshotsource = "train" + self.fewshotsource = "train" else: if self._fewshot_docs is None: self._fewshot_docs = list( @@ -929,18 +929,18 @@ def fewshot_context( else self.test_docs() ) if self.has_validation_docs(): - fewshotsource = "val" + self.fewshotsource = "val" elif self.test_docs(): - fewshotsource = "test" + self.fewshotsource = "test" fewshotex, fewshotidx = self._get_fewshot_examples( self._fewshot_docs, k=num_fewshot + 1, rnd=rnd ) - fewshotex, fewshotidx = [ + fewshotex, fewshotidx = zip(*[ (shot, idx) for shot, idx in zip(fewshotex, fewshotidx) if shot != doc - ] + ]) # get rid of the doc that's the one we're evaluating, if it's in the fewshot fewshotex, fewshotidx = ( fewshotex[:num_fewshot], @@ -966,7 +966,7 @@ def fewshot_context( ctx, { "fewshot_idx": fewshotidx, - "fewshot_source": fewshotsource, + "fewshot_source": self.fewshotsource, "fewshot_num": num_fewshot, "ctx": ctx, },