Skip to content

Commit

Permalink
bugfix formatting on titanic dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
annahedstroem committed Mar 25, 2024
1 parent 3d06f90 commit 76d9fdc
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,26 @@
import pytest
import torch
from keras.datasets import cifar10
from quantus.helpers.model.models import (CifarCNNModel, ConvNet1D,
ConvNet1DTF, LeNet, LeNetTF,
TitanicSimpleTFModel,
TitanicSimpleTorchModel)
from quantus.helpers.model.models import (
CifarCNNModel,
ConvNet1D,
ConvNet1DTF,
LeNet,
LeNetTF,
TitanicSimpleTFModel,
TitanicSimpleTorchModel,
)
from sklearn.model_selection import train_test_split
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
set_seed)
from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed

CIFAR_IMAGE_SIZE = 32
MNIST_IMAGE_SIZE = 28
BATCH_SIZE = 124
MINI_BATCH_SIZE = 8
RANDOM_SEED = 42

@pytest.fixture(scope='function', autouse=True)

@pytest.fixture(scope="function", autouse=True)
def reset_prngs():
set_seed(42)

Expand Down Expand Up @@ -208,7 +213,10 @@ def titanic_dataset():
X = df_enc.drop(["survived"], axis=1).values.astype(float)
Y = df_enc["survived"].values.astype(int)
_, test_features, _, test_labels = train_test_split(X, Y, test_size=0.3)
return {"x_batch": test_features, "y_batch": test_labels}
return {
"x_batch": test_features[:MINI_BATCH_SIZE],
"y_batch": test_labels[:BATCH_SIZE].reshape(-1).astype(int)[:MINI_BATCH_SIZE],
}


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -264,4 +272,4 @@ def dummy_hf_tokenizer():
@pytest.fixture(scope="session", autouse=True)
def set_env():
"""Set ENV var, so test outputs are not polluted by progress bars and warnings."""
os.environ["PYTEST"] = "1"
os.environ["PYTEST"] = "1"

0 comments on commit 76d9fdc

Please sign in to comment.