From bc897a56cd740ee31c5ac1cd5761b588cd3348bc Mon Sep 17 00:00:00 2001 From: ArneBinder Date: Sun, 17 Sep 2023 21:14:43 +0200 Subject: [PATCH] (de-)serialize `document_type` when calling `DatasetDict` (#346) * implement utils.hydra.serialize_document_type() * DatasetDict.to_json(): save a metadata.json that contains the serialized document type * DatasetDict.from_json(): make the parameter document_type optional, but try to load a metadata.json file and use the document_type from there (if it is not explicitly specified) --- src/pytorch_ie/data/dataset_dict.py | 43 +++++++++++++++++++++++++++-- src/pytorch_ie/utils/hydra.py | 4 +++ tests/data/test_dataset_dict.py | 22 +++++++++++++++ tests/utils/test_hydra.py | 9 ++++++ 4 files changed, 75 insertions(+), 3 deletions(-) diff --git a/src/pytorch_ie/data/dataset_dict.py b/src/pytorch_ie/data/dataset_dict.py index 5e8b9c70..27f058c6 100644 --- a/src/pytorch_ie/data/dataset_dict.py +++ b/src/pytorch_ie/data/dataset_dict.py @@ -8,7 +8,7 @@ from pytorch_ie.core import Document from pytorch_ie.data.dataset import Dataset, IterableDataset, get_pie_dataset_type -from pytorch_ie.utils.hydra import resolve_target +from pytorch_ie.utils.hydra import resolve_target, serialize_document_type from .common import ( EnterDatasetDictMixin, @@ -19,6 +19,8 @@ logger = logging.getLogger(__name__) +METADATA_FILE_NAME = "metadata.json" + D = TypeVar("D", bound=Document) @@ -73,17 +75,40 @@ def from_hf( @classmethod def from_json( # type: ignore cls, - document_type: Union[Type[Document], str], + document_type: Optional[Union[Type[Document], str]] = None, + metadata_path: Optional[Union[str, Path]] = None, + data_dir: Optional[str] = None, **kwargs, ) -> "DatasetDict": """Creates a PIE DatasetDict from JSONLINE files. Uses `datasets.load_dataset("json")` under the hood. + Requires a document type to be provided. If the document type is not provided, we try to load it from the + metadata file. 
Args: document_type: document type of the dataset + data_dir: Defining the `data_dir` of the dataset configuration. See datasets.load_dataset() for more + information. + metadata_path: path to the metadata file. Should point to a directory containing the metadata file + `metadata.json`. Defaults to the value of the `data_dir` parameter. **kwargs: additional keyword arguments for `datasets.load_dataset()` """ - hf_dataset = datasets.load_dataset("json", **kwargs) + # try to load metadata + if metadata_path is None: + metadata_path = data_dir + if metadata_path is not None: + metadata_file_name = Path(metadata_path) / METADATA_FILE_NAME + if os.path.exists(metadata_file_name): + with open(metadata_file_name) as f: + metadata = json.load(f) + document_type = document_type or metadata.get("document_type", None) + + if document_type is None: + raise ValueError( + "document_type must be provided if it cannot be loaded from the metadata file" + ) + + hf_dataset = datasets.load_dataset("json", data_dir=data_dir, **kwargs) if isinstance( hf_dataset, ( @@ -110,6 +135,18 @@ def to_json(self, path: Union[str, Path], **kwargs) -> None: """ path = Path(path) + + # save the metadata + metadata = {"document_type": serialize_document_type(self.document_type)} + os.makedirs(path, exist_ok=True) + if os.path.exists(path / METADATA_FILE_NAME): + logger.warning( + f"metadata file '{path / METADATA_FILE_NAME}' already exists, overwriting it" + ) + with open(path / METADATA_FILE_NAME, "w") as f: + json.dump(metadata, f, indent=2) + + # save the splits for split, dataset in self.items(): split_path = path / split logger.info(f'serialize documents to "{split_path}" ...') diff --git a/src/pytorch_ie/utils/hydra.py b/src/pytorch_ie/utils/hydra.py index 71ef6860..8d35f9ab 100644 --- a/src/pytorch_ie/utils/hydra.py +++ b/src/pytorch_ie/utils/hydra.py @@ -104,3 +104,7 @@ def resolve_optional_document_type( f"(resolved) document_type must be a subclass of Document, but it is: {dt}" ) return dt 
+ + +def serialize_document_type(document_type: Type[Document]) -> str: + return f"{document_type.__module__}.{document_type.__name__}" diff --git a/tests/data/test_dataset_dict.py b/tests/data/test_dataset_dict.py index ef0a3e13..fbd9b268 100644 --- a/tests/data/test_dataset_dict.py +++ b/tests/data/test_dataset_dict.py @@ -52,6 +52,15 @@ def test_from_json(dataset_dict): assert len(dataset_dict["validation"]) == 3 +def test_from_json_no_serialized_document_type(dataset_dict): + with pytest.raises(ValueError) as excinfo: + DatasetDict.from_json(data_dir=DATA_PATH) + assert ( + str(excinfo.value) + == "document_type must be provided if it cannot be loaded from the metadata file" + ) + + def test_load_dataset(): dataset_dict = DatasetDict.load_dataset( "pie/brat", base_dataset_kwargs=dict(data_dir=FIXTURES_ROOT / "datasets" / "brat") @@ -90,6 +99,19 @@ def test_to_json_and_back(dataset_dict, tmp_path): assert doc1 == doc2 +def test_to_json_and_back_serialize_document_type(dataset_dict, tmp_path): + path = Path(tmp_path) / "dataset_dict" + dataset_dict.to_json(path) + dataset_dict_from_json = DatasetDict.from_json( + data_dir=path, + ) + assert set(dataset_dict_from_json) == set(dataset_dict) + for split in dataset_dict: + assert len(dataset_dict_from_json[split]) == len(dataset_dict[split]) + for doc1, doc2 in zip(dataset_dict_from_json[split], dataset_dict[split]): + assert doc1 == doc2 + + def test_document_type_empty_no_splits(): with pytest.raises(ValueError) as excinfo: DatasetDict().document_type diff --git a/tests/utils/test_hydra.py b/tests/utils/test_hydra.py index d549f791..3f9efbc8 100644 --- a/tests/utils/test_hydra.py +++ b/tests/utils/test_hydra.py @@ -8,7 +8,9 @@ InstantiationException, resolve_optional_document_type, resolve_target, + serialize_document_type, ) +from tests.conftest import TestDocument def test_resolve_target_string(): @@ -102,3 +104,10 @@ def test_resolve_optional_document_type_no_document(): str(excinfo.value) == "(resolved) 
document_type must be a subclass of Document, but it is: " ) + + +def test_serialize_document_type(): + serialized_dt = serialize_document_type(TestDocument) + assert serialized_dt == "tests.conftest.TestDocument" + resolved_dt = resolve_optional_document_type(serialized_dt) + assert resolved_dt == TestDocument