Allow pickling in COCO, VOC, YOLO and ImageNet (#674)
* Allow pickling for Environment

* Allow reusing merged dataset storage

* Enable pickling in COCO, VOC, YOLO and ImageNet formats

* Add pickling tests
Maxim Zhiltsov authored Feb 18, 2022
1 parent 6070d05 commit 128f546
Showing 10 changed files with 135 additions and 29 deletions.
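
The change is easiest to see end to end. A minimal sketch of what the commit enables (the dataset path is hypothetical; the pickle round-trip mirrors the tests added below):

    import pickle

    from datumaro.components.dataset import Dataset

    # Hypothetical path; any format touched by this commit works the same way.
    source = Dataset.import_from('path/to/coco_dataset', format='coco')

    payload = pickle.dumps(source)    # nosec - trusted, in-process data
    restored = pickle.loads(payload)  # nosec

    assert len(restored) == len(source)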
5 changes: 3 additions & 2 deletions datumaro/components/annotation.py
@@ -5,6 +5,7 @@
 from __future__ import annotations

 from enum import Enum, auto
+from functools import partial
 from itertools import zip_longest
 from typing import (
     Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union,
@@ -462,7 +463,7 @@ def get_instance_labels(self) -> Dict[int, int]:
             + self.instance_mask.astype(np.uint32)
         keys = np.unique(m)
         instance_labels = {
-            k & ((1 << class_shift) - 1): k >> class_shift
+            int(k & ((1 << class_shift) - 1)): int(k >> class_shift)
            for k in keys
            if k & ((1 << class_shift) - 1) != 0
        }
@@ -476,7 +477,7 @@ def extract(self, instance_id: int) -> IndexMaskImage:
         return self.instance_mask == instance_id

     def lazy_extract(self, instance_id: int) -> Callable[[], IndexMaskImage]:
-        return lambda: self.extract(instance_id)
+        return partial(self.extract, instance_id)

 @attrs(slots=True, order=False)
 class _Shape(Annotation):
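
Both edits above serve the same goal: a lambda has no importable qualified name, so pickle rejects it, while a functools.partial over a named method serializes by reference; the int(...) casts likewise swap numpy scalar keys for plain Python ints. A standalone sketch of the lambda/partial difference (the extract function is a stand-in, not datumaro code):

    import pickle
    from functools import partial

    def extract(instance_id):
        # Stand-in for a real extraction routine
        return instance_id

    # A lambda cannot be pickled: it has no importable qualified name.
    try:
        pickle.dumps(lambda: extract(5))
    except (pickle.PicklingError, AttributeError) as e:
        print('lambda failed:', e)

    # A partial over a named, module-level callable pickles by reference.
    restored = pickle.loads(pickle.dumps(partial(extract, 5)))
    print(restored())  # -> 5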
20 changes: 12 additions & 8 deletions datumaro/components/dataset.py
@@ -266,11 +266,11 @@ def as_dataset(self) -> Dataset:


 class DatasetStorage(IDataset):
-    def __init__(self, source: IDataset = None,
+    def __init__(self, source: Union[IDataset, DatasetItemStorage] = None,
             categories: CategoriesInfo = None):
         if source is None and categories is None:
             categories = {}
-        elif source is not None and categories is not None:
+        elif isinstance(source, IDataset) and categories is not None:
             raise ValueError("Can't use both source and categories")
         self._categories = categories

@@ -280,16 +280,20 @@ def __init__(self, source: IDataset = None,
         # 2. no source + storage
         #  - a dataset created from scratch
         #  - a dataset from a source or transform, which was cached
-        self._source = source
-        self._storage = DatasetItemStorage() # patch or cache
+        if isinstance(source, DatasetItemStorage):
+            self._source = None
+            self._storage = source
+        else:
+            self._source = source
+            self._storage = DatasetItemStorage() # patch or cache
         self._transforms = [] # A stack of postponed transforms

         # Describes changes in the dataset since initialization
         self._updated_items = {} # (id, subset) -> ItemStatus

         self._flush_changes = False # Deferred flush indicator

-        self._length = 0 if source is None else None
+        self._length = len(self._storage) if self._source is None else None

     def is_cache_initialized(self) -> bool:
         return self._source is None and not self._transforms
@@ -647,14 +651,14 @@ def from_extractors(*sources: IDataset,
             env: Optional[Environment] = None) -> Dataset:
         if len(sources) == 1:
             source = sources[0]
+            dataset = Dataset(source=source, env=env)
         else:
             from datumaro.components.operations import ExactMerge
             source = ExactMerge.merge(*sources)
             categories = ExactMerge.merge_categories(
                 s.categories() for s in sources)
-            source = DatasetItemStorageDatasetView(source, categories)
-
-        return Dataset(source=source, env=env)
+            dataset = Dataset(source=source, categories=categories, env=env)
+        return dataset

     def __init__(self, source: Optional[IDataset] = None, *,
             categories: Optional[CategoriesInfo] = None,
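
The constructor now accepts a ready DatasetItemStorage and adopts it as the cache, which is what lets from_extractors hand a merged result straight to Dataset instead of wrapping it in a DatasetItemStorageDatasetView. A reduced sketch of the dispatch (simplified stand-in types, not the full class):

    class DatasetItemStorage(dict):
        # Reduced stand-in for datumaro's item storage
        pass

    class DatasetStorage:
        def __init__(self, source=None):
            if isinstance(source, DatasetItemStorage):
                # Already-materialized items: adopt them as the cache.
                self._source = None
                self._storage = source
            else:
                # Lazy source: keep it, start with an empty patch/cache.
                self._source = source
                self._storage = DatasetItemStorage()
            # Length is only known up front when there is no lazy source.
            self._length = len(self._storage) if self._source is None else None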
28 changes: 17 additions & 11 deletions datumaro/components/environment.py
@@ -3,12 +3,12 @@
 # SPDX-License-Identifier: MIT

 from functools import partial
+from inspect import isclass
 from typing import (
     Callable, Dict, Generic, Iterable, Iterator, Optional, Type, TypeVar,
 )
 import glob
 import importlib
-import inspect
 import logging as log
 import os.path as osp

@@ -46,31 +46,37 @@ class PluginRegistry(Registry[Type[CliPlugin]]):
     def __init__(self, filter: Callable[[Type[CliPlugin]], bool] = None): \
             #pylint: disable=redefined-builtin
         super().__init__()
-        self.filter = filter
+        self._filter = filter

     def batch_register(self, values: Iterable[CliPlugin]):
         for v in values:
-            if self.filter and not self.filter(v):
+            if self._filter and not self._filter(v):
                 continue

             self.register(v.NAME, v)

 class Environment:
     _builtin_plugins = None

-    def __init__(self):
-        def _filter(accept, skip=None):
-            accept = (accept, ) if inspect.isclass(accept) else tuple(accept)
-            skip = {skip} if inspect.isclass(skip) else set(skip or [])
-            skip = tuple(skip | set(accept))
-            return lambda t: issubclass(t, accept) and t not in skip
+    @classmethod
+    def _make_filter(cls, accept, skip=None):
+        accept = (accept, ) if isclass(accept) else tuple(accept)
+        skip = {skip} if isclass(skip) else set(skip or [])
+        skip = tuple(skip | set(accept))
+        return partial(cls._check_type, accept=accept, skip=skip)

+    @staticmethod
+    def _check_type(t, *, accept, skip):
+        return issubclass(t, accept) and t not in skip
+
+    def __init__(self):
         from datumaro.components.converter import Converter
         from datumaro.components.extractor import (
             Extractor, Importer, ItemTransform, SourceExtractor, Transform,
         )
         from datumaro.components.launcher import Launcher
         from datumaro.components.validator import Validator
+        _filter = self._make_filter
         self._extractors = PluginRegistry(_filter(Extractor,
             skip=SourceExtractor))
         self._importers = PluginRegistry(_filter(Importer))
@@ -147,7 +153,7 @@ def _get_plugin_exports(cls, module, types):
                 exports.append(getattr(module, symbol))

         exports = [s for s in exports
-            if inspect.isclass(s) and issubclass(s, types) and not s in types]
+            if isclass(s) and issubclass(s, types) and not s in types]

         return exports

@@ -219,7 +225,7 @@ def make_launcher(self, name, *args, **kwargs):

     def make_converter(self, name, *args, **kwargs):
         result = self.converters.get(name)
-        if inspect.isclass(result):
+        if isclass(result):
             result = result.convert
         return partial(result, *args, **kwargs)
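
Hoisting the filter out of __init__ matters for the same pickling reason: the old nested _filter returned a closure, which, like a lambda, cannot be pickled, whereas a partial binding keyword arguments to a named staticmethod can. A self-contained sketch of the pattern (simplified, without the skip-augmentation logic above):

    import pickle
    from functools import partial

    class Environment:
        @staticmethod
        def _check_type(t, *, accept, skip):
            return issubclass(t, accept) and t not in skip

        @classmethod
        def _make_filter(cls, accept, skip=()):
            return partial(cls._check_type,
                accept=tuple(accept), skip=tuple(skip))

    f = Environment._make_filter([int])
    assert f(bool) and not f(str)  # bool subclasses int; str does not
    pickle.loads(pickle.dumps(f))  # round-trips, unlike a closure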
4 changes: 3 additions & 1 deletion datumaro/util/mask_tools.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: MIT

+from functools import partial
 from itertools import chain
 from typing import Tuple

@@ -136,7 +137,8 @@ def load_mask(path, inverse_colormap=None):
     return mask

 def lazy_mask(path, inverse_colormap=None):
-    return lazy_image(path, lambda path: load_mask(path, inverse_colormap))
+    return lazy_image(path,
+        partial(load_mask, inverse_colormap=inverse_colormap))

 def mask_to_rle(binary_mask):
     # walk in row-major order as COCO format specifies
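
Beyond picklability, the partial pre-binds only the keyword argument, so the result still matches the one-argument loader(path) callable that lazy_image expects. A stand-in sketch (simplified load_mask, not the real decoder):

    from functools import partial

    def load_mask(path, inverse_colormap=None):
        # Simplified stand-in: the real function decodes an image file.
        return (path, inverse_colormap)

    loader = partial(load_mask, inverse_colormap={0: 255})
    # Still callable with just a path, as lazy_image requires:
    print(loader('mask.png'))  # -> ('mask.png', {0: 255})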
1 change: 1 addition & 0 deletions tests/requirements.py
@@ -35,6 +35,7 @@ class Requirements:
     DATUM_497 = "Support import for SYNTHIA dataset"
     DATUM_542 = "Images missing after merging two datasets"
     DATUM_580 = "Import for MPII Human Pose Dataset"
+    DATUM_673 = "Pickling for Dataset and Annotations"

     # GitHub issues (bugs)
     # https://github.com/openvinotoolkit/datumaro/issues
25 changes: 24 additions & 1 deletion tests/test_coco_format.py
@@ -3,6 +3,7 @@
 from unittest import TestCase
 import os
 import os.path as osp
+import pickle # nosec - disable B403:import_pickle check

 import numpy as np

@@ -21,7 +22,7 @@
 )
 from datumaro.plugins.coco_format.importer import CocoImporter
 from datumaro.util.test_utils import (
-    TestDir, check_save_and_load, compare_datasets,
+    TestDir, check_save_and_load, compare_datasets, compare_datasets_strict,
 )

 from .requirements import Requirements, mark_requirement
@@ -549,6 +550,28 @@ def test_can_detect(self):
         detected_formats = env.detect_dataset(dataset_dir)
         self.assertEqual([CocoImporter.NAME], detected_formats)

+    @mark_requirement(Requirements.DATUM_673)
+    def test_can_pickle(self):
+        subdirs = [
+            'coco',
+            'coco_captions',
+            'coco_image_info',
+            'coco_instances',
+            'coco_labels',
+            'coco_panoptic',
+            'coco_person_keypoints',
+            'coco_stuff',
+        ]
+
+        for subdir in subdirs:
+            with self.subTest(fmt=subdir, subdir=subdir):
+                dataset_dir = osp.join(DUMMY_DATASET_DIR, subdir)
+                source = Dataset.import_from(dataset_dir, format=subdir)
+
+                parsed = pickle.loads(pickle.dumps(source)) # nosec
+
+                compare_datasets_strict(self, source, parsed)
+
 class CocoConverterTest(TestCase):
     def _test_save_and_load(self, source_dataset, converter, test_dir,
             target_dataset=None, importer_args=None, **kwargs):
34 changes: 31 additions & 3 deletions tests/test_dataset.py
@@ -1,12 +1,13 @@
 from unittest import TestCase, mock
 import os
 import os.path as osp
+import pickle # nosec - disable B403:import_pickle check

 import numpy as np

 from datumaro.components.annotation import (
-    AnnotationType, Bbox, Caption, Label, LabelCategories, Mask, Points,
-    Polygon, PolyLine,
+    AnnotationType, Bbox, Caption, Label, LabelCategories, Mask, MaskCategories,
+    Points, Polygon, PolyLine,
 )
 from datumaro.components.converter import Converter
 from datumaro.components.dataset import (
@@ -30,7 +31,9 @@
 from datumaro.components.launcher import Launcher
 from datumaro.components.media import Image
 from datumaro.components.progress_reporting import NullProgressReporter
-from datumaro.util.test_utils import TestDir, compare_datasets
+from datumaro.util.test_utils import (
+    TestDir, compare_datasets, compare_datasets_strict,
+)
 import datumaro.components.hl_ops as hl_ops

 from .requirements import Requirements, mark_requirement
@@ -1685,6 +1688,31 @@ class TestErrorPolicy(ImportErrorPolicy):
         error_policy.report_item_error.assert_called()
         error_policy.report_annotation_error.assert_called()

+    @mark_requirement(Requirements.DATUM_673)
+    def test_can_pickle(self):
+        source = Dataset.from_iterable([
+            DatasetItem(id=1, subset='subset',
+                image=np.ones((5, 4, 3)),
+                annotations=[
+                    Label(0, attributes={'a1': 1, 'a2': '2'}, id=1, group=2),
+                    Caption('hello', id=1, group=5),
+                    Label(2, id=3, group=2, attributes={ 'x': 1, 'y': '2' }),
+                    Bbox(1, 2, 3, 4, label=4, id=4, attributes={ 'a': 1.0 }),
+                    Points([1, 2, 2, 0, 1, 1], label=0, id=5, group=6),
+                    Mask(label=3, id=5, image=np.ones((2, 3))),
+                    PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11),
+                    Polygon([1, 2, 3, 4, 5, 6, 7, 8]),
+                ])
+        ], categories={
+            AnnotationType.label: LabelCategories.from_iterable(['a', 'b']),
+            AnnotationType.mask: MaskCategories.generate(2),
+        })
+        source.init_cache()
+
+        parsed = pickle.loads(pickle.dumps(source)) # nosec
+
+        compare_datasets_strict(self, source, parsed)
+
 class DatasetItemTest(TestCase):
     @mark_requirement(Requirements.DATUM_GENERAL_REQ)
     def test_ctor_requires_id(self):
13 changes: 12 additions & 1 deletion tests/test_imagenet_format.py
@@ -1,5 +1,6 @@
 from unittest import TestCase
 import os.path as osp
+import pickle # nosec - disable B403:import_pickle check

 import numpy as np

@@ -11,7 +12,9 @@
 from datumaro.components.extractor import DatasetItem
 from datumaro.components.media import Image
 from datumaro.plugins.imagenet_format import ImagenetConverter, ImagenetImporter
-from datumaro.util.test_utils import TestDir, compare_datasets
+from datumaro.util.test_utils import (
+    TestDir, compare_datasets, compare_datasets_strict,
+)

 from .requirements import Requirements, mark_requirement

@@ -143,3 +146,11 @@ def test_can_import(self):
     def test_can_detect_imagenet(self):
         detected_formats = Environment().detect_dataset(DUMMY_DATASET_DIR)
         self.assertIn(ImagenetImporter.NAME, detected_formats)
+
+    @mark_requirement(Requirements.DATUM_673)
+    def test_can_pickle(self):
+        source = Dataset.import_from(DUMMY_DATASET_DIR, format='imagenet')
+
+        parsed = pickle.loads(pickle.dumps(source)) # nosec
+
+        compare_datasets_strict(self, source, parsed)
21 changes: 20 additions & 1 deletion tests/test_voc_format.py
@@ -3,6 +3,7 @@
 from unittest import TestCase
 import os
 import os.path as osp
+import pickle # nosec - disable B403:import_pickle check

 import numpy as np

@@ -20,7 +21,7 @@
 )
 from datumaro.plugins.voc_format.importer import VocImporter
 from datumaro.util.mask_tools import load_mask
 from datumaro.util.test_utils import (
-    TestDir, check_save_and_load, compare_datasets,
+    TestDir, check_save_and_load, compare_datasets, compare_datasets_strict,
 )
 import datumaro.plugins.voc_format.format as VOC

@@ -392,6 +393,24 @@ def test_can_import_voc_dataset_with_empty_lines_in_subset_lists(self):

         compare_datasets(self, expected, actual, require_images=True)

+    @mark_requirement(Requirements.DATUM_673)
+    def test_can_pickle(self):
+        formats = [
+            'voc',
+            'voc_classification',
+            'voc_detection',
+            'voc_action',
+            'voc_layout',
+            'voc_segmentation'
+        ]
+
+        for fmt in formats:
+            with self.subTest(fmt=fmt):
+                source = Dataset.import_from(DUMMY_DATASET_DIR, format=fmt)
+
+                parsed = pickle.loads(pickle.dumps(source)) # nosec
+
+                compare_datasets_strict(self, source, parsed)
+
 class VocConverterTest(TestCase):
     def _test_save_and_load(self, source_dataset, converter, test_dir,
13 changes: 12 additions & 1 deletion tests/test_yolo_format.py
@@ -1,6 +1,7 @@
 from unittest import TestCase
 import os
 import os.path as osp
+import pickle # nosec - disable B403:import_pickle check

 import numpy as np

@@ -12,7 +13,9 @@
 from datumaro.plugins.yolo_format.converter import YoloConverter
 from datumaro.plugins.yolo_format.extractor import YoloImporter
 from datumaro.util.image import save_image
-from datumaro.util.test_utils import TestDir, compare_datasets
+from datumaro.util.test_utils import (
+    TestDir, compare_datasets, compare_datasets_strict,
+)

 from .requirements import Requirements, mark_requirement

@@ -239,3 +242,11 @@ def test_can_import(self):
         dataset = Dataset.import_from(DUMMY_DATASET_DIR, 'yolo')

         compare_datasets(self, expected_dataset, dataset)
+
+    @mark_requirement(Requirements.DATUM_673)
+    def test_can_pickle(self):
+        source = Dataset.import_from(DUMMY_DATASET_DIR, format='yolo')
+
+        parsed = pickle.loads(pickle.dumps(source)) # nosec
+
+        compare_datasets_strict(self, source, parsed)
