From cd722d7815bb4b1492f0a0a384e78bcbafbf3b83 Mon Sep 17 00:00:00 2001 From: MichaelFu512 Date: Tue, 14 May 2024 10:41:03 -0700 Subject: [PATCH] new function for test partial dependence --- .../test_partial_dependence.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/evalml/tests/model_understanding_tests/test_partial_dependence.py b/evalml/tests/model_understanding_tests/test_partial_dependence.py index 09fb5ce29e..bf12f072ad 100644 --- a/evalml/tests/model_understanding_tests/test_partial_dependence.py +++ b/evalml/tests/model_understanding_tests/test_partial_dependence.py @@ -1,3 +1,4 @@ +import collections import re from unittest.mock import patch @@ -713,13 +714,24 @@ def test_partial_dependence_more_categories_than_grid_resolution( fraud_local, logistic_regression_binary_pipeline, ): - def round_dict_keys(dictionary, places=3): + def round_dict_keys(dictionary, places=6): """Function to round all keys of a dictionary that has floats as keys.""" dictionary_rounded = {} for key in dictionary: dictionary_rounded[round(key, places)] = dictionary[key] return dictionary_rounded + def check_dicts_approx_equal(part_dep_ans, part_dep_dict, rel=1e-3): + keys_part_dep_ans = part_dep_ans.keys() + keys_part_dep_dict = part_dep_dict.keys() + keys_part_dep_ans.sort() + keys_part_dep_dict.sort() + + assert keys_part_dep_ans == pytest.approx(keys_part_dep_dict, rel=rel) + assert collections.Counter(list(part_dep_ans.values())) == collections.Counter( + list(part_dep_dict.values()), + ) + X, y = fraud_local X = X[:100] y = y[:100] @@ -753,7 +765,7 @@ def round_dict_keys(dictionary, places=3): grid_resolution=round(num_cat_features / 2), ) part_dep_dict = dict(part_dep["partial_dependence"].value_counts()) - assert part_dep_ans_rounded == round_dict_keys(part_dep_dict) + check_dicts_approx_equal(part_dep_ans_rounded, round_dict_keys(part_dep_dict)) fast_part_dep = partial_dependence( pipeline, @@ -774,7 +786,7 @@ def round_dict_keys(dictionary, places=3): grid_resolution=round(num_cat_features), ) part_dep_dict = dict(part_dep["partial_dependence"].value_counts()) - assert part_dep_ans_rounded == round_dict_keys(part_dep_dict) + check_dicts_approx_equal(part_dep_ans_rounded, round_dict_keys(part_dep_dict)) fast_part_dep = partial_dependence( pipeline, @@ -795,7 +807,7 @@ def round_dict_keys(dictionary, places=3): grid_resolution=round(num_cat_features * 2), ) part_dep_dict = dict(part_dep["partial_dependence"].value_counts()) - assert part_dep_ans_rounded == round_dict_keys(part_dep_dict) + check_dicts_approx_equal(part_dep_ans_rounded, round_dict_keys(part_dep_dict)) fast_part_dep = partial_dependence( pipeline,