From 24e13947cde61e74c94dbd8494bae3784cc14bc4 Mon Sep 17 00:00:00 2001 From: Jonathan Lessinger Date: Fri, 8 Dec 2023 17:38:38 -0500 Subject: [PATCH] [AIC-py][eval] make e2e test a little more robust --- python/tests/test_eval.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 8dbe96d05..13caba6c9 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -140,8 +140,6 @@ async def test_run_test_suite_with_inputs(data: st.DataObject): ) ) - input_data, _ = cu.unzip(user_test_suite_with_inputs) - mock_aiconfig = MockAIConfigRuntime() out = await run_test_suite_helper( @@ -166,13 +164,25 @@ async def test_run_test_suite_with_inputs(data: st.DataObject): "best_possible_value", "worst_possible_value", ] - inputs = df["input"].astype(str).tolist() # type: ignore[no-untyped-call] - assert set(inputs) == set(input_data) # type: ignore[no-untyped-call] + + input_pairs = { + (input_datum, metric.interpretation.id) + for input_datum, metric in user_test_suite_with_inputs + } + result_pairs = set( # type: ignore[no-untyped-call] + df[["input", "metric_id"]].itertuples(index=False, name=None) # type: ignore[no-untyped-call] + ) + + assert input_pairs == result_pairs df_brevity = df[df["metric_name"] == "brevity"] assert ( df_brevity["aiconfig_output"].apply(len) # type: ignore[no-untyped-call] == df_brevity["value"] # type: ignore[no-untyped-call] ).all() + + df_substring = df[df["metric_name"] == "substring_match"] + assert (df_substring["value"].apply(lambda x: x in {0.0, 0.1})).all() # type: ignore[no-untyped-call] + case Err(e): assert False, f"expected Ok, got Err({e})"