-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* re-add statistics: TokenCountCollector, FieldLengthCollector, SubFieldLengthCollector, DummyCollector, LabelCountCollector * revert: change document_dataset fixture scope * fix document_dataset fixture * cleanup tests
- Loading branch information
1 parent
5924d7b
commit e358c55
Showing
2 changed files
with
222 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import logging | ||
from collections import defaultdict | ||
from typing import Any, Callable, Dict, List, Optional, Type, Union | ||
|
||
from transformers import AutoTokenizer, PreTrainedTokenizer | ||
|
||
from pytorch_ie.core import Document, DocumentStatistic | ||
from pytorch_ie.documents import TextBasedDocument | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TokenCountCollector(DocumentStatistic):
    """Collects the token count of a field when tokenizing its content with a Huggingface
    tokenizer.

    The content of the field should be a string.
    """

    def __init__(
        self,
        tokenizer: Union[str, PreTrainedTokenizer],
        text_field: str = "text",
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        document_type: Optional[Type[Document]] = None,
        **kwargs,
    ):
        # Default to TextBasedDocument when the standard "text" field is requested, so
        # the statistic is restricted to documents that actually carry a text attribute.
        if document_type is None and text_field == "text":
            document_type = TextBasedDocument
        super().__init__(document_type=document_type, **kwargs)
        # Accept either a model identifier (resolved via AutoTokenizer) or an
        # already-instantiated tokenizer object.
        self.tokenizer = (
            AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
        )
        self.tokenizer_kwargs = tokenizer_kwargs or {}
        self.text_field = text_field

    def _collect(self, doc: Document) -> int:
        """Return the number of tokens the tokenizer produces for the document's text field."""
        text = getattr(doc, self.text_field)
        encodings = self.tokenizer(text, **self.tokenizer_kwargs)
        # Use len(input_ids) instead of encodings.tokens(): tokens() is only available
        # on fast (Rust-backed) tokenizers and raises for slow Python tokenizers, while
        # input_ids exists for both and has exactly one entry per token.
        return len(encodings["input_ids"])
|
||
|
||
class FieldLengthCollector(DocumentStatistic):
    """Collects the length of a single document field, e.g. the number of characters in the
    input text.

    The field's value must support ``len()``.
    """

    def __init__(self, field: str, **kwargs):
        super().__init__(**kwargs)
        # Name of the document attribute whose length is collected.
        self.field = field

    def _collect(self, doc: Document) -> int:
        """Return ``len()`` of the configured field on the given document."""
        return len(getattr(doc, self.field))
|
||
|
||
class SubFieldLengthCollector(DocumentStatistic):
    """Collects the length of a subfield for every entry of a field, e.g. the number of
    arguments of N-ary relations."""

    def __init__(self, field: str, subfield: str, **kwargs):
        super().__init__(**kwargs)
        # Outer document attribute (an iterable of entries).
        self.field = field
        # Attribute on each entry whose length is collected.
        self.subfield = subfield

    def _collect(self, doc: Document) -> List[int]:
        """Return one length per entry in the outer field."""
        entries = getattr(doc, self.field)
        return [len(getattr(entry, self.subfield)) for entry in entries]
|
||
|
||
class DummyCollector(DocumentStatistic):
    """A dummy collector that always returns 1, e.g. to count the number of documents."""

    # Summing the per-document 1s yields the total document count.
    DEFAULT_AGGREGATION_FUNCTIONS = ["sum"]

    def _collect(self, doc: Document) -> int:
        """Return the constant 1 for every document."""
        return 1
|
||
|
||
class LabelCountCollector(DocumentStatistic):
    """Collects the number of field entries per label, e.g. the number of entities per type.

    The field should be a list of elements with a label attribute.

    Important: To make correct use of the result data, missing values need to be filled with 0, e.g.:
    {("ORG",): [2, 3], ("LOC",): [2]} -> {("ORG",): [2, 3], ("LOC",): [2, 0]}
    """

    DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len", "sum"]

    def __init__(
        self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs
    ):
        super().__init__(**kwargs)
        # Document attribute holding the labeled elements.
        self.field = field
        # Attribute name on each element that carries the label.
        self.label_attribute = label_attribute
        if not isinstance(labels, list) and labels != "INFERRED":
            raise ValueError("labels must be a list of strings or 'INFERRED'")
        if labels == "INFERRED":
            logger.warning(
                f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
                "for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
                "are not included in the calculation. We remove these aggregation functions from "
                "this collector, but be aware that the results may be wrong for your own aggregation "
                "functions that rely on zero values."
            )
            # Drop the aggregations that would be skewed by missing zero counts.
            unsafe = {"mean", "std", "min"}
            self.aggregation_functions: Dict[str, Callable[[List], Any]] = {
                name: func
                for name, func in self.aggregation_functions.items()
                if name not in unsafe
            }

        self.labels = labels

    def _collect(self, doc: Document) -> Dict[str, int]:
        """Return a mapping from label to the number of field entries carrying it."""
        elements = getattr(doc, self.field)
        counts: Dict[str, int]
        if self.labels == "INFERRED":
            counts = defaultdict(int)
        else:
            # Pre-fill with zeros so every known label appears in the result.
            counts = {label: 0 for label in self.labels}
        for elem in elements:
            counts[getattr(elem, self.label_attribute)] += 1
        return dict(counts)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from pytorch_ie.metrics.statistics import ( | ||
DummyCollector, | ||
FieldLengthCollector, | ||
LabelCountCollector, | ||
SubFieldLengthCollector, | ||
TokenCountCollector, | ||
) | ||
|
||
|
||
def test_statistics(document_dataset):
    """Exercise the non-tokenizing collectors against the document_dataset fixture."""
    # DummyCollector: summing the per-document 1s counts the documents per split.
    collector = DummyCollector()
    result = collector(document_dataset)
    assert result == {"test": {"sum": 2}, "train": {"sum": 8}, "val": {"sum": 2}}

    # note that we check for labels=["LOC", "PER", "ORG"], but the actual labels in the data are just ["PER", "ORG"]
    collector = LabelCountCollector(field="entities", labels=["LOC", "PER", "ORG"])
    result = collector(document_dataset)
    assert result == {
        "test": {
            "LOC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 2, "sum": 0},
            "PER": {"mean": 0.5, "std": 0.5, "min": 0, "max": 1, "len": 2, "sum": 1},
            "ORG": {"mean": 1.0, "std": 1.0, "min": 0, "max": 2, "len": 2, "sum": 2},
        },
        "val": {
            "LOC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 2, "sum": 0},
            "PER": {"mean": 0.5, "std": 0.5, "min": 0, "max": 1, "len": 2, "sum": 1},
            "ORG": {"mean": 1.0, "std": 1.0, "min": 0, "max": 2, "len": 2, "sum": 2},
        },
        "train": {
            "LOC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 8, "sum": 0},
            "PER": {
                "mean": 0.875,
                "std": 0.5994789404140899,
                "min": 0,
                "max": 2,
                "len": 8,
                "sum": 7,
            },
            "ORG": {
                "mean": 1.125,
                "std": 0.7806247497997998,
                "min": 0,
                "max": 2,
                "len": 8,
                "sum": 9,
            },
        },
    }

    # With labels="INFERRED", the zero-sensitive aggregations (mean/std/min) are removed.
    collector = LabelCountCollector(field="entities", labels="INFERRED")
    result = collector(document_dataset)
    assert result == {
        "test": {"PER": {"max": 1, "len": 1, "sum": 1}, "ORG": {"max": 2, "len": 1, "sum": 2}},
        "val": {"PER": {"max": 1, "len": 1, "sum": 1}, "ORG": {"max": 2, "len": 1, "sum": 2}},
        "train": {"PER": {"max": 2, "len": 6, "sum": 7}, "ORG": {"max": 2, "len": 6, "sum": 9}},
    }

    # FieldLengthCollector on "text" yields character-count statistics.
    collector = FieldLengthCollector(field="text")
    result = collector(document_dataset)
    assert result == {
        "test": {"max": 51, "mean": 34.5, "min": 18, "std": 16.5},
        "train": {"max": 54, "mean": 28.25, "min": 15, "std": 14.694812009685595},
        "val": {"max": 51, "mean": 34.5, "min": 18, "std": 16.5},
    }

    # this is not super useful, we just collect the lengths of the labels, but it is enough to test the code
    collector = SubFieldLengthCollector(field="entities", subfield="label")
    result = collector(document_dataset)
    assert result == {
        "test": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0},
        "train": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0},
        "val": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0},
    }
|
||
|
||
def test_statistics_with_tokenize(document_dataset):
    """Exercise TokenCountCollector with a real Huggingface tokenizer on the fixture data."""
    collector = TokenCountCollector(
        text_field="text",
        tokenizer="bert-base-uncased",
        # Exclude [CLS]/[SEP] so only content tokens are counted.
        tokenizer_kwargs=dict(add_special_tokens=False),
    )
    result = collector(document_dataset)
    assert result == {
        "test": {"max": 13, "mean": 8.5, "min": 4, "std": 4.5},
        "train": {"max": 14, "mean": 7.75, "min": 4, "std": 3.6314597615834874},
        "val": {"max": 13, "mean": 8.5, "min": 4, "std": 4.5},
    }