Skip to content

Commit

Permalink
Merge pull request #328 from understandable-machine-intelligence-lab/…
Browse files Browse the repository at this point in the history
…fixes-explain-func

Bugfixes tabular data and added unit tests on explanation_func.py
  • Loading branch information
annahedstroem authored Mar 1, 2024
2 parents d67efa6 + 9d02d0b commit 279fe00
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
25 changes: 18 additions & 7 deletions quantus/functions/explanation_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import warnings
from importlib import util
from typing import Optional, Union
from typing import Optional, Union, Callable

import numpy as np
import quantus
Expand Down Expand Up @@ -391,7 +391,6 @@ def generate_tf_explanation(
)

elif method == "SmoothGrad":

num_samples = kwargs.get("num_samples", 5)
noise = kwargs.get("noise", 0.1)
explainer = tf_explain.core.smoothgrad.SmoothGrad()
Expand Down Expand Up @@ -513,6 +512,7 @@ def generate_captum_explanation(

if not isinstance(inputs, torch.Tensor):
inputs = torch.Tensor(inputs).to(device)
inputs.requires_grad_()

if not isinstance(targets, torch.Tensor):
targets = torch.as_tensor(targets).to(device)
Expand Down Expand Up @@ -667,14 +667,22 @@ def f_reduce_axes(a):
elif method == "Control Var. Sobel Filter":
explanation = torch.zeros(size=inputs.shape)

if inputs.is_cuda:
inputs = inputs.cpu()

inputs_numpy = inputs.detach().numpy()

for i in range(len(explanation)):
explanation[i] = torch.Tensor(
np.clip(scipy.ndimage.sobel(inputs[i].cpu().numpy()), 0, 1)
np.clip(scipy.ndimage.sobel(inputs_numpy[i]), 0, 1)
)
explanation = explanation.mean(**reduce_axes)
if len(explanation.shape) > 2:
explanation = explanation.mean(**reduce_axes)

elif method == "Control Var. Random Uniform":
explanation = torch.rand(size=(inputs.shape[0], *inputs.shape[2:]))
explanation = torch.rand(size=(inputs.shape))
if len(explanation.shape) > 2:
explanation = explanation.mean(**reduce_axes)

elif method == "Control Var. Constant":
assert (
Expand All @@ -686,13 +694,16 @@ def f_reduce_axes(a):
# Update the tensor with values per input x.
for i in range(explanation.shape[0]):
constant_value = get_baseline_value(
value=kwargs["constant_value"], arr=inputs[i], return_shape=(1,)
value=kwargs["constant_value"],
arr=inputs[i],
return_shape=kwargs.get("return_shape", (1,)),
)[0]
explanation[i] = torch.Tensor().new_full(
size=explanation[0].shape, fill_value=constant_value
)

explanation = explanation.mean(**reduce_axes)
if len(explanation.shape) > 2:
explanation = explanation.mean(**reduce_axes)

else:
raise KeyError(
Expand Down
26 changes: 25 additions & 1 deletion tests/functions/test_explanation_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,14 @@
},
{"shape": (8, 1, 28, 28)},
),
(
lazy_fixture("titanic_model_torch"),
lazy_fixture("titanic_dataset"),
{
"method": "Control Var. Sobel Filter",
},
{"min": -10000.0, "max": 10000.0},
),
(
lazy_fixture("load_1d_3ch_conv_model"),
lazy_fixture("almost_uniform_1d_no_abatch"),
Expand Down Expand Up @@ -334,6 +342,15 @@
},
{"value": 0.0},
),
(
lazy_fixture("titanic_model_torch"),
lazy_fixture("titanic_dataset"),
{
"method": "Control Var. Constant",
"constant_value": 0.0,
},
{"value": 0.0},
),
(
lazy_fixture("load_mnist_model"),
lazy_fixture("load_mnist_images"),
Expand All @@ -342,6 +359,14 @@
},
{"min": 0.0, "max": 1.0},
),
(
lazy_fixture("titanic_model_torch"),
lazy_fixture("titanic_dataset"),
{
"method": "Control Var. Random Uniform",
},
{"min": 0.0, "max": 1.0},
),
(
lazy_fixture("load_1d_3ch_conv_model"),
lazy_fixture("almost_uniform_1d_no_abatch"),
Expand Down Expand Up @@ -750,7 +775,6 @@ def test_explain_func(
params: dict,
expected: Union[float, dict, bool],
):

x_batch, y_batch = (data["x_batch"], data["y_batch"])
if "exception" in expected:
with pytest.raises(expected["exception"]):
Expand Down

0 comments on commit 279fe00

Please sign in to comment.