diff --git a/langchain_benchmarks/tool_usage/evaluators.py b/langchain_benchmarks/tool_usage/evaluators.py index 9d6e89e..d29c3ff 100644 --- a/langchain_benchmarks/tool_usage/evaluators.py +++ b/langchain_benchmarks/tool_usage/evaluators.py @@ -134,11 +134,17 @@ def compare_outputs( if "output" in run_outputs and qa_evaluator: output = run_outputs["output"] with collect_runs() as cb: - qa_results = qa_evaluator.evaluate_strings( - prediction=output, - reference=example_outputs["reference"], - input=run_inputs["question"], - ) + if isinstance(qa_evaluator, QAMathEvaluator): + qa_results = qa_evaluator.evaluate_strings( + prediction=output, + reference=example_outputs["reference"], + ) + else: + qa_results = qa_evaluator.evaluate_strings( + prediction=output, + reference=example_outputs["reference"], + input=run_inputs["question"], + ) results.append( EvaluationResult( key="correctness",