Skip to content

Commit

Permalink
sciarg: fix partitioning (#159)
Browse files Browse the repository at this point in the history
* test label counts for labeled_partitions in sciarg

* fix and simplify tests

* fix: allow newlines between matching tags (important for the abstract)
  • Loading branch information
ArneBinder authored Nov 1, 2024
1 parent 06e3af5 commit 2b3ba52
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
3 changes: 2 additions & 1 deletion dataset_builders/pie/sciarg/sciarg.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def _generate_document(self, example, **kwargs):
def document_converters(self) -> DocumentConvertersType:
regex_partitioner = RegexPartitioner(
partition_layer_name="labeled_partitions",
pattern="<([^>/]+)>.*</\\1>",
# match an opening tag and its closing counterpart (non-greedy); the scoped inline flag (?s:.) lets "." also match newlines, and group 1 captures the tag name
pattern="<([^>/]+)>(?s:.)*?</\\1>",
label_group_id=1,
label_whitelist=["Title", "Abstract", "H1"],
skip_initial_partition=True,
Expand Down
34 changes: 21 additions & 13 deletions tests/dataset_builders/pie/sciarg/test_sciarg.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@
"spans": {"background_claim": 2752, "data": 4093, "own_claim": 5450},
},
}
# Per dataset variant: original layer name -> layer name after conversion.
CONVERTED_LAYER_MAPPING = {
    "default": {
        "spans": "labeled_spans",
        "relations": "binary_relations",
    },
    "resolve_parts_of_same": {
        "spans": "labeled_multi_spans",
        "relations": "binary_relations",
    },
}
# FULL_LABEL_COUNTS with the layer names of each variant replaced by their
# converted counterparts from CONVERTED_LAYER_MAPPING.
FULL_LABEL_COUNTS_CONVERTED = {
    variant_name: {
        CONVERTED_LAYER_MAPPING[variant_name][layer_name]: label_counts
        for layer_name, label_counts in layer_counts.items()
    }
    for variant_name, layer_counts in FULL_LABEL_COUNTS.items()
}
# Expected label counts for the "labeled_partitions" layer.
LABELED_PARTITION_COUNTS = {"Abstract": 40, "H1": 340, "Title": 40}


def resolve_annotation(annotation: Annotation) -> Any:
Expand Down Expand Up @@ -257,32 +272,25 @@ def converted_dataset(dataset, target_document_type) -> Optional[DatasetDict]:
return dataset.to_document_type(target_document_type)


def test_converted_datasets(converted_dataset, dataset_variant):
def test_converted_datasets(converted_dataset, dataset_variant, target_document_type):
if converted_dataset is not None:
split_sizes = {name: len(ds) for name, ds in converted_dataset.items()}
assert split_sizes == SPLIT_SIZES
if dataset_variant == "default":
expected_document_type = TextDocumentWithLabeledSpansAndBinaryRelations
layer_name_mapping = {
"spans": "labeled_spans",
"relations": "binary_relations",
}
elif dataset_variant == "resolve_parts_of_same":
expected_document_type = TextDocumentWithLabeledMultiSpansAndBinaryRelations
layer_name_mapping = {
"spans": "labeled_multi_spans",
"relations": "binary_relations",
}
else:
raise ValueError(f"Unknown dataset variant: {dataset_variant}")

assert issubclass(converted_dataset.document_type, expected_document_type)
assert isinstance(converted_dataset["train"][0], expected_document_type)

if TEST_FULL_DATASET:
expected_label_counts = {
layer_name_mapping[ln]: value
for ln, value in FULL_LABEL_COUNTS[dataset_variant].items()
}
# copy to avoid modifying the original dict
expected_label_counts = {**FULL_LABEL_COUNTS_CONVERTED[dataset_variant]}
if issubclass(target_document_type, TextDocumentWithLabeledPartitions):
expected_label_counts["labeled_partitions"] = LABELED_PARTITION_COUNTS
assert_dataset_label_counts(converted_dataset, expected_label_counts)


Expand Down

0 comments on commit 2b3ba52

Please sign in to comment.