diff --git a/lm_eval/base.py b/lm_eval/base.py
index 6ab40a3a6c..16707d771b 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -345,25 +345,27 @@ def _collate(x):
         reord = utils.Reorderer(requests, _collate)
 
-        for context, until in tqdm(reord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
-                max_length = None
-            elif isinstance(until, list) and len(until) == 2:
-                until, max_length = [until[0]], until[1]
-            elif isinstance(until, list):
-                max_length = None
+        for context, request_args in tqdm(reord.get_reordered()):
+            stopping_criteria = request_args["stopping_criteria"]
+            max_generation_length = request_args["max_generation_length"]
+            assert isinstance(stopping_criteria, str) or stopping_criteria is None
+            assert (
+                isinstance(max_generation_length, int) or max_generation_length is None
+            )
+
+            until = [stopping_criteria]
 
             primary_until = self.tok_encode(until[0])
 
             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)
 
-            if max_length is not None:
-                max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
-            else:
+            if max_generation_length is None:
                 max_length = context_enc.shape[1] + self.max_gen_toks
-
+            else:
+                max_length = min(
+                    max_generation_length, context_enc.shape[1] + self.max_gen_toks
+                )
             cont = self._model_generate(
                 context_enc,
                 max_length,
@@ -720,9 +722,11 @@ def construct_requests(self, doc, ctx):
         else:
             # If not, then this is a generation prompt.
            # NOTE: In the future, target will be a list of strings.
-            cont_request = rf.greedy_until(
-                ctx, [self.stopping_criteria(), self.max_generation_length()]
-            )
+            request_args = {
+                "stopping_criteria": self.stopping_criteria(),
+                "max_generation_length": self.max_generation_length(),
+            }
+            cont_request = rf.greedy_until(ctx, request_args)
             _requests.append(cont_request)
 
         return _requests
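
Note: with this patch a greedy_until request carries a dict of named arguments rather than a positional list. A minimal sketch of the new call shape inside a task's construct_requests, assuming the literal values shown (the newline stop string and the 128-token cap) are purely illustrative; the dict keys and the rf.greedy_until(ctx, request_args) call follow the diff above:

    # Illustrative values only; keys and call shape come from the patch above.
    request_args = {
        "stopping_criteria": "\n",     # str or None; asserted in greedy_until
        "max_generation_length": 128,  # int or None; capped at context_enc.shape[1] + self.max_gen_toks
    }
    cont_request = rf.greedy_until(ctx, request_args)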