diff --git a/tests/unit/core/test_dataset.py b/tests/unit/core/test_dataset.py index 00f1df45..b946158a 100644 --- a/tests/unit/core/test_dataset.py +++ b/tests/unit/core/test_dataset.py @@ -434,22 +434,37 @@ def test_dataset_with_taskmodule( assert not document["entities"].predictions -def test_pie_dataset_from_documents(documents): - assert len(documents) == 8 - assert all(isinstance(doc, TextBasedDocument) for doc in documents) +@pytest.mark.parametrize("as_iterable_dataset", [False, True]) +def test_pie_dataset_from_documents(documents, as_iterable_dataset): + if as_iterable_dataset: + dataset_class = IterableDataset + else: + dataset_class = Dataset - dataset_from_documents = Dataset.from_documents(documents) + dataset_from_documents = dataset_class.from_documents(documents) - assert isinstance(dataset_from_documents, Dataset) + assert isinstance(dataset_from_documents, dataset_class) - assert len(dataset_from_documents) == 8 assert all(isinstance(doc, TextBasedDocument) for doc in dataset_from_documents) assert all(doc1.id == doc2.id for doc1, doc2 in zip(documents, dataset_from_documents)) - assert hasattr(dataset_from_documents, "document_type") + # Test dataset creation with document converter + dataset_from_documents_with_converter = dataset_class.from_documents( + documents, document_converters={TestDocumentWithLabel: convert_to_document_with_label} + ) + + assert isinstance(dataset_from_documents_with_converter, dataset_class) + + assert len(dataset_from_documents_with_converter.document_converters) == 1 + assert TestDocumentWithLabel in dataset_from_documents_with_converter.document_converters + assert ( + dataset_from_documents_with_converter.document_converters[TestDocumentWithLabel] + == convert_to_document_with_label + ) + # Test dataset creation with empty list empty_doc_list = list[Document]() with pytest.raises(ValueError) as excinfo: - Dataset.from_documents(empty_doc_list) + dataset_class.from_documents(empty_doc_list) assert str(excinfo.value) == "No documents to create dataset from"