From ab0d724a135cc527f35f93107b5cc7ee17c4cbd3 Mon Sep 17 00:00:00 2001
From: Kevin Lin
Date: Wed, 8 Jan 2025 14:54:55 -0800
Subject: [PATCH] add majority vote to eval script

---
 evaluate_gsm8k.py | 44 +++++++++++++++++++++++++++++++-------------
 1 file changed, 31 insertions(+), 13 deletions(-)

diff --git a/evaluate_gsm8k.py b/evaluate_gsm8k.py
index 8ad9a27eae..0bf5611693 100644
--- a/evaluate_gsm8k.py
+++ b/evaluate_gsm8k.py
@@ -11,26 +11,25 @@ import re
 import statistics
 
 
+def standardize_answer(answer: str):
+    answer = answer.replace("$", "")
+    answer = answer.replace(",", "")
+    answer = answer.strip(".")
+    return answer
+
 
 def evaluate(input_file: str):
     correct = 0
+    any_correct_count = 0
     total = 0
     example_num_rethinks = []
 
     with jsonlines.open(input_file) as reader:
         for obj in reader:
             ignore_regex = '(?s).*#### '
             answer = re.sub(ignore_regex, '', obj['answer'])
-
-            answer = answer.replace("$", "")
-            answer = answer.replace(",", "")
-            answer = answer.strip(".")
-
+            answer = standardize_answer(answer)
             final_answer = ""
-
-            # do majority voting over obj['responses']
-
-
             if 'responses' not in obj:
                 # then just do 'response'
                 for message in obj['response']['messages']:
@@ -54,9 +53,7 @@ def evaluate(input_file: str):
                                 if matches == []:
                                     continue
                                 final_answer = "".join(matches[-1])
-                                final_answer = final_answer.replace("$", "")
-                                final_answer = final_answer.replace(",", "")
-                                final_answer = final_answer.strip(".")
+                                final_answer = standardize_answer(final_answer)
 
                                 if final_answer in votes:
                                     votes[final_answer] += 1
@@ -73,6 +70,26 @@ def evaluate(input_file: str):
                 print("\n\n")
             total += 1
 
+            # see if any of the answers match the final answer
+            any_correct = False
+            for response in obj['responses']:
+                for message in response['messages']:
+                    if message['message_type'] == "function_call":
+                        if message['function_call']['name'] == "send_message":
+                            arguments = json.loads(message['function_call']['arguments'])
+                            response_answer = arguments['message']
+                            # do the same sanitization as above
+                            regex_str = "(-?[$0-9.,]{2,})|(-?[0-9]+)"
+                            matches = re.findall(regex_str, response_answer)
+                            if matches == []:
+                                continue
+                            response_answer = "".join(matches[-1])
+                            response_answer = standardize_answer(response_answer)
+                            if response_answer == answer:
+                                any_correct = True
+            if any_correct:
+                any_correct_count += 1
+
             num_rethinks = 0
             if 'offline_responses' in obj and len(obj['offline_responses']) > 0:
                 for message in obj['offline_responses'][0]['messages']:
@@ -81,7 +98,8 @@ def evaluate(input_file: str):
                             num_rethinks += 1
             example_num_rethinks.append(num_rethinks)
 
-    print("Accuracy: ", correct / total)
+    print("Accuracy (any correct): ", any_correct_count / total)
+    print("Accuracy (final): ", correct / total)
     with jsonlines.open(input_file) as reader:
         lines = list(reader)
     print("avg. num rethinks: ", statistics.mean(example_num_rethinks))