Skip to content

Commit

Permalink
Enable test case test_columns_out_of_order (#418)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaugup authored Nov 3, 2023
2 parents 3101ff9 + fc8aedd commit 0ff6831
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
11 changes: 0 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,6 @@ def binary_classification_exp_object(method="random"):
return exp


@pytest.fixture(scope="session")
def binary_classification_exp_object_out_of_order(method="random"):
backend = 'sklearn'
dataset = helpers.load_outcome_not_last_column_dataset()
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
model = _load_custom_testing_binary_model()
m = dice_ml.Model(model=model, backend=backend)
exp = dice_ml.Dice(d, m, method=method)
return exp


@pytest.fixture(scope="session")
def multi_classification_exp_object(method="random"):
backend = 'sklearn'
Expand Down
35 changes: 23 additions & 12 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import dice_ml
from dice_ml.diverse_counterfactuals import CounterfactualExamples
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
from dice_ml.utils import helpers

from ..conftest import _load_custom_testing_binary_model


@pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree'])
Expand Down Expand Up @@ -116,18 +119,26 @@ def test_global_feature_importance(

self._verify_feature_importance(global_importance.summary_importance)

# @pytest.mark.parametrize("desired_class, binary_classification_exp_object_out_of_order",
# [(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
# indirect=['binary_classification_exp_object_out_of_order'])
# def test_columns_out_of_order(self, desired_class, binary_classification_exp_object_out_of_order, sample_custom_query_1):
# exp = binary_classification_exp_object_out_of_order # explainer object
# exp._generate_counterfactuals(
# query_instance=sample_custom_query_1,
# total_CFs=0,
# desired_class=desired_class,
# desired_range=None,
# permitted_range=None,
# features_to_vary='all')
@pytest.mark.parametrize("desired_class", [1])
def test_columns_out_of_order(self, desired_class, method, sample_custom_query_1):
if method == 'genetic':
pytest.skip('DiceGenetic explainer fails this test case')

dataset = helpers.load_outcome_not_last_column_dataset()
d = dice_ml.Data(
dataframe=dataset, continuous_features=['Numerical'],
outcome_name='Outcome')
model = _load_custom_testing_binary_model()
m = dice_ml.Model(model=model, backend='sklearn')
exp = dice_ml.Dice(d, m, method=method)

exp._generate_counterfactuals(
query_instance=sample_custom_query_1,
total_CFs=0,
desired_class=desired_class,
desired_range=None,
permitted_range=None,
features_to_vary='all')

@pytest.mark.parametrize("desired_class", [1])
def test_incorrect_features_to_vary_list(
Expand Down

0 comments on commit 0ff6831

Please sign in to comment.