diff --git a/src/pytorch_ie/metrics/f1.py b/src/pytorch_ie/metrics/f1.py index 5610b310..721661e2 100644 --- a/src/pytorch_ie/metrics/f1.py +++ b/src/pytorch_ie/metrics/f1.py @@ -1,11 +1,12 @@ import logging from collections import defaultdict from functools import partial -from typing import Callable, Collection, Dict, Optional, Tuple, Union +from typing import Callable, Collection, Dict, Hashable, Optional, Tuple, Union import pandas as pd from pytorch_ie.core import Annotation, Document, DocumentMetric +from pytorch_ie.utils.hydra import resolve_target logger = logging.getLogger(__name__) @@ -35,11 +36,15 @@ def __init__( labels: Optional[Union[Collection[str], str]] = None, label_field: str = "label", show_as_markdown: bool = False, + annotation_processor: Optional[Union[Callable[[Annotation], Hashable], str]] = None, ): super().__init__() self.layer = layer self.label_field = label_field self.show_as_markdown = show_as_markdown + if isinstance(annotation_processor, str): + annotation_processor = resolve_target(annotation_processor) + self.annotation_processor = annotation_processor self.per_label = labels is not None self.infer_labels = False @@ -71,12 +76,18 @@ def calculate_counts( self, document: Document, annotation_filter: Optional[Callable[[Annotation], bool]] = None, + annotation_processor: Optional[Callable[[Annotation], Hashable]] = None, ) -> Tuple[int, int, int]: + annotation_processor = annotation_processor or (lambda ann: ann) annotation_filter = annotation_filter or (lambda ann: True) predicted_annotations = { - ann for ann in document[self.layer].predictions if annotation_filter(ann) + annotation_processor(ann) + for ann in document[self.layer].predictions + if annotation_filter(ann) + } + gold_annotations = { + annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann) } - gold_annotations = {ann for ann in document[self.layer] if annotation_filter(ann)} tp = len([ann for ann in predicted_annotations & gold_annotations]) fn = len([ann for ann in gold_annotations - predicted_annotations]) fp = len([ann for ann in predicted_annotations - gold_annotations]) @@ -97,6 +108,7 @@ def _update(self, document: Document): ) if self.per_label and not self.infer_labels else None, + annotation_processor=self.annotation_processor, ) self.add_counts(new_counts, label="MICRO") if self.infer_labels: @@ -111,6 +123,7 @@ def _update(self, document: Document): annotation_filter=partial( has_this_label, label_field=self.label_field, label=label ), + annotation_processor=self.annotation_processor, ) self.add_counts(new_counts, label=label)