From 4f74a057110667d460e210509b42c3eded5ce1f0 Mon Sep 17 00:00:00 2001 From: Ruangrin L <88072261+idalr@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:38:45 +0100 Subject: [PATCH] edit types.py and test_cdcp.py --- src/pie_datasets/document/types.py | 31 +++++++++++++++++++++---- tests/dataset_builders/pie/test_cdcp.py | 6 ++--- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/pie_datasets/document/types.py b/src/pie_datasets/document/types.py index f8a2a353..983d80ff 100644 --- a/src/pie_datasets/document/types.py +++ b/src/pie_datasets/document/types.py @@ -1,12 +1,33 @@ import dataclasses -import logging -from typing import Any, Dict, Optional +from typing import Optional -from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import Annotation, AnnotationList, Document, annotation_field +from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan +from pytorch_ie.core import Annotation, AnnotationList, annotation_field from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument -logger = logging.getLogger(__name__) + +@dataclasses.dataclass(eq=True, frozen=True) +class Attribute(Annotation): + annotation: Annotation + label: str + value: Optional[str] = None + score: Optional[float] = dataclasses.field(default=None, compare=False) + + +@dataclasses.dataclass +class BratDocument(TextBasedDocument): + spans: AnnotationList[LabeledMultiSpan] = annotation_field(target="text") + relations: AnnotationList[BinaryRelation] = annotation_field(target="spans") + span_attributes: AnnotationList[Attribute] = annotation_field(target="spans") + relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations") + + +@dataclasses.dataclass +class BratDocumentWithMergedSpans(TextBasedDocument): + spans: AnnotationList[LabeledSpan] = annotation_field(target="text") + relations: AnnotationList[BinaryRelation] = annotation_field(target="spans") + span_attributes: AnnotationList[Attribute] = annotation_field(target="spans") + relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations") @dataclasses.dataclass diff --git a/tests/dataset_builders/pie/test_cdcp.py b/tests/dataset_builders/pie/test_cdcp.py index 758d3d7b..e2cb01f1 100644 --- a/tests/dataset_builders/pie/test_cdcp.py +++ b/tests/dataset_builders/pie/test_cdcp.py @@ -121,13 +121,13 @@ def test_generated_document(generated_document, split): @pytest.fixture(scope="module") -def reversed_generated_document(generated_document, generate_document_kwargs): +def hf_example_back(generated_document, generate_document_kwargs): return document_to_example(generated_document, **generate_document_kwargs) -def test_example_to_document_and_back(hf_example, reversed_generated_document): +def test_example_to_document_and_back(hf_example, hf_example_back): _deep_compare( - obj=reversed_generated_document, + obj=hf_example_back, obj_expected=hf_example, )