From 3c3cf11e1cebb8fc228c570fbcf6705e8044dbe9 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 2 Nov 2023 18:00:41 +0100 Subject: [PATCH 1/6] add PIE dataset loading script for conll2003 --- dataset_builders/pie/conll2003/conll2003.py | 60 +++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 dataset_builders/pie/conll2003/conll2003.py diff --git a/dataset_builders/pie/conll2003/conll2003.py b/dataset_builders/pie/conll2003/conll2003.py new file mode 100644 index 00000000..ff7b1f27 --- /dev/null +++ b/dataset_builders/pie/conll2003/conll2003.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass + +import datasets +import pytorch_ie.data.builder +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledSpans +from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans + + +class CoNLL2003Config(datasets.BuilderConfig): + """BuilderConfig for CoNLL2003""" + + def __init__(self, **kwargs): + """BuilderConfig for CoNLL2003. + Args: + **kwargs: keyword arguments forwarded to super. + """ + super().__init__(**kwargs) + + +@dataclass +class CoNLL2003Document(TextDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +class Conll2003(pytorch_ie.data.builder.GeneratorBasedBuilder): + DOCUMENT_TYPE = CoNLL2003Document + + BASE_DATASET_PATH = "conll2003" + + BUILDER_CONFIGS = [ + CoNLL2003Config( + name="conll2003", version=datasets.Version("1.0.0"), description="CoNLL2003 dataset" + ), + ] + + DOCUMENT_CONVERTERS = { + TextDocumentWithLabeledSpans: { + # just rename the layer + "entities": "labeled_spans", + } + } + + def _generate_document_kwargs(self, dataset): + return {"int_to_str": dataset.features["ner_tags"].feature.int2str} + + def _generate_document(self, example, int_to_str): + doc_id = example["id"] + tokens = example["tokens"] + ner_tags = [int_to_str(tag) for tag in example["ner_tags"]] + + text, ner_spans = tokens_and_tags_to_text_and_labeled_spans(tokens=tokens, tags=ner_tags) + + document = CoNLL2003Document(text=text, id=doc_id) + + for span in sorted(ner_spans, key=lambda span: span.start): + document.entities.append(span) + + return document From 50e52218eb9e79ec1278ed554acdc05bfc56656e Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 2 Nov 2023 18:36:04 +0100 Subject: [PATCH 2/6] simplify --- dataset_builders/pie/conll2003/conll2003.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/dataset_builders/pie/conll2003/conll2003.py b/dataset_builders/pie/conll2003/conll2003.py index ff7b1f27..ebfe7191 100644 --- a/dataset_builders/pie/conll2003/conll2003.py +++ b/dataset_builders/pie/conll2003/conll2003.py @@ -8,17 +8,6 @@ from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans -class CoNLL2003Config(datasets.BuilderConfig): - """BuilderConfig for CoNLL2003""" - - def __init__(self, **kwargs): - """BuilderConfig for CoNLL2003. - Args: - **kwargs: keyword arguments forwarded to super. - """ - super().__init__(**kwargs) - - @dataclass class CoNLL2003Document(TextDocument): entities: AnnotationList[LabeledSpan] = annotation_field(target="text") @@ -30,7 +19,7 @@ class Conll2003(pytorch_ie.data.builder.GeneratorBasedBuilder): BASE_DATASET_PATH = "conll2003" BUILDER_CONFIGS = [ - CoNLL2003Config( + datasets.BuilderConfig( name="conll2003", version=datasets.Version("1.0.0"), description="CoNLL2003 dataset" ), ] From ea73e0ebf083c08ca3930f65b9f8992c76761094 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 2 Nov 2023 18:36:26 +0100 Subject: [PATCH 3/6] add tests --- tests/dataset_builders/pie/test_conll2003.py | 94 ++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/dataset_builders/pie/test_conll2003.py diff --git a/tests/dataset_builders/pie/test_conll2003.py b/tests/dataset_builders/pie/test_conll2003.py new file mode 100644 index 00000000..0e779db7 --- /dev/null +++ b/tests/dataset_builders/pie/test_conll2003.py @@ -0,0 +1,94 @@ +import datasets +import pytest +from pytorch_ie import DatasetDict +from pytorch_ie.core import Document + +from dataset_builders.pie.conll2003.conll2003 import Conll2003 +from tests.dataset_builders.common import PIE_BASE_PATH + +DATASET_NAME = "conll2003" +PIE_DATASET_PATH = PIE_BASE_PATH / DATASET_NAME +HF_DATASET_PATH = Conll2003.BASE_DATASET_PATH +SPLIT_NAMES = {"train", "validation", "test"} +SPLIT_SIZES = {"train": 14041, "validation": 3250, "test": 3453} + + +@pytest.fixture(params=[config.name for config in Conll2003.BUILDER_CONFIGS], scope="module") +def dataset_name(request): + return request.param + + +@pytest.fixture(scope="module") +def hf_dataset(dataset_name): + return datasets.load_dataset(str(HF_DATASET_PATH), name=dataset_name) + + +def test_hf_dataset(hf_dataset): + assert set(hf_dataset) == SPLIT_NAMES + split_sizes = {split_name: len(ds) for split_name, ds in hf_dataset.items()} + assert split_sizes == SPLIT_SIZES + + +@pytest.fixture(scope="module") +def hf_example(hf_dataset): + return hf_dataset["train"][0] + + +def test_hf_example(hf_example, dataset_name): + if dataset_name == "conll2003": + assert hf_example == { + "chunk_tags": [11, 21, 11, 12, 21, 22, 11, 12, 0], + "id": "0", + "ner_tags": [3, 0, 7, 0, 0, 0, 7, 0, 0], + "pos_tags": [22, 42, 16, 21, 35, 37, 16, 21, 7], + "tokens": ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."], + } + else: + raise ValueError(f"Unknown dataset name: {dataset_name}") + + +def test_generate_document(hf_example, hf_dataset, dataset_name): + conll2003 = Conll2003() + generate_document_kwargs = conll2003._generate_document_kwargs(hf_dataset["train"]) + document = conll2003._generate_document(example=hf_example, **generate_document_kwargs) + assert isinstance(document, Document) + if dataset_name == "conll2003": + assert document.text == "EU rejects German call to boycott British lamb ." + entities = list(document.entities) + assert len(entities) == 3 + assert str(entities[0]) == "EU" + assert str(entities[1]) == "German" + assert str(entities[2]) == "British" + else: + raise ValueError(f"Unknown dataset name: {dataset_name}") + + +@pytest.fixture(scope="module") +def pie_dataset(dataset_name): + return DatasetDict.load_dataset(str(PIE_DATASET_PATH), name=dataset_name) + + +def test_dataset(pie_dataset): + assert set(pie_dataset) == SPLIT_NAMES + split_sizes = {split_name: len(ds) for split_name, ds in pie_dataset.items()} + assert split_sizes == SPLIT_SIZES + + +@pytest.fixture(scope="module", params=list(Conll2003.DOCUMENT_CONVERTERS)) +def converter_document_type(request): + return request.param + + +@pytest.fixture(scope="module") +def converted_pie_dataset(pie_dataset, converter_document_type): + pie_dataset_converted = pie_dataset.to_document_type(document_type=converter_document_type) + return pie_dataset_converted + + +def test_converted_pie_dataset(converted_pie_dataset, converter_document_type): + assert set(converted_pie_dataset) == SPLIT_NAMES + split_sizes = {split_name: len(ds) for split_name, ds in converted_pie_dataset.items()} + assert split_sizes == SPLIT_SIZES + for ds in converted_pie_dataset.values(): + for document in ds: + assert isinstance(document, converter_document_type) From 5d64047aa005e3e9d259ea590b8b728759ec32e7 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 2 Nov 2023 18:47:07 +0100 Subject: [PATCH 4/6] improve tests --- tests/dataset_builders/pie/test_conll2003.py | 28 ++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/dataset_builders/pie/test_conll2003.py b/tests/dataset_builders/pie/test_conll2003.py index 0e779db7..9bb5e8f1 100644 --- a/tests/dataset_builders/pie/test_conll2003.py +++ b/tests/dataset_builders/pie/test_conll2003.py @@ -2,6 +2,7 @@ import pytest from pytorch_ie import DatasetDict from pytorch_ie.core import Document +from pytorch_ie.documents import TextDocumentWithLabeledSpans from dataset_builders.pie.conll2003.conll2003 import Conll2003 from tests.dataset_builders.common import PIE_BASE_PATH @@ -47,10 +48,15 @@ def test_hf_example(hf_example, dataset_name): raise ValueError(f"Unknown dataset name: {dataset_name}") -def test_generate_document(hf_example, hf_dataset, dataset_name): +@pytest.fixture(scope="module") +def document(hf_example, hf_dataset): conll2003 = Conll2003() generate_document_kwargs = conll2003._generate_document_kwargs(hf_dataset["train"]) document = conll2003._generate_document(example=hf_example, **generate_document_kwargs) + return document + + +def test_document(document, dataset_name): assert isinstance(document, Document) if dataset_name == "conll2003": assert document.text == "EU rejects German call to boycott British lamb ." @@ -68,7 +74,7 @@ def pie_dataset(dataset_name): return DatasetDict.load_dataset(str(PIE_DATASET_PATH), name=dataset_name) -def test_dataset(pie_dataset): +def test_pie_dataset(pie_dataset): assert set(pie_dataset) == SPLIT_NAMES split_sizes = {split_name: len(ds) for split_name, ds in pie_dataset.items()} assert split_sizes == SPLIT_SIZES @@ -92,3 +98,21 @@ def test_converted_pie_dataset(converted_pie_dataset, converter_document_type): for ds in converted_pie_dataset.values(): for document in ds: assert isinstance(document, converter_document_type) + + +@pytest.fixture(scope="module") +def converted_document(converted_pie_dataset): + return converted_pie_dataset["train"][0] + + +def test_converted_document(converted_document, converter_document_type): + assert isinstance(document, converter_document_type) + if converter_document_type == TextDocumentWithLabeledSpans: + assert document.text == "EU rejects German call to boycott British lamb ." + entities = list(document.labeled_spans) + assert len(entities) == 3 + assert str(entities[0]) == "EU" + assert str(entities[1]) == "German" + assert str(entities[2]) == "British" + else: + raise ValueError(f"Unknown converter document type: {converter_document_type}") From 28d164beaca050ff1ed3bc41eac09964c790fbad Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 2 Nov 2023 18:59:37 +0100 Subject: [PATCH 5/6] fix test --- tests/dataset_builders/pie/test_conll2003.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/dataset_builders/pie/test_conll2003.py b/tests/dataset_builders/pie/test_conll2003.py index 9bb5e8f1..017eb8b8 100644 --- a/tests/dataset_builders/pie/test_conll2003.py +++ b/tests/dataset_builders/pie/test_conll2003.py @@ -106,10 +106,10 @@ def converted_document(converted_pie_dataset): def test_converted_document(converted_document, converter_document_type): - assert isinstance(document, converter_document_type) + assert isinstance(converted_document, converter_document_type) if converter_document_type == TextDocumentWithLabeledSpans: - assert document.text == "EU rejects German call to boycott British lamb ." - entities = list(document.labeled_spans) + assert converted_document.text == "EU rejects German call to boycott British lamb ." + entities = list(converted_document.labeled_spans) assert len(entities) == 3 assert str(entities[0]) == "EU" assert str(entities[1]) == "German" From 61a962b48f76df1e53fa7b456b93f2bf5ca0fe26 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 2 Nov 2023 20:00:21 +0100 Subject: [PATCH 6/6] add README.md --- dataset_builders/pie/conll2003/README.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 dataset_builders/pie/conll2003/README.md diff --git a/dataset_builders/pie/conll2003/README.md b/dataset_builders/pie/conll2003/README.md new file mode 100644 index 00000000..c8b5c4c1 --- /dev/null +++ b/dataset_builders/pie/conll2003/README.md @@ -0,0 +1,19 @@ +# PIE Dataset Card for "conll2003" + +This is a [PyTorch-IE](https://github.com/ChristophAlt/pytorch-ie) wrapper for the +[CoNLL 2003 Huggingface dataset loading script](https://huggingface.co/datasets/conll2003). + +## Data Schema + +The document type for this dataset is `CoNLL2003Document` which defines the following data fields: + +- `text` (str) +- `id` (str, optional) +- `metadata` (dictionary, optional) + +and the following annotation layers: + +- `entities` (annotation type: `LabeledSpan`, target: `text`) + +See [here](https://github.com/ChristophAlt/pytorch-ie/blob/main/src/pytorch_ie/annotations.py) for the definitions of +the annotation types.