From 3873566e13e9236ee36c41d7e0f95b20eab46819 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Thu, 2 Nov 2023 18:47:07 +0100
Subject: [PATCH] improve tests

---
 tests/dataset_builders/pie/test_conll2003.py | 28 ++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/tests/dataset_builders/pie/test_conll2003.py b/tests/dataset_builders/pie/test_conll2003.py
index 0e779db7..9bb5e8f1 100644
--- a/tests/dataset_builders/pie/test_conll2003.py
+++ b/tests/dataset_builders/pie/test_conll2003.py
@@ -2,6 +2,7 @@
 import pytest
 from pytorch_ie import DatasetDict
 from pytorch_ie.core import Document
+from pytorch_ie.documents import TextDocumentWithLabeledSpans
 
 from dataset_builders.pie.conll2003.conll2003 import Conll2003
 from tests.dataset_builders.common import PIE_BASE_PATH
@@ -47,10 +48,15 @@ def test_hf_example(hf_example, dataset_name):
         raise ValueError(f"Unknown dataset name: {dataset_name}")
 
 
-def test_generate_document(hf_example, hf_dataset, dataset_name):
+@pytest.fixture(scope="module")
+def document(hf_example, hf_dataset):
     conll2003 = Conll2003()
     generate_document_kwargs = conll2003._generate_document_kwargs(hf_dataset["train"])
     document = conll2003._generate_document(example=hf_example, **generate_document_kwargs)
+    return document
+
+
+def test_document(document, dataset_name):
     assert isinstance(document, Document)
     if dataset_name == "conll2003":
         assert document.text == "EU rejects German call to boycott British lamb ."
@@ -68,7 +74,7 @@ def pie_dataset(dataset_name):
     return DatasetDict.load_dataset(str(PIE_DATASET_PATH), name=dataset_name)
 
 
-def test_dataset(pie_dataset):
+def test_pie_dataset(pie_dataset):
     assert set(pie_dataset) == SPLIT_NAMES
     split_sizes = {split_name: len(ds) for split_name, ds in pie_dataset.items()}
     assert split_sizes == SPLIT_SIZES
@@ -92,3 +98,21 @@ def test_converted_pie_dataset(converted_pie_dataset, converter_document_type):
     for ds in converted_pie_dataset.values():
         for document in ds:
             assert isinstance(document, converter_document_type)
+
+
+@pytest.fixture(scope="module")
+def converted_document(converted_pie_dataset):
+    return converted_pie_dataset["train"][0]  # first document of the converted "train" split
+
+
+def test_converted_document(converted_document, converter_document_type):
+    assert isinstance(converted_document, converter_document_type)  # injected value; bare `document` is the fixture function
+    if converter_document_type == TextDocumentWithLabeledSpans:
+        assert converted_document.text == "EU rejects German call to boycott British lamb ."
+        entities = list(converted_document.labeled_spans)  # materialize the layer so it can be indexed below
+        assert len(entities) == 3
+        assert str(entities[0]) == "EU"
+        assert str(entities[1]) == "German"
+        assert str(entities[2]) == "British"
+    else:
+        raise ValueError(f"Unknown converter document type: {converter_document_type}")