From c038b9a982d499f3aa809274fb3b3c785dd87f6f Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 20 Sep 2023 17:01:14 +0200 Subject: [PATCH] feat: support for additional hf dtypes --- renumics/spotlight_plugins/core/__init__.py | 5 +++ .../core/huggingface_datasource.py | 43 +++++++++++++++++-- src/datatypes.ts | 2 +- src/stores/dataset/columnFactory.ts | 2 +- 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/renumics/spotlight_plugins/core/__init__.py b/renumics/spotlight_plugins/core/__init__.py index d14dd27c..f741a4dd 100644 --- a/renumics/spotlight_plugins/core/__init__.py +++ b/renumics/spotlight_plugins/core/__init__.py @@ -21,6 +21,11 @@ def __register__() -> None: """ register data sources """ + from . import ( + pandas_data_source, # noqa: F401 + hdf5_data_source, # noqa: F401 + huggingface_datasource, # noqa: F401 + ) def __activate__(app: SpotlightApp) -> None: diff --git a/renumics/spotlight_plugins/core/huggingface_datasource.py b/renumics/spotlight_plugins/core/huggingface_datasource.py index 9a8e5fd2..fcd81e2b 100644 --- a/renumics/spotlight_plugins/core/huggingface_datasource.py +++ b/renumics/spotlight_plugins/core/huggingface_datasource.py @@ -45,10 +45,13 @@ def __init__(self, source: datasets.Dataset): super().__init__(source) self._dataset = source self._intermediate_dtypes = { - col: _dtype_from_feature(feat) + col: _dtype_from_feature(feat, True) + for col, feat in self._dataset.features.items() + } + self._guessed_dtypes = { + col: _dtype_from_feature(feat, False) for col, feat in self._dataset.features.items() } - self._guessed_dtypes = self._intermediate_dtypes @property def column_names(self) -> List[str]: @@ -88,11 +91,23 @@ def get_column_values( raw_values = self._dataset.data[column_name].take(indices) feature = self._dataset.features[column_name] - if isinstance(feature, datasets.Audio): + if isinstance(feature, datasets.Audio) or isinstance(feature, datasets.Image): # TODO: use path for name if available? + return np.array( + [ + value["path"].as_py() + if value["bytes"].as_py() is None + else value["bytes"].as_py() + for value in raw_values + ], + dtype=object, + ) + if isinstance(feature, dict): return np.array( [value["bytes"].as_py() for value in raw_values], dtype=object ) + if isinstance(feature, datasets.Sequence): + return raw_values.to_numpy() else: return raw_values.to_numpy() @@ -100,7 +115,7 @@ def get_column_metadata(self, _: str) -> ColumnMetadata: return ColumnMetadata(nullable=True, editable=False) -def _dtype_from_feature(feature: _FeatureType) -> DType: +def _dtype_from_feature(feature: _FeatureType, intermediate: bool) -> DType: if isinstance(feature, datasets.Value): hf_dtype = cast(datasets.Value, feature).dtype if hf_dtype == "bool": @@ -146,8 +161,28 @@ def _dtype_from_feature(feature: _FeatureType) -> DType: elif isinstance(feature, datasets.ClassLabel): return dtypes.CategoryDType(categories=cast(datasets.ClassLabel, feature).names) elif isinstance(feature, datasets.Audio): + if intermediate: + return dtypes.bytes_dtype return dtypes.audio_dtype elif isinstance(feature, datasets.Image): + if intermediate: + return dtypes.bytes_dtype return dtypes.image_dtype + elif isinstance(feature, datasets.Sequence): + if isinstance(feature.feature, datasets.Value): + if feature.length != -1: + return dtypes.embedding_dtype + return dtypes.array_dtype + else: + raise UnsupportedFeature(feature) + elif isinstance(feature, dict): + if len(feature) == 2 and "bytes" in feature and "path" in feature: + if intermediate: + return dtypes.bytes_dtype + else: + # TODO: guess filetype from bytes + return dtypes.str_dtype + else: + raise UnsupportedFeature(feature) else: raise UnsupportedFeature(feature) diff --git a/src/datatypes.ts b/src/datatypes.ts index 6ac98b87..29ed7bca 100644 --- a/src/datatypes.ts +++ b/src/datatypes.ts @@ -35,7 +35,7 @@ export type IntegerDataType = BaseDataType<'int'>; export type FloatDataType = BaseDataType<'float'>; export type BooleanDataType = BaseDataType<'bool'>; export type DateTimeDataType = BaseDataType<'datetime'>; -export type ArrayDataType = BaseDataType<'array'>; +export type ArrayDataType = BaseDataType<'array', true>; export type WindowDataType = BaseDataType<'Window'>; export type StringDataType = BaseDataType<'str', true>; export type EmbeddingDataType = BaseDataType<'Embedding', true>; diff --git a/src/stores/dataset/columnFactory.ts b/src/stores/dataset/columnFactory.ts index 3ec20678..26018d58 100644 --- a/src/stores/dataset/columnFactory.ts +++ b/src/stores/dataset/columnFactory.ts @@ -11,7 +11,6 @@ function makeDatatype(column: Column): DataType { case 'float': case 'bool': case 'Window': - case 'array': case 'datetime': return { kind, @@ -20,6 +19,7 @@ function makeDatatype(column: Column): DataType { optional: column.optional, }; case 'str': + case 'array': case 'Embedding': return { kind,