From 358cecc281bd744684b405a23b0b704d99f2eefd Mon Sep 17 00:00:00 2001
From: Oscar Lo
Date: Thu, 2 May 2024 02:46:40 -0700
Subject: [PATCH] added gqa as eval dataset

---
 open_flamingo/eval/eval_datasets.py    |  3 +-
 open_flamingo/eval/eval_models/blip.py |  3 +
 .../eval/eval_models/open_flamingo.py  |  3 +
 open_flamingo/eval/evaluate.py         | 91 +++++++++++++++++++
 4 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/open_flamingo/eval/eval_datasets.py b/open_flamingo/eval/eval_datasets.py
index df50af6a..23d4ae1d 100644
--- a/open_flamingo/eval/eval_datasets.py
+++ b/open_flamingo/eval/eval_datasets.py
@@ -14,6 +14,7 @@
     "okvqa",
     "vizwiz",
     "textvqa",
+    "gqa",
     "hateful_memes",
     "imagenet",
 ]
@@ -104,7 +105,7 @@ def get_img_path(self, question):
             )
         elif self.dataset_name == "vizwiz":
             return os.path.join(self.image_dir_path, question["image_id"])
-        elif self.dataset_name == "textvqa":
+        elif self.dataset_name == "textvqa" or self.dataset_name == "gqa":
             return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
         else:
             raise Exception(f"Unknown VQA dataset {self.dataset_name}")
diff --git a/open_flamingo/eval/eval_models/blip.py b/open_flamingo/eval/eval_models/blip.py
index 87f08036..725b0470 100644
--- a/open_flamingo/eval/eval_models/blip.py
+++ b/open_flamingo/eval/eval_models/blip.py
@@ -108,6 +108,9 @@ def get_vizwiz_prompt(self, question, answer=None) -> str:
 
     def get_textvqa_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_gqa_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
 
     def get_coco_prompt(self, caption=None) -> str:
         return f"A photo of {caption if caption is not None else ''}"
diff --git a/open_flamingo/eval/eval_models/open_flamingo.py b/open_flamingo/eval/eval_models/open_flamingo.py
index 0a25198c..d73417ab 100644
--- a/open_flamingo/eval/eval_models/open_flamingo.py
+++ b/open_flamingo/eval/eval_models/open_flamingo.py
@@ -287,6 +287,9 @@ def get_vizwiz_prompt(self, question, answer=None) -> str:
 
     def get_textvqa_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
+
+    def get_gqa_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
     def get_coco_prompt(self, caption=None) -> str:
         return f"Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py
index 4a25fca0..9a2c4e47 100644
--- a/open_flamingo/eval/evaluate.py
+++ b/open_flamingo/eval/evaluate.py
@@ -139,6 +139,14 @@
     default=False,
     help="Whether to evaluate on TextVQA.",
 )
+
+parser.add_argument(
+    "--eval_gqa",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on GQA.",
+)
+
 parser.add_argument(
     "--eval_imagenet",
     action="store_true",
@@ -346,6 +354,44 @@
     default=None,
 )
 
+# GQA Dataset
+parser.add_argument(
+    "--gqa_train_image_dir_path",
+    type=str,
+    help="Path to the gqa train images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_train_questions_json_path",
+    type=str,
+    help="Path to the gqa train questions json file.",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_train_annotations_json_path",
+    type=str,
+    help="Path to the gqa train annotations json file.",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_test_image_dir_path",
+    type=str,
+    help="Path to the gqa test images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_test_questions_json_path",
+    type=str,
+    help="Path to the gqa test questions json file.",
+    default=None,
+)
+parser.add_argument(
+    "--gqa_test_annotations_json_path",
+    type=str,
+    help="Path to the gqa test annotations json file.",
+    default=None,
+)
+
 ## Imagenet dataset
 parser.add_argument("--imagenet_root", type=str, default="/tmp")
 
@@ -650,6 +696,44 @@ def main():
                 "stddev": np.nanstd(scores),
             }
         )
+
+    if args.eval_gqa:
+        print("Evaluating on GQA...")
+
+        # load cached demonstration features for GQA
+        if args.cached_demonstration_features is not None:
+            cached_features = torch.load(
+                f"{args.cached_demonstration_features}/gqa.pkl", map_location="cpu"
+            )
+        else:
+            cached_features = None
+
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                gqa_score = evaluate_vqa(
+                    args=args,
+                    eval_model=eval_model,
+                    num_shots=shot,
+                    seed=seed,
+                    dataset_name="gqa",
+                    max_new_tokens=10,
+                    cached_features=cached_features,
+                )
+                if args.rank == 0:
+                    print(f"Shots {shot} Trial {trial} GQA score: {gqa_score}")
+                    scores.append(gqa_score)
+
+            if args.rank == 0:
+                print(f"Shots {shot} Mean GQA score: {np.nanmean(scores)}")
+                results["gqa"].append(
+                    {
+                        "shots": shot,
+                        "trials": scores,
+                        "mean": np.nanmean(scores),
+                        "stddev": np.nanstd(scores),
+                    }
+                )
 
     if args.eval_imagenet:
         print("Evaluating on ImageNet...")
@@ -968,6 +1052,13 @@ def evaluate_vqa(
         test_image_dir_path = args.textvqa_image_dir_path
         test_questions_json_path = args.textvqa_test_questions_json_path
         test_annotations_json_path = args.textvqa_test_annotations_json_path
+    elif dataset_name == "gqa":
+        train_image_dir_path = args.gqa_train_image_dir_path
+        train_questions_json_path = args.gqa_train_questions_json_path
+        train_annotations_json_path = args.gqa_train_annotations_json_path
+        test_image_dir_path = args.gqa_test_image_dir_path
+        test_questions_json_path = args.gqa_test_questions_json_path
+        test_annotations_json_path = args.gqa_test_annotations_json_path
     else:
         raise ValueError(f"Unsupported dataset: {dataset_name}")
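
Note on data preparation: the harness's VQADataset reads VQAv2-style "questions" and "annotations" json files (it indexes json.load(...)["questions"] and ["annotations"]), whereas the official GQA release ships a single questions json keyed by question id, with "question", "answer", and "imageId" fields per record. The helper below is a minimal conversion sketch under those assumptions; the function name and the single-answer wrapping are illustrative and not part of this patch.

import json

def convert_gqa_to_vqa_format(
    gqa_questions_path, questions_out_path, annotations_out_path
):
    # Split one GQA questions json (question id -> record) into the
    # VQAv2-style questions/annotations pair the harness expects.
    # Assumed GQA keys: "question", "answer", "imageId" (official release).
    with open(gqa_questions_path) as f:
        gqa = json.load(f)

    questions, annotations = [], []
    for qid, record in gqa.items():
        questions.append(
            {
                "question_id": qid,
                # bare image id: matches the f"{image_id}.jpg" lookup
                # added to get_img_path above
                "image_id": record["imageId"],
                "question": record["question"],
            }
        )
        annotations.append(
            {
                "question_id": qid,
                "image_id": record["imageId"],
                # GQA provides a single ground-truth answer; wrap it in the
                # list-of-answers shape VQA-style scoring expects
                "answers": [{"answer": record["answer"]}],
            }
        )

    with open(questions_out_path, "w") as f:
        json.dump({"questions": questions}, f)
    with open(annotations_out_path, "w") as f:
        json.dump({"annotations": annotations}, f)

Files produced this way would then be supplied through the new flags, e.g. --gqa_train_questions_json_path and --gqa_train_annotations_json_path for the in-context split and the corresponding --gqa_test_* flags for the eval split, together with --eval_gqa and the usual model arguments.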