From b78f7fc1c59ba718942b2f2662cbd089f250151f Mon Sep 17 00:00:00 2001 From: ArneBinder Date: Wed, 13 Sep 2023 20:24:31 +0200 Subject: [PATCH] do not try to infer the document_type from the function in (Iterable)Dataset.map() (#337) --- src/pytorch_ie/data/dataset.py | 10 ++-------- tests/data/test_dataset.py | 16 ++-------------- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/src/pytorch_ie/data/dataset.py b/src/pytorch_ie/data/dataset.py index 32c754ab..c64281f4 100644 --- a/src/pytorch_ie/data/dataset.py +++ b/src/pytorch_ie/data/dataset.py @@ -369,10 +369,7 @@ def map( ) if result_document_type is None: - if function is not None and as_documents: - result_document_type = _infer_document_type_from_function_return(function=function) - if result_document_type is None: - result_document_type = self.document_type + result_document_type = self.document_type return Dataset.from_hf_dataset( dataset, @@ -518,10 +515,7 @@ def map( # type: ignore ) if result_document_type is None: - if function is not None and as_documents: - result_document_type = _infer_document_type_from_function_return(function=function) - if result_document_type is None: - result_document_type = self.document_type + result_document_type = self.document_type return IterableDataset.from_hf_dataset( dataset_mapped, diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index ab65e11e..7db0f3b8 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -111,8 +111,7 @@ def clear_relations_batched(documents): assert sum(len(doc.relations) for doc in train_dataset) == 7 -@pytest.mark.parametrize("infer_type", [False, True]) -def test_dataset_map_with_result_document_type(maybe_iterable_dataset, infer_type): +def test_dataset_map_with_result_document_type(maybe_iterable_dataset): @dataclass class TestDocument(TextDocument): sentences: AnnotationList[Span] = annotation_field(target="text") @@ -142,7 +141,7 @@ def clear_relations_and_add_one_token( mapped_dataset1 = train_dataset.map( clear_relations_and_add_one_token, - result_document_type=TestDocumentWithTokensButNoRelations if not infer_type else None, + result_document_type=TestDocumentWithTokensButNoRelations, ) assert sum(len(doc.relations) for doc in train_dataset) == 7 @@ -160,17 +159,6 @@ def clear_relations_and_add_one_token( f.name for f in TestDocumentWithTokensButNoRelations.fields() } - if infer_type: - - def func_wrong_return_type(document: TestDocument) -> Dict: - return document # type: ignore - - with pytest.raises( - TypeError, - match="the return type annotation of the function used with map is not a subclass of Document", - ): - train_dataset.map(func_wrong_return_type) - @pytest.mark.parametrize("encode_target", [False, True]) @pytest.mark.parametrize("inplace", [False, True])