Skip to content

Commit

Permalink
feat: support for additional hf dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
neindochoh committed Sep 20, 2023
1 parent 8c1da4c commit c038b9a
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 6 deletions.
5 changes: 5 additions & 0 deletions renumics/spotlight_plugins/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def __register__() -> None:
"""
register data sources
"""
from . import (
pandas_data_source, # noqa: F401
hdf5_data_source, # noqa: F401
huggingface_datasource, # noqa: F401
)


def __activate__(app: SpotlightApp) -> None:
Expand Down
43 changes: 39 additions & 4 deletions renumics/spotlight_plugins/core/huggingface_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@ def __init__(self, source: datasets.Dataset):
super().__init__(source)
self._dataset = source
self._intermediate_dtypes = {
col: _dtype_from_feature(feat)
col: _dtype_from_feature(feat, True)
for col, feat in self._dataset.features.items()
}
self._guessed_dtypes = {
col: _dtype_from_feature(feat, False)
for col, feat in self._dataset.features.items()
}
self._guessed_dtypes = self._intermediate_dtypes

@property
def column_names(self) -> List[str]:
Expand Down Expand Up @@ -88,19 +91,31 @@ def get_column_values(
raw_values = self._dataset.data[column_name].take(indices)

feature = self._dataset.features[column_name]
if isinstance(feature, datasets.Audio):
if isinstance(feature, datasets.Audio) or isinstance(feature, datasets.Image):
# TODO: use path for name if available?
return np.array(
[
value["path"].as_py()
if value["bytes"].as_py() is None
else value["bytes"].as_py()
for value in raw_values
],
dtype=object,
)
if isinstance(feature, dict):
return np.array(
[value["bytes"].as_py() for value in raw_values], dtype=object
)
if isinstance(feature, datasets.Sequence):
return raw_values.to_numpy()
else:
return raw_values.to_numpy()

def get_column_metadata(self, _: str) -> ColumnMetadata:
return ColumnMetadata(nullable=True, editable=False)


def _dtype_from_feature(feature: _FeatureType) -> DType:
def _dtype_from_feature(feature: _FeatureType, intermediate: bool) -> DType:
if isinstance(feature, datasets.Value):
hf_dtype = cast(datasets.Value, feature).dtype
if hf_dtype == "bool":
Expand Down Expand Up @@ -146,8 +161,28 @@ def _dtype_from_feature(feature: _FeatureType) -> DType:
elif isinstance(feature, datasets.ClassLabel):
return dtypes.CategoryDType(categories=cast(datasets.ClassLabel, feature).names)
elif isinstance(feature, datasets.Audio):
if intermediate:
return dtypes.bytes_dtype
return dtypes.audio_dtype
elif isinstance(feature, datasets.Image):
if intermediate:
return dtypes.bytes_dtype
return dtypes.image_dtype
elif isinstance(feature, datasets.Sequence):
if isinstance(feature.feature, datasets.Value):
if feature.length != -1:
return dtypes.embedding_dtype
return dtypes.array_dtype
else:
raise UnsupportedFeature(feature)
elif isinstance(feature, dict):
if len(feature) == 2 and "bytes" in feature and "path" in feature:
if intermediate:
return dtypes.bytes_dtype
else:
# TODO: guess filetype from bytes
return dtypes.str_dtype
else:
raise UnsupportedFeature(feature)
else:
raise UnsupportedFeature(feature)
2 changes: 1 addition & 1 deletion src/datatypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export type IntegerDataType = BaseDataType<'int'>;
export type FloatDataType = BaseDataType<'float'>;
export type BooleanDataType = BaseDataType<'bool'>;
export type DateTimeDataType = BaseDataType<'datetime'>;
export type ArrayDataType = BaseDataType<'array'>;
export type ArrayDataType = BaseDataType<'array', true>;
export type WindowDataType = BaseDataType<'Window'>;
export type StringDataType = BaseDataType<'str', true>;
export type EmbeddingDataType = BaseDataType<'Embedding', true>;
Expand Down
2 changes: 1 addition & 1 deletion src/stores/dataset/columnFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ function makeDatatype(column: Column): DataType {
case 'float':
case 'bool':
case 'Window':
case 'array':
case 'datetime':
return {
kind,
Expand All @@ -20,6 +19,7 @@ function makeDatatype(column: Column): DataType {
optional: column.optional,
};
case 'str':
case 'array':
case 'Embedding':
return {
kind,
Expand Down

0 comments on commit c038b9a

Please sign in to comment.