Skip to content

Commit

Permalink
🚧
Browse files Browse the repository at this point in the history
  • Loading branch information
aaarrti committed Apr 27, 2024
1 parent 6c57e7f commit b3179ba
Showing 1 changed file with 56 additions and 56 deletions.
112 changes: 56 additions & 56 deletions tests/functions/test_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,62 +247,62 @@ def test_add_mean_shift_to_first_layer(load_mnist_model):
assert torch.all(torch.isclose(a1, a2, atol=1e-04))


@pytest.mark.pytorch_model
@pytest.mark.parametrize(
"hf_model,data,softmax,model_kwargs,expected",
[
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("dummy_hf_tokenizer"),
False,
{},
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("dummy_hf_tokenizer"),
False,
{"labels": torch.tensor([1]), "output_hidden_states": True},
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
{
"input_ids": torch.tensor(
[[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]]
),
"attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
},
False,
{"labels": torch.tensor([1]), "output_hidden_states": True},
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("dummy_hf_tokenizer"),
True,
{},
nullcontext(np.array([[0.51075452, 0.4892454]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
np.array([1, 2, 3]),
False,
{},
pytest.raises(ValueError),
),
],
)
def test_huggingface_classifier_predict(
hf_model, data, softmax, model_kwargs, expected
):
model = PyTorchModel(
model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs
)
with expected:
out = model.predict(x=data)
assert np.allclose(out, expected.enter_result), "Test failed."

#@pytest.mark.pytorch_model
#@pytest.mark.parametrize(
# "hf_model,data,softmax,model_kwargs,expected",
# [
# (
# lazy_fixture("load_hf_distilbert_sequence_classifier"),
# lazy_fixture("dummy_hf_tokenizer"),
# False,
# {},
# nullcontext(np.array([[0.00424026, -0.03878461]])),
# ),
# (
# lazy_fixture("load_hf_distilbert_sequence_classifier"),
# lazy_fixture("dummy_hf_tokenizer"),
# False,
# {"labels": torch.tensor([1]), "output_hidden_states": True},
# nullcontext(np.array([[0.00424026, -0.03878461]])),
# ),
# (
# lazy_fixture("load_hf_distilbert_sequence_classifier"),
# {
# "input_ids": torch.tensor(
# [[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]]
# ),
# "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
# },
# False,
# {"labels": torch.tensor([1]), "output_hidden_states": True},
# nullcontext(np.array([[0.00424026, -0.03878461]])),
# ),
# (
# lazy_fixture("load_hf_distilbert_sequence_classifier"),
# lazy_fixture("dummy_hf_tokenizer"),
# True,
# {},
# nullcontext(np.array([[0.51075452, 0.4892454]])),
# ),
# (
# lazy_fixture("load_hf_distilbert_sequence_classifier"),
# np.array([1, 2, 3]),
# False,
# {},
# pytest.raises(ValueError),
# ),
# ],
#)
#def test_huggingface_classifier_predict(
# hf_model, data, softmax, model_kwargs, expected
#):
# model = PyTorchModel(
# model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs
# )
# with expected:
# out = model.predict(x=data)
# assert np.allclose(out, expected.enter_result), "Test failed."
#

@pytest.mark.pytorch_model
@pytest.mark.parametrize(
Expand Down

0 comments on commit b3179ba

Please sign in to comment.