Skip to content

Commit

Permalink
🚧
Browse files Browse the repository at this point in the history
  • Loading branch information
aaarrti committed Apr 27, 2024
1 parent 6c57e7f commit 56b82d4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 154 deletions.
128 changes: 0 additions & 128 deletions pr.patch

This file was deleted.

52 changes: 26 additions & 26 deletions tests/functions/test_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,29 +304,29 @@ def test_huggingface_classifier_predict(
assert np.allclose(out, expected.enter_result), "Test failed."


@pytest.mark.pytorch_model
@pytest.mark.parametrize(
"transformers_installed,base_class,expected",
[
(True, PreTrainedModel, nullcontext(np.array([[0.1, 0.9]], dtype=np.float32))),
(False, None, pytest.raises(ValueError)),
],
)
def test_predict_transformers_installed(
mocker, transformers_installed, base_class, expected
):
mocker.patch("importlib.util.find_spec", return_value=transformers_installed)
from quantus.helpers.model import pytorch_model

reload(pytorch_model)
# Mock the model's behavior
model_instance = PyTorchModel(model=mocker.MagicMock(spec=base_class))
model_instance.model.training = False
model_instance.model.return_value.logits = torch.tensor([[0.1, 0.9]])
model_instance.softmax = False

# Prepare input and call the predict method
x = {"input_ids": np.array([1, 2, 3]), "attention_mask": np.array([1, 1, 1])}
with expected:
predictions = model_instance.predict(x)
assert np.array_equal(predictions, expected.enter_result), "Test failed."
#@pytest.mark.pytorch_model
#@pytest.mark.parametrize(
# "transformers_installed,base_class,expected",
# [
# (True, PreTrainedModel, nullcontext(np.array([[0.1, 0.9]], dtype=np.float32))),
# (False, None, pytest.raises(ValueError)),
# ],
#)
#def test_predict_transformers_installed(
# mocker, transformers_installed, base_class, expected
#):
# mocker.patch("importlib.util.find_spec", return_value=transformers_installed)
# from quantus.helpers.model import pytorch_model
#
# reload(pytorch_model)
# # Mock the model's behavior
# model_instance = PyTorchModel(model=mocker.MagicMock(spec=base_class))
# model_instance.model.training = False
# model_instance.model.return_value.logits = torch.tensor([[0.1, 0.9]])
# model_instance.softmax = False
#
# # Prepare input and call the predict method
# x = {"input_ids": np.array([1, 2, 3]), "attention_mask": np.array([1, 1, 1])}
# with expected:
# predictions = model_instance.predict(x)
# assert np.array_equal(predictions, expected.enter_result), "Test failed."

0 comments on commit 56b82d4

Please sign in to comment.