Skip to content

Commit

Permalink
Merge pull request #244 from Renumics/feature/huggingface-datasource
Browse files Browse the repository at this point in the history
Feature/huggingface datasource
  • Loading branch information
neindochoh authored Sep 22, 2023
2 parents 6b39a69 + c65411c commit cd16f7b
Show file tree
Hide file tree
Showing 13 changed files with 381 additions and 29 deletions.
5 changes: 2 additions & 3 deletions renumics/spotlight/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Union, Any

import pandas as pd

from renumics.spotlight.layout.nodes import Layout
from renumics.spotlight.analysis.typing import DataIssue
Expand All @@ -20,7 +19,7 @@ class AppConfig:
"""

# dataset
dataset: Optional[Union[Path, pd.DataFrame]] = None
dataset: Any = None
dtypes: Optional[DTypeMap] = None
project_root: Optional[Path] = None

Expand Down
5 changes: 3 additions & 2 deletions renumics/spotlight/data_source/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ def check_generation_id(self, generation_id: int) -> None:
if self.get_generation_id() != generation_id:
raise GenerationIDMismatch()

@property
@abstractmethod
def guess_dtypes(self) -> DTypeMap:
def semantic_dtypes(self) -> DTypeMap:
"""
Guess data source's dtypes.
Semantic dtypes for viewer.
"""

@abstractmethod
Expand Down
125 changes: 123 additions & 2 deletions renumics/spotlight/data_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import hashlib
import io
from typing import List, Optional, Set, Union, cast
import os
import statistics
from typing import Any, List, Optional, Set, Union, cast
import numpy as np
import filetype
import trimesh
import PIL.Image

from renumics.spotlight.cache import external_data_cache
from renumics.spotlight.data_source import DataSource
Expand All @@ -10,13 +15,33 @@
from renumics.spotlight.io import audio
from renumics.spotlight.dtypes import (
CategoryDType,
DType,
DTypeMap,
is_audio_dtype,
is_category_dtype,
is_file_dtype,
is_str_dtype,
is_mixed_dtype,
is_bytes_dtype,
str_dtype,
audio_dtype,
image_dtype,
video_dtype,
mesh_dtype,
embedding_dtype,
array_dtype,
window_dtype,
sequence_1d_dtype,
)

from renumics.spotlight.typing import is_iterable, is_pathtype
from renumics.spotlight.media.mesh import Mesh
from renumics.spotlight.media.video import Video
from renumics.spotlight.media.audio import Audio
from renumics.spotlight.media.image import Image
from renumics.spotlight.media.sequence_1d import Sequence1D
from renumics.spotlight.media.embedding import Embedding


class DataStore:
_data_source: DataSource
Expand Down Expand Up @@ -102,7 +127,14 @@ def get_waveform(self, column_name: str, index: int) -> Optional[np.ndarray]:
return waveform

def _update_dtypes(self) -> None:
guessed_dtypes = self._data_source.guess_dtypes()
guessed_dtypes = self._data_source.semantic_dtypes.copy()

# guess missing dtypes from intermediate dtypes
for col, dtype in self._data_source.intermediate_dtypes.items():
if col not in guessed_dtypes:
guessed_dtypes[col] = self._guess_dtype(col)

# merge guessed semantic dtypes with user dtypes
dtypes = {
**guessed_dtypes,
**{
Expand All @@ -111,6 +143,8 @@ def _update_dtypes(self) -> None:
if column_name in guessed_dtypes
},
}

# determine categories for _automatic_ CategoryDtypes
for column_name, dtype in dtypes.items():
if (
is_category_dtype(dtype)
Expand All @@ -124,4 +158,91 @@ def _update_dtypes(self) -> None:
]
category_names = sorted(cast(Set[str], set(converted_values)))
dtypes[column_name] = CategoryDType(category_names)

self._dtypes = dtypes

def _guess_dtype(self, col: str) -> DType:
    """
    Guess the semantic dtype of column `col` from a sample of its values.

    The column's intermediate dtype, mapped through
    `_intermediate_to_semantic_dtype`, serves as the fallback. The values
    selected by `slice(10)` are each passed to `_guess_value_dtype`, and
    the most common per-value guess wins. The fallback is returned when
    no mode can be computed or when the mode is falsy (e.g. `None`,
    meaning the values could not be classified).
    """
    fallback = _intermediate_to_semantic_dtype(
        self._data_source.intermediate_dtypes[col]
    )

    samples = self._data_source.get_column_values(col, slice(10))
    guesses = [_guess_value_dtype(sample) for sample in samples]

    try:
        most_common = statistics.mode(guesses)
    except statistics.StatisticsError:
        # No usable mode (e.g. no samples at all).
        most_common = None

    if most_common:
        return most_common
    return fallback


def _intermediate_to_semantic_dtype(intermediate_dtype: DType) -> DType:
    """
    Map an intermediate dtype to the semantic dtype used as a fallback.

    File, mixed and bytes dtypes are mapped to `str_dtype`; every other
    dtype is returned unchanged.
    """
    falls_back_to_str = (
        is_file_dtype(intermediate_dtype)
        or is_mixed_dtype(intermediate_dtype)
        or is_bytes_dtype(intermediate_dtype)
    )
    return str_dtype if falls_back_to_str else intermediate_dtype


def _guess_value_dtype(value: Any) -> Optional[DType]:
    """
    Infer the dtype of a single value.

    Checks run in a fixed order: Spotlight media classes, PIL images and
    trimesh meshes, numpy arrays, bytes or paths to existing files
    (classified by their MIME group), and finally generic iterables that
    can be converted to a float array. Returns `None` when none of these
    checks yields a dtype.
    """
    # Ordered (python type, dtype) pairs; the first matching type wins.
    typed_checks = (
        (Embedding, embedding_dtype),
        (Sequence1D, sequence_1d_dtype),
        (Image, image_dtype),
        (Audio, audio_dtype),
        (Video, video_dtype),
        (Mesh, mesh_dtype),
        (PIL.Image.Image, image_dtype),
        (trimesh.Trimesh, mesh_dtype),
    )
    for python_type, matching_dtype in typed_checks:
        if isinstance(value, python_type):
            return matching_dtype

    if isinstance(value, np.ndarray):
        return _infer_array_dtype(value)

    if isinstance(value, bytes) or (is_pathtype(value) and os.path.isfile(value)):
        detected = filetype.guess(value)
        if detected is not None:
            media_dtypes = {
                "image": image_dtype,
                "audio": audio_dtype,
                "video": video_dtype,
            }
            mime_group = detected.mime.partition("/")[0]
            if mime_group in media_dtypes:
                return media_dtypes[mime_group]
        # Bytes or an existing file without a recognised media type.
        return str_dtype

    if is_iterable(value):
        try:
            as_floats = np.asarray(value, dtype=float)
        except (TypeError, ValueError):
            # Not convertible to a numeric array: nothing to infer.
            return None
        return _infer_array_dtype(as_floats)

    return None


def _infer_array_dtype(value: np.ndarray) -> DType:
    """
    Infer the dtype of a numpy array from its number of dimensions and shape.

    - 1 dimension: `window_dtype` for length 2, otherwise `embedding_dtype`.
    - 2 dimensions: `sequence_1d_dtype` if either axis has length 2.
    - 3 dimensions: `image_dtype` if the last axis has length 1, 3 or 4.

    Anything else is reported as `array_dtype`.
    """
    ndim = value.ndim
    if ndim == 1:
        return window_dtype if len(value) == 2 else embedding_dtype
    if ndim == 2 and 2 in value.shape:
        return sequence_1d_dtype
    if ndim == 3 and value.shape[-1] in (1, 3, 4):
        return image_dtype
    return array_dtype
10 changes: 5 additions & 5 deletions renumics/spotlight/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _user_column_attributes(dtype: spotlight_dtypes.DType) -> Dict[str, Type]:
if spotlight_dtypes.is_sequence_1d_dtype(dtype):
attribute_names["x_label"] = str
attribute_names["y_label"] = str
if spotlight_dtypes.is_file_dtype(dtype):
if spotlight_dtypes.is_filebased_dtype(dtype):
attribute_names["lookup"] = dict
attribute_names["external"] = bool
if spotlight_dtypes.is_audio_dtype(dtype):
Expand Down Expand Up @@ -758,7 +758,7 @@ def from_pandas(

column = prepare_column(column, dtype)

if workdir is not None and spotlight_dtypes.is_file_dtype(dtype):
if workdir is not None and spotlight_dtypes.is_filebased_dtype(dtype):
# For file-based data types, relative paths should be resolved.
str_mask = is_string_mask(column)
column[str_mask] = column[str_mask].apply(
Expand All @@ -777,7 +777,7 @@ def from_pandas(
else:
values = column.to_numpy()

if spotlight_dtypes.is_file_dtype(dtype):
if spotlight_dtypes.is_filebased_dtype(dtype):
attrs["external"] = False # type: ignore
attrs["lookup"] = False # type: ignore

Expand Down Expand Up @@ -2435,7 +2435,7 @@ def _append_column(
elif spotlight_dtypes.is_sequence_1d_dtype(dtype):
attrs["x_label"] = dtype.x_label
attrs["y_label"] = dtype.y_label
elif spotlight_dtypes.is_file_dtype(dtype):
elif spotlight_dtypes.is_filebased_dtype(dtype):
lookup = attrs.get("lookup", None)
if is_iterable(lookup) and not isinstance(lookup, dict):
# Assume that we can keep all the lookup values in memory.
Expand Down Expand Up @@ -3002,7 +3002,7 @@ def _encode_value(
if self._is_ref_column(column):
value = cast(RefColumnInputType, value)
self._assert_valid_value_type(value, dtype, column_name)
if spotlight_dtypes.is_file_dtype(dtype) and isinstance(value, str):
if spotlight_dtypes.is_filebased_dtype(dtype) and isinstance(value, str):
try:
return self._find_lookup_ref(value, column)
except KeyError:
Expand Down
32 changes: 25 additions & 7 deletions renumics/spotlight/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,14 @@ def __init__(self, x_label: str = "x", y_label: str = "y"):
ALIASES: Dict[Any, DType] = {}


def register_dtype(dtype: DType, aliases: list) -> None:
for alias in aliases:
assert dtype.name.lower() not in ALIASES
ALIASES[dtype.name.lower()] = dtype
assert alias not in ALIASES
ALIASES[alias] = dtype
def register_dtype(dtype: DType, aliases: Optional[list] = None) -> None:
    """
    Register `dtype` in `ALIASES` under its lower-cased name and under
    each of the optional `aliases`.

    Every key is asserted to be unused before it is stored.
    """
    canonical_name = dtype.name.lower()
    assert canonical_name not in ALIASES
    ALIASES[canonical_name] = dtype

    for alias in [] if aliases is None else aliases:
        assert alias not in ALIASES
        ALIASES[alias] = dtype


bool_dtype = DType("bool")
Expand Down Expand Up @@ -150,9 +152,13 @@ def register_dtype(dtype: DType, aliases: list) -> None:
video_dtype = DType("Video")
"""Video dtype"""
register_dtype(video_dtype, [Video])

mixed_dtype = DType("mixed")
"""Unknown or mixed dtype"""

file_dtype = DType("file")
"""File Dtype (bytes or str(path))"""


DTypeMap = Dict[str, DType]

Expand Down Expand Up @@ -221,9 +227,21 @@ def is_video_dtype(dtype: DType) -> bool:
return dtype.name == "Video"


def is_bytes_dtype(dtype: DType) -> bool:
    """
    Tell whether `dtype` is named "bytes".
    """
    name = dtype.name
    return name == "bytes"


def is_mixed_dtype(dtype: DType) -> bool:
    """
    Tell whether `dtype` is named "mixed".
    """
    name = dtype.name
    return name == "mixed"


def is_scalar_dtype(dtype: DType) -> bool:
    """
    Tell whether `dtype` is named "bool", "int" or "float".
    """
    scalar_names = ("bool", "int", "float")
    return dtype.name in scalar_names


def is_file_dtype(dtype: DType) -> bool:
return dtype.name in ("Audio", "Image", "Video", "Mesh")
return dtype.name == "file"


def is_filebased_dtype(dtype: DType) -> bool:
    """
    Tell whether `dtype` is named "Audio", "Image", "Video", "Mesh" or "file".
    """
    filebased_names = ("Audio", "Image", "Video", "Mesh", "file")
    return dtype.name in filebased_names
2 changes: 1 addition & 1 deletion renumics/spotlight/io/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def prepare_column(column: pd.Series, dtype: dtypes.DType) -> pd.Series:
str_mask = is_string_mask(column)
column[str_mask] = column[str_mask].apply(try_literal_eval)

if dtypes.is_file_dtype(dtype):
if dtypes.is_filebased_dtype(dtype):
dict_mask = column.map(type) == dict
column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict)

Expand Down
4 changes: 1 addition & 3 deletions renumics/spotlight/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,8 @@ def show(
project_root = dataset
else:
project_root = dataset.parent
elif isinstance(dataset, pd.DataFrame) or dataset is None:
project_root = None
else:
raise TypeError("Dataset has invalid type")
project_root = None

if folder:
project_root = Path(folder)
Expand Down
6 changes: 5 additions & 1 deletion renumics/spotlight_plugins/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ def __register__() -> None:
"""
register data sources
"""
from . import pandas_data_source, hdf5_data_source # noqa: F401
from . import (
pandas_data_source, # noqa: F401
hdf5_data_source, # noqa: F401
huggingface_datasource, # noqa: F401
)


def __activate__(app: SpotlightApp) -> None:
Expand Down
5 changes: 3 additions & 2 deletions renumics/spotlight_plugins/core/hdf5_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,13 @@ def column_names(self) -> List[str]:

@property
def intermediate_dtypes(self) -> DTypeMap:
return self.guess_dtypes()
return self.semantic_dtypes

def __len__(self) -> int:
return len(self._table)

def guess_dtypes(self) -> DTypeMap:
@property
def semantic_dtypes(self) -> DTypeMap:
return {
column_name: create_dtype(self._table.get_dtype(column_name))
for column_name in self.column_names
Expand Down
Loading

0 comments on commit cd16f7b

Please sign in to comment.