Skip to content

Commit

Permalink
adjust random seed in testing
Browse files Browse the repository at this point in the history
  • Loading branch information
abarbosa94 committed Mar 18, 2024
1 parent 179da1e commit 9c56c50
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
12 changes: 7 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
MINI_BATCH_SIZE = 8
RANDOM_SEED = 42

set_seed(42)
@pytest.fixture(scope='function', autouse=True)
def reset_prngs():
set_seed(42)


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -236,7 +238,7 @@ def load_mnist_model_softmax():
return model


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope="session", autouse=False)
def load_hf_distilbert_sequence_classifier():
"""
TODO
Expand All @@ -248,8 +250,8 @@ def load_hf_distilbert_sequence_classifier():
return model


@pytest.fixture(scope="session", autouse=True)
def mock_hf_text():
@pytest.fixture(scope="session", autouse=False)
def dummy_hf_tokenizer():
"""
TODO
"""
Expand All @@ -262,4 +264,4 @@ def mock_hf_text():
@pytest.fixture(scope="session", autouse=True)
def set_env():
"""Set ENV var, so test outputs are not polluted by progress bars and warnings."""
os.environ["PYTEST"] = "1"
os.environ["PYTEST"] = "1"
14 changes: 7 additions & 7 deletions tests/functions/test_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,32 +250,32 @@ def test_add_mean_shift_to_first_layer(load_mnist_model):
[
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("mock_hf_text"),
lazy_fixture("dummy_hf_tokenizer"),
False,
{},
nullcontext(np.array([[0.01157812, 0.03933399]])),
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("mock_hf_text"),
lazy_fixture("dummy_hf_tokenizer"),
False,
{"labels": torch.tensor([1]), "output_hidden_states": True},
nullcontext(np.array([[0.01157812, 0.03933399]])),
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
{'input_ids': torch.tensor([[ 101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899,
102]]), 'attention_mask': torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])},
False,
{"labels": torch.tensor([1]), "output_hidden_states": True},
nullcontext(np.array([[0.01157812, 0.03933399]])),
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("mock_hf_text"),
lazy_fixture("dummy_hf_tokenizer"),
True,
{},
nullcontext(np.array([[0.49306148, 0.5069385]])),
nullcontext(np.array([[0.51075452, 0.4892454]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
Expand Down

0 comments on commit 9c56c50

Please sign in to comment.