From 563b93d23166e16d70dbd3d3c0067241dba83101 Mon Sep 17 00:00:00 2001 From: ArneBinder Date: Sun, 26 Nov 2023 06:28:00 +0100 Subject: [PATCH] add dataset mixins (#382) --- src/pytorch_ie/core/__init__.py | 9 +++++++- src/pytorch_ie/core/module_mixins.py | 33 ++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/pytorch_ie/core/__init__.py b/src/pytorch_ie/core/__init__.py index b3462abd..41240eb2 100644 --- a/src/pytorch_ie/core/__init__.py +++ b/src/pytorch_ie/core/__init__.py @@ -1,7 +1,14 @@ from .document import Annotation, AnnotationLayer, Document, annotation_field from .metric import DocumentMetric from .model import PyTorchIEModel -from .module_mixins import WithDocumentTypeMixin +from .module_mixins import ( + EnterDatasetDictMixin, + EnterDatasetMixin, + ExitDatasetDictMixin, + ExitDatasetMixin, + PreparableMixin, + WithDocumentTypeMixin, +) from .statistic import DocumentStatistic from .taskmodule import TaskEncoding, TaskModule diff --git a/src/pytorch_ie/core/module_mixins.py b/src/pytorch_ie/core/module_mixins.py index 01c4bc3f..79c98ebe 100644 --- a/src/pytorch_ie/core/module_mixins.py +++ b/src/pytorch_ie/core/module_mixins.py @@ -1,4 +1,5 @@ import logging +from abc import ABC, abstractmethod from typing import List, Optional, Type from pytorch_ie.core.document import Document @@ -97,3 +98,35 @@ def prepare(self, *args, **kwargs) -> None: ) self._post_prepare() return None + + +class EnterDatasetMixin(ABC): + """Mixin for processors that enter a dataset context.""" + + @abstractmethod + def enter_dataset(self, dataset, name: Optional[str] = None) -> None: + """Enter dataset context.""" + + +class ExitDatasetMixin(ABC): + """Mixin for processors that exit a dataset context.""" + + @abstractmethod + def exit_dataset(self, dataset, name: Optional[str] = None) -> None: + """Exit dataset context.""" + + +class EnterDatasetDictMixin(ABC): + """Mixin for processors that enter a dataset dict context.""" + + @abstractmethod + def enter_dataset_dict(self, dataset_dict) -> None: + """Enter dataset dict context.""" + + +class ExitDatasetDictMixin(ABC): + """Mixin for processors that exit a dataset dict context.""" + + @abstractmethod + def exit_dataset_dict(self, dataset_dict) -> None: + """Exit dataset dict context."""