-
Notifications
You must be signed in to change notification settings - Fork 168
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #43 from zou-group/vision-math
Added two multimodal experiments: MathVista and ScienceQA
- Loading branch information
Showing
4 changed files
with
823 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Results for Solution Optimization on Multimodal Tasks | ||
|
||
We are pleased to release the results of our solution optimization experiments for multimodal tasks, including [MathVista](https://mathvista.github.io/) and [ScienceQA](https://scienceqa.github.io/) tasks. This release aims to ensure transparency and provide detailed insights into our optimization processes. | ||
|
||
The files in this repository contain the solution trajectory, final solution, and the answers for each question in the datasets. Each `.json` file (can be downloaded from google drive) is structured as a dictionary where the key is the question index, and the value is another dictionary with the following possible keys: | ||
|
||
- **`question`**: The question text. | ||
- **`answer`**: The ground truth answer. | ||
- **`predictions`**: The trajectory of solutions throughout the optimization process. | ||
- **`loss_history`**: The history of loss function values over different iterations of optimization. | ||
- **`performance_history`**: The prediction score, where 1 indicates a correct answer and 0 indicates an incorrect answer. | ||
- **`result_data`**: The intermediate results of predictions. | ||
- **`ques_data`**: The metadata of the question, which is useful for further fine-grained analysis. | ||
|
||
## Notes | ||
Solution optimization in general depends heavily on the test-time objective. Depending on the test-time objective, e.g. the model can be driven to explore more, which is the approach we took. For this reason, we used majority voting to get the final prediction. There are lots of interesting questions around identifying good test-time training strategies! | ||
|
||
## Experiment on MathVista | ||
|
||
[MathVista](https://mathvista.github.io/) is a comprehensive benchmark for mathematical reasoning within visual contexts. It emphasizes the diversity and complexity of visual perception and mathematical reasoning challenges. | ||
|
||
To conduct an experiment on MathVista, use the following example command: | ||
|
||
```sh | ||
cd evaluation | ||
python solution_optimization_mm.py --task mathvista \ | ||
--engine=gpt-4o \ | ||
--eval_engine=gpt-4o \ | ||
--max_iterations 4 \ | ||
--num_threads 10 \ | ||
--majority_voting | ||
``` | ||
|
||
The resulting `mathvista_predictions.json` file can be downloaded from [here](https://drive.google.com/drive/u/0/folders/1hJ6kQozkTNiUPxtEhZCgCqnK2EFW5MHY). | ||
|
||
## Experiment on ScienceQA | ||
|
||
[ScienceQA](https://scienceqa.github.io/) (Science Question Answering) is a multimodal benchmark consisting of multiple-choice questions covering a diverse set of science topics. It challenges participants to understand scientific images, retrieve relevant knowledge, and provide accurate reasoning for high-school-level scientific questions. | ||
|
||
Running an experiment on ScienceQA, use the following example command: | ||
|
||
```sh | ||
cd evaluation | ||
|
||
python solution_optimization_mm.py --task scienceqa \ | ||
--engine=gpt-4o \ | ||
--eval_engine=gpt-4o \ | ||
--max_iterations 8 \ | ||
--num_threads 20 | ||
``` | ||
|
||
The resulting `scienceqa_predictions.json` file can be downloaded from [here](https://drive.google.com/file/d/1BkMD3CcaxAUpB-9L0aI8bLEqoo8jHqlZ/view?usp=sharing). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,285 @@ | ||
import re | ||
import os | ||
import json | ||
import argparse | ||
import concurrent | ||
from tqdm import tqdm | ||
import numpy as np | ||
from dotenv import load_dotenv | ||
load_dotenv(override=True) | ||
from collections import Counter | ||
|
||
import textgrad as tg | ||
from textgrad.tasks.multimodal import load_multimodal_instance_task | ||
|
||
def load_checkpoint_results(checkpoint_file): | ||
all_solutions = {} | ||
if checkpoint_file.endswith(".jsonl"): | ||
with open(checkpoint_file, "r") as f: | ||
for line in f: | ||
data = json.loads(line) | ||
all_solutions[data["ques_data"]["pid"]] = data | ||
elif checkpoint_file.endswith(".json"): | ||
with open(checkpoint_file, "r") as f: | ||
all_solutions = json.load(f) | ||
print("Loaded existing results from", checkpoint_file) | ||
print(f"There are {len(all_solutions)} samples in the file!") | ||
return all_solutions | ||
|
||
class MajorityVoting: | ||
def __init__(self, solutions, max_iterations): | ||
self.solutions = solutions | ||
self.max_iterations = max_iterations | ||
|
||
def majority_element(self, predictions): | ||
count = Counter(predictions) | ||
return max(count.keys(), key=count.get) | ||
|
||
def compute_accuracy(self): | ||
final_score = 0 | ||
for n in range(1, self.max_iterations + 1): | ||
correct = 0 | ||
for _, solution in self.solutions.items(): | ||
answer = solution["answer"] | ||
predictions = [res["normalized_answer"] for res in solution["result_data"]] | ||
predictions = predictions[1:n+1] | ||
majority = self.majority_element(predictions) | ||
if majority == answer: | ||
correct += 1 | ||
acc = round(correct / len(self.solutions), 3) | ||
if acc > final_score: | ||
final_score = acc | ||
return final_score | ||
|
||
def config(): | ||
parser = argparse.ArgumentParser(description="Optimize a prompt for a task.") | ||
parser.add_argument("--task", type=str, default="mathvista", choices = ["mathvista", "scienceqa"], help="The task to evaluate the model on.") | ||
parser.add_argument("--multimodal", action='store_true', help="Determine if we want to load multimodal dataset.") | ||
parser.add_argument("--task_instruction", type=str, default=None, help="The instruction for the task.") | ||
parser.add_argument("--evaluation_instruction", type=str, default=None, help="The instruction for the evaluation.") | ||
parser.add_argument("--instance_role", type=str, default=None, help="The role description of the instance.") | ||
parser.add_argument("--image_role", type=str, default=None, help="The role description of the image.") | ||
parser.add_argument("--question_role", type=str, default=None, help="The role description of the question.") | ||
parser.add_argument("--cache_root", type=str, default=None, help="The cache directory to save data.") | ||
parser.add_argument("--results_dir", type=str, default="./results/solution_optimization_mm", help="The directory to save the results.") | ||
parser.add_argument("--output_file", type=str, default="", help="The output file to save or load the results.") | ||
parser.add_argument("--save_every", type=int, default=50, help="The frequency to save the results.") | ||
parser.add_argument("--engine", type=str, default="gpt-4o", help="The API to use for inference.") | ||
parser.add_argument("--eval_engine", type=str, default="gpt-4o", help="The API to use for evaluation.") | ||
parser.add_argument("--max_iterations", type=int, default=3, help="The maximum number of iterations of test-time updates.") | ||
parser.add_argument("--num_threads", type=int, default=16, help="The number of threads to use for evaluation.") | ||
parser.add_argument("--test_num", type=int, default=-1, help="The number of test samples to evaluate., -1 means all.") | ||
parser.add_argument("--majority_voting", action='store_true', default=False, help="Determine if we want to use majority voting.") | ||
return parser.parse_args() | ||
|
||
def get_zeroshot_answer(question, mm_data=None): | ||
"""Getting the zero-shot answer from an LLM without optimizing the response at test time.""" | ||
# The system prompt is from: https://github.com/openai/simple-evals/blob/main/sampler/chat_completion_sampler.py | ||
STARTING_SYSTEM_PROMPT = ( | ||
"You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." | ||
+ "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-07-01" | ||
) | ||
if args.question_role is not None: | ||
question_role = args.question_role | ||
else: | ||
question_role = "question to the language model" | ||
if args.image_role is not None: | ||
image_role = args.image_role | ||
else: | ||
image_role = "image associate with question to the language model" | ||
system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT, requires_grad=False, role_description="system prompt to the language model") | ||
if mm_data is None: # vanilla text. | ||
model = tg.BlackboxLLM(llm_engine, system_prompt) | ||
response = model(tg.Variable(question, requires_grad=False, role_description=question_role)) | ||
else: # multi-modal (such as image). | ||
from textgrad.autograd import MultimodalLLMCall | ||
variable_mm_data = tg.Variable(mm_data, requires_grad=False, role_description=image_role) | ||
variable_question = tg.Variable(question, requires_grad=False, role_description=question_role) | ||
response = MultimodalLLMCall(args.engine)([variable_mm_data, variable_question]) | ||
return response | ||
|
||
def run_test_time_training(sample): | ||
performance_history = [] | ||
results_data = [] | ||
mm_data, question, answer, ques_data, test_time_objective, instance_eval_fn = sample # NOTE: check the sample format | ||
zero_shot_response = get_zeroshot_answer(question, mm_data) | ||
|
||
# NOTE: Set the instance role description | ||
if args.instance_role is not None: | ||
instance_role_description = args.instance_role | ||
else: | ||
if args.task == "mathvista": | ||
instance_role_description = "detailed reasoning and prediction for the image-related math question." | ||
elif args.task == "scienceqa": | ||
instance_role_description = "detailed reasoning and prediction for the image-related science question." | ||
else: | ||
instance_role_description = "precise and short answer to the image-related question." | ||
|
||
instance_var = tg.Variable(zero_shot_response.value, requires_grad=True, role_description=instance_role_description) | ||
|
||
# Evaluate the zero-shot response | ||
# NOTE: The instance_eval_fn should return a tuple of (score, result_data) if result_data is needed. | ||
eval_score = instance_eval_fn(instance_var) | ||
if isinstance(eval_score, tuple): | ||
score, result_data = eval_score | ||
else: | ||
score = eval_score | ||
result_data = None | ||
performance_history.append(int(score)) | ||
results_data.append(result_data) | ||
|
||
optimizer = tg.TextualGradientDescent(engine=llm_engine, | ||
parameters=[instance_var], | ||
constraints=["Your response should answer the question given an image."]) | ||
|
||
predictions = [] | ||
predictions.append(tg.Variable( | ||
instance_var.value, | ||
role_description=instance_var.role_description | ||
)) | ||
|
||
loss_history = [] | ||
|
||
# Start test time optimization | ||
for _ in range(args.max_iterations): | ||
optimizer.zero_grad() | ||
# Compute the test time loss | ||
test_time_loss = test_time_objective(instance_var) | ||
loss_history.append(str(test_time_loss)) | ||
test_time_loss.backward() | ||
optimizer.step() | ||
|
||
eval_score = instance_eval_fn(instance_var) | ||
# NOTE: The instance_eval_fn should return a tuple of (score, result_data) if result_data is needed. | ||
if isinstance(eval_score, tuple): | ||
score, result_data = eval_score | ||
else: | ||
score = eval_score | ||
result_data = None | ||
|
||
predictions.append(tg.Variable( | ||
instance_var.value, | ||
role_description=instance_var.role_description | ||
)) | ||
performance_history.append(int(score)) | ||
results_data.append(result_data) | ||
|
||
return question, answer, predictions, performance_history, loss_history, results_data, ques_data # NOTE: check the return format | ||
|
||
args = config() | ||
llm_engine = tg.get_engine(engine_name=args.engine) | ||
tg.set_backward_engine(llm_engine, override=True) | ||
|
||
eval_engine = tg.get_engine(engine_name=args.eval_engine) | ||
test_set = load_multimodal_instance_task(args.task, | ||
evaluation_api=eval_engine, | ||
task_instruction=args.task_instruction, # NOTE: check the instruction | ||
evaluation_instruction=args.evaluation_instruction, # NOTE: check the instruction | ||
root=args.cache_root) | ||
|
||
# Create the results directory if it does not exist | ||
os.makedirs(args.results_dir, exist_ok=True) | ||
output_file = f"{args.task}_predictions.json" if args.output_file == "" else args.output_file | ||
output_file = os.path.join(args.results_dir, output_file) | ||
output_jsonl_file = output_file.replace(".json", ".jsonl") | ||
|
||
# Read the checkpoint result file if it exists | ||
if os.path.exists(output_jsonl_file): | ||
all_solutions = load_checkpoint_results(output_jsonl_file) | ||
elif os.path.exists(output_file): | ||
all_solutions = load_checkpoint_results(output_file) | ||
else: | ||
all_solutions = {} | ||
all_history = [solution["performance_history"] for solution in all_solutions.values()] | ||
|
||
if args.num_threads > 0: | ||
# Use ThreadPoolExecutor | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_threads) as executor: | ||
futures = [] | ||
print("\nPreparing to run test-time training...") | ||
for i, sample in enumerate(test_set): | ||
image_bytes, query, answer, ques_data, test_time_objective, instance_eval_fn = sample # NOTE: check the sample format | ||
pid = ques_data["pid"] # NOTE: check the question id | ||
if pid in all_solutions: | ||
# print(f"Skipping sample {pid}") | ||
continue | ||
if args.test_num > 0 and i >= args.test_num: | ||
break | ||
future = executor.submit(run_test_time_training, sample) | ||
futures.append(future) | ||
|
||
print("\nRunning test-time training...") | ||
for _, future in enumerate(tqdm(concurrent.futures.as_completed(futures), total=len(futures), position=0)): | ||
question, answer, predictions, performance_history, loss_history, result_data, ques_data = future.result() # NOTE: check the return format | ||
pid = ques_data["pid"] # NOTE: check the question id | ||
all_solutions[pid] = { | ||
"question": question, | ||
"answer": answer, | ||
"predictions": [p.value for p in predictions], | ||
"loss_history": loss_history, | ||
"performance_history": performance_history, | ||
"result_data": result_data, | ||
"ques_data": ques_data | ||
} | ||
all_history.append(performance_history) | ||
# save result to output file in append mode | ||
with open(output_jsonl_file, "a") as f: | ||
json_line = json.dumps(all_solutions[pid]) | ||
f.write(json_line + "\n") | ||
else: | ||
# Use regular for loop | ||
if args.test_num > 0: | ||
total = min(args.test_num, len(test_set)) | ||
else: | ||
total = len(test_set) | ||
for i, sample in enumerate(tqdm(test_set, total=total, position=0)): | ||
image_bytes, query, answer, ques_data, test_time_objective, instance_eval_fn = sample # NOTE: check the sample format | ||
pid = ques_data["pid"] # NOTE: check the question id | ||
if pid in all_solutions: | ||
continue | ||
if args.test_num > 0 and i >= args.test_num: | ||
break | ||
future = run_test_time_training(sample) | ||
|
||
question, answer, predictions, performance_history, loss_history, result_data, ques_data = future # NOTE: check the return format | ||
result = { | ||
"question": question, | ||
"answer": answer, | ||
"predictions": [p.value for p in predictions], | ||
"performance_history": performance_history, | ||
"loss_history": loss_history, | ||
"result_data": result_data, | ||
"ques_data": ques_data | ||
} | ||
all_solutions[pid] = result | ||
all_history.append(performance_history) | ||
|
||
# save result to output file in append mode | ||
with open(output_jsonl_file, "a") as f: | ||
json_line = json.dumps(result) | ||
f.write(json_line + "\n") | ||
if i % args.save_every == 0: | ||
with open(output_file, "w") as f: | ||
print("Saving results to", output_file) | ||
json.dump(all_solutions, f, indent=4) | ||
|
||
# Quick summary of the results | ||
all_solutions = {k: all_solutions[k] for k in sorted(all_solutions, key=lambda x: int(x))} # sorted by int(pid) | ||
with open(output_file, "w") as f: | ||
json.dump(all_solutions, f, indent=4) | ||
|
||
print("\nSummary of the results:") | ||
score_history = np.array(all_history).mean(axis=0) | ||
score_history = [round(float(s), 3) for s in score_history] | ||
print(score_history) | ||
|
||
if args.majority_voting: | ||
print("\nMajority Voting:") | ||
majority_voting = MajorityVoting(all_solutions, args.max_iterations) | ||
final_score = majority_voting.compute_accuracy() | ||
else: | ||
final_score = round(max(np.array(all_history).mean(axis=0)), 3) | ||
|
||
print(f"Final Accuracy: {final_score}") | ||
print(f"Total samples: {len(all_solutions)}") | ||
print("Done!") |
Oops, something went wrong.