Skip to content

Commit

Permalink
do not try to infer the document_type from the function in (Iterable)Dataset.map() (#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder authored Sep 13, 2023
1 parent 1ede365 commit b78f7fc
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 22 deletions.
10 changes: 2 additions & 8 deletions src/pytorch_ie/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,7 @@ def map(
)

if result_document_type is None:
if function is not None and as_documents:
result_document_type = _infer_document_type_from_function_return(function=function)
if result_document_type is None:
result_document_type = self.document_type
result_document_type = self.document_type

return Dataset.from_hf_dataset(
dataset,
Expand Down Expand Up @@ -518,10 +515,7 @@ def map( # type: ignore
)

if result_document_type is None:
if function is not None and as_documents:
result_document_type = _infer_document_type_from_function_return(function=function)
if result_document_type is None:
result_document_type = self.document_type
result_document_type = self.document_type

return IterableDataset.from_hf_dataset(
dataset_mapped,
Expand Down
16 changes: 2 additions & 14 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def clear_relations_batched(documents):
assert sum(len(doc.relations) for doc in train_dataset) == 7


@pytest.mark.parametrize("infer_type", [False, True])
def test_dataset_map_with_result_document_type(maybe_iterable_dataset, infer_type):
def test_dataset_map_with_result_document_type(maybe_iterable_dataset):
@dataclass
class TestDocument(TextDocument):
sentences: AnnotationList[Span] = annotation_field(target="text")
Expand Down Expand Up @@ -142,7 +141,7 @@ def clear_relations_and_add_one_token(

mapped_dataset1 = train_dataset.map(
clear_relations_and_add_one_token,
result_document_type=TestDocumentWithTokensButNoRelations if not infer_type else None,
result_document_type=TestDocumentWithTokensButNoRelations,
)

assert sum(len(doc.relations) for doc in train_dataset) == 7
Expand All @@ -160,17 +159,6 @@ def clear_relations_and_add_one_token(
f.name for f in TestDocumentWithTokensButNoRelations.fields()
}

if infer_type:

def func_wrong_return_type(document: TestDocument) -> Dict:
return document # type: ignore

with pytest.raises(
TypeError,
match="the return type annotation of the function used with map is not a subclass of Document",
):
train_dataset.map(func_wrong_return_type)


@pytest.mark.parametrize("encode_target", [False, True])
@pytest.mark.parametrize("inplace", [False, True])
Expand Down

0 comments on commit b78f7fc

Please sign in to comment.