Skip to content

Commit

Permalink
kwarg change
Browse files Browse the repository at this point in the history
  • Loading branch information
i-gao committed Oct 9, 2023
1 parent 3f1d66b commit b51caf9
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions open_flamingo/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def main():
eval_model=eval_model,
results=results,
eval_fn=evaluate_vqa,
max_generation_length=10,
max_new_tokens=10,
)

if args.eval_imagenet:
Expand Down Expand Up @@ -631,8 +631,8 @@ def evaluate_captioning(
args: argparse.Namespace,
eval_model: BaseEvalModel,
seed: int = 42,
min_generation_length: int = 0,
max_generation_length: int = 20,
min_new_tokens: int = 0,
max_new_tokens: int = 20,
num_beams: int = 3,
length_penalty: float = 0.0,
num_shots: int = 8,
Expand All @@ -645,7 +645,7 @@ def evaluate_captioning(
args (argparse.Namespace): arguments
eval_model (BaseEvalModel): model to evaluate
seed (int, optional): seed for random number generator. Defaults to 42.
max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20.
max_new_tokens (int, optional): maximum length of the generated caption. Defaults to 20.
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
length_penalty (float, optional): length penalty for beam search. Defaults to 0.0.
num_shots (int, optional): number of in-context samples to use. Defaults to 8.
Expand Down Expand Up @@ -747,8 +747,8 @@ def evaluate_captioning(
outputs = eval_model.get_outputs(
batch_images=batch_images,
batch_text=batch_text,
min_generation_length=min_generation_length,
max_generation_length=max_generation_length,
min_new_tokens=min_new_tokens,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
)
Expand Down Expand Up @@ -808,8 +808,8 @@ def evaluate_vqa(
args: argparse.Namespace,
eval_model: BaseEvalModel,
seed: int = 42,
min_generation_length: int = 0,
max_generation_length: int = 5,
min_new_tokens: int = 0,
max_new_tokens: int = 5,
num_beams: int = 3,
length_penalty: float = 0.0,
num_shots: int = 8,
Expand All @@ -823,7 +823,7 @@ def evaluate_vqa(
args (argparse.Namespace): arguments
eval_model (BaseEvalModel): model to evaluate
seed (int, optional): random seed. Defaults to 42.
max_generation_length (int, optional): max generation length. Defaults to 5.
max_new_tokens (int, optional): max generation length. Defaults to 5.
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
length_penalty (float, optional): length penalty for beam search. Defaults to 0.0.
num_shots (int, optional): number of shots to use. Defaults to 8.
Expand Down Expand Up @@ -948,8 +948,8 @@ def evaluate_vqa(
outputs = eval_model.get_outputs(
batch_images=batch_images,
batch_text=batch_text,
min_generation_length=min_generation_length,
max_generation_length=max_generation_length,
min_new_tokens=min_new_tokens,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
)
Expand Down

0 comments on commit b51caf9

Please sign in to comment.