add conll2003 dataset #14

Merged · 6 commits · Nov 2, 2023
19 changes: 19 additions & 0 deletions dataset_builders/pie/conll2003/README.md
@@ -0,0 +1,19 @@
# PIE Dataset Card for "conll2003"

This is a [PyTorch-IE](https://github.com/ChristophAlt/pytorch-ie) wrapper for the
[CoNLL 2003 Hugging Face dataset loading script](https://huggingface.co/datasets/conll2003).

## Data Schema

The document type for this dataset is `CoNLL2003Document`, which defines the following data fields:

- `text` (str)
- `id` (str, optional)
- `metadata` (dictionary, optional)

and the following annotation layers:

- `entities` (annotation type: `LabeledSpan`, target: `text`)

See [here](https://github.com/ChristophAlt/pytorch-ie/blob/main/src/pytorch_ie/annotations.py) for the definitions of
the annotation types.
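
As a usage sketch grounded in the tests below (the repository-local script path is an assumption, not a published hub ID):

```python
from pytorch_ie import DatasetDict

# Path assumption: load the builder script from this repository; the tests
# below construct the same path as PIE_BASE_PATH / "conll2003".
dataset = DatasetDict.load_dataset("dataset_builders/pie/conll2003", name="conll2003")

doc = dataset["train"][0]
print(doc.text)              # "EU rejects German call to boycott British lamb ."
for entity in doc.entities:  # LabeledSpan annotations over `text`
    print(entity.start, entity.end, entity.label)
```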
49 changes: 49 additions & 0 deletions dataset_builders/pie/conll2003/conll2003.py
@@ -0,0 +1,49 @@
from dataclasses import dataclass

import datasets
import pytorch_ie.data.builder
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledSpans
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


@dataclass
class CoNLL2003Document(TextDocument):
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")


class Conll2003(pytorch_ie.data.builder.GeneratorBasedBuilder):
    DOCUMENT_TYPE = CoNLL2003Document

    BASE_DATASET_PATH = "conll2003"

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name="conll2003", version=datasets.Version("1.0.0"), description="CoNLL2003 dataset"
        ),
    ]

    DOCUMENT_CONVERTERS = {
        TextDocumentWithLabeledSpans: {
            # just rename the layer
            "entities": "labeled_spans",
        }
    }

    def _generate_document_kwargs(self, dataset):
        # pass the ClassLabel int-to-string mapping for ner_tags to _generate_document
        return {"int_to_str": dataset.features["ner_tags"].feature.int2str}

    def _generate_document(self, example, int_to_str):
        doc_id = example["id"]
        tokens = example["tokens"]
        # convert integer tag ids to tag strings such as "B-ORG"
        ner_tags = [int_to_str(tag) for tag in example["ner_tags"]]

        text, ner_spans = tokens_and_tags_to_text_and_labeled_spans(tokens=tokens, tags=ner_tags)

        document = CoNLL2003Document(text=text, id=doc_id)

        for span in sorted(ner_spans, key=lambda span: span.start):
            document.entities.append(span)

        return document
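
To make the document generation concrete, here is a small walk-through of the builder on the first training example; the values mirror the test expectations below, and the tag strings assume the standard CoNLL 2003 label vocabulary (3 = `B-ORG`, 7 = `B-MISC`):

```python
import datasets

from dataset_builders.pie.conll2003.conll2003 import Conll2003

# First training example, as asserted in the tests below.
example = {
    "id": "0",
    "tokens": ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."],
    "ner_tags": [3, 0, 7, 0, 0, 0, 7, 0, 0],  # 3 = B-ORG, 7 = B-MISC (assumed label vocabulary)
}

builder = Conll2003()
hf_dataset = datasets.load_dataset("conll2003")  # the wrapped HF dataset
kwargs = builder._generate_document_kwargs(hf_dataset["train"])
doc = builder._generate_document(example=example, **kwargs)

# Tokens are joined into plain text and the BIO tags become LabeledSpan annotations.
assert doc.text == "EU rejects German call to boycott British lamb ."
assert [str(span) for span in doc.entities] == ["EU", "German", "British"]
```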
118 changes: 118 additions & 0 deletions tests/dataset_builders/pie/test_conll2003.py
@@ -0,0 +1,118 @@
import datasets
import pytest
from pytorch_ie import DatasetDict
from pytorch_ie.core import Document
from pytorch_ie.documents import TextDocumentWithLabeledSpans

from dataset_builders.pie.conll2003.conll2003 import Conll2003
from tests.dataset_builders.common import PIE_BASE_PATH

DATASET_NAME = "conll2003"
PIE_DATASET_PATH = PIE_BASE_PATH / DATASET_NAME
HF_DATASET_PATH = Conll2003.BASE_DATASET_PATH
SPLIT_NAMES = {"train", "validation", "test"}
SPLIT_SIZES = {"train": 14041, "validation": 3250, "test": 3453}


@pytest.fixture(params=[config.name for config in Conll2003.BUILDER_CONFIGS], scope="module")
def dataset_name(request):
    return request.param


@pytest.fixture(scope="module")
def hf_dataset(dataset_name):
    return datasets.load_dataset(str(HF_DATASET_PATH), name=dataset_name)


def test_hf_dataset(hf_dataset):
    assert set(hf_dataset) == SPLIT_NAMES
    split_sizes = {split_name: len(ds) for split_name, ds in hf_dataset.items()}
    assert split_sizes == SPLIT_SIZES


@pytest.fixture(scope="module")
def hf_example(hf_dataset):
    return hf_dataset["train"][0]


def test_hf_example(hf_example, dataset_name):
    if dataset_name == "conll2003":
        assert hf_example == {
            "chunk_tags": [11, 21, 11, 12, 21, 22, 11, 12, 0],
            "id": "0",
            "ner_tags": [3, 0, 7, 0, 0, 0, 7, 0, 0],
            "pos_tags": [22, 42, 16, 21, 35, 37, 16, 21, 7],
            "tokens": ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."],
        }
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")


@pytest.fixture(scope="module")
def document(hf_example, hf_dataset):
    conll2003 = Conll2003()
    generate_document_kwargs = conll2003._generate_document_kwargs(hf_dataset["train"])
    document = conll2003._generate_document(example=hf_example, **generate_document_kwargs)
    return document


def test_document(document, dataset_name):
    assert isinstance(document, Document)
    if dataset_name == "conll2003":
        assert document.text == "EU rejects German call to boycott British lamb ."
        entities = list(document.entities)
        assert len(entities) == 3
        assert str(entities[0]) == "EU"
        assert str(entities[1]) == "German"
        assert str(entities[2]) == "British"
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")


@pytest.fixture(scope="module")
def pie_dataset(dataset_name):
    return DatasetDict.load_dataset(str(PIE_DATASET_PATH), name=dataset_name)


def test_pie_dataset(pie_dataset):
    assert set(pie_dataset) == SPLIT_NAMES
    split_sizes = {split_name: len(ds) for split_name, ds in pie_dataset.items()}
    assert split_sizes == SPLIT_SIZES


@pytest.fixture(scope="module", params=list(Conll2003.DOCUMENT_CONVERTERS))
def converter_document_type(request):
    return request.param


@pytest.fixture(scope="module")
def converted_pie_dataset(pie_dataset, converter_document_type):
    pie_dataset_converted = pie_dataset.to_document_type(document_type=converter_document_type)
    return pie_dataset_converted


def test_converted_pie_dataset(converted_pie_dataset, converter_document_type):
    assert set(converted_pie_dataset) == SPLIT_NAMES
    split_sizes = {split_name: len(ds) for split_name, ds in converted_pie_dataset.items()}
    assert split_sizes == SPLIT_SIZES
    for ds in converted_pie_dataset.values():
        for document in ds:
            assert isinstance(document, converter_document_type)


@pytest.fixture(scope="module")
def converted_document(converted_pie_dataset):
    return converted_pie_dataset["train"][0]


def test_converted_document(converted_document, converter_document_type):
    assert isinstance(converted_document, converter_document_type)
    if converter_document_type == TextDocumentWithLabeledSpans:
        assert converted_document.text == "EU rejects German call to boycott British lamb ."
        entities = list(converted_document.labeled_spans)
        assert len(entities) == 3
        assert str(entities[0]) == "EU"
        assert str(entities[1]) == "German"
        assert str(entities[2]) == "British"
    else:
        raise ValueError(f"Unknown converter document type: {converter_document_type}")
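
Finally, a short sketch of applying the registered document converter outside the test suite; the loading path is again an assumption, and the expected values mirror `test_converted_document` above:

```python
from pytorch_ie import DatasetDict
from pytorch_ie.documents import TextDocumentWithLabeledSpans

# Path assumption: repository-local dataset script, as in the tests.
dataset = DatasetDict.load_dataset("dataset_builders/pie/conll2003", name="conll2003")

# DOCUMENT_CONVERTERS maps TextDocumentWithLabeledSpans to a layer rename,
# so the `entities` layer shows up as `labeled_spans` after conversion.
converted = dataset.to_document_type(document_type=TextDocumentWithLabeledSpans)

doc = converted["train"][0]
assert [str(span) for span in doc.labeled_spans] == ["EU", "German", "British"]
```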