diff --git a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py index 41877225..81b9f0a2 100644 --- a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py +++ b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py @@ -55,7 +55,13 @@ def __init__( def save_jsonl(self): data = [] for example in self.examples: - data.append({self._image_column: example.image_path, self._caption_column: example.caption}) + data.append( + { + self._image_column: example.image_path, + self._caption_column: example.caption, + MASK_COLUMN_DEFAULT: example.mask_path, + } + ) save_jsonl(data, self._jsonl_path) def _get_image_path(self, idx: int) -> str: diff --git a/tests/invoke_training/_shared/data/datasets/test_image_caption_jsonl_dataset.py b/tests/invoke_training/_shared/data/datasets/test_image_caption_jsonl_dataset.py index a0aa2380..a6d27b9f 100644 --- a/tests/invoke_training/_shared/data/datasets/test_image_caption_jsonl_dataset.py +++ b/tests/invoke_training/_shared/data/datasets/test_image_caption_jsonl_dataset.py @@ -1,6 +1,10 @@ +import shutil +from pathlib import Path + import PIL.Image from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset +from invoke_training._shared.utils.jsonl import load_jsonl from ..dataset_fixtures import image_caption_jsonl # noqa: F401 @@ -52,3 +56,21 @@ def test_image_caption_jsonl_dataset_get_image_dimensions(image_caption_jsonl): image_dims = dataset.get_image_dimensions() assert len(image_dims) == len(dataset) + + +def test_image_caption_jsonl_dataset_save_jsonl(image_caption_jsonl, tmp_path: Path): # noqa: F811 + # Create a copy of the image_caption_jsonl file to avoid modifying the original file. + image_caption_jsonl_copy = tmp_path / "test.jsonl" + shutil.copy(image_caption_jsonl, image_caption_jsonl_copy) + + # Load the dataset from the copied jsonl file. + dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl)) + + # Save the dataset to a new jsonl file. + dataset.save_jsonl() + + # Verify that the roundtrip was successful. + assert image_caption_jsonl != image_caption_jsonl_copy + original_jsonl = load_jsonl(image_caption_jsonl) + roundtrip_jsonl = load_jsonl(image_caption_jsonl_copy) + assert original_jsonl == roundtrip_jsonl