From 5a0e15bf1ec0501037d228d85025fdb80ab1576f Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 10 Sep 2023 23:41:32 +0200 Subject: [PATCH] try to resolve the converter if it is a string --- src/pytorch_ie/data/builder.py | 10 ++++++++-- src/pytorch_ie/data/dataset_dict.py | 19 +++++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/pytorch_ie/data/builder.py b/src/pytorch_ie/data/builder.py index 7772bc26..635826f9 100644 --- a/src/pytorch_ie/data/builder.py +++ b/src/pytorch_ie/data/builder.py @@ -60,7 +60,7 @@ def __init__( self, base_dataset_kwargs: Optional[Dict[str, Any]] = None, document_converters: Optional[ - Dict[Union[Type[Document], str], Union[Callable[..., Document], Dict[str, str]]] + Dict[Union[Type[Document], str], Union[Callable[..., Document], Dict[str, str], str]] ] = None, **kwargs, ): @@ -119,9 +119,15 @@ def __init__( self.document_converters = dict(self.DOCUMENT_CONVERTERS) if document_converters is not None: - for document_type_or_str, document_converter in document_converters.items(): + for document_type_or_str, document_converter_or_str in document_converters.items(): document_type = resolve_target(document_type_or_str) if isinstance(document_type, type) and issubclass(document_type, Document): + document_converter: Union[Callable[..., Any], dict[str, str]] + if isinstance(document_converter_or_str, str): + document_converter = resolve_target(document_converter_or_str) + else: + document_converter = document_converter_or_str + self.document_converters[document_type] = document_converter else: raise TypeError( diff --git a/src/pytorch_ie/data/dataset_dict.py b/src/pytorch_ie/data/dataset_dict.py index 4296a7fb..556f1807 100644 --- a/src/pytorch_ie/data/dataset_dict.py +++ b/src/pytorch_ie/data/dataset_dict.py @@ -162,7 +162,7 @@ def dataset_type(self) -> Union[Type[Dataset], Type[IterableDataset]]: def register_document_converter( self, - converter: Union[Callable[..., D], Dict[str, str]], + converter: Union[Callable[..., D], Dict[str, str], str], document_type: Optional[Union[Type[D], str]] = None, ) -> "DatasetDict": """Register a converter function or field mapping for a target document type. @@ -172,8 +172,9 @@ def register_document_converter( of Document or string that can be resolved to such a type. If `None`, the document type is tried to be inferred from the converter function signature. converter: Either a function that converts a document of the document type of this dataset to a document - of the target document_type, or a field mapping (dict[str, str]) that maps fields of the document type - of this dataset to fields of the target document_type. + of the target document_type, a string that can be resolved to such a function, or a field mapping + (dict[str, str]) that maps fields of the document type of this dataset to fields of the target + document_type. """ resolved_document_type: Optional[Union[Type[D], Callable]] = None if document_type is not None: @@ -189,9 +190,19 @@ def register_document_converter( f"document_type must be or resolv to a subclass of Document, but is {document_type}" ) + resolved_converter: Union[Callable[..., Any], dict[str, str]] + if isinstance(converter, str): + resolved_converter = resolve_target(converter) + else: + resolved_converter = converter + if not (callable(resolved_converter) or isinstance(resolved_converter, dict)): + raise TypeError( + f"converter must be a callable or a dict, but is {type(resolved_converter)}" + ) + for ds in self.values(): ds.register_document_converter( - document_type=resolved_document_type, converter=converter + document_type=resolved_document_type, converter=resolved_converter ) return self