Skip to content

Commit

Permalink
Merge pull request #1 from ArneBinder/tacred_dataset
Browse files (browse the repository at this point in the history)
add TACRED dataset
  • Loading branch information
ArneBinder authored Oct 24, 2023
2 parents f4d6697 + 8986e3b commit e1c4beb
Show file tree
Hide file tree
Showing 30 changed files with 6,290 additions and 0 deletions.
207 changes: 207 additions & 0 deletions dataset_builders/pie/tacred/tacred.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

import datasets
import pytorch_ie.data.builder
from pytorch_ie import token_based_document_to_text_based
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, _post_init_single_label
from pytorch_ie.core import Annotation, AnnotationList, Document, annotation_field
from pytorch_ie.documents import (
TextDocumentWithLabeledSpansAndBinaryRelations,
TokenBasedDocument,
)


@dataclass(eq=True, frozen=True)
class TokenRelation(Annotation):
    """A labeled, directed relation between two tokens, addressed by token index.

    In this module it represents one dependency-parse edge: ``example_to_document``
    fills ``head_idx`` from the ``stanford_head`` value and ``tail_idx`` with the
    position of the token that value belongs to.
    """

    # Index of the token the relation starts from.
    head_idx: int
    # Index of the token the relation points to.
    tail_idx: int
    # Relation label (here: the ``stanford_deprel`` value of the tail token).
    label: str
    # Score attached to the relation; defaults to 1.0.
    score: float = 1.0

    def __post_init__(self) -> None:
        # NOTE(review): `_post_init_single_label` is imported from
        # pytorch_ie.annotations; presumably it validates label/score of
        # single-label annotations - confirm against the pytorch_ie sources.
        _post_init_single_label(self)


@dataclass(eq=True, frozen=True)
class TokenAttribute(Annotation):
    """A label attached to a single token, addressed by token index.

    ``TacredDocument`` uses it for the per-token ``stanford_ner`` and
    ``stanford_pos`` layers.
    """

    # Index of the annotated token.
    idx: int
    # The attribute value for that token (e.g. its NER or POS tag string).
    label: str


@dataclass
class TacredDocument(Document):
    """Token-based document type holding one TACRED example.

    All annotation layers except ``relations`` target the ``tokens`` field;
    ``relations`` targets the ``entities`` layer.
    """

    # The token sequence of the example.
    tokens: Tuple[str, ...]
    # Example identifier (``example["id"]`` in ``example_to_document``).
    id: Optional[str] = None
    # Free-form metadata; ``example_to_document`` stores the source ``docid``
    # under the key ``doc_id``.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # One TokenAttribute per token carrying its ``stanford_ner`` value.
    stanford_ner: AnnotationList[TokenAttribute] = annotation_field(target="tokens")
    # One TokenAttribute per token carrying its ``stanford_pos`` value.
    stanford_pos: AnnotationList[TokenAttribute] = annotation_field(target="tokens")
    # Subject and object spans of the example (token offsets).
    entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
    # Relation(s) between spans of the ``entities`` layer.
    relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
    # Dependency edges between tokens (head index -> tail index).
    dependency_relations: AnnotationList[TokenRelation] = annotation_field(target="tokens")


@dataclass
class SimpleTacredDocument(TokenBasedDocument):
    """Reduced token-based document with only span and relation layers.

    Used as the intermediate type in
    ``convert_to_text_document_with_labeled_spans_and_binary_relations``, where
    the ``entities`` / ``relations`` layers of a ``TacredDocument`` are mapped
    onto ``labeled_spans`` / ``binary_relations``.
    """

    # Labeled spans over the tokens.
    labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
    # Binary relations between entries of ``labeled_spans``.
    binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans")


def example_to_document(
    example: Dict[str, Any],
    relation_int2str: Callable[[int], str],
    ner_int2str: Callable[[int], str],
) -> TacredDocument:
    """Build a ``TacredDocument`` from one dataset example.

    Args:
        example: Mapping with the keys ``token``, ``id``, ``docid``,
            ``stanford_ner``, ``stanford_pos``, ``stanford_deprel``,
            ``stanford_head``, ``subj_start``, ``subj_end``, ``subj_type``,
            ``obj_start``, ``obj_end``, ``obj_type`` and ``relation``.
        relation_int2str: Maps the integer in ``example["relation"]`` to its
            label string.
        ner_int2str: Maps the integers in ``example["subj_type"]`` and
            ``example["obj_type"]`` to their label strings.

    Returns:
        A document with per-token NER/POS attributes, dependency relations,
        the subject and object spans (in that order) and one relation from
        subject (head) to object (tail).
    """
    doc = TacredDocument(
        tokens=tuple(example["token"]),
        id=example["id"],
        metadata={"doc_id": example["docid"]},
    )

    # Per-token attributes: NER and POS tags share the same token index.
    token_tags = zip(example["stanford_ner"], example["stanford_pos"])
    for token_idx, (ner_tag, pos_tag) in enumerate(token_tags):
        doc.stanford_ner.append(TokenAttribute(idx=token_idx, label=ner_tag))
        doc.stanford_pos.append(TokenAttribute(idx=token_idx, label=pos_tag))

    # Dependency edges: the position in the lists is the dependent (tail)
    # token; tokens with a negative head index get no edge.
    dependency_arcs = zip(example["stanford_deprel"], example["stanford_head"])
    for dependent_idx, (arc_label, governor_idx) in enumerate(dependency_arcs):
        if governor_idx < 0:
            continue
        arc = TokenRelation(head_idx=governor_idx, tail_idx=dependent_idx, label=arc_label)
        doc.dependency_relations.append(arc)

    subject_span = LabeledSpan(
        start=example["subj_start"],
        end=example["subj_end"],
        label=ner_int2str(example["subj_type"]),
    )
    object_span = LabeledSpan(
        start=example["obj_start"],
        end=example["obj_end"],
        label=ner_int2str(example["obj_type"]),
    )
    # Subject first, object second - the same order the relation uses below.
    doc.entities.append(subject_span)
    doc.entities.append(object_span)

    doc.relations.append(
        BinaryRelation(
            head=subject_span,
            tail=object_span,
            label=relation_int2str(example["relation"]),
        )
    )

    return doc


def _entity_to_dict(
    entity: LabeledSpan, key_prefix: str = "", label_mapping: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Flatten a labeled span into ``<prefix>start`` / ``<prefix>end`` / ``<prefix>type`` entries.

    Args:
        entity: The span to serialize; its ``start``, ``end`` and ``label`` are read.
        key_prefix: String prepended to each of the three result keys.
        label_mapping: Optional lookup applied to ``entity.label``; when it is
            ``None`` the label is emitted unchanged.

    Returns:
        A dict with exactly three entries (start, end, type).

    Raises:
        KeyError: If ``label_mapping`` is given but lacks ``entity.label``.
    """
    if label_mapping is None:
        entity_type = entity.label
    else:
        entity_type = label_mapping[entity.label]

    flattened: Dict[str, Any] = {}
    flattened[f"{key_prefix}start"] = entity.start
    flattened[f"{key_prefix}end"] = entity.end
    flattened[f"{key_prefix}type"] = entity_type
    return flattened


def document_to_example(
    document: TacredDocument,
    ner_names: Optional[List[str]] = None,
    relation_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Convert a ``TacredDocument`` back into a flat example dict.

    Args:
        document: The document to serialize. Only its first relation is used.
        ner_names: Optional list of NER label names; when given, the subject
            and object types are emitted as positions in this list instead of
            label strings.
        relation_names: Optional list of relation label names; when given, the
            relation is emitted as its position in this list instead of the
            label string.

    Returns:
        A dict with the keys ``id``, ``docid``, ``relation``, ``token``,
        ``stanford_ner``, ``stanford_pos``, ``stanford_deprel``,
        ``stanford_head`` followed by the ``obj_*`` and ``subj_*`` entries.
    """

    def _index_lookup(names: Optional[List[str]]) -> Optional[Dict[str, int]]:
        # Map each name to its position, or pass ``None`` through unchanged.
        if names is None:
            return None
        return {name: position for position, name in enumerate(names)}

    ner_to_index = _index_lookup(ner_names)
    relation_to_index = _index_lookup(relation_names)

    tokens = list(document.tokens)
    token_positions = range(len(tokens))

    # Re-assemble the per-token tag lists from the index-addressed attributes.
    ner_by_token = {attribute.idx: attribute.label for attribute in document.stanford_ner}
    pos_by_token = {attribute.idx: attribute.label for attribute in document.stanford_pos}
    ner_tags = [ner_by_token[position] for position in token_positions]
    pos_tags = [pos_by_token[position] for position in token_positions]

    # Tokens without a dependency relation keep the defaults "ROOT" / -1.
    deprel_labels = ["ROOT"] * len(document.tokens)
    head_indices = [-1] * len(document.tokens)
    for arc in document.dependency_relations:
        deprel_labels[arc.tail_idx] = arc.label
        head_indices[arc.tail_idx] = arc.head_idx

    relation = document.relations[0]
    object_span: LabeledSpan = relation.tail
    subject_span: LabeledSpan = relation.head

    example: Dict[str, Any] = {
        "id": document.id,
        "docid": document.metadata["doc_id"],
        "relation": (
            relation.label if relation_to_index is None else relation_to_index[relation.label]
        ),
        "token": tokens,
        "stanford_ner": ner_tags,
        "stanford_pos": pos_tags,
        "stanford_deprel": deprel_labels,
        "stanford_head": head_indices,
    }
    # Object entries are added before subject entries to keep the key order.
    example.update(_entity_to_dict(object_span, key_prefix="obj_", label_mapping=ner_to_index))
    example.update(_entity_to_dict(subject_span, key_prefix="subj_", label_mapping=ner_to_index))
    return example


def convert_to_text_document_with_labeled_spans_and_binary_relations(
    document: TacredDocument,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
    """Convert a token-based ``TacredDocument`` into a text-based document.

    The document is first cast to ``SimpleTacredDocument`` (mapping the
    ``entities`` layer to ``labeled_spans`` and ``relations`` to
    ``binary_relations``) and then handed to
    ``token_based_document_to_text_based`` with a single space as token
    separator.
    """
    layer_renames = {"entities": "labeled_spans", "relations": "binary_relations"}
    token_document = document.as_type(SimpleTacredDocument, field_mapping=layer_renames)
    return token_based_document_to_text_based(
        token_document,
        result_document_type=TextDocumentWithLabeledSpansAndBinaryRelations,
        join_tokens_with=" ",
    )


class TacredConfig(datasets.BuilderConfig):
    """BuilderConfig for Tacred.

    Adds no options of its own; it only forwards its keyword arguments to
    ``datasets.BuilderConfig``.
    """

    def __init__(self, **kwargs):
        """Create a BuilderConfig for Tacred.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super().__init__(**kwargs)


class Tacred(pytorch_ie.data.builder.GeneratorBasedBuilder):
    """Dataset builder that turns TACRED examples into ``TacredDocument`` objects.

    Three configurations are declared: ``original``, ``revisited`` and
    ``re-tacred``. A converter to
    ``TextDocumentWithLabeledSpansAndBinaryRelations`` is registered.
    """

    DOCUMENT_TYPE = TacredDocument

    DOCUMENT_CONVERTERS = {
        TextDocumentWithLabeledSpansAndBinaryRelations: (
            convert_to_text_document_with_labeled_spans_and_binary_relations
        ),
    }

    # NOTE(review): presumably the Hugging Face dataset this builder wraps -
    # confirm against the base class in pytorch_ie.
    BASE_DATASET_PATH = "DFKI-SLT/tacred"

    BUILDER_CONFIGS = [
        TacredConfig(
            name="original",
            version=datasets.Version("1.1.0"),
            description="The original TACRED.",
        ),
        TacredConfig(
            name="revisited",
            version=datasets.Version("1.1.0"),
            description="The revised TACRED (corrected labels in dev and test split).",
        ),
        TacredConfig(
            name="re-tacred",
            version=datasets.Version("1.1.0"),
            description="Relabeled TACRED (corrected labels for all splits and pruned)",
        ),
    ]

    def _generate_document_kwargs(self, dataset):
        """Return the label-decoding callables passed on to ``example_to_document``.

        Both are the ``int2str`` methods of the dataset features ``subj_type``
        (for NER labels) and ``relation`` (for relation labels).
        """
        features = dataset.features
        return dict(
            ner_int2str=features["subj_type"].int2str,
            relation_int2str=features["relation"].int2str,
        )

    def _generate_document(self, example, **kwargs):
        """Convert one dataset example into a ``TacredDocument``."""
        document = example_to_document(example, **kwargs)
        return document
69 changes: 69 additions & 0 deletions tests/dataset_builders/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import json
import logging
import os
import re
from pathlib import Path
from typing import List, Optional

from tests import FIXTURES_ROOT

# Relative path of the directory that contains all dataset builders.
DATASET_BUILDER_BASE_PATH = Path("dataset_builders")
# "hf" sub-directory of the builder directory.
HF_BASE_PATH = DATASET_BUILDER_BASE_PATH / "hf"
# "pie" sub-directory of the builder directory.
PIE_BASE_PATH = DATASET_BUILDER_BASE_PATH / "pie"
# Fixture data location below FIXTURES_ROOT: <FIXTURES_ROOT>/dataset_builders/hf.
HF_DS_FIXTURE_DATA_PATH = FIXTURES_ROOT / "dataset_builders" / "hf"

# Module-level logger; `_dump_json` uses it to report written fixture files.
logger = logging.getLogger(__name__)


def _deep_compare(
obj,
obj_expected,
path: Optional[str] = None,
excluded_paths: Optional[List[str]] = None,
enforce_equal_dict_keys: bool = True,
):
if path is not None and excluded_paths is not None:
for excluded_path in excluded_paths:
if re.match(excluded_path, path):
return

if type(obj) != type(obj_expected):
raise AssertionError(f"{path}: {obj} != {obj_expected}")
if isinstance(obj, (list, tuple)):
if len(obj) != len(obj_expected):
raise AssertionError(f"{path}: {obj} != {obj_expected}")
for i in range(len(obj)):
_deep_compare(
obj[i],
obj_expected[i],
path=f"{path}.{i}" if path is not None else str(i),
excluded_paths=excluded_paths,
enforce_equal_dict_keys=enforce_equal_dict_keys,
)
elif isinstance(obj, dict):
if enforce_equal_dict_keys and obj.keys() != obj_expected.keys():
raise AssertionError(f"{path}: {obj} != {obj_expected}")
for k in set(obj) | set(obj_expected):
_deep_compare(
obj.get(k, None),
obj_expected.get(k, None),
path=f"{path}.{k}" if path is not None else str(k),
excluded_paths=excluded_paths,
enforce_equal_dict_keys=enforce_equal_dict_keys,
)
else:
if obj != obj_expected:
raise AssertionError(f"{path}: {obj} != {obj_expected}")


def _dump_json(obj, fn):
    """Write ``obj`` as pretty-printed JSON to the file ``fn``.

    Missing parent directories of ``fn`` are created. The output uses an
    indentation of 2 and sorted keys, and the file is written as UTF-8.
    A warning is logged for every file written.

    Args:
        obj: A JSON-serializable object.
        fn: Target file path; may be a bare file name without a directory.
    """
    logger.warning("dump fixture data: %s", fn)
    parent_dir = os.path.dirname(fn)
    # os.path.dirname returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create a directory if there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(fn, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, sort_keys=True)


def _load_json(fn: str):
with open(fn) as f:
ex = json.load(f)
return ex
Loading

0 comments on commit e1c4beb

Please sign in to comment.