Skip to content

Commit

Permalink
feat: handle non-numeric sequences in huggingface datasource
Browse files Browse the repository at this point in the history
  • Loading branch information
neindochoh committed Sep 21, 2023
1 parent 50488f9 commit 21a8a3d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
16 changes: 7 additions & 9 deletions renumics/spotlight/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,15 @@ def _guess_dtype(self, col: str) -> DType:
intermediate_dtype = self._data_source.intermediate_dtypes[col]
fallback_dtype = _intermediate_to_semantic_dtype(intermediate_dtype)

sample_count = min(len(self._data_source), 10)
if sample_count == 0:
return fallback_dtype

sample_values = self._data_source.get_column_values(
col, list(range(sample_count))
)
sample_values = self._data_source.get_column_values(col, slice(10))
sample_dtypes = [_guess_value_dtype(value) for value in sample_values]
guessed_dtype = statistics.mode(sample_dtypes) or fallback_dtype

return guessed_dtype
try:
mode_dtype = statistics.mode(sample_dtypes)
except statistics.StatisticsError:
return fallback_dtype

return mode_dtype or fallback_dtype


def _intermediate_to_semantic_dtype(intermediate_dtype: DType) -> DType:
Expand Down
24 changes: 19 additions & 5 deletions renumics/spotlight_plugins/core/huggingface_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

from renumics.spotlight.data_source import DataSource
from renumics.spotlight.data_source.decorator import datasource
from renumics.spotlight.dtypes import DType, DTypeMap
from renumics.spotlight.dtypes import (
DType,
DTypeMap,
is_array_dtype,
is_embedding_dtype,
is_float_dtype,
is_int_dtype,
)
from renumics.spotlight.data_source.data_source import ColumnMetadata


Expand Down Expand Up @@ -83,12 +90,14 @@ def get_column_values(
column_name: str,
indices: Union[List[int], np.ndarray, slice] = slice(None),
) -> np.ndarray:
intermediate_dtype = self._intermediate_dtypes[column_name]

if isinstance(indices, slice):
if indices == slice(None):
raw_values = self._dataset.data[column_name]
else:
actual_indices = list(range(len(self._dataset)))[indices]
raw_values = self._dataset.data[column_name].take[actual_indices]
raw_values = self._dataset.data[column_name].take(actual_indices)
else:
raw_values = self._dataset.data[column_name].take(indices)

Expand All @@ -108,7 +117,11 @@ def get_column_values(
[value["bytes"].as_py() for value in raw_values], dtype=object
)
if isinstance(feature, datasets.Sequence):
return raw_values.to_numpy()
if is_array_dtype(intermediate_dtype):
return raw_values.to_numpy()
if is_embedding_dtype(intermediate_dtype):
return raw_values.to_numpy()
return np.array([str(value) for value in raw_values])
else:
return raw_values.to_numpy()

Expand Down Expand Up @@ -178,10 +191,11 @@ def _get_intermediate_dtype(feature: _FeatureType) -> DType:
elif isinstance(feature, datasets.Image):
return dtypes.file_dtype
elif isinstance(feature, datasets.Sequence):
if isinstance(feature.feature, datasets.Value):
inner_dtype = _get_intermediate_dtype(feature.feature)
if is_int_dtype(inner_dtype) or is_float_dtype(inner_dtype):
return dtypes.array_dtype
else:
raise UnsupportedFeature(feature)
return dtypes.str_dtype
elif isinstance(feature, dict):
if len(feature) == 2 and "bytes" in feature and "path" in feature:
return dtypes.file_dtype
Expand Down

0 comments on commit 21a8a3d

Please sign in to comment.