diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8c6ef304..691a85f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -376,7 +376,7 @@ jobs: with: python-version: ${{ matrix.python-version }} install-dependencies: false - - name: Cache pip cache folder + - name: '♻️ Cache pip cache folder' uses: actions/cache@v3 with: path: ${{ steps.setup-poetry.outputs.pip-cache-dir }} diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index 7d70f9bc..42867e5a 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -14,9 +14,12 @@ from renumics.spotlight.data_source.data_source import ColumnMetadata from renumics.spotlight.io import audio from renumics.spotlight.dtypes import ( + ArrayDType, CategoryDType, DType, DTypeMap, + EmbeddingDType, + is_array_dtype, is_audio_dtype, is_category_dtype, is_file_dtype, @@ -29,7 +32,6 @@ video_dtype, mesh_dtype, embedding_dtype, - array_dtype, window_dtype, sequence_1d_dtype, ) @@ -163,7 +165,10 @@ def _update_dtypes(self) -> None: def _guess_dtype(self, col: str) -> DType: intermediate_dtype = self._data_source.intermediate_dtypes[col] - fallback_dtype = _intermediate_to_semantic_dtype(intermediate_dtype) + semantic_dtype = _intermediate_to_semantic_dtype(intermediate_dtype) + + if is_array_dtype(intermediate_dtype): + return semantic_dtype sample_values = self._data_source.get_column_values(col, slice(10)) sample_dtypes = [_guess_value_dtype(value) for value in sample_values] @@ -171,12 +176,28 @@ def _guess_dtype(self, col: str) -> DType: try: mode_dtype = statistics.mode(sample_dtypes) except statistics.StatisticsError: - return fallback_dtype + return semantic_dtype - return mode_dtype or fallback_dtype + return mode_dtype or semantic_dtype def _intermediate_to_semantic_dtype(intermediate_dtype: DType) -> DType: + if is_array_dtype(intermediate_dtype): + if intermediate_dtype.shape is None: + return intermediate_dtype + if intermediate_dtype.shape == (2,): + return window_dtype + if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is not None: + return EmbeddingDType(intermediate_dtype.shape[0]) + if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is None: + return sequence_1d_dtype + if intermediate_dtype.ndim == 2 and ( + intermediate_dtype.shape[0] == 2 or intermediate_dtype.shape[1] == 2 + ): + return sequence_1d_dtype + if intermediate_dtype.ndim == 3 and intermediate_dtype.shape[-1] in (1, 3, 4): + return image_dtype + return intermediate_dtype if is_file_dtype(intermediate_dtype): return str_dtype if is_mixed_dtype(intermediate_dtype): @@ -208,7 +229,7 @@ def _guess_value_dtype(value: Any) -> Optional[DType]: if isinstance(value, trimesh.Trimesh): return mesh_dtype if isinstance(value, np.ndarray): - return _infer_array_dtype(value) + return ArrayDType(value.shape) if isinstance(value, bytes) or (is_pathtype(value) and os.path.isfile(value)): kind = filetype.guess(value) @@ -227,22 +248,5 @@ def _guess_value_dtype(value: Any) -> Optional[DType]: except (TypeError, ValueError): pass else: - return _infer_array_dtype(value) + return ArrayDType(value.shape) return None - - -def _infer_array_dtype(value: np.ndarray) -> DType: - """ - Infer dtype of a numpy array - """ - if value.ndim == 3: - if value.shape[-1] in (1, 3, 4): - return image_dtype - elif value.ndim == 2: - if value.shape[0] == 2 or value.shape[1] == 2: - return sequence_1d_dtype - elif value.ndim == 1: - if len(value) == 2: - return window_dtype - return embedding_dtype - return array_dtype diff --git a/renumics/spotlight/dtypes/__init__.py b/renumics/spotlight/dtypes/__init__.py index 7348097c..0e24ea10 100644 --- a/renumics/spotlight/dtypes/__init__.py +++ b/renumics/spotlight/dtypes/__init__.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Dict, Iterable, Optional, Tuple, Union import numpy as np from typing_extensions import TypeGuard @@ -80,6 +80,38 @@ def inverted_categories(self) -> Optional[Dict[int, str]]: return self._inverted_categories +class ArrayDType(DType): + """ + Array dtype with optional shape. + """ + + shape: Optional[Tuple[Optional[int], ...]] + + def __init__(self, shape: Optional[Tuple[Optional[int], ...]] = None): + super().__init__("array") + self.shape = shape + + @property + def ndim(self) -> int: + if self.shape is None: + return 0 + return len(self.shape) + + +class EmbeddingDType(DType): + """ + Embedding dtype with optional length. + """ + + length: Optional[int] + + def __init__(self, length: Optional[int] = None): + super().__init__("Embedding") + if length is not None and length < 0: + raise ValueError(f"Length must be non-negative, but {length} received.") + self.length = length + + class Sequence1DDType(DType): """ 1D-sequence dtype with predefined axis labels. @@ -131,10 +163,10 @@ def register_dtype(dtype: DType, aliases: Optional[list] = None) -> None: window_dtype = DType("Window") """Window dtype""" register_dtype(window_dtype, [Window]) -embedding_dtype = DType("Embedding") +embedding_dtype = EmbeddingDType() """Embedding dtype""" register_dtype(embedding_dtype, [Embedding]) -array_dtype = DType("array") +array_dtype = ArrayDType() """numpy array dtype""" register_dtype(array_dtype, [np.ndarray]) image_dtype = DType("Image") @@ -195,7 +227,7 @@ def is_category_dtype(dtype: DType) -> TypeGuard[CategoryDType]: return dtype.name == "Category" -def is_array_dtype(dtype: DType) -> bool: +def is_array_dtype(dtype: DType) -> TypeGuard[ArrayDType]: return dtype.name == "array" @@ -203,7 +235,7 @@ def is_window_dtype(dtype: DType) -> bool: return dtype.name == "Window" -def is_embedding_dtype(dtype: DType) -> bool: +def is_embedding_dtype(dtype: DType) -> TypeGuard[EmbeddingDType]: return dtype.name == "Embedding" diff --git a/renumics/spotlight_plugins/core/huggingface_datasource.py b/renumics/spotlight_plugins/core/huggingface_datasource.py index 49a50f63..870245cb 100644 --- a/renumics/spotlight_plugins/core/huggingface_datasource.py +++ b/renumics/spotlight_plugins/core/huggingface_datasource.py @@ -2,20 +2,31 @@ import datasets import numpy as np -from renumics.spotlight import dtypes from renumics.spotlight.data_source import DataSource from renumics.spotlight.data_source.decorator import datasource from renumics.spotlight.dtypes import ( + ArrayDType, + CategoryDType, DType, DTypeMap, + audio_dtype, + bool_dtype, + datetime_dtype, + float_dtype, + image_dtype, + int_dtype, + file_dtype, + bytes_dtype, is_array_dtype, is_embedding_dtype, is_file_dtype, is_float_dtype, is_int_dtype, + str_dtype, ) from renumics.spotlight.data_source.data_source import ColumnMetadata +from renumics.spotlight.logging import logger _FeatureType = Union[ @@ -84,7 +95,7 @@ def get_uid(self) -> str: return self._dataset._fingerprint def get_name(self) -> str: - return self._dataset.builder_name + return f"🤗 Dataset {self._dataset.builder_name or ''}" def get_column_values( self, @@ -125,14 +136,41 @@ def get_column_values( if isinstance(feature, datasets.Sequence): if is_array_dtype(intermediate_dtype): - return raw_values.to_numpy() + values = [ + _convert_object_array(value) for value in raw_values.to_numpy() + ] + return_array = np.empty(len(values), dtype=object) + return_array[:] = values + return return_array if is_embedding_dtype(intermediate_dtype): return raw_values.to_numpy() return np.array([str(value) for value in raw_values]) + if isinstance( + feature, + (datasets.Array2D, datasets.Array3D, datasets.Array4D, datasets.Array5D), + ): + if is_array_dtype(intermediate_dtype): + values = [ + _convert_object_array(value) for value in raw_values.to_numpy() + ] + return_array = np.empty(len(values), dtype=object) + return_array[:] = values + return return_array + return np.array([str(value) for value in raw_values]) + if isinstance(feature, datasets.Translation): return np.array([str(value) for value in raw_values]) + if isinstance(feature, datasets.Value): + hf_dtype = feature.dtype + if hf_dtype.startswith("duration"): + return raw_values.to_numpy().astype(int) + if hf_dtype.startswith("time32") or hf_dtype.startswith("time64"): + return raw_values.to_numpy().astype(str) + if hf_dtype.startswith("timestamp[ns"): + return raw_values.to_numpy().astype(int) + return raw_values.to_numpy() def get_column_metadata(self, _: str) -> ColumnMetadata: @@ -141,69 +179,100 @@ def get_column_metadata(self, _: str) -> ColumnMetadata: def _guess_semantic_dtype(feature: _FeatureType) -> Optional[DType]: if isinstance(feature, datasets.Audio): - return dtypes.audio_dtype + return audio_dtype if isinstance(feature, datasets.Image): - return dtypes.image_dtype - if isinstance(feature, datasets.Sequence): - if isinstance(feature.feature, datasets.Value): - if feature.length != -1: - return dtypes.embedding_dtype + return image_dtype return None def _get_intermediate_dtype(feature: _FeatureType) -> DType: if isinstance(feature, datasets.Value): - hf_dtype = cast(datasets.Value, feature).dtype + hf_dtype = feature.dtype if hf_dtype == "bool": - return dtypes.bool_dtype + return bool_dtype elif hf_dtype.startswith("int"): - return dtypes.int_dtype + return int_dtype elif hf_dtype.startswith("uint"): - return dtypes.int_dtype + return int_dtype elif hf_dtype.startswith("float"): - return dtypes.float_dtype + return float_dtype elif hf_dtype.startswith("time32"): - return dtypes.datetime_dtype + return str_dtype elif hf_dtype.startswith("time64"): - return dtypes.datetime_dtype + return str_dtype elif hf_dtype.startswith("timestamp"): - return dtypes.datetime_dtype + if hf_dtype.startswith("timestamp[ns"): + return int_dtype + return datetime_dtype elif hf_dtype.startswith("date32"): - return dtypes.datetime_dtype + return datetime_dtype elif hf_dtype.startswith("date64"): - return dtypes.datetime_dtype + return datetime_dtype elif hf_dtype.startswith("duration"): - return dtypes.float_dtype + return int_dtype elif hf_dtype.startswith("decimal"): - return dtypes.float_dtype + return float_dtype elif hf_dtype == "binary": - return dtypes.bytes_dtype + return bytes_dtype elif hf_dtype == "large_binary": - return dtypes.bytes_dtype + return bytes_dtype elif hf_dtype == "string": - return dtypes.str_dtype + return str_dtype elif hf_dtype == "large_string": - return dtypes.str_dtype + return str_dtype else: - raise UnsupportedFeature(feature) + logger.warning(f"Unsupported Hugging Face value dtype: {hf_dtype}.") + return str_dtype elif isinstance(feature, datasets.ClassLabel): - return dtypes.CategoryDType(categories=cast(datasets.ClassLabel, feature).names) + return CategoryDType(categories=cast(datasets.ClassLabel, feature).names) elif isinstance(feature, datasets.Audio): - return dtypes.file_dtype + return file_dtype elif isinstance(feature, datasets.Image): - return dtypes.file_dtype + return file_dtype elif isinstance(feature, datasets.Sequence): inner_dtype = _get_intermediate_dtype(feature.feature) if is_int_dtype(inner_dtype) or is_float_dtype(inner_dtype): - return dtypes.array_dtype - else: - return dtypes.str_dtype + return ArrayDType((None if feature.length == -1 else feature.length,)) + if is_array_dtype(inner_dtype): + if inner_dtype.shape is None: + return str_dtype + shape = ( + None if feature.length == -1 else feature.length, + *inner_dtype.shape, + ) + if shape.count(None) > 1: + return str_dtype + return ArrayDType(shape) + return str_dtype + elif isinstance(feature, list): + inner_dtype = _get_intermediate_dtype(feature[0]) + if is_int_dtype(inner_dtype) or is_float_dtype(inner_dtype): + return ArrayDType((None,)) + if is_array_dtype(inner_dtype): + if inner_dtype.shape is None: + return str_dtype + shape = (None, *inner_dtype.shape) + if shape.count(None) > 1: + return str_dtype + return ArrayDType(shape) + return str_dtype + elif isinstance( + feature, + (datasets.Array2D, datasets.Array3D, datasets.Array4D, datasets.Array5D), + ): + return ArrayDType(feature.shape) elif isinstance(feature, dict): if len(feature) == 2 and "bytes" in feature and "path" in feature: - return dtypes.file_dtype + return file_dtype else: - return dtypes.str_dtype + return str_dtype elif isinstance(feature, datasets.Translation): - return dtypes.str_dtype - else: - raise UnsupportedFeature(feature) + return str_dtype + logger.warning(f"Unsupported Hugging Face feature: {feature}.") + return str_dtype + + +def _convert_object_array(value: np.ndarray) -> np.ndarray: + if value.dtype.type is np.object_: + return np.array([_convert_object_array(x) for x in value]) + return value diff --git a/tests/integration/huggingface/__init__.py b/tests/integration/huggingface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/huggingface/conftest.py b/tests/integration/huggingface/conftest.py new file mode 100644 index 00000000..5f6f089b --- /dev/null +++ b/tests/integration/huggingface/conftest.py @@ -0,0 +1,17 @@ +""" +Pytest Fixtures for Hugging Face tests +""" + +import datasets +import pytest + +from .dataset import create_hf_dataset + + +@pytest.fixture +def dataset() -> datasets.Dataset: + """ + H5 Dataset for tests + """ + + return create_hf_dataset() diff --git a/tests/integration/huggingface/dataset.py b/tests/integration/huggingface/dataset.py new file mode 100644 index 00000000..3b81abb6 --- /dev/null +++ b/tests/integration/huggingface/dataset.py @@ -0,0 +1,146 @@ +""" +Data for Hugging Face tests +""" + +import datetime + +import datasets + + +DATA = { + "bool": [True, False, False], + "int": [-1, 1, 100000], + "uint": [1, 1, 30000], + "float": [1.0, float("nan"), 1000], + "string": ["foo", "barbaz", ""], + "label": ["foo", "bar", "foo"], + "binary": [b"foo", b"bar", b""], + "duration": [-1, 2, 10], + "decimal": [1.0, 3.0, 1000], + "date": [datetime.date.min, datetime.date(2001, 2, 15), datetime.date.max], + "time": [ + datetime.time.min, + datetime.time(14, 24, 15, 2672), + datetime.time.max, + ], + "timestamp": [ + datetime.datetime(1970, 2, 15, 14, 24, 15, 2672), + datetime.datetime(2001, 2, 15, 14, 24, 15, 2672), + datetime.datetime(2170, 2, 15, 14, 24, 15, 2672), + ], + "timestamp_ns": [ + datetime.datetime(1970, 2, 15, 14, 24, 15, 2672), + datetime.datetime(2001, 2, 15, 14, 24, 15, 2672), + datetime.datetime(2170, 2, 15, 14, 24, 15, 2672), + ], + "embedding": [[1, 2, 3, 4], [1, 6, 3, 7], [-1, -2, -3, -4]], + "audio": [ + "data/audio/mono/gs-16b-1c-44100hz.mp3", + "data/audio/1.wav", + "data/audio/stereo/gs-16b-2c-44100hz.ogg", + ], + "image": [ + "data/images/nature-256p.ico", + "data/images/sea-360p.gif", + "data/images/nature-360p.jpg", + ], + # HF sequence as Spotlight sequence + "sequence_1d": [[1, 2, 3, 4], [1, 6, 3], [-1, -2, float("nan"), -4, 10]], + "sequence_2d": [ + [[1, 2, 3, 4], [-1, 3, 1, 6]], + [[1, -3, 10], [1, 6, 3]], + [[-10, 0, 10], [-1, -2, -3]], + ], + "sequence_2d_t": [[[5, 3], [2, 5], [10, 8]], [], [[-1, 1], [1, 10]]], + # HF sequence as Spotlight array + "sequence_2d_array": [ + [[1, 2, 3, 4], [-1, 3, 1, 6], [1, 2, 4, 4]], + [[1, -3, 10], [1, 6, 3], [1, float("nan"), 4]], + [[-10, 0, 10], [-1, -2, -3], [1, 2, 4]], + ], + "sequence_3d_array": [ + [[[1, 2, 3, 4], [-1, 3, 1, 6], [1, 2, 4, 4]]], + [[[1, -3, 10], [1, 6, 3], [1, float("nan"), 4]]], + [[[-10, 0, 10], [-1, -2, -3], [1, 2, 4]]], + ], + # HF 2D array as Spotlight sequence + "array_2d_sequence": [ + [[1, 2, 3], [-1, 3, 1]], + [[1, -3, 10], [1, 6, 3]], + [[-10, 0, 10], [-1, -2, -3]], + ], + "array_2d_t_sequence": [ + [[5, 3], [2, 5], [10, 8]], + [[float("nan"), 1], [1, 1], [2, 2]], + [[-1, 1], [1, 10], [10, 1]], + ], + "array_2d_vlen_sequence": [ + [[5, 3], [2, 5], [10, 8]], + [], + [[-1, 1], [1, 10]], + ], + # HF 4D array as Spotlight array + "array_4d": [ + [[[[1.0, 1.0, -10.0]]], [[[-1.0, 1.0, -1.0]]], [[[2.0, 1.0, 1.0]]]], + [ + [[[2.0, -3.0, 0.0]]], + [[[3.0, 6.0, -2.0]]], + [[[4.0, float("nan"), 2.0]]], + [[[4.0, float("nan"), 2.0]]], + ], + [[[[3.0, 10.0, 10.0]]], [[[6.0, 3.0, -3.0]]], [[[4.0, 4.0, 4.0]]]], + ], + # HF list as Spotlight embedding + "list_sequence": [[1, 2, 3], [1, 6, 3, 7, 8], [-1, -2, -3, -4]], +} + +FEATURES = { + "bool": datasets.Value("bool"), + "int": datasets.Value("int32"), + "uint": datasets.Value("uint16"), + "float": datasets.Value("float64"), + "string": datasets.Value("string"), + "label": datasets.ClassLabel(num_classes=4, names=["foo", "bar", "baz", "barbaz"]), + "binary": datasets.Value("binary"), + "duration": datasets.Value("duration[s]"), + "decimal": datasets.Value("decimal128(10, 2)"), + "date": datasets.Value("date32"), + "time": datasets.Value("time64[us]"), + "timestamp": datasets.Value("timestamp[us]"), + "timestamp_ns": datasets.Value("timestamp[ns]"), + "audio": datasets.Audio(), + "image": datasets.Image(), + "embedding": datasets.Sequence(feature=datasets.Value("float64"), length=4), + "sequence_1d": datasets.Sequence(feature=datasets.Value("float64")), + "sequence_2d": datasets.Sequence( + feature=datasets.Sequence(feature=datasets.Value("float64")), + length=2, + ), + "sequence_2d_t": datasets.Sequence( + feature=datasets.Sequence(feature=datasets.Value("float64"), length=2), + ), + "sequence_2d_array": datasets.Sequence( + feature=datasets.Sequence(feature=datasets.Value("float64")), + length=3, + ), + "sequence_3d_array": datasets.Sequence( + feature=datasets.Sequence( + feature=datasets.Sequence(feature=datasets.Value("float64")), + length=3, + ), + length=1, + ), + "array_2d_sequence": datasets.Array2D(shape=(2, 3), dtype="float64"), + "array_2d_t_sequence": datasets.Array2D(shape=(3, 2), dtype="float64"), + "array_2d_vlen_sequence": datasets.Array2D(shape=(None, 2), dtype="float64"), + "array_4d": datasets.Array4D(shape=(None, 1, 1, 3), dtype="float64"), + "list_sequence": [datasets.Value("float64")], +} + + +def create_hf_dataset() -> datasets.Dataset: + ds = datasets.Dataset.from_dict( + DATA, + features=datasets.Features(FEATURES), + ) + return ds diff --git a/tests/integration/huggingface/test_hf.py b/tests/integration/huggingface/test_hf.py new file mode 100644 index 00000000..9380380f --- /dev/null +++ b/tests/integration/huggingface/test_hf.py @@ -0,0 +1,37 @@ +""" +Integration Test on API level for h5 data sources +""" +import pytest +import httpx + +import datasets + +from renumics import spotlight + +from .dataset import DATA + + +def test_get_table_returns_http_ok(dataset: datasets.Dataset) -> None: + """ + Ensure /api/table/ returns a valid response + """ + viewer = spotlight.show(dataset, no_browser=True, wait=False) + response = httpx.Client(base_url=viewer.url).get("/api/table/") + viewer.close() + assert response.status_code == 200 + + +@pytest.mark.parametrize("col", DATA.keys()) +def test_get_cell_returns_http_ok(dataset: str, col: str) -> None: + """ + Serve h5 dataset and get cell data for dtype + """ + viewer = spotlight.show(dataset, no_browser=True, wait=False) + gen_id = ( + httpx.Client(base_url=viewer.url).get("/api/table/").json()["generation_id"] + ) + response = httpx.Client(base_url=viewer.url).get( + f"/api/table/{col}/0?generation_id={gen_id}" + ) + viewer.close() + assert response.status_code == 200