Skip to content

Commit

Permalink
added gqa as eval dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscar Lo committed May 2, 2024
1 parent a5378a8 commit 358cecc
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 1 deletion.
3 changes: 2 additions & 1 deletion open_flamingo/eval/eval_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"okvqa",
"vizwiz",
"textvqa",
"gqa",
"hateful_memes",
"imagenet",
]
Expand Down Expand Up @@ -104,7 +105,7 @@ def get_img_path(self, question):
)
elif self.dataset_name == "vizwiz":
return os.path.join(self.image_dir_path, question["image_id"])
elif self.dataset_name == "textvqa":
elif self.dataset_name == "textvqa" or self.dataset_name == "gqa":
return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
else:
raise Exception(f"Unknown VQA dataset {self.dataset_name}")
Expand Down
3 changes: 3 additions & 0 deletions open_flamingo/eval/eval_models/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def get_vizwiz_prompt(self, question, answer=None) -> str:

def get_textvqa_prompt(self, question, answer=None) -> str:
return f"Question:{question} Short answer:{answer if answer is not None else ''}"

def get_gqa_prompt(self, question, answer=None) -> str:
return f"Question:{question} Short answer:{answer if answer is not None else ''}"

def get_coco_prompt(self, caption=None) -> str:
return f"A photo of {caption if caption is not None else ''}"
Expand Down
3 changes: 3 additions & 0 deletions open_flamingo/eval/eval_models/open_flamingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ def get_vizwiz_prompt(self, question, answer=None) -> str:

def get_textvqa_prompt(self, question, answer=None) -> str:
return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

def get_gqa_prompt(self, question, answer=None) -> str:
return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

def get_coco_prompt(self, caption=None) -> str:
return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
Expand Down
91 changes: 91 additions & 0 deletions open_flamingo/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@
default=False,
help="Whether to evaluate on TextVQA.",
)

parser.add_argument(
"--eval_gqa",
action="store_true",
default=False,
help="Whether to evaluate on GQA.",
)

parser.add_argument(
"--eval_imagenet",
action="store_true",
Expand Down Expand Up @@ -346,6 +354,44 @@
default=None,
)

# GQA Dataset
parser.add_argument(
"--gqa_train_image_dir_path",
type=str,
help="Path to the gqa train images directory.",
default=None,
)
parser.add_argument(
"--gqa_train_questions_json_path",
type=str,
help="Path to the gqa questions json file.",
default=None,
)
parser.add_argument(
"--gqa_train_annotations_json_path",
type=str,
help="Path to the gqa annotations json file",
default=None,
)
parser.add_argument(
"--gqa_test_image_dir_path",
type=str,
help="Path to the gqa test images directory.",
default=None,
)
parser.add_argument(
"--gqa_test_questions_json_path",
type=str,
help="Path to the gqa questions json file",
default=None,
)
parser.add_argument(
"--gqa_test_annotations_json_path",
type=str,
help="Path to the gqa annotations json file",
default=None,
)

## Imagenet dataset
parser.add_argument("--imagenet_root", type=str, default="/tmp")

Expand Down Expand Up @@ -650,6 +696,44 @@ def main():
"stddev": np.nanstd(scores),
}
)

if args.eval_gqa:
print("Evaluating on GQA...")

#load cached demonstration features on GQA
if args.cached_demonstration_features is not None:
cached_features = torch.load(
f"{args.cached_demonstration_features}/imagenet.pkl", map_location="cpu"

This comment has been minimized.

Copy link
@anas-awadalla

anas-awadalla May 2, 2024

Collaborator

This is still imagenet .pkl lets change to vqa

)
else:
cached_features = None

for shot in args.shots:
scores = []
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
gqa_score = evaluate_vqa(
args=args,
eval_model=eval_model,
num_shots=shot,
seed=seed,
dataset_name="gqa",
max_new_tokens=10,
cached_features=cached_features,
)
if args.rank == 0:
print(f"Shots {shot} Trial {trial} GQA score: {gqa_score}")
scores.append(gqa_score)

if args.rank == 0:
print(f"Shots {shot} Mean GQA score: {np.nanmean(scores)}")
results["gqa"].append(
{
"shots": shot,
"trials": scores,
"mean": np.nanmean(scores),
"stddev": np.nanstd(scores),
}
)

if args.eval_imagenet:
print("Evaluating on ImageNet...")
Expand Down Expand Up @@ -968,6 +1052,13 @@ def evaluate_vqa(
test_image_dir_path = args.textvqa_image_dir_path
test_questions_json_path = args.textvqa_test_questions_json_path
test_annotations_json_path = args.textvqa_test_annotations_json_path
elif dataset_name == "gqa":
train_image_dir_path = args.gqa_train_image_dir_path
train_questions_json_path = args.gqa_train_questions_json_path
train_annotations_json_path = args.gqa_train_annotations_json_path
test_image_dir_path = args.gqa_test_image_dir_path
test_questions_json_path = args.gqa_test_questions_json_path
test_annotations_json_path = args.gqa_test_annotations_json_path
else:
raise ValueError(f"Unsupported dataset: {dataset_name}")

Expand Down

0 comments on commit 358cecc

Please sign in to comment.