Skip to content

Commit

Permalink
feat: add simple handling for datasets.Translation
Browse files: browse the repository at this point in the history
  • Loading branch information
neindochoh committed Sep 21, 2023
1 parent 21a8a3d commit 700b5b0
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions renumics/spotlight_plugins/core/huggingface_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def get_column_values(
raw_values = self._dataset.data[column_name].take(indices)

feature = self._dataset.features[column_name]

if isinstance(feature, datasets.Audio) or isinstance(feature, datasets.Image):
return np.array(
[
Expand All @@ -112,18 +113,23 @@ def get_column_values(
],
dtype=object,
)

if isinstance(feature, dict):
return np.array(
[value["bytes"].as_py() for value in raw_values], dtype=object
)

if isinstance(feature, datasets.Sequence):
if is_array_dtype(intermediate_dtype):
return raw_values.to_numpy()
if is_embedding_dtype(intermediate_dtype):
return raw_values.to_numpy()
return np.array([str(value) for value in raw_values])
else:
return raw_values.to_numpy()

if isinstance(feature, datasets.Translation):
return np.array([str(value) for value in raw_values])

return raw_values.to_numpy()

def get_column_metadata(self, _: str) -> ColumnMetadata:
    """
    Return the metadata for a column.

    The column name argument is ignored: every column is reported with
    `nullable=True` and `editable=False`.
    """
    return ColumnMetadata(nullable=True, editable=False)
Expand Down Expand Up @@ -201,5 +207,7 @@ def _get_intermediate_dtype(feature: _FeatureType) -> DType:
return dtypes.file_dtype
else:
raise UnsupportedFeature(feature)
elif isinstance(feature, datasets.Translation):
return dtypes.str_dtype
else:
raise UnsupportedFeature(feature)

0 comments on commit 700b5b0

Please sign in to comment.