From 700b5b06d511f0e679e4750994bb29808f74d2f8 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Thu, 21 Sep 2023 15:46:03 +0200 Subject: [PATCH] feat: add simple handling for datasets.Translation --- .../spotlight_plugins/core/huggingface_datasource.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/renumics/spotlight_plugins/core/huggingface_datasource.py b/renumics/spotlight_plugins/core/huggingface_datasource.py index abe2248c..de85d2e8 100644 --- a/renumics/spotlight_plugins/core/huggingface_datasource.py +++ b/renumics/spotlight_plugins/core/huggingface_datasource.py @@ -102,6 +102,7 @@ def get_column_values( raw_values = self._dataset.data[column_name].take(indices) feature = self._dataset.features[column_name] + if isinstance(feature, datasets.Audio) or isinstance(feature, datasets.Image): return np.array( [ @@ -112,18 +113,23 @@ def get_column_values( ], dtype=object, ) + if isinstance(feature, dict): return np.array( [value["bytes"].as_py() for value in raw_values], dtype=object ) + if isinstance(feature, datasets.Sequence): if is_array_dtype(intermediate_dtype): return raw_values.to_numpy() if is_embedding_dtype(intermediate_dtype): return raw_values.to_numpy() return np.array([str(value) for value in raw_values]) - else: - return raw_values.to_numpy() + + if isinstance(feature, datasets.Translation): + return np.array([str(value) for value in raw_values]) + + return raw_values.to_numpy() def get_column_metadata(self, _: str) -> ColumnMetadata: return ColumnMetadata(nullable=True, editable=False) @@ -201,5 +207,7 @@ def _get_intermediate_dtype(feature: _FeatureType) -> DType: return dtypes.file_dtype else: raise UnsupportedFeature(feature) + elif isinstance(feature, datasets.Translation): + return dtypes.str_dtype else: raise UnsupportedFeature(feature)