Skip to content

Commit

Permalink
edit types.py and test_cdcp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
idalr authored and ArneBinder committed Nov 9, 2023
1 parent 8729583 commit 4f74a05
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
31 changes: 26 additions & 5 deletions src/pie_datasets/document/types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
import dataclasses
import logging
from typing import Any, Dict, Optional
from typing import Optional

from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, Document, annotation_field
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument

logger = logging.getLogger(__name__)

@dataclasses.dataclass(eq=True, frozen=True)
class Attribute(Annotation):
annotation: Annotation
label: str
value: Optional[str] = None
score: Optional[float] = dataclasses.field(default=None, compare=False)


@dataclasses.dataclass
class BratDocument(TextBasedDocument):
spans: AnnotationList[LabeledMultiSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
span_attributes: AnnotationList[Attribute] = annotation_field(target="spans")
relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")


@dataclasses.dataclass
class BratDocumentWithMergedSpans(TextBasedDocument):
spans: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
span_attributes: AnnotationList[Attribute] = annotation_field(target="spans")
relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")


@dataclasses.dataclass
Expand Down
6 changes: 3 additions & 3 deletions tests/dataset_builders/pie/test_cdcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ def test_generated_document(generated_document, split):


@pytest.fixture(scope="module")
def reversed_generated_document(generated_document, generate_document_kwargs):
def hf_example_back(generated_document, generate_document_kwargs):
return document_to_example(generated_document, **generate_document_kwargs)


def test_example_to_document_and_back(hf_example, reversed_generated_document):
def test_example_to_document_and_back(hf_example, hf_example_back):
_deep_compare(
obj=reversed_generated_document,
obj=hf_example_back,
obj_expected=hf_example,
)

Expand Down

0 comments on commit 4f74a05

Please sign in to comment.