Skip to content

Commit

Permalink
define and use TextDocumentWithPartitions in test_regex_partitioner.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 3, 2023
1 parent aaea168 commit 8641cba
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/pie_datasets/document/processing/regex_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(
self,
pattern: str,
collect_statistics: bool = False,
partition_layer_name: str = "labeled_partitions",
partition_layer_name: str = "partitions",
text_field_name: str = "text",
**partitioner_kwargs,
):
Expand Down
39 changes: 23 additions & 16 deletions tests/unit/document/test_regex_partitioner.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import dataclasses
import json
import logging
from typing import Tuple

import pytest
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledPartitions
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument

from pie_datasets.document.processing import RegexPartitioner
from pie_datasets.document.processing.regex_partitioner import (
_get_partitions_with_matcher,
)


# Test-local document type: a text-based document with a single annotation layer named
# "partitions" that holds LabeledSpan annotations targeting the "text" field.
@dataclasses.dataclass
class TextDocumentWithPartitions(TextBasedDocument):
# The tests read the partitioner's output back from this layer via `new_document.partitions`.
partitions: AnnotationList[LabeledSpan] = annotation_field(target="text")


def have_overlap(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool:
other_start_overlaps = start_end[0] <= other_start_end[0] < start_end[1]
other_end_overlaps = start_end[0] < other_start_end[1] <= start_end[1]
Expand All @@ -31,10 +38,10 @@ def test_regex_partitioner():
)
# The document contains a text separated by some markers like <start>, <middle> and <end>. RegexPartitioner
# partitions the text based on the given pattern. After partitioning, there are four partitions with the same label.
document = TextDocumentWithLabeledPartitions(text=TEXT1)
document = TextDocumentWithPartitions(text=TEXT1)
new_document = regex_partitioner(document)

partitions = new_document.labeled_partitions
partitions = new_document.partitions
labels = [partition.label for partition in partitions]
assert len(partitions) == 4
assert labels == ["partition"] * len(partitions)
Expand Down Expand Up @@ -64,13 +71,13 @@ def test_regex_partitioner_with_statistics(caplog):

# The document contains a text separated by some markers like <start>, <middle> and <end>. After partitioning, there
# are three partitions excluding the initial part. Therefore, document length is not equal to the sum of partitions.
document = TextDocumentWithLabeledPartitions(text=TEXT1)
document = TextDocumentWithPartitions(text=TEXT1)
caplog.set_level(logging.INFO)
caplog.clear()
regex_partitioner.enter_dataset(None)
new_document = regex_partitioner(document)
regex_partitioner.exit_dataset(None)
partitions = new_document.labeled_partitions
partitions = new_document.partitions
assert len(partitions) == 3

assert len(caplog.records) == 1
Expand All @@ -91,13 +98,13 @@ def test_regex_partitioner_with_statistics(caplog):
# from each document, therefore statistics contains information from previous document as well. After partitioning,
# there are two partitions excluding the initial part. Therefore, the sum of document lengths is not equal to the sum
# of partitions.
document = TextDocumentWithLabeledPartitions(text=TEXT2)
document = TextDocumentWithPartitions(text=TEXT2)
caplog.set_level(logging.INFO)
caplog.clear()
regex_partitioner.enter_dataset(None)
new_document = regex_partitioner(document)
regex_partitioner.exit_dataset(None)
partitions = new_document.labeled_partitions
partitions = new_document.partitions
assert len(partitions) == 2

assert len(caplog.records) == 1
Expand Down Expand Up @@ -133,9 +140,9 @@ def test_regex_partitioner_without_label_group_id(label_whitelist, skip_initial_
)
# The document contains a text separated by some markers like <start>, <middle> and <end>. Since label_group_id is
# None, the partitions (if any) will have same label.
document = TextDocumentWithLabeledPartitions(text=TEXT1)
document = TextDocumentWithPartitions(text=TEXT1)
new_document = regex_partitioner(document)
partitions = new_document.labeled_partitions
partitions = new_document.partitions
assert [partition.label for partition in partitions] == ["partition"] * len(partitions)
if skip_initial_partition:
if label_whitelist == ["<start>", "<middle>", "<end>"] or label_whitelist == []:
Expand Down Expand Up @@ -194,9 +201,9 @@ def test_regex_partitioner_with_label_group_id(label_whitelist, skip_initial_par
)
# The document contains a text separated by some markers like <start>, <middle> and <end>. Possible partitions can
# be four including the initial partition.
document = TextDocumentWithLabeledPartitions(text=TEXT1)
document = TextDocumentWithPartitions(text=TEXT1)
new_document = regex_partitioner(document)
partitions = new_document.labeled_partitions
partitions = new_document.partitions
labels = [partition.label for partition in partitions]
if skip_initial_partition:
if label_whitelist == ["<start>", "<end>"] or label_whitelist == [
Expand Down Expand Up @@ -283,10 +290,10 @@ def test_regex_partitioner_with_no_match_found(skip_initial_partition, label_whi
)
# The document contains a text separated by some markers like <start> and <end>. Only possible partition in the
# document based on the given pattern is the initial partition.
document = TextDocumentWithLabeledPartitions(text=TEXT2)
document = TextDocumentWithPartitions(text=TEXT2)
new_document = regex_partitioner(document)

partitions = new_document.labeled_partitions
partitions = new_document.partitions
if skip_initial_partition:
# No matter what the value of label_whitelist is, there will be no partition created, since the given pattern
# is not in the document and skip_initial_partition is True.
Expand Down Expand Up @@ -324,7 +331,7 @@ def test_get_partitions_with_matcher():
# The document contains a text separated by some markers like <start>, <middle> and <end>. The finditer method is
# used, which returns non-overlapping matches from the text. Therefore, none of the created partitions should have
# overlapping spans and all of them should be instances of LabeledSpan.
document = TextDocumentWithLabeledPartitions(text=TEXT1)
document = TextDocumentWithPartitions(text=TEXT1)
partitions = []
for partition in _get_partitions_with_matcher(
text=document.text,
Expand Down Expand Up @@ -358,10 +365,10 @@ def test_regex_partitioner_with_strip_whitespace(strip_whitespace, verbose, capl
strip_whitespace=strip_whitespace,
verbose=verbose,
)
document = TextDocumentWithLabeledPartitions(text=TEXT1)
document = TextDocumentWithPartitions(text=TEXT1)
new_document = regex_partitioner(document)

partitions = new_document.labeled_partitions
partitions = new_document.partitions
labels = [partition.label for partition in partitions]
if strip_whitespace:
assert len(partitions) == 3
Expand Down

0 comments on commit 8641cba

Please sign in to comment.