Skip to content

Commit

Permalink
try to resolve the converter if it is a string
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 10, 2023
1 parent b333841 commit d333c1a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
10 changes: 8 additions & 2 deletions src/pytorch_ie/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self,
base_dataset_kwargs: Optional[Dict[str, Any]] = None,
document_converters: Optional[
Dict[Union[Type[Document], str], Union[Callable[..., Document], Dict[str, str]]]
Dict[Union[Type[Document], str], Union[Callable[..., Document], Dict[str, str], str]]
] = None,
**kwargs,
):
Expand Down Expand Up @@ -119,9 +119,15 @@ def __init__(

self.document_converters = dict(self.DOCUMENT_CONVERTERS)
if document_converters is not None:
for document_type_or_str, document_converter in document_converters.items():
for document_type_or_str, document_converter_or_str in document_converters.items():
document_type = resolve_target(document_type_or_str)
if isinstance(document_type, type) and issubclass(document_type, Document):
document_converter: Union[Callable[..., Any], dict[str, str]]
if isinstance(document_converter_or_str, str):
document_converter = resolve_target(document_converter_or_str)
else:
document_converter = document_converter_or_str

self.document_converters[document_type] = document_converter
else:
raise TypeError(
Expand Down
19 changes: 15 additions & 4 deletions src/pytorch_ie/data/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def dataset_type(self) -> Union[Type[Dataset], Type[IterableDataset]]:

def register_document_converter(
self,
converter: Union[Callable[..., D], Dict[str, str]],
converter: Union[Callable[..., D], Dict[str, str], str],
document_type: Optional[Union[Type[D], str]] = None,
) -> "DatasetDict":
"""Register a converter function or field mapping for a target document type.
Expand All @@ -172,8 +172,9 @@ def register_document_converter(
of Document or string that can be resolved to such a type. If `None`, the document type is tried to be
inferred from the converter function signature.
converter: Either a function that converts a document of the document type of this dataset to a document
of the target document_type, or a field mapping (dict[str, str]) that maps fields of the document type
of this dataset to fields of the target document_type.
of the target document_type, a string that can be resolved to such a function, or a field mapping
(dict[str, str]) that maps fields of the document type of this dataset to fields of the target
document_type.
"""
resolved_document_type: Optional[Union[Type[D], Callable]] = None
if document_type is not None:
Expand All @@ -189,9 +190,19 @@ def register_document_converter(
f"document_type must be or resolv to a subclass of Document, but is {document_type}"
)

resolved_converter: Union[Callable[..., Any], dict[str, str]]
if isinstance(converter, str):
resolved_converter = resolve_target(converter)
else:
resolved_converter = converter
if not (callable(resolved_converter) or isinstance(resolved_converter, dict)):
raise TypeError(
f"converter must be a callable or a dict, but is {type(resolved_converter)}"
)

for ds in self.values():
ds.register_document_converter(
document_type=resolved_document_type, converter=converter
document_type=resolved_document_type, converter=resolved_converter
)
return self

Expand Down

0 comments on commit d333c1a

Please sign in to comment.