Skip to content

Commit

Permalink
(de-)serialize document_type when calling DatasetDict (#346)
Browse files Browse the repository at this point in the history
* implement utils.hydra.serialize_document_type()

* DatasetDict.to_json(): save a metadata.json that contains the serialized document type

* DatasetDict.from_json(): make the parameter document_type optional, but try to load a metadata.json file and use the document_type from there (if it is not explicitly specified)
  • Loading branch information
ArneBinder authored Sep 17, 2023
1 parent 758e436 commit bc897a5
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 3 deletions.
43 changes: 40 additions & 3 deletions src/pytorch_ie/data/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pytorch_ie.core import Document
from pytorch_ie.data.dataset import Dataset, IterableDataset, get_pie_dataset_type
from pytorch_ie.utils.hydra import resolve_target
from pytorch_ie.utils.hydra import resolve_target, serialize_document_type

from .common import (
EnterDatasetDictMixin,
Expand All @@ -19,6 +19,8 @@

logger = logging.getLogger(__name__)

METADATA_FILE_NAME = "metadata.json"


D = TypeVar("D", bound=Document)

Expand Down Expand Up @@ -73,17 +75,40 @@ def from_hf(
@classmethod
def from_json( # type: ignore
cls,
document_type: Union[Type[Document], str],
document_type: Optional[Union[Type[Document], str]] = None,
metadata_path: Optional[Union[str, Path]] = None,
data_dir: Optional[str] = None,
**kwargs,
) -> "DatasetDict":
"""Creates a PIE DatasetDict from JSONLINE files. Uses `datasets.load_dataset("json")` under the hood.
Requires a document type to be provided. If the document type is not provided, we try to load it from the
metadata file.
Args:
document_type: document type of the dataset
data_dir: Defining the `data_dir` of the dataset configuration. See datasets.load_dataset() for more
information.
metadata_path: path to the metadata file. Should point to a directory containing the metadata file
`metadata.json`. Defaults to the value of the `data_dir` parameter.
**kwargs: additional keyword arguments for `datasets.load_dataset()`
"""

hf_dataset = datasets.load_dataset("json", **kwargs)
# try to load metadata
if metadata_path is None:
metadata_path = data_dir
if metadata_path is not None:
metadata_file_name = Path(metadata_path) / METADATA_FILE_NAME
if os.path.exists(metadata_file_name):
with open(metadata_file_name) as f:
metadata = json.load(f)
document_type = document_type or metadata.get("document_type", None)

if document_type is None:
raise ValueError(
f"document_type must be provided if it cannot be loaded from the metadata file"
)

hf_dataset = datasets.load_dataset("json", data_dir=data_dir, **kwargs)
if isinstance(
hf_dataset,
(
Expand All @@ -110,6 +135,18 @@ def to_json(self, path: Union[str, Path], **kwargs) -> None:
"""

path = Path(path)

# save the metadata
metadata = {"document_type": serialize_document_type(self.document_type)}
os.makedirs(path, exist_ok=True)
if os.path.exists(path / METADATA_FILE_NAME):
logger.warning(
f"metadata file '{path / METADATA_FILE_NAME}' already exists, overwriting it"
)
with open(path / METADATA_FILE_NAME, "w") as f:
json.dump(metadata, f, indent=2)

# save the splits
for split, dataset in self.items():
split_path = path / split
logger.info(f'serialize documents to "{split_path}" ...')
Expand Down
4 changes: 4 additions & 0 deletions src/pytorch_ie/utils/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,7 @@ def resolve_optional_document_type(
f"(resolved) document_type must be a subclass of Document, but it is: {dt}"
)
return dt


def serialize_document_type(document_type: Type[Document]) -> str:
return f"{document_type.__module__}.{document_type.__name__}"
22 changes: 22 additions & 0 deletions tests/data/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ def test_from_json(dataset_dict):
assert len(dataset_dict["validation"]) == 3


def test_from_json_no_serialized_document_type(dataset_dict):
    # Without a metadata.json next to the data, from_json() has no way to
    # determine the document type and must fail with a clear error.
    expected_message = (
        "document_type must be provided if it cannot be loaded from the metadata file"
    )
    with pytest.raises(ValueError) as exc_info:
        DatasetDict.from_json(data_dir=DATA_PATH)
    assert str(exc_info.value) == expected_message


def test_load_dataset():
dataset_dict = DatasetDict.load_dataset(
"pie/brat", base_dataset_kwargs=dict(data_dir=FIXTURES_ROOT / "datasets" / "brat")
Expand Down Expand Up @@ -90,6 +99,19 @@ def test_to_json_and_back(dataset_dict, tmp_path):
assert doc1 == doc2


def test_to_json_and_back_serialize_document_type(dataset_dict, tmp_path):
    # Round trip: to_json() writes a metadata.json containing the serialized
    # document type, so from_json() can be called without document_type.
    dump_path = Path(tmp_path) / "dataset_dict"
    dataset_dict.to_json(dump_path)
    reloaded = DatasetDict.from_json(
        data_dir=dump_path,
    )
    assert set(reloaded) == set(dataset_dict)
    for split_name in dataset_dict:
        reloaded_split = reloaded[split_name]
        original_split = dataset_dict[split_name]
        assert len(reloaded_split) == len(original_split)
        for reloaded_doc, original_doc in zip(reloaded_split, original_split):
            assert reloaded_doc == original_doc


def test_document_type_empty_no_splits():
with pytest.raises(ValueError) as excinfo:
DatasetDict().document_type
Expand Down
9 changes: 9 additions & 0 deletions tests/utils/test_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
InstantiationException,
resolve_optional_document_type,
resolve_target,
serialize_document_type,
)
from tests.conftest import TestDocument


def test_resolve_target_string():
Expand Down Expand Up @@ -102,3 +104,10 @@ def test_resolve_optional_document_type_no_document():
str(excinfo.value)
== "(resolved) document_type must be a subclass of Document, but it is: <class 'tests.utils.test_hydra.NoDocument'>"
)


def test_serialize_document_type():
    # Serialization yields the fully qualified class path ...
    qualified_path = serialize_document_type(TestDocument)
    assert qualified_path == "tests.conftest.TestDocument"
    # ... which resolves back to the original class.
    assert resolve_optional_document_type(qualified_path) == TestDocument

0 comments on commit bc897a5

Please sign in to comment.