diff --git a/evaluate_gsm8k.py b/evaluate_gsm8k.py index 0bf5611693..17f15c4ae6 100644 --- a/evaluate_gsm8k.py +++ b/evaluate_gsm8k.py @@ -99,7 +99,7 @@ def evaluate(input_file: str): example_num_rethinks.append(num_rethinks) print("Accuracy (any correct): ", any_correct_count / total) - print("Accuracy (final): ", correct / total) + print("Accuracy (majority vote): ", correct / total) with jsonlines.open(input_file) as reader: lines = list(reader) print("avg. num rethinks: ", statistics.mean(example_num_rethinks))