Skip to content

Commit

Permalink
Merge pull request #2 from cjlovering/cjlovering/request_args
Browse files Browse the repository at this point in the history
Updated the requests so that it's easier to understand.
  • Loading branch information
jon-tow authored Apr 27, 2022
2 parents a3a9a7c + 21d897d commit 5b0d95a
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions lm_eval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,25 +345,27 @@ def _collate(x):

reord = utils.Reorderer(requests, _collate)

for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str):
until = [until]
max_length = None
elif isinstance(until, list) and len(until) == 2:
until, max_length = [until[0]], until[1]
elif isinstance(until, list):
max_length = None
for context, request_args in tqdm(reord.get_reordered()):
stopping_criteria = request_args["stopping_criteria"]
max_generation_length = request_args["max_generation_length"]

assert isinstance(stopping_criteria, str) or stopping_criteria is None
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)

until = [stopping_criteria]
primary_until = self.tok_encode(until[0])
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device)

if max_length is not None:
max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
else:
if max_generation_length is None:
max_length = context_enc.shape[1] + self.max_gen_toks

else:
max_length = min(
max_generation_length, context_enc.shape[1] + self.max_gen_toks
)
cont = self._model_generate(
context_enc,
max_length,
Expand Down Expand Up @@ -720,9 +722,11 @@ def construct_requests(self, doc, ctx):
else:
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
cont_request = rf.greedy_until(
ctx, [self.stopping_criteria(), self.max_generation_length()]
)
request_args = {
"stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(),
}
cont_request = rf.greedy_until(ctx, request_args)
_requests.append(cont_request)

return _requests
Expand Down

0 comments on commit 5b0d95a

Please sign in to comment.