diff --git a/dataset_builders/pie/sciarg/sciarg.py b/dataset_builders/pie/sciarg/sciarg.py index dc0e2a7d..002a8fff 100644 --- a/dataset_builders/pie/sciarg/sciarg.py +++ b/dataset_builders/pie/sciarg/sciarg.py @@ -123,7 +123,8 @@ def _generate_document(self, example, **kwargs): def document_converters(self) -> DocumentConvertersType: regex_partitioner = RegexPartitioner( partition_layer_name="labeled_partitions", - pattern="<([^>/]+)>.*", + # find matching tags, allow newlines in between (s flag) and capture the tag name + pattern="<([^>/]+)>(?s:.)*?", label_group_id=1, label_whitelist=["Title", "Abstract", "H1"], skip_initial_partition=True, diff --git a/tests/dataset_builders/pie/sciarg/test_sciarg.py b/tests/dataset_builders/pie/sciarg/test_sciarg.py index 48104f08..3a3ad8e0 100644 --- a/tests/dataset_builders/pie/sciarg/test_sciarg.py +++ b/tests/dataset_builders/pie/sciarg/test_sciarg.py @@ -56,6 +56,21 @@ "spans": {"background_claim": 2752, "data": 4093, "own_claim": 5450}, }, } +CONVERTED_LAYER_MAPPING = { + "default": { + "spans": "labeled_spans", + "relations": "binary_relations", + }, + "resolve_parts_of_same": { + "spans": "labeled_multi_spans", + "relations": "binary_relations", + }, +} +FULL_LABEL_COUNTS_CONVERTED = { + variant: {CONVERTED_LAYER_MAPPING[variant][ln]: value for ln, value in counts.items()} + for variant, counts in FULL_LABEL_COUNTS.items() +} +LABELED_PARTITION_COUNTS = {"Abstract": 40, "H1": 340, "Title": 40} def resolve_annotation(annotation: Annotation) -> Any: @@ -257,32 +272,25 @@ def converted_dataset(dataset, target_document_type) -> Optional[DatasetDict]: return dataset.to_document_type(target_document_type) -def test_converted_datasets(converted_dataset, dataset_variant): +def test_converted_datasets(converted_dataset, dataset_variant, target_document_type): if converted_dataset is not None: split_sizes = {name: len(ds) for name, ds in converted_dataset.items()} assert split_sizes == SPLIT_SIZES if dataset_variant == "default": expected_document_type = TextDocumentWithLabeledSpansAndBinaryRelations - layer_name_mapping = { - "spans": "labeled_spans", - "relations": "binary_relations", - } elif dataset_variant == "resolve_parts_of_same": expected_document_type = TextDocumentWithLabeledMultiSpansAndBinaryRelations - layer_name_mapping = { - "spans": "labeled_multi_spans", - "relations": "binary_relations", - } else: raise ValueError(f"Unknown dataset variant: {dataset_variant}") + assert issubclass(converted_dataset.document_type, expected_document_type) assert isinstance(converted_dataset["train"][0], expected_document_type) if TEST_FULL_DATASET: - expected_label_counts = { - layer_name_mapping[ln]: value - for ln, value in FULL_LABEL_COUNTS[dataset_variant].items() - } + # copy to avoid modifying the original dict + expected_label_counts = {**FULL_LABEL_COUNTS_CONVERTED[dataset_variant]} + if issubclass(target_document_type, TextDocumentWithLabeledPartitions): + expected_label_counts["labeled_partitions"] = LABELED_PARTITION_COUNTS assert_dataset_label_counts(converted_dataset, expected_label_counts)