diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index 8dbe96d05..13caba6c9 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -140,8 +140,6 @@ async def test_run_test_suite_with_inputs(data: st.DataObject):
         )
     )
 
-    input_data, _ = cu.unzip(user_test_suite_with_inputs)
-
     mock_aiconfig = MockAIConfigRuntime()
 
     out = await run_test_suite_helper(
@@ -166,13 +164,25 @@ async def test_run_test_suite_with_inputs(data: st.DataObject):
                 "best_possible_value",
                 "worst_possible_value",
             ]
-            inputs = df["input"].astype(str).tolist()  # type: ignore[no-untyped-call]
-            assert set(inputs) == set(input_data)  # type: ignore[no-untyped-call]
+
+            input_pairs = {
+                (input_datum, metric.interpretation.id)
+                for input_datum, metric in user_test_suite_with_inputs
+            }
+            result_pairs = set(  # type: ignore[no-untyped-call]
+                df[["input", "metric_id"]].itertuples(index=False, name=None)  # type: ignore[no-untyped-call]
+            )
+
+            assert input_pairs == result_pairs
 
             df_brevity = df[df["metric_name"] == "brevity"]
             assert (
                 df_brevity["aiconfig_output"].apply(len)  # type: ignore[no-untyped-call]
                 == df_brevity["value"]  # type: ignore[no-untyped-call]
             ).all()
+
+            df_substring = df[df["metric_name"] == "substring_match"]
+            assert (df_substring["value"].apply(lambda x: x in {0.0, 0.1})).all()  # type: ignore[no-untyped-call]
+
         case Err(e):
             assert False, f"expected Ok, got Err({e})"
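
The new assertion compares (input, metric_id) pairs rather than bare inputs, which is
why cu.unzip is no longer needed and why the same input may legitimately appear once
per metric. It relies on DataFrame.itertuples(index=False, name=None) yielding one
plain tuple per row, so the set comparison is order-insensitive. A minimal standalone
sketch of the idiom with toy data; only the column names "input" and "metric_id" are
taken from the patch, everything else is illustrative:

    import pandas as pd

    df = pd.DataFrame(
        {
            "input": ["hello", "hello", "world"],
            "metric_id": ["brevity", "substring_match", "brevity"],
        }
    )

    expected_pairs = {
        ("hello", "brevity"),
        ("hello", "substring_match"),
        ("world", "brevity"),
    }

    # One plain tuple per row; set() discards row order.
    result_pairs = set(df[["input", "metric_id"]].itertuples(index=False, name=None))

    assert expected_pairs == result_pairs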