From 8986e3ba2743dc3f1ed17073d0727a5757e2e491 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 24 Oct 2023 20:56:51 +0200 Subject: [PATCH] improve splits constant / fixture --- tests/dataset_builders/pie/test_tacred.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/dataset_builders/pie/test_tacred.py b/tests/dataset_builders/pie/test_tacred.py index 8a9c7217..70463120 100644 --- a/tests/dataset_builders/pie/test_tacred.py +++ b/tests/dataset_builders/pie/test_tacred.py @@ -24,7 +24,7 @@ PIE_DATASET_PATH = f"{PIE_BASE_PATH}/tacred" HF_DATASET_PATH = Tacred.BASE_DATASET_PATH -SPLITS = ["train", "validation", "test"] +SPLIT_NAMES = {"train", "validation", "test"} EXAMPLE_IDX = 0 NUM_SAMPLES = 3 @@ -38,8 +38,8 @@ def dataset_variant(request): return request.param -@pytest.fixture(params=SPLITS, scope="module") -def split(request): +@pytest.fixture(params=SPLIT_NAMES, scope="module") +def split_name(request): return request.param @@ -68,12 +68,12 @@ def hf_dataset(dataset_variant): @pytest.fixture(scope="module") def hf_dataset_samples(hf_samples_fn): - data_files = {split: hf_samples_fn.format(split=split) for split in SPLITS} + data_files = {split: hf_samples_fn.format(split=split) for split in SPLIT_NAMES} return load_dataset("json", data_files=data_files) def test_hf_dataset_samples(hf_dataset_samples): - assert set(hf_dataset_samples) == {"train", "validation", "test"} + assert set(hf_dataset_samples) == SPLIT_NAMES for ds in hf_dataset_samples.values(): assert len(ds) == NUM_SAMPLES @@ -96,18 +96,18 @@ def test_dump_hf(hf_dataset, hf_samples_fn, hf_metadata_fn): @pytest.fixture(params=range(NUM_SAMPLES), scope="module") -def hf_example(hf_dataset_samples, split, request): - return hf_dataset_samples[split][request.param] +def hf_example(hf_dataset_samples, split_name, request): + return hf_dataset_samples[split_name][request.param] @pytest.fixture(scope="module") -def ner_names(hf_metadata_fn, split): - return _load_json(hf_metadata_fn.format(split=split, idx_or_feature="ner_names")) +def ner_names(hf_metadata_fn, split_name): + return _load_json(hf_metadata_fn.format(split=split_name, idx_or_feature="ner_names")) @pytest.fixture(scope="module") -def relation_names(hf_metadata_fn, split): - return _load_json(hf_metadata_fn.format(split=split, idx_or_feature="relation_names")) +def relation_names(hf_metadata_fn, split_name): + return _load_json(hf_metadata_fn.format(split=split_name, idx_or_feature="relation_names")) @pytest.fixture(scope="module")