From b51caf94d24f0ab1492e4b9608c733d51c578121 Mon Sep 17 00:00:00 2001 From: i-gao Date: Mon, 9 Oct 2023 09:37:21 -0700 Subject: [PATCH] kwarg change --- open_flamingo/eval/evaluate.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py index 4b0b9ce0..8f3632ec 100644 --- a/open_flamingo/eval/evaluate.py +++ b/open_flamingo/eval/evaluate.py @@ -574,7 +574,7 @@ def main(): eval_model=eval_model, results=results, eval_fn=evaluate_vqa, - max_generation_length=10, + max_new_tokens=10, ) if args.eval_imagenet: @@ -631,8 +631,8 @@ def evaluate_captioning( args: argparse.Namespace, eval_model: BaseEvalModel, seed: int = 42, - min_generation_length: int = 0, - max_generation_length: int = 20, + min_new_tokens: int = 0, + max_new_tokens: int = 20, num_beams: int = 3, length_penalty: float = 0.0, num_shots: int = 8, @@ -645,7 +645,7 @@ def evaluate_captioning( args (argparse.Namespace): arguments eval_model (BaseEvalModel): model to evaluate seed (int, optional): seed for random number generator. Defaults to 42. - max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20. + max_new_tokens (int, optional): maximum length of the generated caption. Defaults to 20. num_beams (int, optional): number of beams to use for beam search. Defaults to 3. length_penalty (float, optional): length penalty for beam search. Defaults to -2.0. num_shots (int, optional): number of in-context samples to use. Defaults to 8. @@ -747,8 +747,8 @@ def evaluate_captioning( outputs = eval_model.get_outputs( batch_images=batch_images, batch_text=batch_text, - min_generation_length=min_generation_length, - max_generation_length=max_generation_length, + min_new_tokens=min_new_tokens, + max_new_tokens=max_new_tokens, num_beams=num_beams, length_penalty=length_penalty, ) @@ -808,8 +808,8 @@ def evaluate_vqa( args: argparse.Namespace, eval_model: BaseEvalModel, seed: int = 42, - min_generation_length: int = 0, - max_generation_length: int = 5, + min_new_tokens: int = 0, + max_new_tokens: int = 5, num_beams: int = 3, length_penalty: float = 0.0, num_shots: int = 8, @@ -823,7 +823,7 @@ def evaluate_vqa( args (argparse.Namespace): arguments eval_model (BaseEvalModel): model to evaluate seed (int, optional): random seed. Defaults to 42. - max_generation_length (int, optional): max generation length. Defaults to 5. + max_new_tokens (int, optional): max generation length. Defaults to 5. num_beams (int, optional): number of beams to use for beam search. Defaults to 3. length_penalty (float, optional): length penalty for beam search. Defaults to -2.0. num_shots (int, optional): number of shots to use. Defaults to 8. @@ -948,8 +948,8 @@ def evaluate_vqa( outputs = eval_model.get_outputs( batch_images=batch_images, batch_text=batch_text, - min_generation_length=min_generation_length, - max_generation_length=max_generation_length, + min_new_tokens=min_new_tokens, + max_new_tokens=max_new_tokens, num_beams=num_beams, length_penalty=length_penalty, )