Skip to content

Commit

Permalink
add majority vote to eval script
Browse files Browse the repository at this point in the history
  • Loading branch information
kl2806 committed Jan 8, 2025
1 parent 7b490f5 commit ab0d724
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions evaluate_gsm8k.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,25 @@
import re
import statistics

def standardize_answer(answer: str):
    """Normalize a numeric answer string so answers can be compared as text.

    Every ``$`` and ``,`` character is deleted wherever it occurs, and then
    any run of ``.`` characters is trimmed from both ends of the result.

    NOTE(review): the trim also removes a *leading* period, so ".5" comes
    back as "5" -- confirm that this is the intended behavior.
    """
    # One pass that drops currency signs and thousands separators.
    without_symbols = answer.translate(str.maketrans("", "", "$,"))
    # Trim sentence-ending (and any leading) periods.
    return without_symbols.strip(".")


def evaluate(input_file: str):
correct = 0
any_correct_count = 0
total = 0
example_num_rethinks = []
with jsonlines.open(input_file) as reader:
for obj in reader:
ignore_regex = '(?s).*#### '
answer = re.sub(ignore_regex, '', obj['answer'])

answer = answer.replace("$", "")
answer = answer.replace(",", "")
answer = answer.strip(".")

answer = standardize_answer(answer)

final_answer = ""

# do majority voting over obj['responses']


if 'responses' not in obj:
# then just do 'response'
for message in obj['response']['messages']:
Expand All @@ -54,9 +53,7 @@ def evaluate(input_file: str):
if matches == []:
continue
final_answer = "".join(matches[-1])
final_answer = final_answer.replace("$", "")
final_answer = final_answer.replace(",", "")
final_answer = final_answer.strip(".")
final_answer = standardize_answer(final_answer)

if final_answer in votes:
votes[final_answer] += 1
Expand All @@ -73,6 +70,26 @@ def evaluate(input_file: str):
print("\n\n")
total += 1

# see if any of the answers match the final answer
any_correct = False
for response in obj['responses']:
for message in response['messages']:
if message['message_type'] == "function_call":
if message['function_call']['name'] == "send_message":
arguments = json.loads(message['function_call']['arguments'])
response_answer = arguments['message']
# do the same sanitization as above
regex_str = "(-?[$0-9.,]{2,})|(-?[0-9]+)"
matches = re.findall(regex_str, response_answer)
if matches == []:
continue
response_answer = "".join(matches[-1])
response_answer = standardize_answer(response_answer)
if response_answer == answer:
any_correct = True
if any_correct:
any_correct_count += 1

num_rethinks = 0
if 'offline_responses' in obj and len(obj['offline_responses']) > 0:
for message in obj['offline_responses'][0]['messages']:
Expand All @@ -81,7 +98,8 @@ def evaluate(input_file: str):
num_rethinks += 1
example_num_rethinks.append(num_rethinks)

print("Accuracy: ", correct / total)
print("Accuracy (any correct): ", any_correct_count / total)
print("Accuracy (final): ", correct / total)
with jsonlines.open(input_file) as reader:
lines = list(reader)
print("avg. num rethinks: ", statistics.mean(example_num_rethinks))
Expand Down

0 comments on commit ab0d724

Please sign in to comment.