Skip to content

Commit

Permalink
Ensure same type for CF output as the input features (#412)
Browse files Browse the repository at this point in the history
* same type for the input features

* added tests and fixed bug

Signed-off-by: Amit Sharma <[email protected]>

* fixed lint error

* fixed lint

* fixed private data interface

* common functions in base_data

* fixed lint error

* fixed isort

* avoid flaky test failures when final_cfs_df is None

---------

Signed-off-by: Amit Sharma <[email protected]>
  • Loading branch information
amit-sharma authored Oct 26, 2023
1 parent 8cd02e8 commit b12abb4
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 20 deletions.
26 changes: 26 additions & 0 deletions dice_ml/data_interfaces/base_data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC, abstractmethod

import pandas as pd
from raiutils.exceptions import UserConfigValidationException

from dice_ml.utils.exception import SystemException
Expand Down Expand Up @@ -71,6 +72,31 @@ def _validate_and_set_permitted_range(self, params, features_dict=None):
)
self.permitted_range, _ = self.get_features_range(input_permitted_range, features_dict)

def ensure_consistent_type(self, output_df, query_instance):
    """Cast the columns of ``output_df`` to the dtypes of the query instance.

    The query instance is converted to a dataframe with
    ``query_instance_to_df``. Every column of that dataframe that is also
    present in ``output_df`` is cast to the query's dtype; columns that
    exist only in ``output_df`` keep their dtype.

    :param output_df: pandas dataframe whose columns should be cast.
    :param query_instance: query in any form accepted by ``query_instance_to_df``.
    :return: a new dataframe with the harmonised dtypes.
    """
    qdf = self.query_instance_to_df(query_instance)
    # DataFrame.astype raises KeyError for mapping keys that are not columns
    # of the frame, so restrict the mapping to the columns both frames share.
    shared_dtypes = {
        col: dtype for col, dtype in qdf.dtypes.items()
        if col in output_df.columns
    }
    return output_df.astype(shared_dtypes)

def query_instance_to_df(self, query_instance):
    """Convert a user supplied query instance into a pandas dataframe.

    Supported inputs:

    - a list whose first element is a dict: treated as a list of query
      instances, one dict per row;
    - any other non-empty list: treated as the feature values of a single
      instance, assigned positionally to ``self.feature_names``;
    - a dict: treated as a single instance mapping feature name to value;
    - a pandas dataframe: a copy is returned unchanged.

    For list and dict inputs the resulting columns are ``self.feature_names``.

    :param query_instance: the query in one of the forms listed above.
    :return: a pandas dataframe holding the query instance(s).
    :raises ValueError: if the input is an empty list or an unsupported type.
    """
    if isinstance(query_instance, list):
        if not query_instance:
            # Inspecting the first element of an empty list would otherwise
            # surface as an uninformative IndexError.
            raise ValueError("Query instance list should not be empty")
        if isinstance(query_instance[0], dict):  # a list of query instances
            return pd.DataFrame(query_instance, columns=self.feature_names)
        # a single query instance given as a flat list of values
        return pd.DataFrame.from_dict(
            {'row1': query_instance}, orient='index', columns=self.feature_names)

    if isinstance(query_instance, dict):
        return pd.DataFrame(
            {k: [v] for k, v in query_instance.items()}, columns=self.feature_names)

    if isinstance(query_instance, pd.DataFrame):
        return query_instance.copy()

    raise ValueError("Query instance should be a dict, a pandas dataframe, a list, or a list of dicts")

@abstractmethod
def __init__(self, params):
"""The init method needs to be implemented by the inherting classes."""
Expand Down
7 changes: 5 additions & 2 deletions dice_ml/data_interfaces/private_data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,7 @@ def prepare_df_for_ohe_encoding(self):

return df

def prepare_query_instance(self, query_instance):
"""Prepares user defined test input(s) for DiCE."""
def query_instance_to_df(self, query_instance):
if isinstance(query_instance, list):
if isinstance(query_instance[0], dict): # prepare a list of query instances
test = pd.DataFrame(query_instance, columns=self.feature_names)
Expand All @@ -370,7 +369,11 @@ def prepare_query_instance(self, query_instance):

else:
raise ValueError("Query instance should be a dict, a pandas dataframe, a list, or a list of dicts")
return test

def prepare_query_instance(self, query_instance):
    """Prepares user defined test input(s) for DiCE.

    The query is converted to a dataframe via ``query_instance_to_df`` and
    returned with its index replaced by a fresh default index.
    """
    query_df = self.query_instance_to_df(query_instance)
    return query_df.reset_index(drop=True)

Expand Down
19 changes: 1 addition & 18 deletions dice_ml/data_interfaces/public_data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,24 +451,7 @@ def prepare_df_for_ohe_encoding(self):

def prepare_query_instance(self, query_instance):
"""Prepares user defined test input(s) for DiCE."""
if isinstance(query_instance, list):
if isinstance(query_instance[0], dict): # prepare a list of query instances
test = pd.DataFrame(query_instance, columns=self.feature_names)

else: # prepare a single query instance in list
query_instance = {'row1': query_instance}
test = pd.DataFrame.from_dict(
query_instance, orient='index', columns=self.feature_names)

elif isinstance(query_instance, dict):
test = pd.DataFrame({k: [v] for k, v in query_instance.items()}, columns=self.feature_names)

elif isinstance(query_instance, pd.DataFrame):
test = query_instance.copy()

else:
raise ValueError("Query instance should be a dict, a pandas dataframe, a list, or a list of dicts")

test = self.query_instance_to_df(query_instance)
test = test.reset_index(drop=True)
# encode categorical and numerical columns
test = self._set_feature_dtypes(test,
Expand Down
8 changes: 8 additions & 0 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ def generate_counterfactuals(self, query_instances, total_CFs,
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
verbose=verbose,
**kwargs)
res.test_instance_df = self.data_interface.ensure_consistent_type(
res.test_instance_df, query_instance)
if res.final_cfs_df is not None and len(res.final_cfs_df) > 0:
res.final_cfs_df = self.data_interface.ensure_consistent_type(
res.final_cfs_df, query_instance)
if res.final_cfs_df_sparse is not None and len(res.final_cfs_df_sparse) > 0:
res.final_cfs_df_sparse = self.data_interface.ensure_consistent_type(
res.final_cfs_df_sparse, query_instance)
cf_examples_arr.append(res)
self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)

Expand Down
43 changes: 43 additions & 0 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,28 @@ def test_zero_cfs_internal(
total_CFs=total_CFs, desired_class=desired_class,
desired_range=desired_range, permitted_range=permitted_range)

@pytest.mark.parametrize("desired_class", [1])
def test_cfs_type_consistency(
        self, desired_class, method,
        sample_custom_query_1, sample_counterfactual_example_dummy,
        custom_public_data_interface,
        sklearn_binary_classification_model_interface):
    """Check that explanation dataframes keep the dtypes of the query columns.

    Two query rows are submitted, and every resulting counterfactual example
    is checked rather than only the first one.
    """
    exp = dice_ml.Dice(
        custom_public_data_interface,
        sklearn_binary_classification_model_interface,
        method=method)
    sample_custom_query = pd.concat([sample_custom_query_1, sample_custom_query_1])
    cf_explanations = exp.generate_counterfactuals(
        query_instances=sample_custom_query,
        total_CFs=2,
        desired_class=desired_class)

    expected_dtypes = sample_custom_query.dtypes
    for cf_example in cf_explanations.cf_examples_list:
        frames_to_check = [cf_example.test_instance_df]
        for cfs_df in (cf_example.final_cfs_df, cf_example.final_cfs_df_sparse):
            # generate_counterfactuals only re-types counterfactual frames
            # that are present and non-empty, so apply the same guard here.
            if cfs_df is not None and len(cfs_df) > 0:
                frames_to_check.append(cfs_df)
        for frame in frames_to_check:
            for col in sample_custom_query.columns:
                assert frame[col].dtype == expected_dtypes[col]


@pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree'])
class TestExplainerBaseMultiClassClassification:
Expand Down Expand Up @@ -428,6 +450,27 @@ def test_zero_cfs_internal(
total_CFs=total_CFs, desired_class=desired_class,
desired_range=desired_range, permitted_range=permitted_range)

@pytest.mark.parametrize("desired_class", [1])
def test_cfs_type_consistency(
        self, desired_class, method, sample_custom_query_1,
        custom_public_data_interface,
        sklearn_multiclass_classification_model_interface):
    """Check that explanation dataframes keep the dtypes of the query columns."""
    exp = dice_ml.Dice(
        custom_public_data_interface,
        sklearn_multiclass_classification_model_interface,
        method=method)
    cf_explanations = exp.generate_counterfactuals(
        query_instances=[sample_custom_query_1],
        total_CFs=2,
        desired_class=desired_class)

    expected_dtypes = sample_custom_query_1.dtypes
    for cf_example in cf_explanations.cf_examples_list:
        frames_to_check = [cf_example.test_instance_df]
        for cfs_df in (cf_example.final_cfs_df, cf_example.final_cfs_df_sparse):
            # generate_counterfactuals only re-types counterfactual frames
            # that are present and non-empty, so apply the same guard here.
            if cfs_df is not None and len(cfs_df) > 0:
                frames_to_check.append(cfs_df)
        for frame in frames_to_check:
            for col in sample_custom_query_1.columns:
                assert frame[col].dtype == expected_dtypes[col]


class TestExplainerBaseRegression:

Expand Down

0 comments on commit b12abb4

Please sign in to comment.