Skip to content

Commit

Permalink
fix and simplify tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 1, 2024
1 parent c4f14b8 commit 99505bd
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions tests/dataset_builders/pie/sciarg/test_sciarg.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,31 @@
"supports": 5789,
},
"spans": {"background_claim": 3291, "data": 4297, "own_claim": 6004},
"labeled_partitions": {"Abstract": 39, "H1": 340, "Title": 40},
},
"resolve_parts_of_same": {
"relations": {"contradicts": 696, "semantically_same": 44, "supports": 5788},
"spans": {"background_claim": 2752, "data": 4093, "own_claim": 5450},
"labeled_partitions": {"Abstract": 39, "H1": 340, "Title": 40},
},
}
# Per dataset variant: maps the layer names used as keys in FULL_LABEL_COUNTS
# ("spans", "relations") to the layer names used for the converted datasets.
CONVERTED_LAYER_MAPPING = {
    "default": {"spans": "labeled_spans", "relations": "binary_relations"},
    "resolve_parts_of_same": {
        "spans": "labeled_multi_spans",
        "relations": "binary_relations",
    },
}
# FULL_LABEL_COUNTS with every layer name translated through
# CONVERTED_LAYER_MAPPING for the respective variant. The lookup is strict on
# purpose: a variant or layer without a mapping raises a KeyError at import time.
FULL_LABEL_COUNTS_CONVERTED = {
    variant_name: {
        CONVERTED_LAYER_MAPPING[variant_name][layer_name]: label_counts
        for layer_name, label_counts in layer_counts.items()
    }
    for variant_name, layer_counts in FULL_LABEL_COUNTS.items()
}
# Expected label counts for the "labeled_partitions" layer; test_converted_datasets
# adds them only when the target document type has labeled partitions.
LABELED_PARTITION_COUNTS = {"Abstract": 39, "H1": 340, "Title": 40}


def resolve_annotation(annotation: Annotation) -> Any:
Expand Down Expand Up @@ -265,30 +282,19 @@ def test_converted_datasets(converted_dataset, dataset_variant, target_document_
assert split_sizes == SPLIT_SIZES
if dataset_variant == "default":
expected_document_type = TextDocumentWithLabeledSpansAndBinaryRelations
layer_name_mapping = {
"spans": "labeled_spans",
"relations": "binary_relations",
}
elif dataset_variant == "resolve_parts_of_same":
expected_document_type = TextDocumentWithLabeledMultiSpansAndBinaryRelations
layer_name_mapping = {
"spans": "labeled_multi_spans",
"relations": "binary_relations",
}
else:
raise ValueError(f"Unknown dataset variant: {dataset_variant}")

assert issubclass(converted_dataset.document_type, expected_document_type)
assert isinstance(converted_dataset["train"][0], expected_document_type)

if TEST_FULL_DATASET:
expected_label_counts = {
layer_name_mapping.get(ln, ln): value
for ln, value in FULL_LABEL_COUNTS[dataset_variant].items()
}
if not issubclass(target_document_type, TextDocumentWithLabeledPartitions):
expected_label_counts = {
k: v for k, v in expected_label_counts.items() if k != "labeled_partitions"
}
# copy to avoid modifying the original dict
expected_label_counts = {**FULL_LABEL_COUNTS_CONVERTED[dataset_variant]}
if issubclass(target_document_type, TextDocumentWithLabeledPartitions):
expected_label_counts["labeled_partitions"] = LABELED_PARTITION_COUNTS
assert_dataset_label_counts(converted_dataset, expected_label_counts)


Expand Down

0 comments on commit 99505bd

Please sign in to comment.