diff --git a/src/pytorch_ie/annotations.py b/src/pytorch_ie/annotations.py index b275de43..362cac6e 100644 --- a/src/pytorch_ie/annotations.py +++ b/src/pytorch_ie/annotations.py @@ -92,17 +92,6 @@ def __post_init__(self) -> None: _post_init_multi_label(self) -@dataclass(eq=True, frozen=True) -class MultiLabeledMultiSpan(Annotation): - slices: Tuple[Tuple[int, int], ...] - label: Tuple[str, ...] - score: Optional[Tuple[float, ...]] = field(default=None, compare=False) - - def __post_init__(self) -> None: - _post_init_multi_span(self) - _post_init_multi_label(self) - - @dataclass(eq=True, frozen=True) class BinaryRelation(Annotation): head: Annotation diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 5b3c0627..85af40e2 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -8,7 +8,6 @@ LabeledSpan, MultiLabel, MultiLabeledBinaryRelation, - MultiLabeledMultiSpan, MultiLabeledSpan, Span, ) @@ -125,38 +124,6 @@ def test_multilabeled_span(): MultiLabeledSpan(start=5, end=6, label=("label5", "label6"), score=(0.1, 0.2, 0.3)) -def test_multilabeled_multi_span(): - multilabeled_multi_span1 = MultiLabeledMultiSpan( - slices=((1, 2), (3, 4)), label=("label1", "label2") - ) - assert multilabeled_multi_span1.slices == ((1, 2), (3, 4)) - assert multilabeled_multi_span1.label == ("label1", "label2") - assert multilabeled_multi_span1.score == pytest.approx((1.0, 1.0)) - - multilabeled_multi_span2 = MultiLabeledMultiSpan( - slices=((5, 6), (7, 8)), label=("label3", "label4"), score=(0.4, 0.5) - ) - assert multilabeled_multi_span2.slices == ((5, 6), (7, 8)) - assert multilabeled_multi_span2.label == ("label3", "label4") - assert multilabeled_multi_span2.score == pytest.approx((0.4, 0.5)) - - assert multilabeled_multi_span2.asdict() == { - "_id": multilabeled_multi_span2._id, - "slices": ((5, 6), (7, 8)), - "label": ("label3", "label4"), - "score": (0.4, 0.5), - } - - _test_annotation_reconstruction(multilabeled_multi_span2) - - with pytest.raises( - ValueError, match=re.escape("Number of labels (2) and scores (3) must be equal.") - ): - MultiLabeledMultiSpan( - slices=((9, 10), (11, 12)), label=("label5", "label6"), score=(0.1, 0.2, 0.3) - ) - - def test_binary_relation(): head = Span(start=1, end=2) tail = Span(start=3, end=4)