Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

numpy to cupy replacements #329

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ e.g., multiple XAI methods, multifaceted evaluation through several metrics, or
You can then simply run a large-scale evaluation as follows (this aggregates the result by `np.mean` averaging):

```python
import numpy as np
import cupy as np
results = quantus.evaluate(
metrics=metrics,
xai_methods=xai_methods,
Expand Down
4 changes: 2 additions & 2 deletions docs/source/getting_started/getting_started_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ xai_methods = {
You can then simply run a large-scale evaluation as follows (this aggregates the result by `np.mean` averaging):

```python
import numpy as np
import cupy as np
results = quantus.evaluate(
metrics=metrics,
xai_methods=xai_methods,
Expand Down Expand Up @@ -275,7 +275,7 @@ For example, if you want to replace `similarity_func` in your evaluation, you ca

```python
import scipy
import numpy as np
import cupy as np

def my_similarity_func(a: np.array, b: np.array, **kwargs) -> float:
"""Calculate the similarity of a and b by subtraction."""
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"cupy>=13.0.0; python_version >= '3.9'",
"cupy<12.0.0; python_version == '3.7'",
"numpy>=1.19.5",
"pandas<=1.3.3; python_version == '3.7'",
"pandas>=1.5.3; python_version > '3.7'",
Expand Down
2 changes: 1 addition & 1 deletion quantus/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from typing import Union, Callable, Dict, Optional, Any

import numpy as np
import cupy as np
import pandas as pd

from quantus.helpers import asserts
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/complexity_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import scipy
import numpy as np
import cupy as np


def entropy(a: np.array, x: np.array, **kwargs) -> float:
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/discretise_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import numpy as np
import cupy as np


def floating_points(a: np.array, **kwargs) -> float:
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/explanation_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from importlib import util
from typing import Optional, Union

import numpy as np
import cupy as np
import quantus
import scipy

Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/loss_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import numpy as np
import cupy as np


def mse(a: np.array, b: np.array, **kwargs) -> float:
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/mosaic_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import random
from typing import List, Tuple, Optional, Union, Any

import numpy as np
import cupy as np


def build_single_mosaic(mosaic_images_list: List[np.ndarray]) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/n_bins_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import scipy
import numpy as np
import cupy as np


def freedman_diaconis_rule(a_batch: np.ndarray) -> int:
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/norm_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import numpy as np
import cupy as np


def fro_norm(a: np.array) -> float:
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/normalise_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from typing import Optional, Sequence

import numpy as np
import cupy as np


def normalise_by_max(
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/perturb_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import warnings
from typing import Any, Callable, Sequence, Tuple, Union, Optional
import cv2
import numpy as np
import cupy as np
from scipy.sparse import lil_matrix, csc_matrix
from scipy.sparse.linalg import spsolve

Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/similarity_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from typing import Union

import numpy as np
import cupy as np
import scipy
import skimage

Expand Down
2 changes: 1 addition & 1 deletion quantus/helpers/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


from typing import Callable, Tuple, Sequence, Union
import numpy as np
import cupy as np


def assert_features_in_step(
Expand Down
2 changes: 1 addition & 1 deletion quantus/helpers/model/model_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, List, Union, Generator, TypeVar, Generic

import numpy as np
import cupy as np

if util.find_spec("tensorflow"):
import tensorflow as tf
Expand Down
2 changes: 1 addition & 1 deletion quantus/helpers/model/pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import warnings
import logging

import numpy as np
import cupy as np
import torch
from torch import nn
from functools import lru_cache
Expand Down
2 changes: 1 addition & 1 deletion quantus/helpers/model/tf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from keras import activations
from keras import Model
from keras.models import clone_model
import numpy as np
import cupy as np
import tensorflow as tf
from warnings import warn
from cachetools import cachedmethod, LRUCache
Expand Down
2 changes: 1 addition & 1 deletion quantus/helpers/perturbation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from typing import List, TYPE_CHECKING, Callable, Mapping
import numpy as np
import cupy as np
import functools

if sys.version_info >= (3, 8):
Expand Down
2 changes: 1 addition & 1 deletion quantus/helpers/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import List, Union, Dict, Any

import matplotlib.pyplot as plt
import numpy as np
import cupy as np

from quantus.helpers import warn

Expand Down
2 changes: 1 addition & 1 deletion quantus/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from importlib import util
from typing import Any, Dict, Optional, Sequence, Tuple, Union, List, TypeVar

import numpy as np
import cupy as np
from skimage.segmentation import slic, felzenszwalb

from quantus.helpers import asserts
Expand Down
2 changes: 1 addition & 1 deletion quantus/helpers/warn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import warnings

import numpy as np
import cupy as np

from quantus.helpers.utils import get_name

Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/axiomatic/completeness.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.helpers import warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/axiomatic/input_invariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_shift, perturb_batch
from quantus.helpers import asserts, warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/axiomatic/non_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.helpers import asserts, warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)

import matplotlib.pyplot as plt
import numpy as np
import cupy as np
from sklearn.utils import gen_batches
from tqdm.auto import tqdm

Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/complexity/complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np
import scipy

from quantus.helpers import warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/complexity/effective_complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.helpers import warn
from quantus.helpers.enums import (
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/complexity/sparseness.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.helpers import warn
from quantus.helpers.enums import (
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/faithfulness_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.functions.similarity_func import correlation_pearson
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/faithfulness_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.functions.similarity_func import correlation_pearson
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/infidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import cupy as np

from quantus.functions.loss_func import mse
from quantus.functions.perturb_func import baseline_replacement_by_indices
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/irof.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.helpers import asserts, utils, warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/monotonicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.helpers import asserts, utils, warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/monotonicity_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.functions.similarity_func import correlation_spearman
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/pixel_flipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.helpers import asserts, plotting, utils, warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/region_perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.helpers import asserts, plotting, utils, warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/road.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import noisy_linear_imputation
from quantus.helpers import warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/selectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.helpers import plotting, utils, warn
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/sensitivity_n.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.functions.normalise_func import normalise_by_max
from quantus.functions.perturb_func import baseline_replacement_by_indices
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/sufficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional, no_type_check

import numpy as np
import cupy as np
from scipy.spatial.distance import cdist

from quantus.helpers.model.model_interface import ModelInterface
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/localisation/attribution_localisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Callable, Dict, List, Optional

import numpy as np
import cupy as np

from quantus.helpers import asserts, warn
from quantus.helpers.enums import (
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/localisation/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import cupy as np
from sklearn.metrics import auc, roc_curve

from quantus.helpers import asserts, warn
Expand Down
Loading
Loading