diff --git a/open_flamingo/eval/eval_datasets.py b/open_flamingo/eval/eval_datasets.py
index 23d4ae1d..bbbca150 100644
--- a/open_flamingo/eval/eval_datasets.py
+++ b/open_flamingo/eval/eval_datasets.py
@@ -15,6 +15,7 @@
     "vizwiz",
     "textvqa",
     "gqa",
+    "mantiseval",
     "hateful_memes",
     "imagenet",
 ]
@@ -107,14 +108,26 @@ def get_img_path(self, question):
             return os.path.join(self.image_dir_path, question["image_id"])
         elif self.dataset_name == "textvqa" or self.dataset_name == "gqa":
             return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
+        elif self.dataset_name == "mantiseval":
+            img_paths = []
+            for img_id in question['image_id']:
+                img_paths.append(os.path.join(self.image_dir_path, f"{img_id}.jpg"))
+            return img_paths
         else:
             raise Exception(f"Unknown VQA dataset {self.dataset_name}")
 
     def __getitem__(self, idx):
         question = self.questions[idx]
         img_path = self.get_img_path(question)
-        image = Image.open(img_path)
-        image.load()
+        if self.dataset_name == "mantiseval":
+            image = []
+            for path in img_path:
+                img = Image.open(path)
+                img.load()
+                image.append(img)
+        else:
+            image = Image.open(img_path)
+            image.load()
         results = {
             "image": image,
             "question": question["question"],
diff --git a/open_flamingo/eval/eval_models/blip.py b/open_flamingo/eval/eval_models/blip.py
index 725b0470..a5c1bf76 100644
--- a/open_flamingo/eval/eval_models/blip.py
+++ b/open_flamingo/eval/eval_models/blip.py
@@ -5,7 +5,7 @@
 from transformers import Blip2Processor, Blip2ForConditionalGeneration
 
 from eval_models.eval_model import BaseEvalModel
-from utils import unwrap_model
+from utils import unwrap_model, combine_images
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 
@@ -27,9 +27,14 @@ def required_args(self):
 
     def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
         batch_images = None
+        for i in range(len(batch)):
+            if len(batch[i]) > 1:
+                batch[i] = combine_images(batch[i])
+        """
         assert all(
             len(example) == 1 for example in batch
         ), "BLIP-2 only supports one image per example"
+        """
         for example in batch:
             if batch_images is None:
                 batch_images = self.processor.image_processor(
@@ -111,6 +116,9 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
 
     def get_gqa_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_mantiseval_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
 
     def get_coco_prompt(self, caption=None) -> str:
         return f"A photo of {caption if caption is not None else ''}"
diff --git a/open_flamingo/eval/eval_models/open_flamingo.py b/open_flamingo/eval/eval_models/open_flamingo.py
index d73417ab..98165529 100644
--- a/open_flamingo/eval/eval_models/open_flamingo.py
+++ b/open_flamingo/eval/eval_models/open_flamingo.py
@@ -291,6 +291,9 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
 
     def get_gqa_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
+    def get_mantiseval_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
+
     def get_coco_prompt(self, caption=None) -> str:
         return f"Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
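For context, a minimal sketch of the record shape the new `mantiseval` branches above assume: `image_id` holds a list of ids rather than a single id, so `get_img_path` returns a list of paths and `__getitem__` a list of PIL images. The ids and directory below are made-up placeholders.

```python
import os

# Hypothetical Mantis-Eval style question record (values are illustrative);
# the only structural assumption is that "image_id" is a list of ids.
question = {
    "question_id": 0,
    "question": "Which image contains more people?",
    "image_id": ["example_0_img_0", "example_0_img_1"],
}
image_dir_path = "/path/to/mantis_eval/images"  # placeholder

# Mirrors the new VQADataset.get_img_path branch: one .jpg path per image id.
img_paths = [
    os.path.join(image_dir_path, f"{img_id}.jpg") for img_id in question["image_id"]
]
print(img_paths)
```

Downstream, BLIP-2's `prepare_images` collapses such a list into a single stitched image via `combine_images`, while the OpenFlamingo eval path keeps the images separate in `batch_images`.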
diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py
index 4935d3e2..a1c73955 100644
--- a/open_flamingo/eval/evaluate.py
+++ b/open_flamingo/eval/evaluate.py
@@ -34,7 +34,7 @@
     HatefulMemesDataset,
 )
 from ok_vqa_utils import postprocess_ok_vqa_generation
-from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation
+from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation, compute_mantis_accuracy
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -152,29 +152,30 @@
     default=None,
 )
 
-## VQAV2, OK-VQA, VizWiz, TextVQA, GQA Datasets
-for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa']:
+## VQAV2, OK-VQA, VizWiz, TextVQA, GQA, Mantis-Eval Datasets
+for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa', 'mantiseval']:
     parser.add_argument(
-        f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' else f"--{task}_train_image_dir_path",
+        f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' or task=='mantiseval' else f"--{task}_train_image_dir_path",
         type=str,
         default=None,
     )
-    if task!='gqa' and task!='textvqa':
+    if task != 'mantiseval':
+        if task!='gqa' and task!='textvqa':
+            parser.add_argument(
+                f"--{task}_test_image_dir_path",
+                type=str,
+                default=None,
+            )
         parser.add_argument(
-            f"--{task}_test_image_dir_path",
+            f"--{task}_train_questions_json_path",
+            type=str,
+            default=None,
+        )
+        parser.add_argument(
+            f"--{task}_train_annotations_json_path",
             type=str,
             default=None,
         )
-    parser.add_argument(
-        f"--{task}_train_questions_json_path",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        f"--{task}_train_annotations_json_path",
-        type=str,
-        default=None,
-    )
     parser.add_argument(
         f"--{task}_test_questions_json_path",
         type=str,
@@ -315,7 +316,7 @@ def main():
         }
     )
 
-    for vqa_task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    for vqa_task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         if var_args[f"eval_{vqa_task}"]:
             print(f"Evaluating on {vqa_task}...")
 
@@ -601,16 +602,16 @@
         float: accuracy score
     """
     var_args = vars(args)
-    for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         if dataset_name == task:
             task = task
-    train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"]
-    train_questions_json_path = var_args[f"{task}_train_questions_json_path"]
-    train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"]
-    test_image_dir_path = var_args[f"{task}_test_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"]
+    train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" and task!="mantiseval" else f"{task}_image_dir_path"]
+    train_questions_json_path = var_args[f"{task}_train_questions_json_path"] if task!="mantiseval" else var_args[f"{task}_test_questions_json_path"]
+    train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"] if task!="mantiseval" else var_args[f"{task}_test_annotations_json_path"]
+    test_image_dir_path = var_args[f"{task}_test_image_dir_path" if task!="textvqa" and task!="gqa" and task!="mantiseval" else f"{task}_image_dir_path"]
     test_questions_json_path = var_args[f"{task}_test_questions_json_path"]
     test_annotations_json_path = var_args[f"{task}_test_annotations_json_path"]
-    if dataset_name not in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    if dataset_name not in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         raise ValueError(f"Unsupported dataset: {dataset_name}")
 
     train_dataset = VQADataset(
@@ -675,7 +676,10 @@
                 context_images = [x["image"] for x in batch_demo_samples[i]]
             else:
                 context_images = []
-            batch_images.append(context_images + [batch["image"][i]])
+            if dataset_name == "mantiseval":
+                batch_images.append(context_images + batch["image"][i])
+            else:
+                batch_images.append(context_images + [batch["image"][i]])
 
             context_text = "".join(
                 [
@@ -703,7 +707,7 @@
             num_beams=num_beams,
             length_penalty=length_penalty,
         )
-        
+
         process_function = (
             postprocess_ok_vqa_generation
            if dataset_name == "okvqa"
@@ -732,11 +736,17 @@
            f.write(json.dumps(all_predictions, indent=4))

    if test_annotations_json_path is not None:
-        acc = compute_vqa_accuracy(
-            f"{dataset_name}results_{random_uuid}.json",
-            test_questions_json_path,
-            test_annotations_json_path,
-        )
+        if dataset_name == "mantiseval":
+            acc = compute_mantis_accuracy(
+                f"{dataset_name}results_{random_uuid}.json",
+                test_annotations_json_path,
+            )
+        else:
+            acc = compute_vqa_accuracy(
+                f"{dataset_name}results_{random_uuid}.json",
+                test_questions_json_path,
+                test_annotations_json_path,
+            )
         # delete the temporary file
         os.remove(f"{dataset_name}results_{random_uuid}.json")
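A small illustration of the path plumbing in `evaluate_vqa` for `mantiseval` (the paths below are placeholders): the single `--mantiseval_image_dir_path` serves both the demonstration ("train") and test examples, and the train-split JSONs fall back to the test-split ones, so few-shot context examples are drawn from the evaluation split itself.

```python
# Placeholder argument values; mirrors the var_args lookups above for task == "mantiseval".
var_args = {
    "mantiseval_image_dir_path": "/data/mantis_eval/images",
    "mantiseval_test_questions_json_path": "/data/mantis_eval/test_questions.json",
    "mantiseval_test_annotations_json_path": "/data/mantis_eval/test_annotations.json",
}
task = "mantiseval"

# One image directory for both demo ("train") and test examples.
train_image_dir_path = var_args[f"{task}_image_dir_path"]
test_image_dir_path = var_args[f"{task}_image_dir_path"]

# No separate train split: the demo questions/annotations reuse the test JSONs,
# so in-context demonstrations come from the evaluation split itself.
train_questions_json_path = var_args[f"{task}_test_questions_json_path"]
train_annotations_json_path = var_args[f"{task}_test_annotations_json_path"]
test_questions_json_path = var_args[f"{task}_test_questions_json_path"]
test_annotations_json_path = var_args[f"{task}_test_annotations_json_path"]
```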
diff --git a/open_flamingo/eval/utils.py b/open_flamingo/eval/utils.py
index 6aa2052a..03473461 100644
--- a/open_flamingo/eval/utils.py
+++ b/open_flamingo/eval/utils.py
@@ -3,6 +3,7 @@
 import random
 import torch.nn as nn
 from contextlib import suppress
+from PIL import Image
 
 
 def random_seed(seed=42, rank=0):
@@ -122,3 +123,25 @@
         return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
     else:
         return suppress
+
+def combine_images(images):
+    _, img_heights = zip(*(img.size for img in images))
+    avg_height = sum(img_heights) // len(img_heights)
+    for i, img in enumerate(images):
+        images[i] = img.resize((int(img.size[0] * avg_height / img.size[1]), avg_height))
+    resized_widths, resized_heights = zip(*(img.size for img in images))
+    total_width = sum(resized_widths)
+    max_height = max(resized_heights)
+    new_img = Image.new("RGB", (total_width + 10 * (len(images) - 1), max_height))
+    x_offset = 0
+    for i, img in enumerate(images):
+        if i > 0:
+            new_img.paste(Image.new("RGB", (1, max_height), (0, 0, 0)), (x_offset, 0))
+            x_offset += 1
+            new_img.paste(Image.new("RGB", (8, max_height), (255, 255, 255)), (x_offset, 0))
+            x_offset += 8
+            new_img.paste(Image.new("RGB", (1, max_height), (0, 0, 0)), (x_offset, 0))
+            x_offset += 1
+        new_img.paste(img, (x_offset, 0))
+        x_offset += img.size[0]
+    return new_img
\ No newline at end of file
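A quick sanity check of the `combine_images` layout, assuming `open_flamingo/eval` is on `sys.path` so the helper imports the same way `blip.py` imports it; the input sizes are arbitrary.

```python
from PIL import Image

from utils import combine_images  # same import style as blip.py; needs open_flamingo/eval on sys.path

# Two dummy images of different sizes.
imgs = [Image.new("RGB", (200, 100), "red"), Image.new("RGB", (50, 300), "blue")]
merged = combine_images(imgs)

# Both inputs are rescaled to the average height (200 px here), widths scale
# proportionally (400 px and ~33 px), and each boundary contributes a
# 1 px black / 8 px white / 1 px black separator, so the canvas is about 443 x 200.
print(merged.size)
```

Note that `PIL.Image.size` is `(width, height)`, which is why the tuple unpacking in `combine_images` takes the second element as the height.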
diff --git a/open_flamingo/eval/vqa_metric.py b/open_flamingo/eval/vqa_metric.py
index 3659c556..7168d669 100644
--- a/open_flamingo/eval/vqa_metric.py
+++ b/open_flamingo/eval/vqa_metric.py
@@ -553,6 +553,26 @@ def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_p
     return vqaEval.accuracy["overall"]
 
+def compute_mantis_accuracy(result_json_path, annotation_json_path):
+    dataset = json.load(open(annotation_json_path, "r"))
+    gt_ans = {}
+    for ann in dataset["annotations"]:
+        gt_ans[ann["question_id"]] = {"answer": ann["answers"][0]["answer"], "type": ann["question_type"]}
+    results = json.load(open(result_json_path, "r"))
+    assert type(results) == list, "results is not an array of objects"
+    correct = 0
+    for res in results:
+        res_ans = res["answer"].lower().strip('()\n ')
+        if gt_ans[res["question_id"]]["type"] == "multi-choice":
+            if len(res_ans) > 1:
+                for c in res_ans:
+                    if c.isalpha():
+                        res_ans = c
+                        break
+        if res_ans == gt_ans[res["question_id"]]["answer"].lower().strip('()\n '):
+            correct += 1
+    acc = correct / len(results)
+    return acc
 
 def postprocess_vqa_generation(predictions):
     answer = re.split("Question|Answer|Short", predictions, 1)[0]
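To make the scoring contract concrete, here is a toy round trip through `compute_mantis_accuracy` (again assuming `open_flamingo/eval` is on `sys.path`). The annotation fields shown are only the ones the function reads; the `question_type` values and answers are illustrative.

```python
import json
import os
import tempfile

from vqa_metric import compute_mantis_accuracy  # needs open_flamingo/eval on sys.path

annotations = {
    "annotations": [
        {"question_id": 0, "question_type": "multi-choice", "answers": [{"answer": "(A)"}]},
        {"question_id": 1, "question_type": "short-answer", "answers": [{"answer": "two"}]},
    ]
}
results = [
    {"question_id": 0, "answer": "A) the first image\n"},  # reduced to "a" -> correct
    {"question_id": 1, "answer": "three"},                 # wrong
]

with tempfile.TemporaryDirectory() as tmp:
    ann_path = os.path.join(tmp, "annotations.json")
    res_path = os.path.join(tmp, "results.json")
    with open(ann_path, "w") as f:
        json.dump(annotations, f)
    with open(res_path, "w") as f:
        json.dump(results, f)
    print(compute_mantis_accuracy(res_path, ann_path))  # 0.5
```

For multi-choice questions the scorer keeps only the first alphabetic character of the stripped prediction, so a generation like "(A) the first image" is matched against the ground-truth option letter; all other question types are compared as lowercased, stripped strings.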