diff --git a/tests/dataset_builders/pie/sciarg/test_sciarg.py b/tests/dataset_builders/pie/sciarg/test_sciarg.py index d994023e..1010368b 100644 --- a/tests/dataset_builders/pie/sciarg/test_sciarg.py +++ b/tests/dataset_builders/pie/sciarg/test_sciarg.py @@ -50,14 +50,31 @@ "supports": 5789, }, "spans": {"background_claim": 3291, "data": 4297, "own_claim": 6004}, - "labeled_partitions": {"Abstract": 39, "H1": 340, "Title": 40}, }, "resolve_parts_of_same": { "relations": {"contradicts": 696, "semantically_same": 44, "supports": 5788}, "spans": {"background_claim": 2752, "data": 4093, "own_claim": 5450}, - "labeled_partitions": {"Abstract": 39, "H1": 340, "Title": 40}, }, } +CONVERTED_LAYER_MAPPING = { + "default": { + "spans": "labeled_spans", + "relations": "binary_relations", + }, + "resolve_parts_of_same": { + "spans": "labeled_multi_spans", + "relations": "binary_relations", + }, +} +FULL_LABEL_COUNTS_CONVERTED = { + variant: {CONVERTED_LAYER_MAPPING[variant][ln]: value for ln, value in counts.items()} + for variant, counts in FULL_LABEL_COUNTS.items() +} +LABELED_PARTITION_COUNTS = { + "Abstract": 39, + "H1": 340, + "Title": 40, +} def resolve_annotation(annotation: Annotation) -> Any: @@ -265,30 +282,19 @@ def test_converted_datasets(converted_dataset, dataset_variant, target_document_ assert split_sizes == SPLIT_SIZES if dataset_variant == "default": expected_document_type = TextDocumentWithLabeledSpansAndBinaryRelations - layer_name_mapping = { - "spans": "labeled_spans", - "relations": "binary_relations", - } elif dataset_variant == "resolve_parts_of_same": expected_document_type = TextDocumentWithLabeledMultiSpansAndBinaryRelations - layer_name_mapping = { - "spans": "labeled_multi_spans", - "relations": "binary_relations", - } else: raise ValueError(f"Unknown dataset variant: {dataset_variant}") + assert issubclass(converted_dataset.document_type, expected_document_type) assert isinstance(converted_dataset["train"][0], expected_document_type) if TEST_FULL_DATASET: - expected_label_counts = { - layer_name_mapping.get(ln, ln): value - for ln, value in FULL_LABEL_COUNTS[dataset_variant].items() - } - if not issubclass(target_document_type, TextDocumentWithLabeledPartitions): - expected_label_counts = { - k: v for k, v in expected_label_counts.items() if k != "labeled_partitions" - } + # copy to avoid modifying the original dict + expected_label_counts = {**FULL_LABEL_COUNTS_CONVERTED[dataset_variant]} + if issubclass(target_document_type, TextDocumentWithLabeledPartitions): + expected_label_counts["labeled_partitions"] = LABELED_PARTITION_COUNTS assert_dataset_label_counts(converted_dataset, expected_label_counts)