Skip to content

Commit

Permalink
improve splits constant / fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Oct 24, 2023
1 parent a25875a commit 0959f2d
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions tests/dataset_builders/pie/test_tacred.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

PIE_DATASET_PATH = f"{PIE_BASE_PATH}/tacred"
HF_DATASET_PATH = Tacred.BASE_DATASET_PATH
SPLITS = ["train", "validation", "test"]
SPLIT_NAMES = {"train", "validation", "test"}
EXAMPLE_IDX = 0
NUM_SAMPLES = 3

Expand All @@ -38,8 +38,8 @@ def dataset_variant(request):
return request.param


@pytest.fixture(params=SPLITS, scope="module")
def split(request):
@pytest.fixture(params=SPLIT_NAMES, scope="module")
def split_name(request):
return request.param


Expand Down Expand Up @@ -68,12 +68,12 @@ def hf_dataset(dataset_variant):

@pytest.fixture(scope="module")
def hf_dataset_samples(hf_samples_fn):
data_files = {split: hf_samples_fn.format(split=split) for split in SPLITS}
data_files = {split: hf_samples_fn.format(split=split) for split in SPLIT_NAMES}
return load_dataset("json", data_files=data_files)


def test_hf_dataset_samples(hf_dataset_samples):
assert set(hf_dataset_samples) == {"train", "validation", "test"}
assert set(hf_dataset_samples) == SPLIT_NAMES
for ds in hf_dataset_samples.values():
assert len(ds) == NUM_SAMPLES

Expand All @@ -96,18 +96,18 @@ def test_dump_hf(hf_dataset, hf_samples_fn, hf_metadata_fn):


@pytest.fixture(params=range(NUM_SAMPLES), scope="module")
def hf_example(hf_dataset_samples, split, request):
return hf_dataset_samples[split][request.param]
def hf_example(hf_dataset_samples, split_name, request):
return hf_dataset_samples[split_name][request.param]


@pytest.fixture(scope="module")
def ner_names(hf_metadata_fn, split):
return _load_json(hf_metadata_fn.format(split=split, idx_or_feature="ner_names"))
def ner_names(hf_metadata_fn, split_name):
return _load_json(hf_metadata_fn.format(split=split_name, idx_or_feature="ner_names"))


@pytest.fixture(scope="module")
def relation_names(hf_metadata_fn, split):
return _load_json(hf_metadata_fn.format(split=split, idx_or_feature="relation_names"))
def relation_names(hf_metadata_fn, split_name):
return _load_json(hf_metadata_fn.format(split=split_name, idx_or_feature="relation_names"))


@pytest.fixture(scope="module")
Expand Down

0 comments on commit 0959f2d

Please sign in to comment.