Skip to content

Commit

Permalink
resolving the single missing mypy test
Browse files Browse the repository at this point in the history
  • Loading branch information
davor10105 committed Oct 8, 2024
1 parent 6591aab commit 1609535
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions quantus/metrics/randomisation/random_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,11 @@ def evaluate_instance(
"""
# Randomly select off-class labels.
np.random.seed(self.seed)
y_off = np.array(
[
np.random.choice(
[y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y]
)
]
)
y_off = np.array([np.random.choice([y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y])])
# Explain against a random class.
a_perturbed = self.explain_batch(model, np.expand_dims(x, axis=0), y_off)
return self.similarity_func(a.flatten(), a_perturbed.flatten())
similarity = float(self.similarity_func(a.flatten(), a_perturbed.flatten()))
return similarity

def custom_preprocess(
self,
Expand Down Expand Up @@ -328,7 +323,4 @@ def evaluate_batch(
scores_batch:
Evaluation results.
"""
return [
self.evaluate_instance(model, x, y, a)
for x, y, a in zip(x_batch, y_batch, a_batch)
]
return [self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch)]

0 comments on commit 1609535

Please sign in to comment.