diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index e9f924ea..7d70f9bc 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -165,17 +165,15 @@ def _guess_dtype(self, col: str) -> DType: intermediate_dtype = self._data_source.intermediate_dtypes[col] fallback_dtype = _intermediate_to_semantic_dtype(intermediate_dtype) - sample_count = min(len(self._data_source), 10) - if sample_count == 0: - return fallback_dtype - - sample_values = self._data_source.get_column_values( - col, list(range(sample_count)) - ) + sample_values = self._data_source.get_column_values(col, slice(10)) sample_dtypes = [_guess_value_dtype(value) for value in sample_values] - guessed_dtype = statistics.mode(sample_dtypes) or fallback_dtype - return guessed_dtype + try: + mode_dtype = statistics.mode(sample_dtypes) + except statistics.StatisticsError: + return fallback_dtype + + return mode_dtype or fallback_dtype def _intermediate_to_semantic_dtype(intermediate_dtype: DType) -> DType: diff --git a/renumics/spotlight_plugins/core/huggingface_datasource.py b/renumics/spotlight_plugins/core/huggingface_datasource.py index 9002fdb1..abe2248c 100644 --- a/renumics/spotlight_plugins/core/huggingface_datasource.py +++ b/renumics/spotlight_plugins/core/huggingface_datasource.py @@ -6,7 +6,14 @@ from renumics.spotlight.data_source import DataSource from renumics.spotlight.data_source.decorator import datasource -from renumics.spotlight.dtypes import DType, DTypeMap +from renumics.spotlight.dtypes import ( + DType, + DTypeMap, + is_array_dtype, + is_embedding_dtype, + is_float_dtype, + is_int_dtype, +) from renumics.spotlight.data_source.data_source import ColumnMetadata @@ -83,12 +90,14 @@ def get_column_values( column_name: str, indices: Union[List[int], np.ndarray, slice] = slice(None), ) -> np.ndarray: + intermediate_dtype = self._intermediate_dtypes[column_name] + if isinstance(indices, slice): if indices == slice(None): raw_values = self._dataset.data[column_name] else: actual_indices = list(range(len(self._dataset)))[indices] - raw_values = self._dataset.data[column_name].take[actual_indices] + raw_values = self._dataset.data[column_name].take(actual_indices) else: raw_values = self._dataset.data[column_name].take(indices) @@ -108,7 +117,11 @@ def get_column_values( [value["bytes"].as_py() for value in raw_values], dtype=object ) if isinstance(feature, datasets.Sequence): - return raw_values.to_numpy() + if is_array_dtype(intermediate_dtype): + return raw_values.to_numpy() + if is_embedding_dtype(intermediate_dtype): + return raw_values.to_numpy() + return np.array([str(value) for value in raw_values]) else: return raw_values.to_numpy() @@ -178,10 +191,11 @@ def _get_intermediate_dtype(feature: _FeatureType) -> DType: elif isinstance(feature, datasets.Image): return dtypes.file_dtype elif isinstance(feature, datasets.Sequence): - if isinstance(feature.feature, datasets.Value): + inner_dtype = _get_intermediate_dtype(feature.feature) + if is_int_dtype(inner_dtype) or is_float_dtype(inner_dtype): return dtypes.array_dtype else: - raise UnsupportedFeature(feature) + return dtypes.str_dtype elif isinstance(feature, dict): if len(feature) == 2 and "bytes" in feature and "path" in feature: return dtypes.file_dtype