Skip to content

Commit

Permalink
Merge pull request #21 from tomlimi/tomlimi/fix_fs_write_out
Browse files Browse the repository at this point in the history
Fixed issue with write_out for datasets without a training split
  • Loading branch information
cjlovering authored Apr 28, 2022
2 parents 256de63 + 49f0699 commit a963ab8
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions lm_eval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,12 +915,12 @@ def fewshot_context(

if num_fewshot == 0:
labeled_examples = ""
fewshotex, fewshotidx, fewshotsource = [], [], None
fewshotex, fewshotidx, self.fewshotsource = [], [], None
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
fewshotsource = "train"
self.fewshotsource = "train"
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
Expand All @@ -929,18 +929,18 @@ def fewshot_context(
else self.test_docs()
)
if self.has_validation_docs():
fewshotsource = "val"
self.fewshotsource = "val"
elif self.test_docs():
fewshotsource = "test"
self.fewshotsource = "test"

fewshotex, fewshotidx = self._get_fewshot_examples(
self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
)
fewshotex, fewshotidx = [
fewshotex, fewshotidx = zip(*[
(shot, idx)
for shot, idx in zip(fewshotex, fewshotidx)
if shot != doc
]
])
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex, fewshotidx = (
fewshotex[:num_fewshot],
Expand All @@ -966,7 +966,7 @@ def fewshot_context(
ctx,
{
"fewshot_idx": fewshotidx,
"fewshot_source": fewshotsource,
"fewshot_source": self.fewshotsource,
"fewshot_num": num_fewshot,
"ctx": ctx,
},
Expand Down

0 comments on commit a963ab8

Please sign in to comment.