Skip to content

Commit

Permalink
improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 2, 2023
1 parent 70a70cb commit 3873566
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions tests/dataset_builders/pie/test_conll2003.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
from pytorch_ie import DatasetDict
from pytorch_ie.core import Document
from pytorch_ie.documents import TextDocumentWithLabeledSpans

from dataset_builders.pie.conll2003.conll2003 import Conll2003
from tests.dataset_builders.common import PIE_BASE_PATH
Expand Down Expand Up @@ -47,10 +48,15 @@ def test_hf_example(hf_example, dataset_name):
raise ValueError(f"Unknown dataset name: {dataset_name}")


def test_generate_document(hf_example, hf_dataset, dataset_name):
@pytest.fixture(scope="module")
def document(hf_example, hf_dataset):
    """Build one document from ``hf_example`` with the ``Conll2003`` builder.

    The keyword arguments required by ``_generate_document`` are derived
    from the "train" split of ``hf_dataset``.
    """
    builder = Conll2003()
    kwargs = builder._generate_document_kwargs(hf_dataset["train"])
    return builder._generate_document(example=hf_example, **kwargs)


def test_document(document, dataset_name):
assert isinstance(document, Document)
if dataset_name == "conll2003":
assert document.text == "EU rejects German call to boycott British lamb ."
Expand All @@ -68,7 +74,7 @@ def pie_dataset(dataset_name):
return DatasetDict.load_dataset(str(PIE_DATASET_PATH), name=dataset_name)


def test_dataset(pie_dataset):
def test_pie_dataset(pie_dataset):
assert set(pie_dataset) == SPLIT_NAMES
split_sizes = {split_name: len(ds) for split_name, ds in pie_dataset.items()}
assert split_sizes == SPLIT_SIZES
Expand All @@ -92,3 +98,21 @@ def test_converted_pie_dataset(converted_pie_dataset, converter_document_type):
for ds in converted_pie_dataset.values():
for document in ds:
assert isinstance(document, converter_document_type)


@pytest.fixture(scope="module")
def converted_document(converted_pie_dataset):
    """Return the first entry of the "train" split of the converted dataset."""
    train_split = converted_pie_dataset["train"]
    return train_split[0]


def test_converted_document(converted_document, converter_document_type):
    """Check the type and content of a single converted document.

    Args:
        converted_document: The document provided by the
            ``converted_document`` fixture.
        converter_document_type: The document type the conversion is
            expected to produce.

    Raises:
        ValueError: If ``converter_document_type`` is a type this test has
            no expectations for.
    """
    # Assert on the fixture argument. The bare name ``document`` would
    # resolve to the module-level fixture function, not to a document.
    assert isinstance(converted_document, converter_document_type)
    if converter_document_type == TextDocumentWithLabeledSpans:
        assert converted_document.text == "EU rejects German call to boycott British lamb ."
        entities = list(converted_document.labeled_spans)
        assert len(entities) == 3
        assert str(entities[0]) == "EU"
        assert str(entities[1]) == "German"
        assert str(entities[2]) == "British"
    else:
        raise ValueError(f"Unknown converter document type: {converter_document_type}")

0 comments on commit 3873566

Please sign in to comment.