diff --git a/Makefile b/Makefile
index c4c4cf01..61fa4be3 100644
--- a/Makefile
+++ b/Makefile
@@ -45,6 +45,7 @@ typecheck: ## Typecheck all source files
 	poetry run mypy -p renumics.spotlight
 	poetry run mypy -p renumics.spotlight_plugins.core
 	poetry run mypy scripts
+	poetry run mypy tests
 	pnpm run typecheck
 
 .PHONY: lint
diff --git a/pyproject.toml b/pyproject.toml
index d762b2d6..dda36258 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -140,6 +140,9 @@ files = [
 
 [tool.ruff]
 line-length = 100
+ignore = [
+    "E501"
+]
 
 [tool.mypy]
 ignore_missing_imports = false
@@ -168,7 +171,8 @@ module = [
     "cleanlab.*",
     "machineid",
     "filetype",
-    "datasets"
+    "datasets",
+    "diffimg"
 ]
 ignore_missing_imports = true
 
diff --git a/renumics/spotlight/__init__.py b/renumics/spotlight/__init__.py
index 64e3d471..de22c91b 100644
--- a/renumics/spotlight/__init__.py
+++ b/renumics/spotlight/__init__.py
@@ -4,16 +4,15 @@
 
 from .__version__ import __version__  # noqa: F401
 from .dataset import Dataset  # noqa: F401
-from .dtypes import (
+from .media import (
     Audio,  # noqa: F401
-    Category,  # noqa: F401
     Embedding,  # noqa: F401
     Image,  # noqa: F401
     Mesh,  # noqa: F401
     Sequence1D,  # noqa: F401
     Video,  # noqa: F401
-    Window,  # noqa: F401
 )
+from .dtypes.legacy import Category, Window  # noqa: F401
 from .viewer import Viewer, close, viewers, show
 from .plugin_loader import load_plugins
 from .settings import settings
diff --git a/renumics/spotlight/analysis/analyzers/cleanlab.py b/renumics/spotlight/analysis/analyzers/cleanlab.py
index c7a388a0..84f2deae 100644
--- a/renumics/spotlight/analysis/analyzers/cleanlab.py
+++ b/renumics/spotlight/analysis/analyzers/cleanlab.py
@@ -7,9 +7,9 @@ import numpy as np
 import cleanlab.outlier
 
-from renumics.spotlight.dtypes import Embedding
-from renumics.spotlight.data_store import DataStore
-
+from renumics.spotlight.data_store import DataStore
+from renumics.spotlight.dtypes import is_embedding_dtype
 
 from ..decorator import data_analyzer
 from ..typing import DataIssue
@@ -23,7 +23,9 @@ def analyze_with_cleanlab(
     """
 
     embedding_columns = (
-        col for col in columns if data_store.dtypes.get(col) == Embedding
+        col
+        for col in columns
+        if col in data_store.dtypes and is_embedding_dtype(data_store.dtypes[col])
     )
 
     for column_name in embedding_columns:
diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py
index 3cf42099..8b23e08d 100644
--- a/renumics/spotlight/app.py
+++ b/renumics/spotlight/app.py
@@ -48,13 +48,14 @@ from renumics.spotlight.plugin_loader import load_plugins
 from renumics.spotlight.develop.project import get_project_info
 from renumics.spotlight.backend.middlewares.timing import add_timing_middleware
-from renumics.spotlight.dtypes.typing import ColumnTypeMapping
 from renumics.spotlight.app_config import AppConfig
 from renumics.spotlight.data_source import DataSource, create_datasource
 from renumics.spotlight.layout.default import DEFAULT_LAYOUT
 from renumics.spotlight.data_store import DataStore
 
+from renumics.spotlight.dtypes import DTypeMap
+
 
 class IssuesUpdatedMessage(Message):
     """
@@ -79,7 +80,7 @@ class SpotlightApp(FastAPI):
     # datasource
     _dataset: Optional[Union[PathType, pd.DataFrame]]
-    _user_dtypes: ColumnTypeMapping
+    _user_dtypes: DTypeMap
     _data_source: Optional[DataSource]
     _data_store: Optional[DataStore]
 
diff --git a/renumics/spotlight/app_config.py b/renumics/spotlight/app_config.py
index 779d60fe..0e62bf06 100644
--- a/renumics/spotlight/app_config.py
+++ b/renumics/spotlight/app_config.py
@@ -10,8 +10,7 @@ from renumics.spotlight.layout.nodes import Layout
 from renumics.spotlight.analysis.typing import DataIssue
-
-from renumics.spotlight.dtypes.typing import ColumnTypeMapping
+from renumics.spotlight.dtypes import DTypeMap
 
 
 @dataclass(frozen=True)
@@ -22,7 +21,7 @@ class AppConfig:
 
     # dataset
     dataset: Optional[Union[Path, pd.DataFrame]] = None
-    dtypes: Optional[ColumnTypeMapping] = None
+    dtypes: Optional[DTypeMap] = None
     project_root: Optional[Path] = None
 
     # data analysis
diff --git a/renumics/spotlight/backend/exceptions.py b/renumics/spotlight/backend/exceptions.py
index 05aabfbb..cb8ee678 100644
--- a/renumics/spotlight/backend/exceptions.py
+++ b/renumics/spotlight/backend/exceptions.py
@@ -2,12 +2,13 @@
 Exceptions to be raised from backend.
 """
 
-from typing import Any, Optional, Type
+from typing import Any, Optional
 
 from fastapi import status
 
-from renumics.spotlight.dtypes.typing import ColumnType
 from renumics.spotlight.typing import IndexType, PathOrUrlType, PathType
 
+from renumics.spotlight.dtypes import DType
+
 
 class Problem(Exception):
     """
@@ -127,7 +128,7 @@ class ConversionFailed(Problem):
     Value cannot be converted to the desired dtype.
     """
 
-    def __init__(self, dtype: Type[ColumnType], value: Any) -> None:
+    def __init__(self, dtype: DType, value: Any) -> None:
         self.dtype = dtype
         self.value = value
         super().__init__(
diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py
index 381d48db..00207c38 100644
--- a/renumics/spotlight/backend/tasks/reduction.py
+++ b/renumics/spotlight/backend/tasks/reduction.py
@@ -9,9 +9,8 @@ from sklearn import preprocessing
 
 from renumics.spotlight.dataset.exceptions import ColumnNotExistsError
-from renumics.spotlight.dtypes import Category, Embedding
-
-from renumics.spotlight.data_store import DataStore
+from renumics.spotlight.data_store import DataStore
+from renumics.spotlight.dtypes import is_category_dtype, is_embedding_dtype
 
 SEED = 42
 
@@ -36,7 +35,7 @@ def align_data(
     for column_name in column_names:
         dtype = data_store.dtypes[column_name]
         column_values = data_store.get_converted_values(column_name, indices)
-        if dtype is Embedding:
+        if is_embedding_dtype(dtype):
             embedding_length = max(
                 0 if x is None else len(cast(np.ndarray, x)) for x in column_values
             )
@@ -50,7 +49,7 @@ def align_data(
                     ]
                 )
             )
-        elif dtype is Category:
+        elif is_category_dtype(dtype):
             na_mask = np.array(column_values) == -1
             one_hot_values = preprocessing.label_binarize(
                 column_values, classes=sorted(set(column_values).difference({-1}))  # type: ignore
diff --git a/renumics/spotlight/cli.py b/renumics/spotlight/cli.py
index 358db704..9f9c230e 100644
--- a/renumics/spotlight/cli.py
+++ b/renumics/spotlight/cli.py
@@ -6,44 +6,36 @@ import platform
 import signal
 import sys
-from typing import Optional, Tuple, Union, List
+from typing import Dict, Optional, Tuple, Union, List
 from pathlib import Path
 
 import click
 
 from renumics import spotlight
-from renumics.spotlight.dtypes.typing import COLUMN_TYPES_BY_NAME, ColumnTypeMapping
 from renumics.spotlight import logging
 
 
 def cli_dtype_callback(
     _ctx: click.Context, _param: click.Option, value: Tuple[str, ...]
-) -> Optional[ColumnTypeMapping]:
+) -> Optional[Dict[str, str]]:
     """
     Parse column types from multiple strings in format
     `COLUMN_NAME=DTYPE` to a dict.
     """
     if not value:
         return None
-    dtype = {}
+    dtypes: Dict[str, str] = {}
     for mapping in value:
         try:
-            column_name, dtype_name = mapping.split("=")
+            column_name, dtype = mapping.split("=")
         except ValueError as e:
             raise click.BadParameter(
                 "Column type setting separator '=' not specified or specified "
                 "more than once."
             ) from e
-        try:
-            column_type = COLUMN_TYPES_BY_NAME[dtype_name]
-        except KeyError as e:
-            raise click.BadParameter(
-                f"Column types from {list(COLUMN_TYPES_BY_NAME.keys())} "
-                f"expected, but value '{dtype_name}' recived."
-            ) from e
-        dtype[column_name] = column_type
-    return dtype
+        dtypes[column_name] = dtype
+    return dtypes
 
 
 @click.command()  # type: ignore
@@ -84,9 +76,7 @@ def cli_dtype_callback(
     type=click.UNPROCESSED,
     callback=cli_dtype_callback,
     multiple=True,
-    help="Custom column types setting (use COLUMN_NAME={"
-    + "|".join(sorted(COLUMN_TYPES_BY_NAME.keys()))
-    + "} notation). Multiple settings allowed.",
+    help="Custom column types setting (use COLUMN_NAME=DTYPE notation). Multiple settings allowed.",
 )
 @click.option(
     "--no-browser",
@@ -119,7 +109,7 @@ def main(
     host: str,
     port: Union[int, str],
     layout: Optional[str],
-    dtype: Optional[ColumnTypeMapping],
+    dtype: Optional[Dict[str, str]],
     no_browser: bool,
     filebrowsing: bool,
     analyze: List[str],
diff --git a/renumics/spotlight/data_source/data_source.py b/renumics/spotlight/data_source/data_source.py
index 6437321f..d894d66a 100644
--- a/renumics/spotlight/data_source/data_source.py
+++ b/renumics/spotlight/data_source/data_source.py
@@ -2,7 +2,7 @@
 
 import dataclasses
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, List, Any, Union
+from typing import Optional, List, Any, Union
 
 import pandas as pd
 import numpy as np
@@ -12,11 +12,10 @@
     ColumnExistsError,
     ColumnNotExistsError,
 )
-from renumics.spotlight.dtypes.typing import (
-    ColumnTypeMapping,
-)
 from renumics.spotlight.backend.exceptions import GenerationIDMismatch, NoRowFound
 
+from renumics.spotlight.dtypes import DTypeMap
+
 
 @dataclasses.dataclass
 class ColumnMetadata:
@@ -84,7 +83,7 @@ def check_generation_id(self, generation_id: int) -> None:
             raise GenerationIDMismatch()
 
     @abstractmethod
-    def guess_dtypes(self) -> ColumnTypeMapping:
+    def guess_dtypes(self) -> DTypeMap:
         """
         Guess data source's dtypes.
         """
@@ -117,12 +116,6 @@ def get_column_metadata(self, column_name: str) -> ColumnMetadata:
         """
         Get extra info of a column.
""" - @abstractmethod - def get_column_categories(self, column_name: str) -> Dict[str, int]: - """ - Get column categories (for categorical dtype) - """ - def _assert_index_exists(self, index: int) -> None: if index < -len(self) or index >= len(self): raise NoRowFound(index) diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index ff71d4f7..b140b808 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -1,36 +1,51 @@ import hashlib import io -from typing import List, Optional, Union +from typing import List, Optional, Set, Union, cast import numpy as np from renumics.spotlight.cache import external_data_cache -from renumics.spotlight.dtypes.typing import ColumnTypeMapping from renumics.spotlight.data_source import DataSource -from renumics.spotlight.dtypes import Audio, Category -from renumics.spotlight.dtypes.conversion import ( - ConvertedValue, - DTypeOptions, - convert_to_dtype, -) +from renumics.spotlight.dtypes.conversion import ConvertedValue, convert_to_dtype from renumics.spotlight.data_source.data_source import ColumnMetadata - from renumics.spotlight.io import audio +from renumics.spotlight.dtypes import ( + CategoryDType, + DTypeMap, + is_audio_dtype, + is_category_dtype, + is_str_dtype, + str_dtype, +) class DataStore: - _dtypes: ColumnTypeMapping + _dtypes: DTypeMap _data_source: DataSource - def __init__(self, data_source: DataSource, user_dtypes: ColumnTypeMapping) -> None: + def __init__(self, data_source: DataSource, user_dtypes: DTypeMap) -> None: self._data_source = data_source - dtypes = self._data_source.guess_dtypes() - dtypes.update( - { - column_name: column_type - for column_name, column_type in user_dtypes.items() - if column_name in dtypes - } - ) + guessed_dtypes = self._data_source.guess_dtypes() + dtypes = { + **guessed_dtypes, + **{ + column_name: dtype + for column_name, dtype in user_dtypes.items() + if column_name in guessed_dtypes + }, + } + for column_name, dtype in dtypes.items(): + if ( + is_category_dtype(dtype) + and dtype.categories is None + and is_str_dtype(guessed_dtypes[column_name]) + ): + normalized_values = self._data_source.get_column_values(column_name) + converted_values = [ + convert_to_dtype(value, str_dtype, simple=True, check=True) + for value in normalized_values + ] + category_names = sorted(cast(Set[str], set(converted_values))) + dtypes[column_name] = CategoryDType(category_names) self._dtypes = dtypes def __len__(self) -> int: @@ -53,7 +68,7 @@ def column_names(self) -> List[str]: return self._data_source.column_names @property - def dtypes(self) -> ColumnTypeMapping: + def dtypes(self) -> DTypeMap: return self._dtypes def check_generation_id(self, generation_id: int) -> None: @@ -71,16 +86,8 @@ def get_converted_values( ) -> List[ConvertedValue]: dtype = self._dtypes[column_name] normalized_values = self._data_source.get_column_values(column_name, indices) - if dtype is Category: - dtype_options = DTypeOptions( - categories=self._data_source.get_column_categories(column_name) - ) - else: - dtype_options = DTypeOptions() converted_values = [ - convert_to_dtype( - value, dtype, dtype_options=dtype_options, simple=simple, check=check - ) + convert_to_dtype(value, dtype, simple=simple, check=check) for value in normalized_values ] return converted_values @@ -94,7 +101,7 @@ def get_waveform(self, column_name: str, index: int) -> Optional[np.ndarray]: """ return the waveform of an audio cell """ - assert self._dtypes[column_name] is Audio + assert 
is_audio_dtype(self._dtypes[column_name]) blob = self.get_converted_value(column_name, index, simple=False) if blob is None: diff --git a/renumics/spotlight/dataset/__init__.py b/renumics/spotlight/dataset/__init__.py index 7e837a85..4a4ae153 100644 --- a/renumics/spotlight/dataset/__init__.py +++ b/renumics/spotlight/dataset/__init__.py @@ -29,7 +29,7 @@ import trimesh import validators from loguru import logger -from typing_extensions import Literal, TypeGuard +from typing_extensions import TypeGuard from renumics.spotlight.__version__ import __version__ from renumics.spotlight.io.pandas import ( @@ -57,23 +57,49 @@ Video, Window, ) -from renumics.spotlight.dtypes.base import DType, FileBasedDType -from renumics.spotlight.dtypes.typing import ( - ColumnType, - ColumnTypeMapping, - FileBasedColumnType, - get_column_type, - get_column_type_name, - is_file_based_column_type, -) from renumics.spotlight.dtypes.conversion import prepare_path_or_url + +from renumics.spotlight.dtypes import ( + CategoryDType, + Sequence1DDType, + create_dtype, + is_bool_dtype, + is_int_dtype, + is_float_dtype, + is_str_dtype, + is_datetime_dtype, + is_category_dtype, + is_array_dtype, + is_window_dtype, + is_embedding_dtype, + is_sequence_1d_dtype, + is_audio_dtype, + is_image_dtype, + is_mesh_dtype, + is_video_dtype, + is_file_dtype, + is_scalar_dtype, + DType, + bool_dtype, + int_dtype, + float_dtype, + str_dtype, + datetime_dtype, + array_dtype, + window_dtype, + embedding_dtype, + audio_dtype, + image_dtype, + mesh_dtype, + video_dtype, +) + + from . import exceptions from .typing import ( - REF_COLUMN_TYPE_NAMES, - SimpleColumnType, - RefColumnType, - ExternalColumnType, + OutputType, + ExternalOutputType, BoolColumnInputType, IntColumnInputType, FloatColumnInputType, @@ -94,10 +120,29 @@ ) INTERNAL_COLUMN_NAMES = ["__last_edited_by__", "__last_edited_at__"] +INTERNAL_COLUMN_DTYPES = [str_dtype, datetime_dtype] _EncodedColumnType = Optional[Union[bool, int, float, str, np.ndarray, h5py.Reference]] +VALUE_TYPE_BY_DTYPE_NAME = { + "bool": bool, + "int": int, + "float": float, + "str": str, + "datetime": datetime, + "Category": Category, + "array": np.ndarray, + "Window": Window, + "Embedding": Embedding, + "Sequence1D": Sequence1D, + "Audio": Audio, + "Image": Image, + "Video": Video, + "Mesh": Mesh, +} + + def get_current_datetime() -> datetime: """ Get current datetime with timezone. 
@@ -135,50 +180,70 @@ def unescape_dataset_name(escaped_name: str) -> str:
     return name
 
 
-_ALLOWED_COLUMN_TYPES: Dict[Type[ColumnType], Tuple[Type, ...]] = {
-    bool: (np.bool_,),
-    int: (np.integer,),
-    float: (np.floating,),
-    datetime: (np.datetime64,),
+_ALLOWED_COLUMN_TYPES: Dict[str, Tuple[Type, ...]] = {
+    "bool": (bool, np.bool_),
+    "int": (int, np.integer),
+    "float": (float, np.floating),
+    "str": (str,),
+    "datetime": (datetime, np.datetime64),
+    "Category": (str,),
+    "array": (
+        np.ndarray,
+        list,
+        tuple,
+        bool,
+        int,
+        float,
+        np.bool_,
+        np.integer,
+        np.floating,
+    ),
+    "Window": (np.ndarray, list, tuple),
+    "Embedding": (Embedding, np.ndarray, list, tuple),
+    "Sequence1D": (Sequence1D, np.ndarray, list, tuple),
+    "Audio": (Audio, bytes, str, os.PathLike),
+    "Image": (Image, bytes, str, os.PathLike, np.ndarray, list, tuple),
+    "Mesh": (Mesh, trimesh.Trimesh, str, os.PathLike),
+    "Video": (Video, bytes, str, os.PathLike),
 }
 
-_ALLOWED_COLUMN_DTYPES: Dict[Type[ColumnType], Tuple[Type, ...]] = {
-    bool: (np.bool_,),
-    int: (np.integer,),
-    float: (np.floating,),
-    datetime: (np.datetime64,),
-    Window: (np.floating,),
-    Embedding: (np.floating,),
+_ALLOWED_COLUMN_DTYPES: Dict[str, Tuple[Type, ...]] = {
+    "bool": (np.bool_,),
+    "int": (np.integer,),
+    "float": (np.floating,),
+    "datetime": (np.datetime64,),
+    "Window": (np.floating,),
+    "Embedding": (np.floating,),
 }
 
 
-def _check_valid_value_type(value: Any, column_type: Type[ColumnType]) -> bool:
+def _check_valid_value_type(value: Any, dtype: DType) -> bool:
     """
     Check if a value is suitable for the given column type.
 
     Instances of the given type are always suitable for its type, but extra
     types from `_ALLOWED_COLUMN_TYPES` are also checked.
     """
-    allowed_types = (column_type,) + _ALLOWED_COLUMN_TYPES.get(column_type, ())
+    allowed_types = _ALLOWED_COLUMN_TYPES.get(dtype.name, ())
     return isinstance(value, allowed_types)
 
 
-def _check_valid_value_dtype(dtype: np.dtype, column_type: Type[ColumnType]) -> bool:
+def _check_valid_value_dtype(value_dtype: np.dtype, dtype: DType) -> bool:
     """
     Check if an array with the given dtype is suitable for the given column
     type. Only types from `_ALLOWED_COLUMN_DTYPES` are checked. All other
     column types are assumed to have no dtype equivalent.
     """
-    allowed_dtypes = _ALLOWED_COLUMN_DTYPES.get(column_type, ())
-    return any(np.issubdtype(dtype, allowed_dtype) for allowed_dtype in allowed_dtypes)
+    allowed_dtypes = _ALLOWED_COLUMN_DTYPES.get(dtype.name, ())
+    return any(
+        np.issubdtype(value_dtype, allowed_dtype) for allowed_dtype in allowed_dtypes
+    )
 
 
-def _check_valid_array(
-    value: Any, column_type: Type[ColumnType]
-) -> TypeGuard[np.ndarray]:
+def _check_valid_array(value: Any, dtype: DType) -> TypeGuard[np.ndarray]:
     """
     Check if a value is an array and its type is suitable for the given column type.
""" return isinstance(value, np.ndarray) and _check_valid_value_dtype( - value.dtype, column_type + value.dtype, dtype ) @@ -196,7 +261,7 @@ class Dataset: _length: int @staticmethod - def _user_column_attributes(column_type: Type[ColumnType]) -> Dict[str, Type]: + def _user_column_attributes(dtype: DType) -> Dict[str, Type]: attribute_names = { "order": int, "hidden": bool, @@ -205,43 +270,40 @@ def _user_column_attributes(column_type: Type[ColumnType]) -> Dict[str, Type]: "description": str, "tags": list, } - if column_type in { - bool, - int, - float, - str, - Category, - Window, - }: + if ( + is_scalar_dtype(dtype) + or is_str_dtype(dtype) + or is_category_dtype(dtype) + or is_window_dtype(dtype) + ): attribute_names["editable"] = bool - if column_type is Category: + if is_category_dtype(dtype): attribute_names["categories"] = dict - if column_type is Sequence1D: + if is_sequence_1d_dtype(dtype): attribute_names["x_label"] = str attribute_names["y_label"] = str - if issubclass(column_type, FileBasedDType): + if is_file_dtype(dtype): attribute_names["lookup"] = dict attribute_names["external"] = bool - if column_type is Audio: + if is_audio_dtype(dtype): attribute_names["lossy"] = bool return attribute_names @classmethod - def _default_default(cls, column_type: Type[ColumnType]) -> Any: - if column_type is datetime: - return np.datetime64("NaT") - if column_type in (str, Category): - return "" - if column_type is float: + def _default_default(cls, dtype: DType) -> Any: + if is_bool_dtype(dtype): + return False + if is_int_dtype(dtype): + return 0 + if is_float_dtype(dtype): return float("nan") - if column_type is Window: + if is_str_dtype(dtype): + return "" + if is_datetime_dtype(dtype): + return np.datetime64("NaT") + if is_window_dtype(dtype): return [np.nan, np.nan] - if column_type is np.ndarray or issubclass(column_type, DType): - return None - raise exceptions.InvalidAttributeError( - f"`default` argument for optional column of type " - f"{get_column_type_name(column_type)} should be set, but `None` received." - ) + return None def __init__(self, filepath: PathType, mode: str): self._filepath = os.path.abspath(filepath) @@ -349,7 +411,7 @@ def __delitem__(self, item: Union[str, IndexType, Indices1dType]) -> None: for column_name in self.keys() + INTERNAL_COLUMN_NAMES: column = self._h5_file[column_name] raw_values = column[item + 1 :] - if self._get_column_type(column) is Embedding: + if is_embedding_dtype(self._get_dtype(column)): raw_values = list(raw_values) column[item:-1] = raw_values column.resize(self._length - 1, axis=0) @@ -368,13 +430,13 @@ def __getitem__( ... @overload - def __getitem__(self, item: IndexType) -> Dict[str, Optional[ColumnType]]: + def __getitem__(self, item: IndexType) -> Dict[str, Optional[OutputType]]: ... @overload def __getitem__( self, item: Union[Tuple[str, IndexType], Tuple[IndexType, str]] - ) -> Optional[ColumnType]: + ) -> Optional[OutputType]: ... def __getitem__( @@ -385,7 +447,7 @@ def __getitem__( Tuple[str, Union[IndexType, Indices1dType]], Tuple[Union[IndexType, Indices1dType], str], ], - ) -> Union[np.ndarray, Dict[str, Optional[ColumnType]], Optional[ColumnType],]: + ) -> Union[np.ndarray, Dict[str, Optional[OutputType]], Optional[OutputType]]: """ Get a dataset column, row or value. @@ -606,21 +668,21 @@ def keys(self) -> List[str]: return list(self._column_names) @overload - def iterrows(self) -> Iterable[Dict[str, Optional[ColumnType]]]: + def iterrows(self) -> Iterable[Dict[str, Optional[OutputType]]]: ... 
 
     @overload
     def iterrows(
         self, column_names: Union[str, Iterable[str]]
     ) -> Union[
-        Iterable[Dict[str, Optional[ColumnType]]], Iterable[Optional[ColumnType]]
+        Iterable[Dict[str, Optional[OutputType]]], Iterable[Optional[OutputType]]
     ]:
         ...
 
     def iterrows(
         self, column_names: Optional[Union[str, Iterable[str]]] = None
     ) -> Union[
-        Iterable[Dict[str, Optional[ColumnType]]], Iterable[Optional[ColumnType]]
+        Iterable[Dict[str, Optional[OutputType]]], Iterable[Optional[OutputType]]
     ]:
         """
         Iterate through dataset rows.
@@ -629,19 +691,16 @@ def iterrows(
         if isinstance(column_names, str):
             self._assert_column_exists(column_names)
             column = self._h5_file[column_names]
-            column_type = self._get_column_type(column)
+            dtype = self._get_dtype(column)
             if column.attrs.get("external", False):
                 for value in column:
-                    column_type = cast(Type[ExternalColumnType], column_type)
-                    yield self._decode_external_value(value, column_type)
+                    yield self._decode_external_value(value, dtype)
             elif self._is_ref_column(column):
                 for ref in column:
-                    column_type = cast(Type[RefColumnType], column_type)
-                    yield self._decode_ref_value(ref, column_type, column_names)
+                    yield self._decode_ref_value(ref, dtype, column_names)
             else:
                 for value in column:
-                    column_type = cast(Type[SimpleColumnType], column_type)
-                    yield self._decode_simple_value(value, column, column_type)
+                    yield self._decode_simple_value(value, dtype)
         else:
             if column_names is None:
                 column_names = self._column_names
@@ -666,7 +725,7 @@ def from_pandas(
         self,
         df: pd.DataFrame,
         index: bool = False,
-        dtype: Optional[ColumnTypeMapping] = None,
+        dtypes: Optional[Dict[str, Any]] = None,
         workdir: Optional[PathType] = None,
     ) -> None:
         """
@@ -679,7 +738,7 @@ def from_pandas(
             df: `pandas.DataFrame` to import.
             index: Whether to import index of the dataframe as regular dataset
                 column.
-            dtype: Optional dict with mapping `column name -> column type` with
+            dtypes: Optional dict with mapping `column name -> column type` with
                 column types allowed by Spotlight.
             workdir: Optional folder where audio/images/meshes are stored. If
                 `None`, current folder is used.
@@ -718,16 +777,21 @@ def from_pandas(
             df = df.copy()
         df.columns = pd.Index(stringify_columns(df))
 
-        inferred_dtype = infer_dtypes(df, dtype)
+        if dtypes is None:
+            dtypes = {}
+
+        inferred_dtypes = infer_dtypes(
+            df, {col: create_dtype(dtype) for col, dtype in dtypes.items()}
+        )
 
         for column_name in df.columns:
             try:
                 column = df[column_name]
-                column_type = inferred_dtype[column_name]
+                dtype = inferred_dtypes[column_name]
 
-                column = prepare_column(column, column_type)
+                column = prepare_column(column, dtype)
 
-                if workdir is not None and is_file_based_column_type(dtype):
+                if workdir is not None and is_file_dtype(dtype):
                     # For file-based data types, relative paths should be resolved.
                     str_mask = is_string_mask(column)
                     column[str_mask] = column[str_mask].apply(
@@ -736,30 +800,30 @@ def from_pandas(
 
                 attrs = {}
 
-                if column_type is Category:
+                if is_category_dtype(dtype):
                     attrs["categories"] = column.cat.categories.to_list()
                     values = column.to_numpy()
                     # `pandas` uses `NaN`s for unknown values, we use `None`.
                     values = np.where(pd.isna(values), np.array(None), values)
-                elif column_type is datetime:
+                elif is_datetime_dtype(dtype):
                     values = column.to_numpy("datetime64[us]")
                 else:
                     values = column.to_numpy()
 
-                if is_file_based_column_type(column_type):
+                if is_file_dtype(dtype):
                     attrs["external"] = False  # type: ignore
                     attrs["lookup"] = False  # type: ignore
 
                 self.append_column(
                     column_name,
-                    column_type,
+                    dtype,
                     values,
                     hidden=column_name.startswith("_"),
-                    optional=column_type not in (bool, int),
+                    optional=True,
                     **attrs,  # type: ignore
                 )
             except Exception as e:
-                if column_name in (dtype or {}):
+                if column_name in (dtypes or {}):
                     raise e
                 logger.warning(
                     f"Column '{column_name}' not imported from "
@@ -769,7 +833,7 @@ def from_csv(
         self,
         filepath: PathType,
-        dtype: Optional[ColumnTypeMapping] = None,
+        dtypes: Optional[Dict[str, Any]] = None,
         columns: Optional[Iterable[str]] = None,
         workdir: Optional[PathType] = None,
     ) -> None:
@@ -788,7 +852,7 @@ def from_csv(
         df: pd.DataFrame = pd.read_csv(filepath, usecols=columns or None)
         if workdir is None:
             workdir = os.path.dirname(filepath)
-        self.from_pandas(df, index=False, dtype=dtype, workdir=workdir)
+        self.from_pandas(df, index=False, dtypes=dtypes, workdir=workdir)
 
     def to_pandas(self) -> pd.DataFrame:
         """
@@ -819,14 +883,24 @@ def to_pandas(self) -> pd.DataFrame:
         self._assert_is_opened()
         df = pd.DataFrame()
         for column_name in self._column_names:
-            column_type = self.get_column_type(column_name)
-            if column_type in (bool, int, float, str, datetime):
+            dtype = self.get_dtype(column_name)
+            if (
+                is_scalar_dtype(dtype)
+                or is_str_dtype(dtype)
+                or is_datetime_dtype(dtype)
+            ):
                 df[column_name] = self[column_name]
-            elif column_type is Category:
-                df[column_name] = pd.Categorical.from_codes(
-                    self._h5_file[column_name],
-                    self._h5_file[column_name].attrs["category_keys"],  # type: ignore
-                )
+            elif is_category_dtype(dtype):
+                if dtype.inverted_categories is None:
+                    values: List[Optional[str]] = [None] * len(
+                        self._h5_file[column_name]
+                    )
+                else:
+                    values = [
+                        dtype.inverted_categories.get(code)
+                        for code in self._h5_file[column_name]
+                    ]
+                df[column_name] = pd.Categorical(values)
 
         not_exported_columns = self._column_names.difference(df.columns)
         if len(not_exported_columns) > 0:
@@ -844,7 +918,7 @@ def append_bool_column(
         order: Optional[int] = None,
         hidden: bool = False,
         optional: bool = False,
-        default: BoolColumnInputType = None,
+        default: BoolColumnInputType = False,
         description: Optional[str] = None,
         tags: Optional[List[str]] = None,
         editable: bool = True,
@@ -861,8 +935,7 @@ def append_bool_column(
             hidden: Whether column is hidden in Spotlight.
             optional: Whether column is optional.
             default: Value to use by default if column is optional and no value
-                or `None` is given. If `optional` is `True`, should be
-                explicitly set to `True` or `False`.
+                or `None` is given.
             description: Optional column description.
             tags: Optional tags for the column.
             editable: Whether column is editable in Spotlight.
@@ -878,7 +951,7 @@ def append_bool_column(
         """
         self._append_column(
             name,
-            bool,
+            bool_dtype,
             values,
             np.dtype(bool),
             order,
@@ -897,7 +970,7 @@ def append_int_column(
         order: Optional[int] = None,
         hidden: bool = False,
         optional: bool = False,
-        default: IntColumnInputType = None,
+        default: IntColumnInputType = -1,
         description: Optional[str] = None,
         tags: Optional[List[str]] = None,
         editable: bool = True,
@@ -912,11 +985,9 @@ def append_int_column(
             order: Optional Spotlight priority order value. `None` means the
                 lowest priority.
             hidden: Whether column is hidden in Spotlight.
-            optional: Whether column is optional. If `default` other than `None`
-                is specified, `optional` is automatically set to `True`.
+            optional: Whether column is optional.
             default: Value to use by default if column is optional and no value
-                or `None` is given. If `optional` is `True`, should be
-                explicitly set.
+                or `None` is given.
             description: Optional column description.
             tags: Optional tags for the column.
             editable: Whether column is editable in Spotlight.
@@ -927,7 +998,7 @@ def append_int_column(
         """
         self._append_column(
             name,
-            int,
+            int_dtype,
             values,
             np.dtype(int),
             order,
@@ -975,7 +1046,7 @@ def append_float_column(
         """
         self._append_column(
             name,
-            float,
+            float_dtype,
             values,
             np.dtype(float),
             order,
@@ -1023,7 +1094,7 @@ def append_string_column(
         """
         self._append_column(
             name,
-            str,
+            str_dtype,
             values,
             h5py.string_dtype(),
             order,
@@ -1078,7 +1149,7 @@ def append_datetime_column(
         """
         self._append_column(
             name,
-            datetime,
+            datetime_dtype,
             values,
             h5py.string_dtype(),
             order,
@@ -1129,7 +1200,7 @@ def append_array_column(
         """
         self._append_column(
             name,
-            np.ndarray,
+            array_dtype,
             values,
             h5py.string_dtype(),
             order,
@@ -1179,7 +1250,7 @@ def append_categorical_column(
         """
         self._append_column(
             name,
-            Category,
+            CategoryDType(categories),
             values,
             np.dtype("int32"),
             order,
@@ -1189,7 +1260,6 @@ def append_categorical_column(
             description,
             tags,
             editable=editable,
-            categories=categories,
         )
 
     def append_embedding_column(
@@ -1234,7 +1304,7 @@ def append_embedding_column(
         )
         self._append_column(
             name,
-            Embedding,
+            embedding_dtype,
             values,
             h5py.vlen_dtype(np_dtype),
             order,
@@ -1282,11 +1352,13 @@ def append_sequence_1d_column(
         Example:
             Find an example usage in :class:`renumics.spotlight.dtypes'.Sequence1D`.
         """
+        if x_label is None:
+            x_label = "x"
         if y_label is None:
             y_label = name
         self._append_column(
             name,
-            Sequence1D,
+            Sequence1DDType(x_label, y_label),
             values,
             h5py.string_dtype(),
             order,
@@ -1295,8 +1367,6 @@ def append_sequence_1d_column(
             default,
             description,
             tags,
-            x_label=x_label,
-            y_label=y_label,
         )
 
     def append_mesh_column(
@@ -1349,7 +1419,7 @@ def append_mesh_column(
         """
         self._append_column(
             name,
-            Mesh,
+            mesh_dtype,
             values,
             h5py.string_dtype(),
             order,
@@ -1410,7 +1480,7 @@ def append_image_column(
         """
         self._append_column(
             name,
-            Image,
+            image_dtype,
             values,
             h5py.string_dtype(),
             order,
@@ -1474,7 +1544,7 @@ def append_audio_column(
             slows down the execution.
 
         Example:
-            Find an example usage in :class:`renumics.spotlight.dtypes'.Audio`.
+            Find an example usage in :class:`renumics.spotlight.media.Audio`.
         """
         attrs = {}
         if lossy is None and external is False:
@@ -1483,7 +1553,7 @@ def append_audio_column(
             attrs["lossy"] = lossy
         self._append_column(
             name,
-            Audio,
+            audio_dtype,
             values,
             h5py.string_dtype(),
             order,
@@ -1544,7 +1614,7 @@ def append_video_column(
         """
         self._append_column(
             name,
-            Video,
+            video_dtype,
             values,
             h5py.string_dtype(),
             order,
@@ -1594,7 +1664,7 @@ def append_window_column(
         """
         self._append_column(
             name,
-            Window,
+            window_dtype,
             values,
             np.dtype("float32"),
             order,
@@ -1609,7 +1679,7 @@
     def append_column(
         self,
         name: str,
-        column_type: Type[ColumnType],
+        dtype: Any,
        values: Union[ColumnInputType, Iterable[ColumnInputType]] = None,
         order: Optional[int] = None,
         hidden: bool = False,
@@ -1624,7 +1694,7 @@ def append_column(
 
         Args:
             name: Column name.
-            column_type: Column type.
+            dtype: Column type.
             values: Optional column values. If a single value, the whole
                 column filled with this value.
             order: Optional Spotlight priority order value. `None` means the
@@ -1657,47 +1727,48 @@ def append_column(
         [ True  True  True  True  True]
         [1. 1. 1. 1. 1.]
         """
+        dtype = create_dtype(dtype)
 
-        if column_type is bool:
+        if dtype.name == "bool":
             append_column_fn: Callable = self.append_bool_column
-        elif column_type is int:
+        elif dtype.name == "int":
             append_column_fn = self.append_int_column
-        elif column_type is float:
+        elif dtype.name == "float":
             append_column_fn = self.append_float_column
-        elif column_type is str:
+        elif dtype.name == "str":
             append_column_fn = self.append_string_column
-        elif column_type is datetime:
+        elif dtype.name == "datetime":
             append_column_fn = self.append_datetime_column
-        elif column_type is np.ndarray:
+        elif dtype.name == "array":
             append_column_fn = self.append_array_column
-        elif column_type is Embedding:
+        elif dtype.name == "Embedding":
             append_column_fn = self.append_embedding_column
-        elif column_type is Image:
+        elif dtype.name == "Image":
             append_column_fn = self.append_image_column
-        elif column_type is Mesh:
+        elif dtype.name == "Mesh":
             append_column_fn = self.append_mesh_column
-        elif column_type is Sequence1D:
+        elif dtype.name == "Sequence1D":
             append_column_fn = self.append_sequence_1d_column
-        elif column_type is Audio:
+        elif dtype.name == "Audio":
             append_column_fn = self.append_audio_column
-        elif column_type is Category:
+        elif dtype.name == "Category":
             append_column_fn = self.append_categorical_column
-        elif column_type is Video:
+        elif dtype.name == "Video":
             append_column_fn = self.append_video_column
-        elif column_type is Window:
+        elif dtype.name == "Window":
             append_column_fn = self.append_window_column
         else:
-            raise exceptions.InvalidDTypeError(f"Unknown column type: {column_type}.")
+            raise exceptions.InvalidDTypeError(f"Unknown column type: {dtype}.")
 
         append_column_fn(
             name=name,
-            values=values,
+            values=values,  # type: ignore
             order=order,
             hidden=hidden,
             optional=optional,
-            default=default,
+            default=default,  # type: ignore
             description=description,
             tags=tags,
-            **attrs,
+            **attrs,  # type: ignore
         )
 
     def append_row(self, **values: ColumnInputType) -> None:
@@ -1791,7 +1862,7 @@ def insert_row(self, index: IndexType, values: Dict[str, ColumnInputType]) -> None:
             column = self._h5_file[column_name]
             column.resize(length + 1, axis=0)
             raw_values = column[index:-1]
-            if self._get_column_type(column) is Embedding:
+            if self._get_dtype(column).name == "Embedding":
                 raw_values = list(raw_values)
             column[index + 1 :] = raw_values
         self._length += 1
@@ -1807,12 +1878,12 @@ def pop(self, item: str) -> np.ndarray:
         ...
 
     @overload
-    def pop(self, item: IndexType) -> Dict[str, Optional[ColumnType]]:
+    def pop(self, item: IndexType) -> Dict[str, Optional[OutputType]]:
         ...
 
     def pop(
         self, item: Union[str, IndexType]
-    ) -> Union[np.ndarray, Dict[str, Optional[ColumnType]]]:
+    ) -> Union[np.ndarray, Dict[str, Optional[OutputType]]]:
         """
         Delete a dataset column or row and return it.
""" @@ -1834,18 +1905,18 @@ def isnull(self, column_name: str) -> np.ndarray: column = self._h5_file[column_name] raw_values = column[()] - column_type = self._get_column_type(column) + dtype = self._get_dtype(column) if self._is_ref_column(column): return ~raw_values.astype(bool) - if column_type is datetime: + if dtype.name == "datetime": return np.array([raw_value in ["", b""] for raw_value in raw_values]) - if column_type is float: + if dtype.name == "float": return np.isnan(raw_values) - if column_type is Category: + if dtype.name == "Category": return raw_values == -1 - if column_type is Window: + if dtype.name == "Window": return np.isnan(raw_values).all(axis=1) - if column_type is Embedding: + if dtype.name == "Embedding": return np.array([len(x) == 0 for x in raw_values]) return np.full(len(self), False) @@ -1934,32 +2005,19 @@ def prune(self) -> None: else: refs.append(None) raw_values = refs - if self._get_column_type(column) is Embedding: + if self._get_dtype(column).name == "Embedding": raw_values = list(raw_values) new_column[:] = raw_values self.close() shutil.move(new_dataset, os.path.realpath(self._filepath)) self.open() - @overload - def get_column_type( - self, name: str, as_string: Literal[False] = False - ) -> Type[ColumnType]: - ... - - @overload - def get_column_type(self, name: str, as_string: Literal[True]) -> str: - ... - - def get_column_type( - self, name: str, as_string: bool = False - ) -> Union[Type[ColumnType], str]: + def get_dtype(self, column_name: str) -> DType: """ Get type of dataset column. Args: - name: Column name. - as_string: Get internal name of the column type. + column_name: Column name. Example: >>> from renumics.spotlight import Dataset @@ -1970,46 +2028,22 @@ def get_column_type( ... dataset.append_mesh_column("mesh") >>> with Dataset("docs/example.h5", "r") as dataset: ... for column_name in sorted(dataset.keys()): - ... print(column_name, dataset.get_column_type(column_name)) - array - bool - datetime - mesh - >>> with Dataset("docs/example.h5", "r") as dataset: - ... for column_name in sorted(dataset.keys()): - ... print(column_name, dataset.get_column_type(column_name, True)) + ... print(column_name, dataset.get_dtype(column_name)) array array bool bool datetime datetime mesh Mesh """ self._assert_is_opened() - if not isinstance(name, str): + if not isinstance(column_name, str): raise TypeError( - f"`item` argument should be a string, but value {name} of type " - f"`{type(name)}` received.`" + f"`item` argument should be a string, but value {column_name} of type " + f"`{type(column_name)}` received.`" ) - self._assert_column_exists(name, internal=True) - type_name = self._h5_file[name].attrs["type"] - if as_string: - return type_name - return get_column_type(type_name) - - def get_column_attributes( - self, name: str - ) -> Dict[ - str, - Optional[ - Union[ - bool, - int, - str, - ColumnType, - Dict[str, int], - Dict[str, FileBasedColumnType], - ] - ], - ]: + self._assert_column_exists(column_name, internal=True) + return self._get_dtype(self._h5_file[column_name]) + + def get_column_attributes(self, name: str) -> Dict[str, Any]: """ Get attributes of a column. Available but unset attributes contain None. @@ -2032,11 +2066,11 @@ def get_column_attributes( ... attributes = dataset.get_column_attributes("int") ... for key in sorted(attributes.keys()): ... 
print(key, attributes[key]) - default None + default -1 description None editable True hidden False - optional False + optional True order None tags None >>> with Dataset("docs/example.h5", "r") as dataset: @@ -2061,27 +2095,17 @@ def get_column_attributes( column = self._h5_file[name] column_attrs = column.attrs - column_type = self._get_column_type(column_attrs) - allowed_attributes = self._user_column_attributes(column_type) + dtype = self._get_dtype(column_attrs) + allowed_attributes = self._user_column_attributes(dtype) - attrs: Dict[ - str, - Optional[ - Union[ - bool, - int, - str, - ColumnType, - Dict[str, int], - Dict[str, FileBasedColumnType], - ] - ], - ] = {attribute_name: None for attribute_name in allowed_attributes} + attrs: Dict[str, Any] = { + attribute_name: None for attribute_name in allowed_attributes + } attrs.update( { attribute_name: attribute_type(column_attrs[attribute_name]) - if not attribute_type == object + if attribute_type is not object else column_attrs[attribute_name] for attribute_name, attribute_type in allowed_attributes.items() if attribute_name in column_attrs @@ -2118,14 +2142,14 @@ def _assert_valid_attribute( self, attribute_name: str, attribute_value: ColumnInputType, column_name: str ) -> None: column = self._h5_file.get(column_name) - column_type = self._get_column_type(column) + dtype = self._get_dtype(column) - allowed_attributes = self._user_column_attributes(column_type) + allowed_attributes = self._user_column_attributes(dtype) if attribute_name not in allowed_attributes: raise exceptions.InvalidAttributeError( f'Setting an attribute with the name "{attribute_name}" for column ' f'"{column_name}" is not allowed. ' - f'Allowed attribute names for "{column_type}" ' + f'Allowed attribute names for "{dtype}" ' f'are: "{list(allowed_attributes.keys())}"' ) if not isinstance(attribute_value, allowed_attributes[attribute_name]): @@ -2141,7 +2165,7 @@ def _assert_valid_attribute( ): raise exceptions.InvalidAttributeError( f'Invalid `optional` argument for column "{column_name}" of ' - f"type {column_type}. Columns can not be changed from " + f"type {dtype}. Columns can not be changed from " f"`optional=False` to `optional=True`." ) if attribute_name == "tags" and not all( @@ -2149,7 +2173,7 @@ def _assert_valid_attribute( ): raise exceptions.InvalidAttributeError( f'Invalid `tags` argument for column "{column_name}" of type ' - f"{column_type}. Tags should be a `list of str`." + f"{dtype}. Tags should be a `list of str`." ) @staticmethod @@ -2197,15 +2221,10 @@ def set_column_attributes( or `None` is given. description: Optional column description. tags: Optional tags for the column. - attrs: Optional more ColumnType specific attributes . + attrs: Optional more DType specific attributes . """ self._assert_is_writable() - if not isinstance(name, str): - raise TypeError( - f"`name` argument should be a string, but value {name} of type " - f"`{type(name)}` received.`" - ) - self._assert_column_exists(name) + self._assert_column_exists(name, check_type=True) if default is not None: optional = True @@ -2219,7 +2238,7 @@ def set_column_attributes( attrs = {k: v for k, v in attrs.items() if v is not None} column = self._h5_file[name] - column_type = self._get_column_type(column) + dtype = self._get_dtype(column) if "lookup" in attrs: lookup = attrs["lookup"] @@ -2249,11 +2268,6 @@ def set_column_attributes( f'Attribute "categories" for column "{name}" contains ' "invalid dict - keys must be of type str." 
                     )
-                if any(v == "" for v in attrs["categories"].keys()):
-                    raise exceptions.InvalidAttributeError(
-                        f'Attribute "categories" for column "{name}" contains '
-                        'invalid dict - "" (empty string) is no allowed as category key.'
-                    )
                 if len(attrs["categories"].values()) > len(
                     set(attrs["categories"].values())
                 ):
@@ -2261,6 +2275,10 @@ def set_column_attributes(
                         f'Attribute "categories" for column "{name}" contains '
                         "invalid dict - keys and values must be unique"
                     )
+                if -1 in attrs["categories"].values():
+                    raise exceptions.InvalidAttributeError(
+                        f'Invalid categories received for column "{name}". Code `-1` is reserved.'
+                    )
                 if column.attrs.get("category_keys") is not None:
                     values_must_include = column[:]
                     if "default" in column.attrs:
@@ -2324,28 +2342,27 @@ def set_column_attributes(
             # Set new default value.
             try:
                 if default is None and old_default is None:
-                    default = self._default_default(column_type)
-                if (
-                    default is None
-                    and column_type is Embedding
-                    and not self._is_ref_column(column)
-                ):
-                    # For a non-ref `Embedding` column, replace `None` with an empty array.
-                    default = np.empty(0, column.dtype.metadata["vlen"])
-                if column_type is Category and default != "":
-                    if default not in column.attrs["category_keys"]:
+                    default = self._default_default(dtype)
+                if default is None:
+                    if is_embedding_dtype(dtype) and not self._is_ref_column(
+                        column
+                    ):
+                        # For a non-ref `Embedding` column, replace `None` with an empty array.
+                        default = np.empty(0, column.dtype.metadata["vlen"])
+                if is_category_dtype(dtype) and default is not None:
+                    categories: List[str] = column.attrs["category_keys"].tolist()
+                    if default not in categories:
                         column.attrs["category_values"] = np.append(
-                            column.attrs["category_values"],
-                            max(column.attrs["category_values"] + 1),
+                            column.attrs["category_values"], -1
                         ).astype(dtype=np.int32)
-                        column.attrs["category_keys"] = np.append(
-                            column.attrs["category_keys"], np.array(default)
-                        )
+                        column.attrs["category_keys"] = categories + [default]
                 if default is not None:
                     encoded_value = self._encode_value(default, column)
-                    if column_type is datetime and encoded_value is None:
+                    if dtype.name == "datetime" and encoded_value is None:
                         encoded_value = ""
                     column.attrs["default"] = encoded_value
+                elif is_category_dtype(dtype):
+                    column.attrs["default"] = -1
             except Exception as e:
                 # Rollback
@@ -2389,10 +2406,10 @@ def _format(value: _EncodedColumnType, ref_type_name: str) -> str:
         for column in columns:
             attrs = column.attrs
             type_name = attrs["type"]
-            column_type = get_column_type(type_name)
-            optional_keys = set(
-                self._user_column_attributes(column_type).keys()
-            ).difference(required_keys)
+            dtype = create_dtype(type_name)
+            optional_keys = set(self._user_column_attributes(dtype).keys()).difference(
+                required_keys
+            )
             column_reprs.append(
                 [
                     key + ": " + _format(attrs.get(key), type_name)
@@ -2413,9 +2430,9 @@ def _append_column(
     def _append_column(
         self,
         name: str,
-        column_type: Type[ColumnType],
+        dtype: DType,
         values: Union[ColumnInputType, Iterable[ColumnInputType]],
-        dtype: np.dtype,
+        np_dtype: np.dtype,
         order: Optional[int] = None,
         hidden: bool = True,
         optional: bool = False,
@@ -2431,34 +2448,33 @@ def _append_column(
         # `set_column_attributes` method.
         shape: Tuple[int, ...] = (0,)
         maxshape: Tuple[Optional[int], ...] = (None,)
-        if column_type is Category:
-            categories = attrs.get("categories", None)
-            if categories is None:
-                # Values are given, but no categories.
+        if is_category_dtype(dtype):
+            if dtype.categories is None:
                 if is_iterable(values):
-                    values = list(values)
-                    categories = set(values)
+                    categories: List[str] = sorted(set(values))
+                elif values is None:
+                    categories = []
                 else:
-                    categories = {values}
-                categories.difference_update({"", None})
-                if is_iterable(categories) and not isinstance(categories, dict):
-                    # dict is forced to preserve the order.
-                    categories = list(dict.fromkeys(categories, None).keys())
-                attrs["categories"] = dict(zip(categories, range(len(categories))))
-            # Otherwise, exception about type will be raised later in the
-            # `set_column_attributes` method.
-        elif column_type is Window:
+                    categories = cast(List[str], [values])
+                dtype = CategoryDType(categories)
+            attrs["categories"] = dtype.categories
+        elif is_window_dtype(dtype):
             shape = (0, 2)
             maxshape = (None, 2)
-        elif issubclass(column_type, FileBasedDType):
+        elif is_sequence_1d_dtype(dtype):
+            attrs["x_label"] = dtype.x_label
+            attrs["y_label"] = dtype.y_label
+        elif is_file_dtype(dtype):
             lookup = attrs.get("lookup", None)
             if is_iterable(lookup) and not isinstance(lookup, dict):
                 # Assume that we can keep all the lookup values in memory.
                 attrs["lookup"] = {str(i): v for i, v in enumerate(lookup)}
         try:
-            column = self._h5_file.create_dataset(name, shape, dtype, maxshape=maxshape)
+            column = self._h5_file.create_dataset(
+                name, shape, np_dtype, maxshape=maxshape
+            )
             self._column_names.add(name)
-            column.attrs["type"] = get_column_type_name(column_type)
+            column.attrs["type"] = dtype.name
             self.set_column_attributes(
                 name,
                 order,
@@ -2558,7 +2574,7 @@ def _set_column(
             else:
                 # Reorder values according to the given indices.
                 encoded_values = encoded_values[values_indices]
-            if self._get_column_type(column) is Embedding:
+            if self._get_dtype(column).name == "Embedding":
                 encoded_values = list(encoded_values)
         elif values is not None:
             # A single value is given. `Window` and `Embedding` values should
@@ -2661,70 +2677,63 @@ def _get_column(
         return self._decode_values(values, column)
 
     def _decode_values(self, values: np.ndarray, column: h5py.Dataset) -> np.ndarray:
-        column_type = self._get_column_type(column)
+        dtype = self._get_dtype(column)
         if column.attrs.get("external", False):
-            column_type = cast(Type[ExternalColumnType], column_type)
-            return self._decode_external_values(values, column_type)
+            return self._decode_external_values(values, dtype)
         if self._is_ref_column(column):
-            column_type = cast(Type[RefColumnType], column_type)
-            return self._decode_ref_values(values, column, column_type)
-        column_type = cast(Type[SimpleColumnType], column_type)
-        return self._decode_simple_values(values, column, column_type)
+            return self._decode_ref_values(values, column, dtype)
+        return self._decode_simple_values(values, column, dtype)
 
     @staticmethod
     def _decode_simple_values(
-        values: np.ndarray, column: h5py.Dataset, column_type: Type[SimpleColumnType]
+        values: np.ndarray, column: h5py.Dataset, dtype: DType
     ) -> np.ndarray:
-        if column_type is Category:
-            mapping = dict(
-                zip(column.attrs["category_values"], column.attrs["category_keys"])
-            )
-            mapping[-1] = ""
-            return np.array([mapping[x] for x in values], dtype=str)
+        if is_category_dtype(dtype):
+            if dtype.inverted_categories is None:
+                return np.full(len(values), None)
+            return np.array([dtype.inverted_categories.get(x) for x in values])
         if h5py.check_string_dtype(column.dtype):
-            # `column_type` is `str` or `datetime`.
+            # `dtype` is `str` or `datetime`.
values = np.array([x.decode("utf-8") for x in values]) - if column_type is str: + if is_str_dtype(dtype): return values # Decode datetimes. return np.array( [None if x == "" else datetime.fromisoformat(x) for x in values], dtype=object, ) - if column_type is Embedding: + if is_embedding_dtype(dtype): null_mask = [len(x) == 0 for x in values] values[null_mask] = None # For column types `bool`, `int`, `float` or `Window`, return the array as-is. return values def _decode_ref_values( - self, values: np.ndarray, column: h5py.Dataset, column_type: Type[RefColumnType] + self, values: np.ndarray, column: h5py.Dataset, dtype: DType ) -> np.ndarray: column_name = self._get_column_name(column) - if column_type in (np.ndarray, Embedding): + if dtype.name in ("array", "Embedding"): # `np.array([<...>], dtype=object)` creation does not work for # some cases and erases dtypes of sub-arrays, so we use assignment. decoded_values = np.empty(len(values), dtype=object) decoded_values[:] = [ - self._decode_ref_value(ref, column_type, column_name) for ref in values + self._decode_ref_value(ref, dtype, column_name) for ref in values ] return decoded_values return np.array( - [self._decode_ref_value(ref, column_type, column_name) for ref in values], + [self._decode_ref_value(ref, dtype, column_name) for ref in values], dtype=object, ) - def _decode_external_values( - self, values: np.ndarray, column_type: Type[ExternalColumnType] - ) -> np.ndarray: + def _decode_external_values(self, values: np.ndarray, dtype: DType) -> np.ndarray: return np.array( - [self._decode_external_value(value, column_type) for value in values], + [self._decode_external_value(value, dtype) for value in values], dtype=object, ) def _get_value( self, column: h5py.Dataset, index: IndexType, check_index: bool = False - ) -> Optional[ColumnType]: + ) -> Optional[OutputType]: if check_index: self._assert_index_exists(index) value = column[index] @@ -2750,7 +2759,7 @@ def _get_column_names_and_length(self) -> Tuple[Set[str], int]: h5_dataset = self._h5_file[name] if isinstance(h5_dataset, h5py.Dataset): try: - self._get_column_type(h5_dataset) + self._get_dtype(h5_dataset) except (KeyError, exceptions.InvalidDTypeError): continue else: @@ -2792,40 +2801,49 @@ def _encode_values( def _encode_simple_values( self, values: Iterable[SimpleColumnInputType], column: h5py.Dataset ) -> np.ndarray: - column_type = cast(Type[SimpleColumnType], self._get_column_type(column)) - if column_type is Category: - mapping = dict( - zip(column.attrs["category_keys"], column.attrs["category_values"]) - ) + dtype = self._get_dtype(column) + if is_category_dtype(dtype): + categories = cast(Dict[Optional[str], int], (dtype.categories or {}).copy()) if column.attrs.get("optional", False): default = column.attrs.get("default", -1) - mapping[None] = default - if default == -1: - mapping[""] = -1 + categories[None] = default try: # Map values and save as the right int type. - return np.array([mapping[x] for x in values], dtype=column.dtype) + return np.array( + [ + categories[value] + for value in cast(Iterable[Optional[str]], values) + ], + dtype=column.dtype, + ) except KeyError as e: column_name = self._get_column_name(column) + if dtype.categories: + categories_str = ", ".join(dtype.categories.keys()) + else: + categories_str = "" raise exceptions.InvalidValueError( - f'Values for the categorical column "{column_name}" ' - f"contain unknown categories." + f"Unknown value(s) received for categorical column " + f"'{column_name}'. 
Valid values for this column are: " + f"{categories_str}." ) from e - if column_type is datetime: - if _check_valid_array(values, column_type): + if is_datetime_dtype(dtype): + if _check_valid_array(values, dtype): encoded_values = np.array( [None if x is None else x.isoformat() for x in values.tolist()] ) else: encoded_values = np.array( - [self._encode_value(value, column) for value in values] + [ + self._encode_value(value, column) for value in values + ] # TODO: check for simple ) if np.issubdtype(encoded_values.dtype, str): # That means, we have all strings in array, no `None`s. return encoded_values return self._replace_none(encoded_values, column) - if column_type is Window: - encoded_values = self._asarray(values, column, column_type) + if is_window_dtype(dtype): + encoded_values = self._asarray(values, column, dtype) if encoded_values.ndim == 1: if len(encoded_values) == 2: # A single window, reshape it to an array. @@ -2842,8 +2860,8 @@ def _encode_simple_values( f"one of shapes (2,) (a single window) or (n, 2) (multiple " f"windows), but values with shape {encoded_values.shape} received." ) - if column_type is Embedding: - if _check_valid_array(values, column_type): + if is_embedding_dtype(dtype): + if _check_valid_array(values, dtype): # This is the only case we can handle fast and easily, otherwise # embedding should go through `_encode_value` element-wise. if values.ndim == 1: @@ -2866,12 +2884,12 @@ def _encode_simple_values( encoded_values = self._replace_none(encoded_values, column) return encoded_values # column type is `bool`, `int`, `float` or `str`. - encoded_values = self._asarray(values, column, column_type) + encoded_values = self._asarray(values, column, dtype) if encoded_values.ndim == 1: return encoded_values column_name = self._get_column_name(column) raise exceptions.InvalidShapeError( - f'Input values to `{column_type}` column "{column_name}" should ' + f'Input values to `{dtype}` column "{column_name}" should ' f"be 1-dimensional, but values with shape {encoded_values.shape} " f"received." ) @@ -2891,10 +2909,10 @@ def _asarray( self, values: Iterable[SimpleColumnInputType], column: h5py.Dataset, - column_type: Type[SimpleColumnType], + dtype: DType, ) -> np.ndarray: if isinstance(values, np.ndarray): - if _check_valid_value_dtype(values.dtype, column_type): + if _check_valid_value_dtype(values.dtype, dtype): return values elif not isinstance(values, (list, tuple, range)): # Make iterables, dicts etc. convertible to an array. @@ -2909,7 +2927,7 @@ def _asarray( except TypeError as e: column_name = self._get_column_name(column) raise exceptions.InvalidValueError( - f'Values for the column "{column_name}" of type {column_type} ' + f'Values for the column "{column_name}" of type {dtype} ' f"are not convertible to the dtype {column.dtype}." ) from e @@ -3007,42 +3025,62 @@ def _encode_value( if attrs.get("external", False): value = cast(PathOrUrlType, value) return self._encode_external_value(value, column) - column_type = self._get_column_type(attrs) + dtype = self._get_dtype(attrs) if self._is_ref_column(column): value = cast(RefColumnInputType, value) - return self._encode_ref_value(value, column, column_type, column_name) + self._assert_valid_value_type(value, dtype, column_name) + if is_file_dtype(dtype) and isinstance(value, str): + try: + return self._find_lookup_ref(value, column) + except KeyError: + pass # Don't need to search/update, so encode value as usual. 
+            encoded_value = self._encode_ref_value(value, column, dtype)
+            ref = self._write_ref_value(encoded_value, column, column_name)
+            return ref
         value = cast(SimpleColumnInputType, value)
-        return self._encode_simple_value(value, column, column_type, column_name)
+        return self._encode_simple_value(value, column, dtype, column_name)
+
+    @staticmethod
+    def _find_lookup_ref(key: str, column: h5py.Dataset) -> str:
+        lookup_keys = column.attrs["lookup_keys"].tolist()
+        try:
+            index = lookup_keys.index(key)
+        except ValueError as e:
+            raise KeyError from e
+        else:
+            # Return stored ref, do not process data again.
+            return column.attrs["lookup_values"][index]
 
     def _encode_simple_value(
         self,
         value: SimpleColumnInputType,
         column: h5py.Dataset,
-        column_type: Type[ColumnType],
+        dtype: DType,
         column_name: str,
     ) -> _EncodedColumnType:
         """
         Encode a non-ref value, e.g. bool, int, float, str, datetime, Category,
-        Window and Embedding (in last versions).
+        Window and Embedding.
 
         Value *cannot* be `None` already.
         """
+        self._assert_valid_value_type(value, dtype, column_name)
         attrs = column.attrs
-        if column_type is Category:
-            categories = dict(
-                zip(attrs.get("category_keys"), attrs.get("category_values"))
+        if is_category_dtype(dtype):
+            value = cast(str, value)
+            if dtype.categories:
+                try:
+                    return dtype.categories[value]
+                except KeyError:
+                    ...
+                categories_str = ", ".join(dtype.categories.keys())
+            else:
+                categories_str = ""
+            raise exceptions.InvalidValueError(
+                f"Unknown value '{value}' of type {type(value)} received for "
+                f"categorical column '{column_name}'. Valid values for this "
+                f"column are: {categories_str}."
             )
-            if attrs.get("optional", False) and attrs.get("default", -1) == -1:
-                categories[""] = -1
-            if value not in categories.keys():
-                raise exceptions.InvalidValueError(
-                    f"Values for {column_type} column "
-                    f'"{column.name.lstrip("/")}" should be one of '
-                    f"{list(categories.keys())} "
-                    f"but value '{value}' received."
-                )
-            return categories[value]
-        if column_type is Window:
+        if is_window_dtype(dtype):
             value = np.asarray(value, dtype=column.dtype)
             if value.shape == (2,):
                 return value
@@ -3050,14 +3088,13 @@ def _encode_simple_value(
             f"Windows should consist of 2 values, but window of shape "
             f"{value.shape} received for column {column_name}."
         )
-        if column_type is Embedding:
+        if is_embedding_dtype(dtype):
            # `Embedding` column is not a ref column.
             if isinstance(value, Embedding):
                 value = value.encode(attrs.get("format", None))
             value = np.asarray(value, dtype=column.dtype.metadata["vlen"])
             self._assert_valid_or_set_embedding_shape(value.shape, column)
             return value
-        self._assert_valid_value_type(value, column_type, column_name)
         if isinstance(value, np.str_):
             return value.tolist()
         if isinstance(value, np.datetime64):
@@ -3066,64 +3103,16 @@ def _encode_simple_value(
             return value.isoformat()
         return value
 
-    def _encode_ref_value(
+    def _write_ref_value(
         self,
-        value: RefColumnInputType,
+        value: Optional[Union[np.ndarray, np.void]],
         column: h5py.Dataset,
-        column_type: Type[ColumnType],
         column_name: str,
-    ) -> _EncodedColumnType:
-        """
-        Encode a ref value, e.g. np.ndarray, Sequence1D, Image, Mesh, Audio,
-        Video, and Embedding (in old versions).
-
-        Value *cannot* be `None` already.
- """ - - attrs = column.attrs - key: Optional[str] = None - lookup_keys: List[str] = [] - if column_type is Mesh and isinstance(value, trimesh.Trimesh): - value = Mesh.from_trimesh(value) - elif issubclass(column_type, (Audio, Image, Video)) and isinstance( - value, bytes - ): - value = column_type.from_bytes(value) - elif is_file_based_column_type(column_type) and isinstance( - value, (str, os.PathLike) - ): - try: - lookup_keys = attrs["lookup_keys"].tolist() - except KeyError: - pass # Don't need to search/update, so encode value as usual. - else: - key = str(value) - try: - index = lookup_keys.index(key) - except ValueError: - pass # Index not found, so encode value as usual. - else: - # Return stored ref, do not process data again. - return attrs["lookup_values"][index] - try: - value = column_type.from_file(value) - except Exception: - return None - if issubclass(column_type, (Embedding, Image, Sequence1D)): - if not isinstance(value, column_type): - value = column_type(value) # type: ignore - value = value.encode(attrs.get("format", None)) # type: ignore - elif issubclass(column_type, (Mesh, Audio, Video)): - self._assert_valid_value_type(value, column_type, column_name) - value = value.encode(attrs.get("format", None)) # type: ignore - else: - value = np.asarray(value) - # `value` can be a `np.ndarray` or a `np.void`. - if isinstance(value, np.ndarray): - # Check dtype. - self._assert_valid_or_set_value_dtype(value.dtype, column) - if column_type is Embedding: - self._assert_valid_or_set_embedding_shape(value.shape, column) + key: Optional[str] = None, + ) -> Optional[Union[str, h5py.Reference]]: + if value is None: + return None + # Write value into H5 and return its reference. dataset_name = str(uuid.uuid4()) if key is None else escape_dataset_name(key) h5_dataset = self._h5_file.create_dataset( f"__group__/{column_name}/{dataset_name}", data=value @@ -3132,20 +3121,77 @@ def _encode_ref_value( ref = h5_dataset.ref # Legacy handling. else: ref = dataset_name - if key is not None: - # `lookup_keys` is not `None`, so `lookup_values` too. - self._write_lookup( - attrs, - lookup_keys + [key], - np.concatenate( - (attrs["lookup_values"], [ref]), - dtype=column.dtype, - ), - column_name, - ) - h5_dataset.attrs["key"] = key return ref + def _encode_ref_value( + self, value: RefColumnInputType, column: h5py.Dataset, dtype: DType + ) -> Optional[Union[np.ndarray, np.void]]: + """ + Encode a ref value, e.g. np.ndarray, Sequence1D, Image, Mesh, Audio, + Video, and Embedding (in old versions). + + Value *cannot* be `None` already. 
+ """ + if is_array_dtype(dtype): + value = np.asarray(value) + self._assert_valid_or_set_value_dtype(value.dtype, column) + return value + if is_embedding_dtype(dtype): + if not isinstance(value, Embedding): + value = Embedding(value) # type: ignore + value = value.encode() + self._assert_valid_or_set_value_dtype(value.dtype, column) + self._assert_valid_or_set_embedding_shape(value.shape, column) + return value + if is_sequence_1d_dtype(dtype): + if not isinstance(value, Sequence1D): + value = Sequence1D(value) # type: ignore + value = value.encode() + self._assert_valid_or_set_value_dtype(value.dtype, column) + return value + if is_audio_dtype(dtype): + if isinstance(value, (str, os.PathLike)): + try: + value = Audio.from_file(value) + except Exception: + return None + if isinstance(value, bytes): + value = Audio.from_bytes(value) + assert isinstance(value, Audio) + return value.encode(column.attrs.get("format", None)) + if is_image_dtype(dtype): + if isinstance(value, (str, os.PathLike)): + try: + value = Image.from_file(value) + except Exception: + return None + if isinstance(value, bytes): + value = Image.from_bytes(value) + if not isinstance(value, Image): + value = Image(value) # type: ignore + return value.encode() + if is_mesh_dtype(dtype): + if isinstance(value, (str, os.PathLike)): + try: + value = Mesh.from_file(value) + except Exception: + return None + if isinstance(value, trimesh.Trimesh): + value = Mesh.from_trimesh(value) + assert isinstance(value, Mesh) + return value.encode() + if is_video_dtype(dtype): + if isinstance(value, (str, os.PathLike)): + try: + value = Video.from_file(value) + except Exception: + return None + if isinstance(value, bytes): + value = Video.from_bytes(value) + assert isinstance(value, Video) + return value.encode(column.attrs.get("format", None)) + assert False + def _encode_external_value(self, value: PathOrUrlType, column: h5py.Dataset) -> str: """ Encode an external value, i.e. an URL or a path. @@ -3198,12 +3244,12 @@ def _encode_external_value(self, value: PathOrUrlType, column: h5py.Dataset) -> @staticmethod def _assert_valid_value_type( - value: ColumnInputType, column_type: Type[ColumnType], column_name: str + value: ColumnInputType, dtype: DType, column_name: str ) -> None: - if not _check_valid_value_type(value, column_type): - allowed_types = (column_type,) + _ALLOWED_COLUMN_TYPES.get(column_type, ()) + if not _check_valid_value_type(value, dtype): + allowed_types = _ALLOWED_COLUMN_TYPES.get(dtype.name, ()) raise exceptions.InvalidDTypeError( - f'Values for non-optional {column_type} column "{column_name}" ' + f'Values for non-optional {dtype} column "{column_name}" ' f"should be one of {allowed_types} instances, but value " f"{value} of type `{type(value)}` received." 
            )
@@ -3214,108 +3260,94 @@ def _decode_value(
             np.bool_, np.integer, np.floating, bytes, str, np.ndarray, h5py.Reference
         ],
         column: h5py.Dataset,
-    ) -> Optional[ColumnType]:
-        column_type = self._get_column_type(column)
+    ) -> Optional[OutputType]:
+        dtype = self._get_dtype(column)
         if column.attrs.get("external", False):
             value = cast(bytes, value)
-            column_type = cast(Type[ExternalColumnType], column_type)
-            return self._decode_external_value(value, column_type)
+            return self._decode_external_value(value, dtype)
         if self._is_ref_column(column):
             value = cast(Union[bytes, h5py.Reference], value)
-            column_type = cast(Type[RefColumnType], column_type)
             column_name = self._get_column_name(column)
-            return self._decode_ref_value(value, column_type, column_name)
+            return self._decode_ref_value(value, dtype, column_name)
         value = cast(Union[np.bool_, np.integer, np.floating, bytes, np.ndarray], value)
-        column_type = cast(Type[SimpleColumnType], column_type)
-        return self._decode_simple_value(value, column, column_type)
+        return self._decode_simple_value(value, dtype)

     @staticmethod
     def _decode_simple_value(
         value: Union[np.bool_, np.integer, np.floating, bytes, str, np.ndarray],
-        column: h5py.Dataset,
-        column_type: Type[SimpleColumnType],
+        dtype: DType,
     ) -> Optional[Union[bool, int, float, str, datetime, np.ndarray]]:
-        if column_type is Window:
-            value = cast(np.ndarray, value)
-            return value
-        if column_type is Embedding:
+        if is_window_dtype(dtype):
+            return value  # type: ignore
+        if is_embedding_dtype(dtype):
             value = cast(np.ndarray, value)
             if len(value) == 0:
                 return None
             return value
-        if column_type is Category:
-            mapping = dict(
-                zip(column.attrs["category_values"], column.attrs["category_keys"])
-            )
-            if column.attrs.get("optional", False) and column.attrs.get(
-                "default", None
-            ) in (-1, None):
-                mapping[-1] = ""
-            return mapping[value]
+        if is_category_dtype(dtype):
+            if dtype.inverted_categories is None:
+                return None
+            return dtype.inverted_categories.get(cast(int, value), None)
         if isinstance(value, bytes):
             value = value.decode("utf-8")
-        if column_type is datetime:
+        if is_datetime_dtype(dtype):
             value = cast(str, value)
             if value == "":
                 return None
             return datetime.fromisoformat(value)
-        value = cast(Union[np.bool_, np.integer, np.floating, str], value)
-        column_type = cast(Type[Union[bool, int, float, str]], column_type)
-        return column_type(value)
+        return VALUE_TYPE_BY_DTYPE_NAME[dtype.name](value)  # type: ignore

     def _decode_ref_value(
         self,
         ref: Union[bytes, str, h5py.Reference],
-        column_type: Type[RefColumnType],
+        dtype: DType,
         column_name: str,
     ) -> Optional[Union[np.ndarray, Audio, Image, Mesh, Sequence1D, Video]]:
         # Value can be a H5 reference or a string reference.
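The decode path now dispatches on dtype names through the `VALUE_TYPE_BY_DTYPE_NAME` registry instead of comparing Python classes. A self-contained sketch of that pattern; the toy registry below is illustrative, not the module's actual mapping:

    from datetime import datetime
    from typing import Any, Callable, Dict

    # Toy stand-in for VALUE_TYPE_BY_DTYPE_NAME: dtype name -> value constructor.
    TOY_VALUE_TYPES: Dict[str, Callable[[Any], Any]] = {
        "bool": bool,
        "int": int,
        "float": float,
        "str": str,
    }

    def decode_scalar(value: Any, dtype_name: str) -> Any:
        # datetime is special-cased, as in `_decode_simple_value` above.
        if dtype_name == "datetime":
            return None if value == "" else datetime.fromisoformat(value)
        return TOY_VALUE_TYPES[dtype_name](value)

    assert decode_scalar("3", "int") == 3
    assert decode_scalar("2023-01-01T00:00:00", "datetime") == datetime(2023, 1, 1)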
         if not ref:
             return None
         value = self._resolve_ref(ref, column_name)[()]
-        value = cast(Union[np.ndarray, np.void], value)
-        if column_type in (np.ndarray, Embedding):
-            return value
-        column_type = cast(
-            Type[Union[Audio, Image, Mesh, Sequence1D, Video]], column_type
-        )
-        return column_type.decode(value)
+        if is_array_dtype(dtype) or is_embedding_dtype(dtype):
+            return np.asarray(value)
+        return VALUE_TYPE_BY_DTYPE_NAME[dtype.name].decode(value)  # type: ignore

     def _decode_external_value(
         self,
         value: Union[str, bytes],
-        column_type: Type[ExternalColumnType],
-    ) -> Optional[ExternalColumnType]:
+        dtype: DType,
+    ) -> Optional[ExternalOutputType]:
         if not value:
             return None
         if isinstance(value, bytes):
             value = value.decode("utf-8")
         file = prepare_path_or_url(value, os.path.dirname(self._filepath))
         try:
-            return column_type.from_file(file)
+            return VALUE_TYPE_BY_DTYPE_NAME[dtype.name].from_file(file)  # type: ignore
         except Exception:
             # No matter what happens, we should not crash, but warn instead.
             logger.warning(
                 f"File or URL {value} either does not exist or could not be "
-                f"loaded by the class `spotlight.{column_type.__name__}`."
+                f"loaded. "
                 f"Instead of script failure the value will be replaced with "
                 f"`None`."
             )
-            return None
+        return None

     def _append_internal_columns(self) -> None:
         """
         Append internal columns to the first created or imported dataset.
         """
         internal_column_values = [self._get_username(), get_current_datetime()]
-        for column_name, value in zip(INTERNAL_COLUMN_NAMES, internal_column_values):
+        for column_name, value, dtype in zip(
+            INTERNAL_COLUMN_NAMES, internal_column_values, INTERNAL_COLUMN_DTYPES
+        ):
             try:
                 column = self._h5_file[column_name]
             except KeyError:
                 # Internal column does not exist, create.
                 value = cast(Union[str, datetime], value)
                 self.append_column(
-                    column_name, type(value), value if self._length > 0 else None
+                    column_name, dtype, value if self._length > 0 else None
                 )
             else:
                 # Internal column exists, check type.
@@ -3327,12 +3359,11 @@ def _append_internal_columns(self) -> None:
                         f"has no type stored in attributes. Remove or rename "
                         f"the respective h5 dataset."
                     ) from e
-                column_type = get_column_type(type_name)
-                if column_type is not type(value):
+                if create_dtype(type_name).name != dtype.name:
                     raise exceptions.InconsistentDatasetError(
                         f'Internal column "{column_name}" already exists, '
-                        f"but has invalid type `{column_type}` "
-                        f"(`{type(value)}` expected). Remove or rename "
+                        f"but has invalid type `{type_name}` "
+                        f"(`{dtype}` expected). Remove or rename "
                         f"the respective h5 dataset."
                     )

@@ -3397,23 +3428,21 @@ def _resolve_ref(
     def _get_username() -> str:
         return ""

-    @staticmethod
-    def _get_column_type(
-        x: Union[str, h5py.Dataset, h5py.AttributeManager]
-    ) -> Type[ColumnType]:
+    def _get_dtype(self, x: Union[h5py.Dataset, h5py.AttributeManager]) -> DType:
         """
         Get column type by its name, or extract it from `h5py` entities.
         """
-        if isinstance(x, str):
-            return get_column_type(x)
         if isinstance(x, h5py.Dataset):
-            return get_column_type(x.attrs["type"])
-        if isinstance(x, h5py.AttributeManager):
-            return get_column_type(x["type"])
-        raise TypeError(
-            f"Argument is expected to ba an instance of type `str`, `h5py.Dataset` "
-            f"or `h5py.AttributeManager`, but `x` of type {type(x)} received."
-        )
+            return self._get_dtype(x.attrs)
+
+        type_name = x["type"]
+        if type_name == "Category":
+            return CategoryDType(
+                dict(zip(x.get("category_keys", []), x.get("category_values", [])))
+            )
+        if type_name == "Sequence1D":
+            return Sequence1DDType(x.get("x_label", "x"), x.get("y_label", "y"))
+        return create_dtype(type_name)

     @staticmethod
     def _get_column_name(column: h5py.Dataset) -> str:
@@ -3427,7 +3456,15 @@ def _is_ref_column(column: h5py.Dataset) -> bool:
         """
         Check if a column is ref column.
         """
-        return column.attrs["type"] in REF_COLUMN_TYPE_NAMES and (
+        return column.attrs["type"] in [
+            "array",
+            "Embedding",
+            "Sequence1D",
+            "Audio",
+            "Image",
+            "Mesh",
+            "Video",
+        ] and (
             h5py.check_string_dtype(column.dtype) or h5py.check_ref_dtype(column.dtype)
         )

diff --git a/renumics/spotlight/dataset/descriptors/__init__.py b/renumics/spotlight/dataset/descriptors/__init__.py
index cacaf663..e4dab9a2 100644
--- a/renumics/spotlight/dataset/descriptors/__init__.py
+++ b/renumics/spotlight/dataset/descriptors/__init__.py
@@ -9,7 +9,6 @@ from renumics.spotlight.dataset import Dataset
 from renumics.spotlight.dataset.exceptions import ColumnExistsError, InvalidDTypeError
-from renumics.spotlight.dtypes import Audio, Sequence1D

 from .data_alignment import align_column_data

@@ -77,11 +76,11 @@ def catch22(
     if suffix is None:
         suffix = "catch24" if catch24 else "catch22"
-    column_type = dataset.get_column_type(column)
-    if column_type not in (Audio, Sequence1D):
+    dtype = dataset.get_dtype(column)
+    if dtype.name not in ("Audio", "Sequence1D"):
         raise InvalidDTypeError(
             f"catch22 is only applicable to columns of type `Audio` and "
-            f'`Sequence1D`, but column "{column}" of type {column_type} received.'
+            f'`Sequence1D`, but column "{column}" of type {dtype} received.'
         )

     column_names = []
diff --git a/renumics/spotlight/dataset/descriptors/data_alignment.py b/renumics/spotlight/dataset/descriptors/data_alignment.py
index 19968cc4..0c54fae9 100644
--- a/renumics/spotlight/dataset/descriptors/data_alignment.py
+++ b/renumics/spotlight/dataset/descriptors/data_alignment.py
@@ -10,12 +10,7 @@ from skimage.transform import resize_local_mean

 from renumics.spotlight import (
-    Audio,
     Dataset,
-    Embedding,
-    Image,
-    Sequence1D,
-    Window,
 )
 from renumics.spotlight.dataset import exceptions

@@ -25,10 +20,10 @@ def align_audio_data(dataset: Dataset, column: str) -> Tuple[np.ndarray, np.ndar
     Align data from an audio column.
     """
-    column_type = dataset.get_column_type(column)
-    if column_type is not Audio:
+    dtype = dataset.get_dtype(column)
+    if dtype.name != "Audio":
         raise exceptions.InvalidDTypeError(
-            f'An audio column expected, but column "{column}" of type {column_type} received.'
+            f'An audio column expected, but column "{column}" of type {dtype} received.'
         )
     notnull_mask = dataset.notnull(column)
     if notnull_mask.sum() == 0:
@@ -70,10 +65,10 @@ def align_embedding_data(
     """
     Align data from an embedding column.
     """
-    column_type = dataset.get_column_type(column)
-    if column_type is not Embedding:
+    dtype = dataset.get_dtype(column)
+    if dtype.name != "Embedding":
         raise exceptions.InvalidDTypeError(
-            f'An embedding column expected, but column "{column}" of type {column_type} received.'
+            f'An embedding column expected, but column "{column}" of type {dtype} received.'
         )
     notnull_mask = dataset.notnull(column)
     if notnull_mask.sum() == 0:
@@ -86,10 +81,10 @@ def align_image_data(dataset: Dataset, column: str) -> Tuple[np.ndarray, np.ndar
     """
     Align data from an image column.
""" - column_type = dataset.get_column_type(column) - if column_type is not Image: + dtype = dataset.get_dtype(column) + if dtype.name != "Image": raise exceptions.InvalidDTypeError( - f'An image column expected, but column "{column}" of type {column_type} received.' + f'An image column expected, but column "{column}" of type {dtype} received.' ) notnull_mask = dataset.notnull(column) if notnull_mask.sum() == 0: @@ -128,10 +123,10 @@ def align_sequence_1d_data( """ Align data from an sequence 1D column. """ - column_type = dataset.get_column_type(column) - if column_type is not Sequence1D: + dtype = dataset.get_dtype(column) + if dtype.name != "Sequence1D": raise exceptions.InvalidDTypeError( - f'A sequence 1D column expected, but column "{column}" of type {column_type} received.' + f'A sequence 1D column expected, but column "{column}" of type {dtype} received.' ) notnull_mask = dataset.notnull(column) if notnull_mask.sum() == 0: @@ -165,20 +160,20 @@ def align_column_data( """ Align data from an Spotlight dataset column if possible. """ - column_type = dataset.get_column_type(column) - if column_type is Audio: + dtype = dataset.get_dtype(column) + if dtype.name == "Audio": data, mask = align_audio_data(dataset, column) - elif column_type is Embedding: + elif dtype.name == "Embedding": data, mask = align_embedding_data(dataset, column) - elif column_type is Image: + elif dtype.name == "Image": data, mask = align_image_data(dataset, column) - elif column_type is Sequence1D: + elif dtype.name == "Sequence1D": data, mask = align_sequence_1d_data(dataset, column) - elif column_type in (bool, int, float, Window): + elif dtype.name in ("bool", "int", "float", "Window"): data = dataset[column].astype(np.float64).reshape((len(dataset), -1)) mask = np.full(len(dataset), True) else: - raise NotImplementedError(f"{column_type} column currently not supported.") + raise NotImplementedError(f"{dtype} column currently not supported.") if not allow_nan: # Remove "rows" with `NaN`s. diff --git a/renumics/spotlight/dataset/typing.py b/renumics/spotlight/dataset/typing.py index 0d015fd7..e6a9c91a 100644 --- a/renumics/spotlight/dataset/typing.py +++ b/renumics/spotlight/dataset/typing.py @@ -3,38 +3,33 @@ """ from datetime import datetime -from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, -) +from typing import List, Optional, Sequence, Tuple, Union import numpy as np import trimesh -from typing_extensions import get_args from renumics.spotlight.typing import BoolType, IntType, NumberType, PathOrUrlType -from renumics.spotlight.dtypes import ( +from renumics.spotlight.media import ( Array1dLike, + ImageLike, Embedding, Mesh, Sequence1D, Image, - ImageLike, Audio, - Category, Video, - Window, ) -from renumics.spotlight.dtypes.typing import FileBasedColumnType, NAME_BY_COLUMN_TYPE +OutputType = Union[ + bool, int, float, str, datetime, np.ndarray, Sequence1D, Audio, Image, Mesh, Video +] # Only pure types. -SimpleColumnType = Union[bool, int, float, str, datetime, Category, Window, Embedding] -RefColumnType = Union[np.ndarray, Embedding, Mesh, Sequence1D, Image, Audio, Video] -ExternalColumnType = FileBasedColumnType +SimpleOutputType = Union[bool, int, float, str, datetime, Embedding] +RefOutputType = Union[np.ndarray, Embedding, Mesh, Sequence1D, Image, Audio, Video] +FileBasedOutputType = Union[Audio, Image, Mesh, Video] +ExternalOutputType = FileBasedOutputType +ArrayBasedOutputType = Union[Embedding, Image, Sequence1D] # Pure types, compatible types and `None`. 
BoolColumnInputType = Optional[BoolType] IntColumnInputType = Optional[IntType] @@ -66,18 +61,17 @@ RefColumnInputType = Union[ ArrayColumnInputType, EmbeddingColumnInputType, + Sequence1DColumnInputType, AudioColumnInputType, ImageColumnInputType, MeshColumnInputType, - Sequence1DColumnInputType, VideoColumnInputType, ] ColumnInputType = Union[SimpleColumnInputType, RefColumnInputType] -ExternalColumnInputType = Optional[PathOrUrlType] - -REF_COLUMN_TYPE_NAMES = [ - NAME_BY_COLUMN_TYPE[column_type] for column_type in get_args(RefColumnType) -] -SIMPLE_COLUMN_TYPE_NAMES = [ - NAME_BY_COLUMN_TYPE[column_type] for column_type in get_args(SimpleColumnType) +FileColumnInputType = Union[ + AudioColumnInputType, + ImageColumnInputType, + MeshColumnInputType, + VideoColumnInputType, ] +ExternalColumnInputType = Optional[PathOrUrlType] diff --git a/renumics/spotlight/dtypes/__init__.py b/renumics/spotlight/dtypes/__init__.py index ff0105ea..671bb74c 100644 --- a/renumics/spotlight/dtypes/__init__.py +++ b/renumics/spotlight/dtypes/__init__.py @@ -1,979 +1,185 @@ -""" -This module provides custom data types for Spotlight dataset. -""" +from datetime import datetime +from typing import Any, Dict, Iterable, Optional, Union -import io -import math -import os -from typing import Dict, IO, List, Optional, Sequence, Tuple, Union -from urllib.parse import urlparse - -import imageio.v3 as iio import numpy as np -import pygltflib -import requests -import trimesh -import validators -from loguru import logger - -from renumics.spotlight.requests import headers -from renumics.spotlight.typing import FileType, NumberType, PathType -from . import exceptions, triangulation -from .base import DType, FileBasedDType -from ..io import audio, gltf, file as file_io - -Array1dLike = Union[Sequence[NumberType], np.ndarray] -Array2dLike = Union[Sequence[Sequence[NumberType]], np.ndarray] -ImageLike = Union[ - Sequence[Sequence[Union[NumberType, Sequence[NumberType]]]], np.ndarray -] - - -class Embedding(DType): - """ - Data sample projected onto a new space. - - Attributes: - data: 1-dimensional array-like. Sample embedding. - dtype: Optional data type of embedding. If `None`, data type inferred - from data. - - Example: - >>> import numpy as np - >>> from renumics.spotlight import Dataset, Embedding - >>> value = np.array(np.random.rand(2)) - >>> embedding = Embedding(value) - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_embedding_column("embeddings", 5*[embedding]) - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(len(dataset["embeddings", 3].data)) - 2 - """ - - data: np.ndarray +from typing_extensions import TypeGuard - def __init__( - self, data: Array1dLike, dtype: Optional[Union[str, np.dtype]] = None - ) -> None: - data_array = np.asarray(data, dtype) - if data_array.ndim != 1 or data_array.size == 0: - raise ValueError( - f"`data` argument should an array-like with shape " - f"`(num_features,)` with `num_features > 0`, but shape " - f"{data_array.shape} received." - ) - if data_array.dtype.str[1] not in "fiu": - raise ValueError( - f"`data` argument should be an array-like with integer or " - f"float dtypes, but dtype {data_array.dtype.name} received." - ) - self.data = data_array - - @classmethod - def decode(cls, value: Union[np.ndarray, np.void]) -> "Embedding": - if not isinstance(value, np.ndarray): - raise TypeError( - f"`value` argument should be a numpy array, but {type(value)} " - f"received." 
- ) - return cls(value) - - def encode(self, _target: Optional[str] = None) -> np.ndarray: - return self.data - - -class Sequence1D(DType): - """ - One-dimensional ndarray with optional index values. - - Attributes: - index: 1-dimensional array-like of length `num_steps`. Index values (x-axis). - value: 1-dimensional array-like of length `num_steps`. Respective values (y-axis). - dtype: Optional data type of sequence. If `None`, data type inferred - from data. - - Example: - >>> import numpy as np - >>> from renumics.spotlight import Dataset, Sequence1D - >>> index = np.arange(100) - >>> value = np.array(np.random.rand(100)) - >>> sequence = Sequence1D(index, value) - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_sequence_1d_column("sequences", 5*[sequence]) - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(len(dataset["sequences", 2].value)) - 100 - """ - - index: np.ndarray - value: np.ndarray +from .legacy import Audio, Category, Embedding, Image, Mesh, Sequence1D, Video, Window - def __init__( - self, - index: Optional[Array1dLike], - value: Optional[Array1dLike] = None, - dtype: Optional[Union[str, np.dtype]] = None, - ) -> None: - if value is None: - if index is None: - raise ValueError( - "At least one of arguments `index` or `value` should be " - "set, but both `None` values received." - ) - value = index - index = None - - value_array = np.asarray(value, dtype) - if value_array.dtype.str[1] not in "fiu": - raise ValueError( - f"Input values should be array-likes with integer or float " - f"dtype, but dtype {value_array.dtype.name} received." - ) - if index is None: - if value_array.ndim == 2: - if value_array.shape[0] == 2: - self.index = value_array[0] - self.value = value_array[1] - elif value_array.shape[1] == 2: - self.index = value_array[:, 0] - self.value = value_array[:, 1] - else: - raise ValueError( - f"A single 2-dimensional input value should have one " - f"dimension of length 2, but shape {value_array.shape} received." - ) - elif value_array.ndim == 1: - self.value = value_array - if dtype is None: - dtype = self.value.dtype - self.index = np.arange(len(self.value), dtype=dtype) - else: - raise ValueError( - f"A single input value should be 1- or 2-dimensional, but " - f"shape {value_array.shape} received." - ) - else: - if value_array.ndim != 1: - raise ValueError( - f"Value should be 1-dimensional, but shape {value_array.shape} received." - ) - index_array = np.asarray(index, dtype) - if index_array.ndim != 1: - raise ValueError( - f"INdex should be 1-dimensional array-like, but shape " - f"{index_array.shape} received." - ) - if index_array.dtype.str[1] not in "fiu": - raise ValueError( - f"Index should be array-like with integer or float " - f"dtype, but dtype {index_array.dtype.name} received." - ) - self.value = value_array - self.index = index_array - if len(self.value) != len(self.index): - raise ValueError( - f"Lengths of `index` and `value` should match, but lengths " - f"{len(self.index)} and {len(self.value)} received." - ) - - @classmethod - def decode(cls, value: Union[np.ndarray, np.void]) -> "Sequence1D": - if not isinstance(value, np.ndarray): - raise TypeError( - f"`value` argument should be a numpy array, but {type(value)} " - f"received." - ) - if value.ndim != 2 or value.shape[1] != 2: - raise ValueError( - f"`value` argument should be a numpy array with shape " - f"`(num_steps, 2)`, but shape {value.shape} received." 
- ) - return cls(value[:, 0], value[:, 1]) - - def encode(self, _target: Optional[str] = None) -> np.ndarray: - return np.stack((self.index, self.value), axis=1) - - @classmethod - def empty(cls) -> "Sequence1D": - """ - Create an empty sequence. - """ - return cls(np.empty(0), np.empty(0)) - - -class Mesh(FileBasedDType): - """ - Triangular 3D mesh with optional per-point and per-triangle attributes and - optional per-point displacements over time. - - Example: - >>> import numpy as np - >>> from renumics.spotlight import Dataset, Mesh - >>> points = np.array([[0,0,0],[1,1,1],[0,1,0],[-1,0,1]]) - >>> triangles = np.array([[0,1,2],[2,3,0]]) - >>> mesh = Mesh(points, triangles) - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_mesh_column("meshes", 5*[mesh]) - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(dataset["meshes", 2].triangles) - [[0 1 2] - [2 3 0]] - """ - - _points: np.ndarray - _triangles: np.ndarray - _point_attributes: Dict[str, np.ndarray] - _point_displacements: List[np.ndarray] - - _point_indices: np.ndarray - _triangle_indices: np.ndarray - _triangle_attribute_indices: np.ndarray + +class DType: + _name: str + + def __init__(self, name: str): + self._name = name + + def __str__(self) -> str: + return self.name + + @property + def name(self) -> str: + return self._name + + +class CategoryDType(DType): + _categories: Optional[Dict[str, int]] + _inverted_categories: Optional[Dict[int, str]] def __init__( - self, - points: Array2dLike, - triangles: Array2dLike, - point_attributes: Optional[Dict[str, np.ndarray]] = None, - triangle_attributes: Optional[Dict[str, np.ndarray]] = None, - point_displacements: Optional[Union[np.ndarray, List[np.ndarray]]] = None, + self, categories: Optional[Union[Iterable[str], Dict[str, int]]] = None ): - self._point_attributes = {} - self._point_displacements = [] - self._set_points_triangles(points, triangles) + super().__init__("Category") + if isinstance(categories, dict) or categories is None: + self._categories = categories + else: + self._categories = { + category: code for code, category in enumerate(categories) + } - if point_displacements is None: - point_displacements = [] - self.point_displacements = point_displacements # type: ignore - self.update_attributes(point_attributes, triangle_attributes) + if self._categories is None: + self._inverted_categories = None + else: + self._inverted_categories = { + code: category for category, code in self._categories.items() + } + # invert again to remove duplicate codes + self._categories = { + category: code for code, category in self._inverted_categories.items() + } @property - def points(self) -> np.ndarray: - """ - :code:`np.array` with shape `(num_points, 3)`. Mesh points. - """ - return self._points + def categories(self) -> Optional[Dict[str, int]]: + return self._categories @property - def triangles(self) -> np.ndarray: - """ - :code:`np.array` with shape `(num_triangles, 3)`. Mesh triangles stored as their - CCW nodes referring to the `points` indices. - """ - return self._triangles + def inverted_categories(self) -> Optional[Dict[int, str]]: + return self._inverted_categories - @property - def point_attributes(self) -> Dict[str, np.ndarray]: - """ - Mapping str -> :code:`np.array` with shape `(num_points, ...)`. Point-wise - attributes corresponding to `points`. All possible shapes of a single - attribute can be found in - `renumics.spotlight.mesh_proc.gltf.GLTF_SHAPES`. 
- """ - return self._point_attributes - @property - def point_displacements(self) -> List[np.ndarray]: - """ - List of arrays with shape `(num_points, 3)`. Point-wise relative - displacements (offsets) over the time corresponding to `points`. - Timestep 0 is omitted since it is explicit stored as absolute values in - `points`. - """ - return self._point_displacements - - @point_displacements.setter - def point_displacements(self, value: Union[np.ndarray, List[np.ndarray]]) -> None: - array = triangulation.attribute_to_array(value) - if array.size == 0: - self._point_displacements = [] - else: - array = array.astype(np.float32) - if array.shape[1] != len(self._points): - array = array[:, self._point_indices] - if array.shape[1:] != (len(self._points), 3): - raise ValueError( - f"Point displacements should have the same shape as points " - f"({self._points.shape}), but shape {array.shape[1:]} " - f"received." - ) - self._point_displacements = list(array) - - @classmethod - def from_trimesh(cls, mesh: trimesh.Trimesh) -> "Mesh": - """ - Import a `trimesh.Trimesh` mesh. - """ - return cls( - mesh.vertices, mesh.faces, mesh.vertex_attributes, mesh.face_attributes - ) - - @classmethod - def from_file(cls, filepath: PathType) -> "Mesh": - """ - Read mesh from a filepath or an URL. - - `trimesh` is used inside, so only supported formats are allowed. - """ - file: Union[str, IO] = ( - str(filepath) if isinstance(filepath, os.PathLike) else filepath - ) - extension = None - if isinstance(file, str): - if validators.url(file): - response = requests.get(file, headers=headers, timeout=30) - if not response.ok: - raise exceptions.InvalidFile(f"URL {file} does not exist.") - extension = os.path.splitext(urlparse(file).path)[1] - if extension == "": - raise exceptions.InvalidFile(f"URL {file} has no file extension.") - file = io.BytesIO(response.content) - elif not os.path.isfile(file): - raise exceptions.InvalidFile( - f"File {file} is neither an existing file nor an existing URL." - ) - try: - mesh = trimesh.load(file, file_type=extension, force="mesh") - except Exception as e: - raise exceptions.InvalidFile( - f"Mesh {filepath} does not exist or could not be read." - ) from e - return cls.from_trimesh(mesh) - - @classmethod - def empty(cls) -> "Mesh": - """ - Create an empty mesh. 
- """ - return cls(np.empty((0, 3)), np.empty((0, 3), np.int64)) - - @classmethod - def decode(cls, value: Union[np.ndarray, np.void]) -> "Mesh": - gltf_mesh = pygltflib.GLTF2.load_from_bytes(value.tobytes()) - gltf.check_gltf(gltf_mesh) - arrays = gltf.decode_gltf_arrays(gltf_mesh) - primitive = gltf_mesh.meshes[0].primitives[0] - points = arrays[primitive.attributes.POSITION] - triangles = arrays[primitive.indices].reshape((-1, 3)) - point_attributes = { - k[1:]: arrays[v] - for k, v in primitive.attributes.__dict__.items() - if k.startswith("_") - } - point_displacements = [ - arrays[target["POSITION"]] for target in primitive.targets - ] - return cls( - points, triangles, point_attributes, point_displacements=point_displacements - ) - - def encode(self, _target: Optional[str] = None) -> np.void: - bin_data, buffer_views, accessors = gltf.encode_gltf_array( - self._triangles.flatten(), b"", [], [], pygltflib.ELEMENT_ARRAY_BUFFER - ) - mesh_primitive_attributes_kwargs = {"POSITION": 1} - bin_data, buffer_views, accessors = gltf.encode_gltf_array( - self._points, bin_data, buffer_views, accessors - ) - for attr_name, point_attr in self._point_attributes.items(): - mesh_primitive_attributes_kwargs["_" + attr_name] = len(buffer_views) - bin_data, buffer_views, accessors = gltf.encode_gltf_array( - point_attr, bin_data, buffer_views, accessors - ) - morph_targets = [] - for point_displacement in self._point_displacements: - morph_targets.append(pygltflib.Attributes(POSITION=len(buffer_views))) - bin_data, buffer_views, accessors = gltf.encode_gltf_array( - point_displacement, bin_data, buffer_views, accessors - ) - gltf_mesh = pygltflib.GLTF2( - asset=pygltflib.Asset(), - scene=0, - scenes=[pygltflib.Scene(nodes=[0])], - nodes=[pygltflib.Node(mesh=0)], - meshes=[ - pygltflib.Mesh( - primitives=[ - pygltflib.Primitive( - attributes=pygltflib.Attributes( - **mesh_primitive_attributes_kwargs - ), - indices=0, - mode=pygltflib.TRIANGLES, - targets=morph_targets, - ) - ], - ) - ], - accessors=accessors, - bufferViews=buffer_views, - buffers=[pygltflib.Buffer(byteLength=len(bin_data))], - ) - gltf_mesh.set_binary_blob(bin_data) - return np.void(b"".join(gltf_mesh.save_to_bytes())) - - def update_attributes( - self, - point_attributes: Optional[Dict[str, np.ndarray]] = None, - triangle_attributes: Optional[Dict[str, np.ndarray]] = None, - ) -> None: - """ - Update point and/or triangle attributes dict-like. - """ - if point_attributes: - point_attributes = self._sanitize_point_attributes(point_attributes) - self._point_attributes.update(point_attributes) - if triangle_attributes: - triangle_attributes = self._sanitize_triangle_attributes( - triangle_attributes - ) - logger.info("Triangle attributes will be converted to point attributes.") - self._point_attributes.update( - self._triangle_attributes_to_point_attributes(triangle_attributes) - ) - - def interpolate_point_displacements(self, num_timesteps: int) -> None: - """subsample time dependent attributes with new time step count""" - if num_timesteps < 1: - raise ValueError( - f"`num_timesteps` argument should be non-negative, but " - f"{num_timesteps} received." 
- ) - current_num_timesteps = len(self._point_displacements) - if current_num_timesteps == 0: - logger.info("No displacements found, so cannot interpolate.") - return - if current_num_timesteps == num_timesteps: - return - - def _interpolated_list_access( - arrays: List[np.ndarray], index_float: float - ) -> np.ndarray: - """access a list equally sized numpy arrays with interpolation between two neighbors""" - array_left = arrays[math.floor(index_float)] - array_right = arrays[math.ceil(index_float)] - weight_right = index_float - math.floor(index_float) - return (array_left * (1 - weight_right)) + (array_right * weight_right) - - # simplification assumption : timesteps are equally sized - timesteps = np.linspace(0, current_num_timesteps, num_timesteps + 1)[1:] - - # add implicit 0 displacement for t=0 - displacements = [ - np.zeros_like(self._point_displacements[0]) - ] + self._point_displacements - self._point_displacements = [ - _interpolated_list_access(displacements, t) for t in timesteps - ] - - def _set_points_triangles( - self, points: Array2dLike, triangles: Array2dLike - ) -> None: - # Check points. - points_array = np.asarray(points, np.float32) - if points_array.ndim != 2 or points_array.shape[1] != 3: - raise ValueError( - f"`points` argument should be a numpy array with shape " - f"`(num_points, 3)`, but shape {points_array.shape} received." - ) - # Check triangles. - triangles_array = np.asarray(triangles, np.uint32) - if triangles_array.ndim != 2 or triangles_array.shape[1] != 3: - raise ValueError( - f"`triangles` argument should be a numpy array with shape " - f"`(num_triangles, 3)`, but shape {triangles_array.shape} received." - ) - # Subsample only valid points and triangles. - point_ids = np.arange(len(points_array)) - valid_triangles_mask = ( - (triangles_array[:, 0] != triangles_array[:, 1]) - & (triangles_array[:, 0] != triangles_array[:, 2]) - & (triangles_array[:, 1] != triangles_array[:, 2]) - & np.isin(triangles_array, point_ids).all(axis=1) - ) - self._triangle_indices = np.nonzero(valid_triangles_mask)[0] - self._triangles = triangles_array[self._triangle_indices] - valid_points_mask = np.isin(point_ids, self._triangles) - self._point_indices = np.nonzero(valid_points_mask)[0] - self._points = points_array[self._point_indices] - # Reindex triangles since there can be fewer points than before. - point_ids = point_ids[self._point_indices] - self._triangles, *_ = triangulation.reindex(point_ids, self._triangles) - # Set indices for conversion of triangle attributes to point attributes. - self._triangle_attribute_indices = np.full((len(self._points)), 0, np.uint32) - triangle_indices = np.arange(len(self._triangles), dtype=np.uint32) - self._triangle_attribute_indices[self._triangles[:, 0]] = triangle_indices - self._triangle_attribute_indices[self._triangles[:, 1]] = triangle_indices - self._triangle_attribute_indices[self._triangles[:, 2]] = triangle_indices - - def _triangle_attributes_to_point_attributes( - self, triangle_attributes: Dict[str, np.ndarray] - ) -> Dict[str, np.ndarray]: - return { - f"element_{attr_name}": triangle_attr[self._triangle_attribute_indices] - for attr_name, triangle_attr in triangle_attributes.items() - } - - def _sanitize_point_attributes( - self, point_attributes: Dict[str, np.ndarray] - ) -> Dict[str, np.ndarray]: - if not isinstance(point_attributes, dict): - raise TypeError( - f"`point_attributes` argument should be a dict, but " - f"{type(point_attributes)} received." 
- ) - valid_point_attributes = {} - for attr_name, point_attr in point_attributes.items(): - point_attr = np.asarray(point_attr) - if len(point_attr) != len(self._points): - point_attr = point_attr[self._point_indices] - if point_attr.dtype.str[1] not in "fiu": - raise ValueError( - f"Point attributes should have one of integer or float " - f'dtypes, but attribute "{attr_name}" of dtype ' - f"{point_attr.dtype.name} received." - ) - point_attr = point_attr.squeeze() - if point_attr.shape[1:] not in gltf.GLTF_SHAPES_LOOKUP.keys(): - logger.warning( - f"Element shape {point_attr.shape[1:]} of the point " - f'attribute "{attr_name}" not supported, attribute will be ' - f"removed." - ) - continue - valid_point_attributes[attr_name] = point_attr.astype( - gltf.GLTF_DTYPES_CONVERSION[point_attr.dtype.str[1:]] - ) - return valid_point_attributes - - def _sanitize_triangle_attributes( - self, triangle_attributes: Dict[str, np.ndarray] - ) -> Dict[str, np.ndarray]: - if not isinstance(triangle_attributes, dict): - raise TypeError( - f"`triangle_attributes` argument should be a dict, but " - f"{type(triangle_attributes)} received." - ) - valid_triangle_attributes = {} - for attr_name, triangle_attr in triangle_attributes.items(): - triangle_attr = np.asarray(triangle_attr) - if len(triangle_attr) != len(self._triangles): - triangle_attr = triangle_attr[self._triangle_indices] - if triangle_attr.dtype.str[1] not in "fiu": - raise ValueError( - f"Triangle attributes should have one of integer or float " - f'dtypes, but attribute "{attr_name}" of dtype ' - f"{triangle_attr.dtype.name} received." - ) - triangle_attr = triangle_attr.squeeze() - if triangle_attr.shape[1:] not in gltf.GLTF_SHAPES_LOOKUP.keys(): - logger.warning( - f"Element shape {triangle_attr.shape[1:]} of the triangle " - f'attribute "{attr_name}" not supported, attribute will be ' - f"removed." - ) - continue - valid_triangle_attributes[attr_name] = triangle_attr.astype( - gltf.GLTF_DTYPES_CONVERSION[triangle_attr.dtype.str[1:]] - ) - return valid_triangle_attributes - - -class Image(FileBasedDType): - """ - An RGB(A) or grayscale image that will be saved in encoded form. - - Attributes: - data: Array-like with shape `(num_rows, num_columns)` or - `(num_rows, num_columns, num_channels)` with `num_channels` equal to - 3, or 4; with dtype "uint8". - - Example: - >>> import numpy as np - >>> from renumics.spotlight import Dataset, Image - >>> data = np.full([100,100,3], 255, dtype=np.uint8) # white uint8 image - >>> image = Image(data) - >>> float_data = np.random.uniform(0, 1, (100, 100)) # random grayscale float image - >>> float_image = Image(float_data) - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_image_column("images", [image, float_image, data, float_data]) - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(dataset["images", 0].data[50][50]) - ... print(dataset["images", 3].data.dtype) - [255 255 255] - uint8 - """ - - data: np.ndarray - - def __init__(self, data: ImageLike) -> None: - data_array = np.asarray(data) - if ( - data_array.size == 0 - or data_array.ndim != 2 - and (data_array.ndim != 3 or data_array.shape[-1] not in (1, 3, 4)) - ): - raise ValueError( - f"`data` argument should be a numpy array with shape " - f"`(num_rows, num_columns, num_channels)` or " - f"`(num_rows, num_columns)` or with `num_rows > 0`, " - f"`num_cols > 0` and `num_channels` equal to 1, 3, or 4, but " - f"shape {data_array.shape} received." 
- ) - if data_array.dtype.str[1] not in "fiu": - raise ValueError( - f"`data` argument should be a numpy array with integer or " - f"float dtypes, but dtype {data_array.dtype.name} received." - ) - if data_array.ndim == 3 and data_array.shape[2] == 1: - data_array = data_array.squeeze(axis=2) - if data_array.dtype.str[1] == "f": - logger.info( - 'Image data converted to "uint8" dtype by multiplication with ' - "255 and rounding." - ) - data_array = (255 * data_array).round() - self.data = data_array.astype("uint8") - - @classmethod - def from_file(cls, filepath: FileType) -> "Image": - """ - Read image from a filepath, an URL, or a file-like object. - - `imageio` is used inside, so only supported formats are allowed. - """ - with file_io.as_file(filepath) as file: - try: - image_array = iio.imread(file, index=False) # type: ignore - except Exception as e: - raise exceptions.InvalidFile( - f"Image {filepath} does not exist or could not be read." - ) from e - return cls(image_array) - - @classmethod - def from_bytes(cls, blob: bytes) -> "Image": - """ - Read image from raw bytes. - - `imageio` is used inside, so only supported formats are allowed. - """ - try: - image_array = iio.imread(blob, index=False) # type: ignore - except Exception as e: - raise exceptions.InvalidFile( - "Image could not be read from the given bytes." - ) from e - return cls(image_array) - - @classmethod - def empty(cls) -> "Image": - """ - Create a transparent 1 x 1 image. - """ - return cls(np.zeros((1, 1, 4), np.uint8)) - - @classmethod - def decode(cls, value: Union[np.ndarray, np.void]) -> "Image": - if isinstance(value, np.void): - buffer = io.BytesIO(value.tolist()) - return cls(iio.imread(buffer, extension=".png", index=False)) - raise TypeError( - f"`value` should be a `numpy.void` instance, but {type(value)} " - f"received." - ) - - def encode(self, _target: Optional[str] = None) -> np.void: - buf = io.BytesIO() - iio.imwrite(buf, self.data, extension=".png") - return np.void(buf.getvalue()) - - -class Audio(FileBasedDType): - """ - An Audio Signal that will be saved in encoded form. - - All formats and codecs supported by AV are supported for read. - - Attributes: - data: Array-like with shape `(num_samples, num_channels)` - with `num_channels` <= 5. - If `data` has a float dtype, its values should be between -1 and 1. - If `data` has an int dtype, its values should be between minimum and - maximum possible values for the particular int dtype. - If `data` has an unsigned int dtype, ist values should be between 0 - and maximum possible values for the particular unsigned int dtype. - sampling_rate: Sampling rate (samples per seconds) - - Example: - >>> import numpy as np - >>> from renumics.spotlight import Dataset, Audio - >>> samplerate = 44100 - >>> fs = 100 # 100 Hz audio signal - >>> time = np.linspace(0.0, 1.0, samplerate) - >>> amplitude = np.iinfo(np.int16).max * 0.4 - >>> data = np.array(amplitude * np.sin(2.0 * np.pi * fs * time), dtype=np.int16) - >>> audio = Audio(samplerate, np.array([data, data]).T) # int16 stereo signal - >>> float_data = 0.5 * np.cos(2.0 * np.pi * fs * time).astype(np.float32) - >>> float_audio = Audio(samplerate, float_data) # float32 mono signal - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_audio_column("audio", [audio, float_audio]) - ... dataset.append_audio_column("lossy_audio", [audio, float_audio], lossy=True) - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(dataset["audio", 0].data[100]) - ... 
print(f"{dataset['lossy_audio', 1].data[0, 0]:.5g}") - [12967 12967] - 0.4596 - """ - - data: np.ndarray - sampling_rate: int - - def __init__(self, sampling_rate: int, data: Array2dLike) -> None: - data_array = np.asarray(data) - is_valid_multi_channel = ( - data_array.size > 0 and data_array.ndim == 2 and data_array.shape[1] <= 5 - ) - is_valid_mono = data_array.size > 0 and data_array.ndim == 1 - if not (is_valid_multi_channel or is_valid_mono): - raise ValueError( - f"`data` argument should be a 1D array for mono data" - f" or a 2D numpy array with shape " - f"`(num_samples, num_channels)` and with num_channels <= 5, " - f"but shape {data_array.shape} received." - ) - if data_array.dtype not in [np.float32, np.int32, np.int16, np.uint8]: - raise ValueError( - f"`data` argument should be a numpy array with " - f"dtype np.float32, np.int32, np.int16 or np.uint8, " - f"but dtype {data_array.dtype.name} received." - ) - self.data = data_array - self.sampling_rate = sampling_rate - - @classmethod - def from_file(cls, filepath: FileType) -> "Audio": - """ - Read audio file from a filepath, an URL, or a file-like object. - - `pyav` is used inside, so only supported formats are allowed. - """ - try: - data, sampling_rate = audio.read_audio(filepath) - except Exception as e: - raise exceptions.InvalidFile( - f"Audio file {filepath} does not exist or could not be read." - ) from e - return cls(sampling_rate, data) - - @classmethod - def from_bytes(cls, blob: bytes) -> "Audio": - """ - Read audio from raw bytes. - - `pyav` is used inside, so only supported formats are allowed. - """ - try: - data, sampling_rate = audio.read_audio(io.BytesIO(blob)) - except Exception as e: - raise exceptions.InvalidFile( - "Audio could not be read from the given bytes." - ) from e - return cls(sampling_rate, data) - - @classmethod - def empty(cls) -> "Audio": - """ - Create a single zero-value sample stereo audio signal. - """ - return cls(1, np.zeros((1, 2), np.int16)) - - @classmethod - def decode(cls, value: Union[np.ndarray, np.void]) -> "Audio": - if isinstance(value, np.void): - buffer = io.BytesIO(value.tolist()) - data, sampling_rate = audio.read_audio(buffer) - return cls(sampling_rate, data) - raise TypeError( - f"`value` should be a `numpy.void` instance, but {type(value)} " - f"received." - ) - - def encode(self, target: Optional[str] = None) -> np.void: - format_, codec = self.get_format_codec(target) - buffer = io.BytesIO() - audio.write_audio(buffer, self.data, self.sampling_rate, format_, codec) - return np.void(buffer.getvalue()) - - @staticmethod - def get_format_codec(target: Optional[str] = None) -> Tuple[str, str]: - """ - Get an audio format and an audio codec by an `target`. - """ - format_ = "wav" if target is None else target.lstrip(".").lower() - codec = {"wav": "pcm_s16le", "ogg": "libvorbis", "mp3": "libmp3lame"}.get( - format_, format_ - ) - return format_, codec - - -class Category(str): - """ - A string value that takes only a limited number of possible values (categories). - - The corresponding categories can be got and set with get/set_column_attributes['categories']. - - Dummy class for window column creation, should not be explicitly used as - input data. - - Example: - >>> import numpy as np - >>> from renumics.spotlight import Dataset - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_categorical_column("my_new_cat", - ... categories=["red", "green", "blue"],) - ... dataset.append_row(my_new_cat="blue") - ... 
dataset.append_row(my_new_cat="green") - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(dataset["my_new_cat", 1]) - green - - Example: - >>> import numpy as np - >>> import datetime - >>> from renumics.spotlight import Dataset - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_categorical_column("my_new_cat", - ... categories=["red", "green", "blue"],) - ... current_categories = dataset.get_column_attributes("my_new_cat")["categories"] - ... dataset.set_column_attributes("my_new_cat", categories={**current_categories, - ... "black":100}) - ... dataset.append_row(my_new_cat="black") - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(dataset["my_new_cat", 0]) - black - """ - - -class Video(FileBasedDType): - """ - A video object. No encoding or decoding is currently performed on the python - side, so all formats will be saved into dataset without compatibility check, - but only the formats supported by your browser (apparently .mp4, .ogg, - .webm, .mov etc.) can be played in Spotlight. - """ - - data: bytes - - def __init__(self, data: bytes) -> None: - if not isinstance(data, bytes): - raise TypeError( - f"`data` argument should be video bytes, but type {type(data)} " - f"received." - ) - self.data = data - - @classmethod - def from_file(cls, filepath: PathType) -> "Video": - """ - Read video from a filepath or an URL. - """ - prepared_file = str(filepath) if isinstance(filepath, os.PathLike) else filepath - if not isinstance(prepared_file, str): - raise TypeError( - "`filepath` should be a string or an `os.PathLike` instance, " - f"but value {prepared_file} or type {type(prepared_file)} " - f"received." - ) - if validators.url(prepared_file): - response = requests.get( - prepared_file, headers=headers, stream=True, timeout=10 - ) - if not response.ok: - raise exceptions.InvalidFile(f"URL {prepared_file} does not exist.") - return cls(response.raw.data) - if os.path.isfile(prepared_file): - with open(filepath, "rb") as f: - return cls(f.read()) - raise exceptions.InvalidFile( - f"File {prepared_file} is neither an existing file nor an existing URL." - ) - - @classmethod - def from_bytes(cls, blob: bytes) -> "Video": - """ - Read video from raw bytes. - """ - return cls(blob) - - @classmethod - def empty(cls) -> "Video": - """ - Create an empty video instance. - """ - return cls(b"\x00") - - @classmethod - def decode(cls, value: Union[np.ndarray, np.void]) -> "Video": - if isinstance(value, np.void): - return cls(value.tolist()) - raise TypeError( - f"`value` should be a `numpy.void` instance, but {type(value)} " - f"received." - ) - - def encode(self, _target: Optional[str] = None) -> np.void: - return np.void(self.data) - - -class Window: - - """ - A pair of two timestamps in seconds which can be later projected onto - continuous data (only :class:`Audio ` - is currently supported). - - Dummy class for window column creation - (see :func:`Dataset.append_column `), - should not be explicitly used as input data. - - To create a window column, use - :func:`Dataset.append_window_column ` - method. - - Examples: - - >>> import numpy as np - >>> from renumics.spotlight import Dataset - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_window_column("window", [[1, 2]] * 4) - ... dataset.append_row(window=(0, 1)) - ... dataset.append_row(window=np.array([-1, 0])) - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(dataset["window"]) - [[ 1. 2.] - [ 1. 2.] - [ 1. 2.] - [ 1. 2.] - [ 0. 1.] - [-1. 
0.]] - - - >>> import numpy as np - >>> from renumics.spotlight import Dataset - >>> with Dataset("docs/example.h5", "w") as dataset: - ... dataset.append_int_column("start", range(5)) - ... dataset.append_float_column("end", dataset["start"] + 2) - ... print(dataset["start"]) - ... print(dataset["end"]) - [0 1 2 3 4] - [2. 3. 4. 5. 6.] - >>> with Dataset("docs/example.h5", "a") as dataset: - ... dataset.append_window_column("window", zip(dataset["start"], dataset["end"])) - >>> with Dataset("docs/example.h5", "r") as dataset: - ... print(dataset["window"]) - [[0. 2.] - [1. 3.] - [2. 4.] - [3. 5.] - [4. 6.]] - """ +class Sequence1DDType(DType): + x_label: str + y_label: str + + def __init__(self, x_label: str = "x", y_label: str = "y"): + super().__init__("Sequence1D") + self.x_label = x_label + self.y_label = y_label + + +ALIASES: Dict[Any, DType] = {} + + +def register_dtype(dtype: DType, aliases: list) -> None: + for alias in aliases: + assert dtype.name.lower() not in ALIASES + ALIASES[dtype.name.lower()] = dtype + assert alias not in ALIASES + ALIASES[alias] = dtype + + +bool_dtype = DType("bool") +register_dtype(bool_dtype, [bool]) +int_dtype = DType("int") +register_dtype(int_dtype, [int]) +float_dtype = DType("float") +register_dtype(float_dtype, [float]) +bytes_dtype = DType("bytes") +register_dtype(bytes_dtype, [bytes]) +str_dtype = DType("str") +register_dtype(str_dtype, [str]) +datetime_dtype = DType("datetime") +register_dtype(datetime_dtype, [datetime]) +category_dtype = CategoryDType() +register_dtype(category_dtype, [Category]) +window_dtype = DType("Window") +register_dtype(window_dtype, [Window]) +embedding_dtype = DType("Embedding") +register_dtype(embedding_dtype, [Embedding]) +array_dtype = DType("array") +register_dtype(array_dtype, [np.ndarray]) +image_dtype = DType("Image") +register_dtype(image_dtype, [Image]) +audio_dtype = DType("Audio") +register_dtype(audio_dtype, [Audio]) +mesh_dtype = DType("Mesh") +register_dtype(mesh_dtype, [Mesh]) +sequence_1d_dtype = Sequence1DDType() +register_dtype(sequence_1d_dtype, [Sequence1D]) +video_dtype = DType("Video") +register_dtype(video_dtype, [Video]) +mixed_dtype = DType("mixed") + + +DTypeMap = Dict[str, DType] + + +def create_dtype(x: Any) -> DType: + if isinstance(x, DType): + return x + if isinstance(x, str): + return ALIASES[x.lower()] + return ALIASES[x] + + +def is_bool_dtype(dtype: DType) -> bool: + return dtype.name == "bool" + + +def is_int_dtype(dtype: DType) -> bool: + return dtype.name == "int" + + +def is_float_dtype(dtype: DType) -> bool: + return dtype.name == "float" + + +def is_str_dtype(dtype: DType) -> bool: + return dtype.name == "str" + + +def is_datetime_dtype(dtype: DType) -> bool: + return dtype.name == "datetime" + + +def is_category_dtype(dtype: DType) -> TypeGuard[CategoryDType]: + return dtype.name == "Category" + + +def is_array_dtype(dtype: DType) -> bool: + return dtype.name == "array" + + +def is_window_dtype(dtype: DType) -> bool: + return dtype.name == "Window" + + +def is_embedding_dtype(dtype: DType) -> bool: + return dtype.name == "Embedding" + + +def is_sequence_1d_dtype(dtype: DType) -> TypeGuard[Sequence1DDType]: + return dtype.name == "Sequence1D" + + +def is_audio_dtype(dtype: DType) -> bool: + return dtype.name == "Audio" + + +def is_image_dtype(dtype: DType) -> bool: + return dtype.name == "Image" + + +def is_mesh_dtype(dtype: DType) -> bool: + return dtype.name == "Mesh" + + +def is_video_dtype(dtype: DType) -> bool: + return dtype.name == "Video" + + +def 
is_scalar_dtype(dtype: DType) -> bool: + return dtype.name in ("bool", "int", "float") + + +def is_file_dtype(dtype: DType) -> bool: + return dtype.name in ("Audio", "Image", "Video", "Mesh") diff --git a/renumics/spotlight/dtypes/conversion.py b/renumics/spotlight/dtypes/conversion.py index d699eb5b..2e6421fe 100644 --- a/renumics/spotlight/dtypes/conversion.py +++ b/renumics/spotlight/dtypes/conversion.py @@ -16,11 +16,9 @@ from abc import ABCMeta import ast from collections import defaultdict -from dataclasses import dataclass import inspect import io import os -from inspect import signature from typing import ( Callable, List, @@ -31,7 +29,6 @@ Dict, get_args, get_origin, - cast, ) import datetime from filetype import filetype @@ -44,24 +41,17 @@ from renumics.spotlight.cache import external_data_cache from renumics.spotlight.io import audio from renumics.spotlight.io.file import as_file - -from renumics.spotlight.dtypes.exceptions import InvalidFile - +from renumics.spotlight.media.exceptions import InvalidFile from renumics.spotlight.backend.exceptions import Problem - -from .typing import ( - ColumnType, - Category, - FileBasedColumnType, - Window, - Sequence1D, - Embedding, - Image, - Audio, - Video, - Mesh, - get_column_type_name, +from renumics.spotlight.dtypes import ( + CategoryDType, + DType, + audio_dtype, + image_dtype, + mesh_dtype, + video_dtype, ) +from renumics.spotlight.media import Sequence1D, Image, Audio, Video, Mesh NormalizedValue = Union[ @@ -86,15 +76,6 @@ ] -@dataclass(frozen=True) -class DTypeOptions: - """ - All possible dtype options - """ - - categories: Optional[Dict[str, int]] = None - - class ConversionError(Exception): """ Conversion Error @@ -113,7 +94,7 @@ class ConversionFailed(Problem): def __init__( self, value: NormalizedValue, - dtype: Type[ColumnType], + dtype: DType, reason: Optional[str] = None, ) -> None: super().__init__( @@ -133,58 +114,46 @@ class NoConverterAvailable(Problem): No matching converter could be applied """ - def __init__(self, value: NormalizedValue, dtype: Type[ColumnType]) -> None: + def __init__(self, value: NormalizedValue, dtype: DType) -> None: msg = f"No Converter for {type(value)} -> {dtype}" super().__init__(title="No matching converter", detail=msg, status_code=422) N = TypeVar("N", bound=NormalizedValue) -Converter = Callable[[N, DTypeOptions], ConvertedValue] -ConverterWithoutOptions = Callable[[N], ConvertedValue] +Converter = Callable[[N, DType], ConvertedValue] _converters_table: Dict[ - Type[NormalizedValue], Dict[Type[ColumnType], List[Converter]] + Type[NormalizedValue], Dict[str, List[Converter]] ] = defaultdict(lambda: defaultdict(list)) _simple_converters_table: Dict[ - Type[NormalizedValue], Dict[Type[ColumnType], List[Converter]] + Type[NormalizedValue], Dict[str, List[Converter]] ] = defaultdict(lambda: defaultdict(list)) def register_converter( from_type: Type[N], - to_type: Type[ColumnType], - converter: Union[Converter[N], ConverterWithoutOptions[N]], + to_type: str, + converter: Converter[N], simple: Optional[bool] = None, ) -> None: """ register a converter from NormalizedType to ColumnType """ - - parameter_count = len(signature(converter).parameters) - if parameter_count == 2: - converter_with_options = cast(Converter[N], converter) - else: - - def converter_with_options(value: N, _: DTypeOptions, /) -> ConvertedValue: - return cast(Converter[N], converter)(value) # type: ignore - if simple is None: - _simple_converters_table[from_type][to_type].append(converter_with_options) # type: ignore - 
_converters_table[from_type][to_type].append(converter_with_options) # type: ignore + _simple_converters_table[from_type][to_type].append(converter) # type: ignore + _converters_table[from_type][to_type].append(converter) # type: ignore elif simple: - _simple_converters_table[from_type][to_type].append(converter_with_options) # type: ignore + _simple_converters_table[from_type][to_type].append(converter) # type: ignore else: - _converters_table[from_type][to_type].append(converter_with_options) # type: ignore + _converters_table[from_type][to_type].append(converter) # type: ignore -def convert(to_type: Type[ColumnType], simple: Optional[bool] = None) -> Callable: +def convert(to_type: str, simple: Optional[bool] = None) -> Callable: """ Decorator for simplified registration of converters """ - def _decorate( - func: Union[Converter[N], ConverterWithoutOptions[N]] - ) -> Union[Converter[N], ConverterWithoutOptions[N]]: + def _decorate(func: Converter[N]) -> Converter[N]: value_annotation = next(iter(func.__annotations__.values())) if origin := get_origin(value_annotation): @@ -205,45 +174,41 @@ def _decorate( def convert_to_dtype( - value: NormalizedValue, - dtype: Type[ColumnType], - dtype_options: DTypeOptions = DTypeOptions(), - simple: bool = False, - check: bool = True, + value: NormalizedValue, dtype: DType, simple: bool = False, check: bool = True ) -> ConvertedValue: """ Convert normalized type from data source to internal Spotlight DType """ registered_converters = ( - _simple_converters_table[type(value)][dtype] + _simple_converters_table[type(value)][dtype.name] if simple - else _converters_table[type(value)][dtype] + else _converters_table[type(value)][dtype.name] ) last_conversion_error: Optional[ConversionError] = None for converter in registered_converters: try: - return converter(value, dtype_options) + return converter(value, dtype) except ConversionError as e: last_conversion_error = e try: if value is None: return None - if dtype is bool: + if dtype.name == "bool": return bool(value) # type: ignore - if dtype is int: + if dtype.name == "int": return int(value) # type: ignore - if dtype is float: + if dtype.name == "float": return float(value) # type: ignore - if dtype is str: + if dtype.name == "str": str_value = str(value) if simple and len(str_value) > 100: return str_value[:97] + "..." 
return str_value - if dtype is np.ndarray: + if dtype.name == "array": if simple: return "[...]" if isinstance(value, list): @@ -251,13 +216,6 @@ def convert_to_dtype( if isinstance(value, np.ndarray): return value - if dtype is Category and np.issubdtype(np.dtype(type(value)), np.integer): - assert dtype_options.categories is not None - value_int = int(value) # type: ignore - if value_int != -1 and value_int not in dtype_options.categories.values(): - raise ConversionFailed(value, dtype) - return value_int - except (TypeError, ValueError) as e: if check: raise ConversionFailed(value, dtype) from e @@ -273,52 +231,66 @@ def convert_to_dtype( return None -@convert(datetime.datetime) -def _(value: datetime.datetime) -> datetime.datetime: +@convert("datetime") +def _(value: datetime.datetime, _: DType) -> datetime.datetime: return value -@convert(datetime.datetime) -def _(value: Union[str, np.str_]) -> Optional[datetime.datetime]: +@convert("datetime") +def _(value: Union[str, np.str_], _: DType) -> Optional[datetime.datetime]: if value == "": return None return datetime.datetime.fromisoformat(value) -@convert(datetime.datetime) # type: ignore -def _(value: np.datetime64) -> datetime.datetime: +@convert("datetime") +def _(value: np.datetime64, _: DType) -> Optional[datetime.datetime]: return value.tolist() -@convert(Category) -def _(value: Union[str, np.str_], options: DTypeOptions) -> int: - if not options.categories: +@convert("Category") +def _(value: Union[str, np.str_], dtype: CategoryDType) -> int: + categories = dtype.categories + if not categories: return -1 - return options.categories[value] + return categories[value] -@convert(Category) -def _(_: None) -> int: +@convert("Category") +def _(_: None, _dtype: CategoryDType) -> int: return -1 -@convert(Category) -def _(value: int) -> int: - return value - - -@convert(Window) -def _(value: list) -> np.ndarray: +@convert("Category") +def _( + value: Union[ + int, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ], + _: CategoryDType, +) -> int: + return int(value) + + +@convert("Window") +def _(value: list, _: DType) -> np.ndarray: return np.array(value, dtype=np.float64) -@convert(Window) -def _(value: np.ndarray) -> np.ndarray: +@convert("Window") +def _(value: np.ndarray, _: DType) -> np.ndarray: return value.astype(np.float64) -@convert(Window) -def _(value: Union[str, np.str_]) -> np.ndarray: +@convert("Window") +def _(value: Union[str, np.str_], _: DType) -> np.ndarray: try: obj = ast.literal_eval(value) array = np.array(obj, dtype=np.float64) @@ -329,18 +301,18 @@ def _(value: Union[str, np.str_]) -> np.ndarray: raise ConversionError("Cannot interpret string as a window") -@convert(Embedding, simple=False) -def _(value: list) -> np.ndarray: +@convert("Embedding", simple=False) +def _(value: list, _: DType) -> np.ndarray: return np.array(value, dtype=np.float64) -@convert(Embedding, simple=False) -def _(value: np.ndarray) -> np.ndarray: +@convert("Embedding", simple=False) +def _(value: np.ndarray, _: DType) -> np.ndarray: return value.astype(np.float64) -@convert(Embedding, simple=False) -def _(value: Union[str, np.str_]) -> np.ndarray: +@convert("Embedding", simple=False) +def _(value: Union[str, np.str_], _: DType) -> np.ndarray: try: obj = ast.literal_eval(value) array = np.array(obj, dtype=np.float64) @@ -351,13 +323,13 @@ def _(value: Union[str, np.str_]) -> np.ndarray: raise ConversionError("Cannot interpret string as an embedding") -@convert(Sequence1D, simple=False) -def 
_(value: Union[np.ndarray, list], _: DTypeOptions) -> np.ndarray: +@convert("Sequence1D", simple=False) +def _(value: Union[np.ndarray, list], _: DType) -> np.ndarray: return Sequence1D(value).encode() -@convert(Sequence1D, simple=False) -def _(value: Union[str, np.str_]) -> np.ndarray: +@convert("Sequence1D", simple=False) +def _(value: Union[str, np.str_], _: DType) -> np.ndarray: try: obj = ast.literal_eval(value) return Sequence1D(obj).encode() @@ -365,113 +337,113 @@ def _(value: Union[str, np.str_]) -> np.ndarray: raise ConversionError("Cannot interpret string as a 1D sequence") -@convert(Image, simple=False) -def _(value: Union[str, np.str_]) -> bytes: +@convert("Image", simple=False) +def _(value: Union[str, np.str_], _: DType) -> bytes: try: - if data := read_external_value(value, Image): + if data := read_external_value(value, image_dtype): return data.tolist() except InvalidFile: raise ConversionError() raise ConversionError() -@convert(Image, simple=False) -def _(value: Union[bytes, np.bytes_]) -> bytes: +@convert("Image", simple=False) +def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: return Image.from_bytes(value).encode().tolist() -@convert(Image, simple=False) -def _(value: np.ndarray) -> bytes: +@convert("Image", simple=False) +def _(value: np.ndarray, _: DType) -> bytes: return Image(value).encode().tolist() -@convert(Audio, simple=False) -def _(value: Union[str, np.str_]) -> bytes: +@convert("Audio", simple=False) +def _(value: Union[str, np.str_], _: DType) -> bytes: try: - if data := read_external_value(value, Audio): + if data := read_external_value(value, audio_dtype): return data.tolist() except (InvalidFile, IndexError, ValueError): raise ConversionError() raise ConversionError() -@convert(Audio, simple=False) -def _(value: Union[bytes, np.bytes_]) -> bytes: +@convert("Audio", simple=False) +def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: return Audio.from_bytes(value).encode().tolist() -@convert(Video, simple=False) -def _(value: Union[str, np.str_]) -> bytes: +@convert("Video", simple=False) +def _(value: Union[str, np.str_], _: DType) -> bytes: try: - if data := read_external_value(value, Video): + if data := read_external_value(value, video_dtype): return data.tolist() except InvalidFile: raise ConversionError() raise ConversionError() -@convert(Video, simple=False) -def _(value: Union[bytes, np.bytes_]) -> bytes: +@convert("Video", simple=False) +def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: return Video.from_bytes(value).encode().tolist() -@convert(Mesh, simple=False) -def _(value: Union[str, np.str_]) -> bytes: +@convert("Mesh", simple=False) +def _(value: Union[str, np.str_], _: DType) -> bytes: try: - if data := read_external_value(value, Mesh): + if data := read_external_value(value, mesh_dtype): return data.tolist() except InvalidFile: raise ConversionError() raise ConversionError() -@convert(Mesh, simple=False) -def _(value: Union[bytes, np.bytes_]) -> bytes: +@convert("Mesh", simple=False) +def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: return value # this should not be necessary -@convert(Mesh, simple=False) # type: ignore -def _(value: trimesh.Trimesh) -> bytes: +@convert("Mesh", simple=False) # type: ignore +def _(value: trimesh.Trimesh, _: DType) -> bytes: return Mesh.from_trimesh(value).encode().tolist() -@convert(Embedding, simple=True) -@convert(Sequence1D, simple=True) -def _(_: Union[np.ndarray, list, str, np.str_]) -> str: +@convert("Embedding", simple=True) +@convert("Sequence1D", simple=True) +def 
_(_: Union[np.ndarray, list, str, np.str_], _dtype: DType) -> str: return "[...]" -@convert(Image, simple=True) -def _(_: np.ndarray) -> str: +@convert("Image", simple=True) +def _(_: np.ndarray, _dtype: DType) -> str: return "[...]" -@convert(Image, simple=True) -@convert(Audio, simple=True) -@convert(Video, simple=True) -@convert(Mesh, simple=True) -def _(value: Union[str, np.str_]) -> str: +@convert("Image", simple=True) +@convert("Audio", simple=True) +@convert("Video", simple=True) +@convert("Mesh", simple=True) +def _(value: Union[str, np.str_], _: DType) -> str: return str(value) -@convert(Image, simple=True) -@convert(Audio, simple=True) -@convert(Video, simple=True) -@convert(Mesh, simple=True) -def _(_: Union[bytes, np.bytes_]) -> str: +@convert("Image", simple=True) +@convert("Audio", simple=True) +@convert("Video", simple=True) +@convert("Mesh", simple=True) +def _(_: Union[bytes, np.bytes_], _dtype: DType) -> str: return "" # this should not be necessary -@convert(Mesh, simple=True) # type: ignore -def _(_: trimesh.Trimesh) -> str: +@convert("Mesh", simple=True) # type: ignore +def _(_: trimesh.Trimesh, _dtype: DType) -> str: return "" def read_external_value( path_or_url: Optional[str], - column_type: Type[FileBasedColumnType], + dtype: DType, target_format: Optional[str] = None, workdir: PathType = ".", ) -> Optional[np.void]: @@ -481,7 +453,7 @@ def read_external_value( """ if not path_or_url: return None - cache_key = f"external:{path_or_url},{get_column_type_name(column_type)}" + cache_key = f"external:{path_or_url},{dtype}" if target_format is not None: cache_key += f"/{target_format}" try: @@ -490,7 +462,7 @@ def read_external_value( except KeyError: ... - value = _decode_external_value(path_or_url, column_type, target_format, workdir) + value = _decode_external_value(path_or_url, dtype, target_format, workdir) external_data_cache[cache_key] = value.tolist() return value @@ -508,7 +480,7 @@ def prepare_path_or_url(path_or_url: PathOrUrlType, workdir: PathType) -> str: def _decode_external_value( path_or_url: PathOrUrlType, - column_type: Type[FileBasedColumnType], + dtype: DType, target_format: Optional[str] = None, workdir: PathType = ".", ) -> np.void: @@ -517,7 +489,7 @@ def _decode_external_value( """ path_or_url = prepare_path_or_url(path_or_url, workdir) - if column_type is Audio: + if dtype.name == "Audio": file = audio.prepare_input_file(path_or_url, reusable=True) # `file` is a filepath of type `str` or an URL downloaded as `io.BytesIO`. input_format, input_codec = audio.get_format_codec(file) @@ -549,7 +521,7 @@ def _decode_external_value( audio.transcode_audio(file, buffer, output_format, output_codec) return np.void(buffer.getvalue()) - if column_type is Image: + if dtype.name == "Image": with as_file(path_or_url) as file: kind = filetype.guess(file) if kind is not None and kind.mime.split("/")[1] in ( @@ -566,5 +538,8 @@ def _decode_external_value( # `image/tiff`s become blank in frontend, so convert them too. 
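For illustration, a hedged sketch of the refactored `read_external_value`, which now takes a dtype singleton instead of a column-type class; the file path below is a placeholder, not from the diff:

from renumics.spotlight import dtypes
from renumics.spotlight.dtypes.conversion import read_external_value

# Decodes the file according to the given dtype and caches the encoded blob
# under a key derived from the path and the dtype.
blob = read_external_value("data/cat.png", dtypes.image_dtype)  # hypothetical path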
return Image.from_file(file).encode(target_format) - data_obj = column_type.from_file(path_or_url) - return data_obj.encode(target_format) + if dtype.name == "Mesh": + return Mesh.from_file(path_or_url).encode(target_format) + if dtype.name == "Video": + return Video.from_file(path_or_url).encode(target_format) + assert False diff --git a/renumics/spotlight/dtypes/legacy.py b/renumics/spotlight/dtypes/legacy.py new file mode 100644 index 00000000..86a7a557 --- /dev/null +++ b/renumics/spotlight/dtypes/legacy.py @@ -0,0 +1,100 @@ +from renumics.spotlight.media import ( + Embedding, # noqa: F401 + Sequence1D, # noqa: F401 + Audio, # noqa: F401 + Image, # noqa: F401 + Mesh, # noqa: F401 + Video, # noqa: F401 +) + + +class Category: + """ + A string value that takes only a limited number of possible values (categories). + + The corresponding categories can be got and set with get/set_column_attributes['categories']. + + Dummy class for window column creation, should not be explicitly used as + input data. + + Example: + >>> import numpy as np + >>> from renumics.spotlight import Dataset + >>> with Dataset("docs/example.h5", "w") as dataset: + ... dataset.append_categorical_column("my_new_cat", + ... categories=["red", "green", "blue"],) + ... dataset.append_row(my_new_cat="blue") + ... dataset.append_row(my_new_cat="green") + >>> with Dataset("docs/example.h5", "r") as dataset: + ... print(dataset["my_new_cat", 1]) + green + + Example: + >>> import numpy as np + >>> import datetime + >>> from renumics.spotlight import Dataset + >>> with Dataset("docs/example.h5", "w") as dataset: + ... dataset.append_categorical_column("my_new_cat", + ... categories=["red", "green", "blue"],) + ... current_categories = dataset.get_column_attributes("my_new_cat")["categories"] + ... dataset.set_column_attributes("my_new_cat", categories={**current_categories, + ... "black":100}) + ... dataset.append_row(my_new_cat="black") + >>> with Dataset("docs/example.h5", "r") as dataset: + ... print(dataset["my_new_cat", 0]) + black + """ + + +class Window: + + """ + A pair of two timestamps in seconds which can be later projected onto + continuous data (only :class:`Audio ` + is currently supported). + + Dummy class for window column creation + (see :func:`Dataset.append_column `), + should not be explicitly used as input data. + + To create a window column, use + :func:`Dataset.append_window_column ` + method. + + Examples: + + >>> import numpy as np + >>> from renumics.spotlight import Dataset + >>> with Dataset("docs/example.h5", "w") as dataset: + ... dataset.append_window_column("window", [[1, 2]] * 4) + ... dataset.append_row(window=(0, 1)) + ... dataset.append_row(window=np.array([-1, 0])) + >>> with Dataset("docs/example.h5", "r") as dataset: + ... print(dataset["window"]) + [[ 1. 2.] + [ 1. 2.] + [ 1. 2.] + [ 1. 2.] + [ 0. 1.] + [-1. 0.]] + + + >>> import numpy as np + >>> from renumics.spotlight import Dataset + >>> with Dataset("docs/example.h5", "w") as dataset: + ... dataset.append_int_column("start", range(5)) + ... dataset.append_float_column("end", dataset["start"] + 2) + ... print(dataset["start"]) + ... print(dataset["end"]) + [0 1 2 3 4] + [2. 3. 4. 5. 6.] + >>> with Dataset("docs/example.h5", "a") as dataset: + ... dataset.append_window_column("window", zip(dataset["start"], dataset["end"])) + >>> with Dataset("docs/example.h5", "r") as dataset: + ... print(dataset["window"]) + [[0. 2.] + [1. 3.] + [2. 4.] + [3. 5.] + [4. 
6.]] + """ diff --git a/renumics/spotlight/dtypes/triangulation.py b/renumics/spotlight/dtypes/triangulation.py deleted file mode 100644 index b99c4b1a..00000000 --- a/renumics/spotlight/dtypes/triangulation.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -This module provides common utilities for handling meshes. -""" - -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np - - -def reindex(point_ids: np.ndarray, *elements: np.ndarray) -> Tuple[np.ndarray, ...]: - """ - Reindex elements which refer to the non-negative unique point ids so that - they refer to the indices of point ids (`np.arange(len(point_ids))`) in the - same way. - """ - if elements: - inverse_point_ids = np.full( - max((x.max(initial=0) for x in (point_ids, *elements))) + 1, - -1, - np.int64, - ) - inverse_point_ids[point_ids] = np.arange(len(point_ids)) - return tuple(inverse_point_ids[x].astype(x.dtype) for x in elements) - return () - - -def attribute_to_array(attribute: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: - """ - Encode a single attribute or a list of attribute steps as array with shape - `(num_steps, n, ...)` and squeeze all dimensions except for 0 and 1. - """ - if not isinstance(attribute, list): - attribute = [attribute] - attribute = np.asarray(attribute) - attribute = attribute.reshape( - (*attribute.shape[:2], *(x for x in attribute.shape[2:] if x != 1)) - ) - return attribute - - -def triangulate( - triangles: Optional[np.ndarray] = None, - triangle_attributes: Optional[ - Dict[str, Union[np.ndarray, List[np.ndarray]]] - ] = None, - quadrangles: Optional[np.ndarray] = None, - quadrangle_attributes: Optional[ - Dict[str, Union[np.ndarray, List[np.ndarray]]] - ] = None, -) -> Tuple[np.ndarray, Dict[str, Union[np.ndarray, List[np.ndarray]]]]: - """ - Triangulate quadrangles and respective attributes and append them to the - given triangles/attributes. - """ - - attrs = {} - if triangles is None: - trias = np.empty((0, 3), np.uint32) - if triangle_attributes: - raise ValueError( - f"`triangles` not given, but `triangle_attributes` have " - f"{len(triangle_attributes)} items." - ) - else: - trias = triangles - if triangle_attributes is not None: - for attr_name, triangle_attr in triangle_attributes.items(): - attrs[attr_name] = attribute_to_array(triangle_attr) - - if quadrangles is None: - if quadrangle_attributes is not None and len(quadrangle_attributes) != 0: - raise ValueError( - f"`quadrangles` not given, but `quadrangle_attributes` have " - f"{len(quadrangle_attributes)} items." - ) - else: - trias = np.concatenate( - (trias, quadrangles[:, [0, 1, 2]], quadrangles[:, [0, 2, 3]]) - ) - for attr_name, quadrangle_attr in (quadrangle_attributes or {}).items(): - quadrangle_attr = attribute_to_array(quadrangle_attr) - try: - attr = attrs[attr_name] - except KeyError: - attrs[attr_name] = np.concatenate( - (quadrangle_attr, quadrangle_attr), axis=1 - ) - else: - attrs[attr_name] = np.concatenate( - (attr, quadrangle_attr, quadrangle_attr), axis=1 - ) - - for attr_name, attr in attrs.items(): - if attr.shape[1] != len(trias): - raise ValueError( - f"Values of attributes should have the same length as " - f"triangles ({len(trias)}), but length {attr.shape[1]} " - f"received." 
- ) - return ( - trias, - { - attr_name: attr[0] if len(attr) == 1 else list(attr) - for attr_name, attr in attrs.items() - }, - ) - - -def clean( - points: np.ndarray, - triangles: np.ndarray, - point_attributes: Optional[Dict[str, np.ndarray]] = None, - triangle_attributes: Optional[Dict[str, np.ndarray]] = None, - point_displacements: Optional[List[np.ndarray]] = None, -) -> Tuple[ - np.ndarray, - np.ndarray, - Dict[str, np.ndarray], - Dict[str, np.ndarray], - List[np.ndarray], -]: - """ - Remove: - degenerated triangles and respective attributes; - invalid triangles and respective attributes; - invalid points and respective attributes; - empty attributes and point displacements. - """ - point_ids = np.arange(len(points)) - valid_triangles_mask = ( - (triangles[:, 0] != triangles[:, 1]) - & (triangles[:, 0] != triangles[:, 2]) - & (triangles[:, 1] != triangles[:, 2]) - & np.isin(triangles, point_ids).all(axis=1) - ) - triangles = triangles[valid_triangles_mask] - triangle_attributes = { - k: [x[valid_triangles_mask] for x in v] - if isinstance(v, list) - else v[valid_triangles_mask] - for k, v in (triangle_attributes or {}).items() - } - - valid_points_mask = np.isin(point_ids, triangles) - points = points[valid_points_mask] - point_attributes = { - k: [x[valid_points_mask] for x in v] - if isinstance(v, list) - else v[valid_points_mask] - for k, v in (point_attributes or {}).items() - } - point_displacements = [x[valid_points_mask] for x in point_displacements or []] - point_ids = point_ids[valid_points_mask] - triangles, *_ = reindex(point_ids, triangles) - - return ( - points, - triangles, - {k: v for k, v in point_attributes.items() if len(v) > 0}, - {k: v for k, v in triangle_attributes.items() if len(v) > 0}, - [x for x in point_displacements if len(x) > 0], - ) diff --git a/renumics/spotlight/dtypes/typing.py b/renumics/spotlight/dtypes/typing.py deleted file mode 100644 index e525502b..00000000 --- a/renumics/spotlight/dtypes/typing.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Spotlight data types' typing. -""" - -from datetime import datetime -from typing import Any, Dict, Union, Type - -import numpy as np -from typing_extensions import TypeGuard, get_args - -from . import Audio, Category, Embedding, Image, Mesh, Sequence1D, Video, Window, DType -from .exceptions import NotADType - -ColumnType = Union[bool, int, float, str, datetime, Category, Window, np.ndarray, DType] -ScalarColumnType = Union[bool, int, float, str, datetime, Category] -FileBasedColumnType = Union[Audio, Image, Mesh, Video] -ArrayBasedColumnType = Union[Embedding, Image, Sequence1D] - -ColumnTypeMapping = Dict[str, Type[ColumnType]] - -COLUMN_TYPES_BY_NAME: Dict[str, Type[ColumnType]] = { - "bool": bool, - "int": int, - "float": float, - "str": str, - "datetime": datetime, - "Category": Category, - "Window": Window, - "array": np.ndarray, - "Image": Image, - "Audio": Audio, - "Video": Video, - "Mesh": Mesh, - "Embedding": Embedding, - "Sequence1D": Sequence1D, -} -NAME_BY_COLUMN_TYPE: Dict[Type[ColumnType], str] = { - v: k for k, v in COLUMN_TYPES_BY_NAME.items() -} - - -def get_column_type_name(column_type: Type[ColumnType]) -> str: - """ - Get name of a column type as string. - """ - try: - return NAME_BY_COLUMN_TYPE[column_type] - except KeyError as e: - raise NotADType(f"Unknown column type: {column_type}.") from e - - -def get_column_type(x: str) -> Type[ColumnType]: - """ - Get column type by its name. 
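The mesh helpers deleted here are not dropped: `reindex`, `attribute_to_array`, `triangulate` and `clean` reappear verbatim in renumics/spotlight/media/mesh.py further down. A small sketch of `triangulate` from its new home, assuming unchanged behaviour:

import numpy as np
from renumics.spotlight.media.mesh import triangulate

# Each quadrangle (0, 1, 2, 3) is split into triangles (0, 1, 2) and (0, 2, 3).
trias, attrs = triangulate(quadrangles=np.array([[0, 1, 2, 3]], np.uint32))
assert trias.tolist() == [[0, 1, 2], [0, 2, 3]]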
- """ - try: - return COLUMN_TYPES_BY_NAME[x] - except KeyError as e: - raise NotADType(f"Unknown column type: {x}.") from e - - -def is_column_type(x: Any) -> TypeGuard[Type[ColumnType]]: - """ - Check whether `x` is a Spotlight data type class. - """ - return x in COLUMN_TYPES_BY_NAME.values() - - -def is_scalar_column_type(x: Any) -> TypeGuard[Type[ScalarColumnType]]: - """ - Check whether `x` is a scalar Spotlight data type class. - """ - return x in get_args(ScalarColumnType) - - -def is_file_based_column_type(x: Any) -> TypeGuard[Type[FileBasedColumnType]]: - """ - Check whether `x` is a Spotlight column type class whose instances - can be read from/saved into a file. - """ - return x in get_args(FileBasedColumnType) - - -def is_array_based_column_type(x: Any) -> TypeGuard[Type[ArrayBasedColumnType]]: - """ - Check whether `x` is a Spotlight column type class which can be instantiated - from a single array-like argument. - """ - return x in get_args(ArrayBasedColumnType) diff --git a/renumics/spotlight/io/file.py b/renumics/spotlight/io/file.py index 022c36dd..ffeb2c73 100644 --- a/renumics/spotlight/io/file.py +++ b/renumics/spotlight/io/file.py @@ -12,7 +12,7 @@ from renumics.spotlight.requests import headers from renumics.spotlight.typing import FileType -from renumics.spotlight.dtypes import exceptions +from renumics.spotlight.media import exceptions @contextlib.contextmanager diff --git a/renumics/spotlight/io/pandas.py b/renumics/spotlight/io/pandas.py index 5f66eb20..c09f32f3 100644 --- a/renumics/spotlight/io/pandas.py +++ b/renumics/spotlight/io/pandas.py @@ -6,8 +6,7 @@ import os.path import statistics from contextlib import suppress -from datetime import datetime -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional import PIL.Image import filetype @@ -17,25 +16,15 @@ from renumics.spotlight.dtypes import ( Audio, - Category, Embedding, Image, Mesh, Sequence1D, Video, - Window, -) -from renumics.spotlight.dtypes.exceptions import NotADType, UnsupportedDType -from renumics.spotlight.dtypes.typing import ( - COLUMN_TYPES_BY_NAME, - ColumnType, - ColumnTypeMapping, - is_column_type, - is_file_based_column_type, - is_scalar_column_type, ) +from renumics.spotlight.media.exceptions import UnsupportedDType from renumics.spotlight.typing import is_iterable, is_pathtype -from renumics.spotlight.dtypes.base import DType +from renumics.spotlight import dtypes def is_empty(value: Any) -> bool: @@ -66,7 +55,7 @@ def stringify_columns(df: pd.DataFrame) -> List[str]: return [str(column_name) for column_name in df.columns] -def infer_dtype(column: pd.Series) -> Type[ColumnType]: +def infer_dtype(column: pd.Series) -> dtypes.DType: """ Get an equivalent Spotlight data type for a `pandas` column, if possible. @@ -83,20 +72,25 @@ def infer_dtype(column: pd.Series) -> Type[ColumnType]: Returns: Inferred dtype. - Reises: + Raises: ValueError: If dtype cannot be inferred automatically. 
""" if pd.api.types.is_bool_dtype(column) and not column.hasnans: - return bool + return dtypes.bool_dtype if pd.api.types.is_categorical_dtype(column): - return Category + return dtypes.CategoryDType( + { + category: code + for code, category in zip(column.cat.codes, column.cat.categories) + } + ) if pd.api.types.is_integer_dtype(column) and not column.hasnans: - return int + return dtypes.int_dtype if pd.api.types.is_float_dtype(column): - return float + return dtypes.float_dtype if pd.api.types.is_datetime64_any_dtype(column): - return datetime + return dtypes.datetime_dtype column = column.copy() str_mask = is_string_mask(column) @@ -104,15 +98,15 @@ def infer_dtype(column: pd.Series) -> Type[ColumnType]: column = column[~column.isna()] if len(column) == 0: - return str + return dtypes.str_dtype column_head = column.iloc[:10] head_dtypes = column_head.apply(infer_value_dtype).to_list() # type: ignore dtype_mode = statistics.mode(head_dtypes) if dtype_mode is None: - return str - if issubclass(dtype_mode, (Window, Embedding)): + return dtypes.str_dtype + if dtype_mode in [dtypes.window_dtype, dtypes.embedding_dtype]: column = column.astype(object) str_mask = is_string_mask(column) x = column[str_mask].apply(try_literal_eval) @@ -122,39 +116,31 @@ def infer_dtype(column: pd.Series) -> Type[ColumnType]: try: np.asarray(column.to_list(), dtype=float) except (TypeError, ValueError): - return Sequence1D + return dtypes.sequence_1d_dtype return dtype_mode return dtype_mode -def infer_array_dtype(value: np.ndarray) -> Type[ColumnType]: - """ - Infer dtype of a numpy array - """ - if value.ndim == 3: - if value.shape[-1] in (1, 3, 4): - return Image - elif value.ndim == 2: - if value.shape[0] == 2 or value.shape[1] == 2: - return Sequence1D - elif value.ndim == 1: - if len(value) == 2: - return Window - return Embedding - return np.ndarray - - -def infer_value_dtype(value: Any) -> Optional[Type[ColumnType]]: +def infer_value_dtype(value: Any) -> Optional[dtypes.DType]: """ Infer dtype for value """ - - if isinstance(value, DType): - return type(value) + if isinstance(value, Embedding): + return dtypes.embedding_dtype + if isinstance(value, Sequence1D): + return dtypes.sequence_1d_dtype + if isinstance(value, Image): + return dtypes.image_dtype + if isinstance(value, Audio): + return dtypes.audio_dtype + if isinstance(value, Video): + return dtypes.video_dtype + if isinstance(value, Mesh): + return dtypes.mesh_dtype if isinstance(value, PIL.Image.Image): - return Image + return dtypes.image_dtype if isinstance(value, trimesh.Trimesh): - return Mesh + return dtypes.mesh_dtype if isinstance(value, np.ndarray): return infer_array_dtype(value) @@ -168,11 +154,11 @@ def infer_value_dtype(value: Any) -> Optional[Type[ColumnType]]: if kind is not None: mime_group = kind.mime.split("/")[0] if mime_group == "image": - return Image + return dtypes.image_dtype if mime_group == "audio": - return Audio + return dtypes.audio_dtype if mime_group == "video": - return Video + return dtypes.video_dtype return None if is_iterable(value): try: @@ -184,26 +170,35 @@ def infer_value_dtype(value: Any) -> Optional[Type[ColumnType]]: return None -def infer_dtypes( - df: pd.DataFrame, dtype: Optional[ColumnTypeMapping] -) -> ColumnTypeMapping: +def infer_array_dtype(value: np.ndarray) -> dtypes.DType: + """ + Infer dtype of a numpy array + """ + if value.ndim == 3: + if value.shape[-1] in (1, 3, 4): + return dtypes.image_dtype + elif value.ndim == 2: + if value.shape[0] == 2 or value.shape[1] == 2: + return 
dtypes.sequence_1d_dtype + elif value.ndim == 1: + if len(value) == 2: + return dtypes.window_dtype + return dtypes.embedding_dtype + return dtypes.array_dtype + + +def infer_dtypes(df: pd.DataFrame, dtype: Optional[dtypes.DTypeMap]) -> dtypes.DTypeMap: """ Check column types from the given `dtype` and complete it with auto inferred column types for the given `pandas.DataFrame`. """ inferred_dtype = dtype or {} - for column_name, column_type in inferred_dtype.items(): - if not is_column_type(column_type): - raise NotADType( - f"Given column type {column_type} for column '{column_name}' " - f"is not a valid Spotlight column type." - ) for column_index in df: if column_index not in inferred_dtype: try: column_type = infer_dtype(df[column_index]) except UnsupportedDType: - column_type = str + column_type = dtypes.str_dtype inferred_dtype[str(column_index)] = column_type return inferred_dtype @@ -244,7 +239,7 @@ def prepare_hugging_face_dict(x: Dict) -> Any: return x["path"] -def prepare_column(column: pd.Series, dtype: Type[ColumnType]) -> pd.Series: +def prepare_column(column: pd.Series, dtype: dtypes.DType) -> pd.Series: """ Convert a `pandas` column to the desired `dtype` and prepare some values, but still as `pandas` column. @@ -261,29 +256,27 @@ def prepare_column(column: pd.Series, dtype: Type[ColumnType]) -> pd.Series: """ column = column.copy() - if dtype is Category: + if dtypes.is_category_dtype(dtype): # We only support string/`NA` categories, but `pandas` can more, so # force categories to be strings (does not affect `NA`s). return to_categorical(column, str_categories=True) - if dtype is datetime: + if dtypes.is_datetime_dtype(dtype): # `errors="coerce"` will produce `NaT`s instead of fail. return pd.to_datetime(column, errors="coerce") - if dtype is str: + if dtypes.is_str_dtype(dtype): # Allow `NA`s, convert all other elements to strings. return column.astype(str).mask(column.isna(), None) # type: ignore - if is_scalar_column_type(dtype): - # `dtype` is `bool`, `int` or `float`. - return column.astype(dtype) + if dtypes.is_bool_dtype(dtype): + return column.astype(bool) - if not is_column_type(dtype): - raise NotADType( - "`dtype` should be one of Spotlight data types (" - + ", ".join(COLUMN_TYPES_BY_NAME.keys()) - + f"), but {dtype} received." - ) + if dtypes.is_int_dtype(dtype): + return column.astype(int) + + if dtypes.is_float_dtype(dtype): + return column.astype(float) # We explicitely don't want to change the original `DataFrame`. 
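A short sketch of `prepare_column` with the new DType-based dispatch (illustrative only):

import pandas as pd
from renumics.spotlight import dtypes
from renumics.spotlight.io.pandas import prepare_column

# `errors="coerce"` maps unparseable entries to NaT instead of raising.
prepared = prepare_column(pd.Series(["2021-01-01", "not a date"]), dtypes.datetime_dtype)
assert prepared.isna().iloc[1]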
        with pd.option_context("mode.chained_assignment", None):
@@ -297,7 +290,7 @@ def prepare_column(column: pd.Series, dtype: Type[ColumnType]) -> pd.Series:
             str_mask = is_string_mask(column)
             column[str_mask] = column[str_mask].apply(try_literal_eval)
-        if is_file_based_column_type(dtype):
+        if dtypes.is_file_dtype(dtype):
             dict_mask = column.map(type) == dict
             column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict)
diff --git a/renumics/spotlight/media/__init__.py b/renumics/spotlight/media/__init__.py
new file mode 100644
index 00000000..8e489843
--- /dev/null
+++ b/renumics/spotlight/media/__init__.py
@@ -0,0 +1,22 @@
+from .base import Array1dLike, Array2dLike, ImageLike, MediaType, FileMediaType
+from .embedding import Embedding
+from .sequence_1d import Sequence1D
+from .audio import Audio
+from .image import Image
+from .mesh import Mesh
+from .video import Video
+
+
+__all__ = [
+    "Array1dLike",
+    "Array2dLike",
+    "ImageLike",
+    "MediaType",
+    "FileMediaType",
+    "Embedding",
+    "Sequence1D",
+    "Audio",
+    "Image",
+    "Mesh",
+    "Video",
+]
diff --git a/renumics/spotlight/media/audio.py b/renumics/spotlight/media/audio.py
new file mode 100644
index 00000000..754e88f1
--- /dev/null
+++ b/renumics/spotlight/media/audio.py
@@ -0,0 +1,138 @@
+import io
+from typing import Optional, Tuple, Union
+
+import numpy as np
+
+from renumics.spotlight.typing import FileType
+from renumics.spotlight.media.base import Array2dLike, FileMediaType
+
+from . import exceptions
+from ..io import audio
+
+
+class Audio(FileMediaType):
+    """
+    An audio signal that will be saved in encoded form.
+
+    All formats and codecs supported by AV are supported for read.
+
+    Attributes:
+        data: Array-like with shape `(num_samples, num_channels)`
+            with `num_channels` <= 5.
+            If `data` has a float dtype, its values should be between -1 and 1.
+            If `data` has an int dtype, its values should be between minimum and
+            maximum possible values for the particular int dtype.
+            If `data` has an unsigned int dtype, its values should be between 0
+            and maximum possible values for the particular unsigned int dtype.
+        sampling_rate: Sampling rate (samples per second)
+
+    Example:
+        >>> import numpy as np
+        >>> from renumics.spotlight import Dataset, Audio
+        >>> samplerate = 44100
+        >>> fs = 100 # 100 Hz audio signal
+        >>> time = np.linspace(0.0, 1.0, samplerate)
+        >>> amplitude = np.iinfo(np.int16).max * 0.4
+        >>> data = np.array(amplitude * np.sin(2.0 * np.pi * fs * time), dtype=np.int16)
+        >>> audio = Audio(samplerate, np.array([data, data]).T) # int16 stereo signal
+        >>> float_data = 0.5 * np.cos(2.0 * np.pi * fs * time).astype(np.float32)
+        >>> float_audio = Audio(samplerate, float_data) # float32 mono signal
+        >>> with Dataset("docs/example.h5", "w") as dataset:
+        ...     dataset.append_audio_column("audio", [audio, float_audio])
+        ...     dataset.append_audio_column("lossy_audio", [audio, float_audio], lossy=True)
+        >>> with Dataset("docs/example.h5", "r") as dataset:
+        ...     print(dataset["audio", 0].data[100])
+        ...
print(f"{dataset['lossy_audio', 1].data[0, 0]:.5g}") + [12967 12967] + 0.4596 + """ + + data: np.ndarray + sampling_rate: int + + def __init__(self, sampling_rate: int, data: Array2dLike) -> None: + data_array = np.asarray(data) + is_valid_multi_channel = ( + data_array.size > 0 and data_array.ndim == 2 and data_array.shape[1] <= 5 + ) + is_valid_mono = data_array.size > 0 and data_array.ndim == 1 + if not (is_valid_multi_channel or is_valid_mono): + raise ValueError( + f"`data` argument should be a 1D array for mono data" + f" or a 2D numpy array with shape " + f"`(num_samples, num_channels)` and with num_channels <= 5, " + f"but shape {data_array.shape} received." + ) + if data_array.dtype not in [np.float32, np.int32, np.int16, np.uint8]: + raise ValueError( + f"`data` argument should be a numpy array with " + f"dtype np.float32, np.int32, np.int16 or np.uint8, " + f"but dtype {data_array.dtype.name} received." + ) + self.data = data_array + self.sampling_rate = sampling_rate + + @classmethod + def from_file(cls, filepath: FileType) -> "Audio": + """ + Read audio file from a filepath, an URL, or a file-like object. + + `pyav` is used inside, so only supported formats are allowed. + """ + try: + data, sampling_rate = audio.read_audio(filepath) + except Exception as e: + raise exceptions.InvalidFile( + f"Audio file {filepath} does not exist or could not be read." + ) from e + return cls(sampling_rate, data) + + @classmethod + def from_bytes(cls, blob: bytes) -> "Audio": + """ + Read audio from raw bytes. + + `pyav` is used inside, so only supported formats are allowed. + """ + try: + data, sampling_rate = audio.read_audio(io.BytesIO(blob)) + except Exception as e: + raise exceptions.InvalidFile( + "Audio could not be read from the given bytes." + ) from e + return cls(sampling_rate, data) + + @classmethod + def empty(cls) -> "Audio": + """ + Create a single zero-value sample stereo audio signal. + """ + return cls(1, np.zeros((1, 2), np.int16)) + + @classmethod + def decode(cls, value: Union[np.ndarray, np.void]) -> "Audio": + if isinstance(value, np.void): + buffer = io.BytesIO(value.tolist()) + data, sampling_rate = audio.read_audio(buffer) + return cls(sampling_rate, data) + raise TypeError( + f"`value` should be a `numpy.void` instance, but {type(value)} " + f"received." + ) + + def encode(self, target: Optional[str] = None) -> np.void: + format_, codec = self.get_format_codec(target) + buffer = io.BytesIO() + audio.write_audio(buffer, self.data, self.sampling_rate, format_, codec) + return np.void(buffer.getvalue()) + + @staticmethod + def get_format_codec(target: Optional[str] = None) -> Tuple[str, str]: + """ + Get an audio format and an audio codec by an `target`. + """ + format_ = "wav" if target is None else target.lstrip(".").lower() + codec = {"wav": "pcm_s16le", "ogg": "libvorbis", "mp3": "libmp3lame"}.get( + format_, format_ + ) + return format_, codec diff --git a/renumics/spotlight/dtypes/base.py b/renumics/spotlight/media/base.py similarity index 59% rename from renumics/spotlight/dtypes/base.py rename to renumics/spotlight/media/base.py index 69ac7d27..e5ba66b9 100644 --- a/renumics/spotlight/dtypes/base.py +++ b/renumics/spotlight/media/base.py @@ -2,21 +2,28 @@ Base classes for dtypes. 
""" from abc import ABC, abstractmethod -from typing import Optional, Union +from typing import Optional, Sequence, Union import numpy as np -from renumics.spotlight.typing import PathType +from renumics.spotlight.typing import NumberType, PathType -class DType(ABC): +Array1dLike = Union[Sequence[NumberType], np.ndarray] +Array2dLike = Union[Sequence[Sequence[NumberType]], Sequence[np.ndarray], np.ndarray] +ImageLike = Union[ + Sequence[Sequence[Union[NumberType, Sequence[NumberType]]]], np.ndarray +] + + +class MediaType(ABC): """ Base Spotlight dataset field data. """ @classmethod @abstractmethod - def decode(cls, value: Union[np.ndarray, np.void]) -> "DType": + def decode(cls, value: Union[np.ndarray, np.void]) -> "MediaType": """ Restore class from its numpy representation. """ @@ -33,14 +40,14 @@ def encode(self, target: Optional[str] = None) -> Union[np.ndarray, np.void]: raise NotImplementedError -class FileBasedDType(DType): +class FileMediaType(MediaType): """ Spotlight dataset field data which can be read from a file. """ @classmethod @abstractmethod - def from_file(cls, filepath: PathType) -> "FileBasedDType": + def from_file(cls, filepath: PathType) -> "FileMediaType": """ Read data from a file. """ diff --git a/renumics/spotlight/media/embedding.py b/renumics/spotlight/media/embedding.py new file mode 100644 index 00000000..bca2b4c2 --- /dev/null +++ b/renumics/spotlight/media/embedding.py @@ -0,0 +1,60 @@ +from typing import Optional, Type, Union + +import numpy as np + +from renumics.spotlight.media.base import Array1dLike, MediaType + + +class Embedding(MediaType): + """ + Data sample projected onto a new space. + + Attributes: + data: 1-dimensional array-like. Sample embedding. + dtype: Optional data type of embedding. If `None`, data type inferred + from data. + + Example: + >>> import numpy as np + >>> from renumics.spotlight import Dataset, Embedding + >>> value = np.array(np.random.rand(2)) + >>> embedding = Embedding(value) + >>> with Dataset("docs/example.h5", "w") as dataset: + ... dataset.append_embedding_column("embeddings", 5*[embedding]) + >>> with Dataset("docs/example.h5", "r") as dataset: + ... print(len(dataset["embeddings", 3].data)) + 2 + """ + + data: np.ndarray + + def __init__( + self, + data: Array1dLike, + dtype: Optional[Union[str, np.dtype, Type[np.number]]] = None, + ) -> None: + data_array = np.asarray(data, dtype) + if data_array.ndim != 1 or data_array.size == 0: + raise ValueError( + f"`data` argument should an array-like with shape " + f"`(num_features,)` with `num_features > 0`, but shape " + f"{data_array.shape} received." + ) + if data_array.dtype.str[1] not in "fiu": + raise ValueError( + f"`data` argument should be an array-like with integer or " + f"float dtypes, but dtype {data_array.dtype.name} received." + ) + self.data = data_array + + @classmethod + def decode(cls, value: Union[np.ndarray, np.void]) -> "Embedding": + if not isinstance(value, np.ndarray): + raise TypeError( + f"`value` argument should be a numpy array, but {type(value)} " + f"received." 
+            )
+        return cls(value)
+
+    def encode(self, _target: Optional[str] = None) -> np.ndarray:
+        return self.data
diff --git a/renumics/spotlight/dtypes/exceptions.py b/renumics/spotlight/media/exceptions.py
similarity index 100%
rename from renumics/spotlight/dtypes/exceptions.py
rename to renumics/spotlight/media/exceptions.py
diff --git a/renumics/spotlight/media/image.py b/renumics/spotlight/media/image.py
new file mode 100644
index 00000000..cfdd35a0
--- /dev/null
+++ b/renumics/spotlight/media/image.py
@@ -0,0 +1,123 @@
+import io
+from typing import Optional, Union
+
+import imageio.v3 as iio
+import numpy as np
+from loguru import logger
+
+from renumics.spotlight.typing import FileType
+
+from renumics.spotlight.media.base import FileMediaType, ImageLike
+
+from . import exceptions
+from ..io import file as file_io
+
+
+class Image(FileMediaType):
+    """
+    An RGB(A) or grayscale image that will be saved in encoded form.
+
+    Attributes:
+        data: Array-like with shape `(num_rows, num_columns)` or
+            `(num_rows, num_columns, num_channels)` with `num_channels` equal to
+            1, 3, or 4; stored with dtype "uint8".
+
+    Example:
+        >>> import numpy as np
+        >>> from renumics.spotlight import Dataset, Image
+        >>> data = np.full([100,100,3], 255, dtype=np.uint8) # white uint8 image
+        >>> image = Image(data)
+        >>> float_data = np.random.uniform(0, 1, (100, 100)) # random grayscale float image
+        >>> float_image = Image(float_data)
+        >>> with Dataset("docs/example.h5", "w") as dataset:
+        ...     dataset.append_image_column("images", [image, float_image, data, float_data])
+        >>> with Dataset("docs/example.h5", "r") as dataset:
+        ...     print(dataset["images", 0].data[50][50])
+        ...     print(dataset["images", 3].data.dtype)
+        [255 255 255]
+        uint8
+    """
+
+    data: np.ndarray
+
+    def __init__(self, data: ImageLike) -> None:
+        data_array = np.asarray(data)
+        if (
+            data_array.size == 0
+            or data_array.ndim != 2
+            and (data_array.ndim != 3 or data_array.shape[-1] not in (1, 3, 4))
+        ):
+            raise ValueError(
+                f"`data` argument should be a numpy array with shape "
+                f"`(num_rows, num_columns, num_channels)` or "
+                f"`(num_rows, num_columns)` with `num_rows > 0`, "
+                f"`num_cols > 0` and `num_channels` equal to 1, 3, or 4, but "
+                f"shape {data_array.shape} received."
+            )
+        if data_array.dtype.str[1] not in "fiu":
+            raise ValueError(
+                f"`data` argument should be a numpy array with integer or "
+                f"float dtypes, but dtype {data_array.dtype.name} received."
+            )
+        if data_array.ndim == 3 and data_array.shape[2] == 1:
+            data_array = data_array.squeeze(axis=2)
+        if data_array.dtype.str[1] == "f":
+            logger.info(
+                'Image data converted to "uint8" dtype by multiplication with '
+                "255 and rounding."
+            )
+            data_array = (255 * data_array).round()
+        self.data = data_array.astype("uint8")
+
+    @classmethod
+    def from_file(cls, filepath: FileType) -> "Image":
+        """
+        Read image from a filepath, a URL, or a file-like object.
+
+        `imageio` is used inside, so only supported formats are allowed.
+        """
+        with file_io.as_file(filepath) as file:
+            try:
+                image_array = iio.imread(file, index=False)  # type: ignore
+            except Exception as e:
+                raise exceptions.InvalidFile(
+                    f"Image {filepath} does not exist or could not be read."
+                ) from e
+        return cls(image_array)
+
+    @classmethod
+    def from_bytes(cls, blob: bytes) -> "Image":
+        """
+        Read image from raw bytes.
+
+        `imageio` is used inside, so only supported formats are allowed.
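The float-to-uint8 conversion documented above, as a minimal sketch:

import numpy as np
from renumics.spotlight.media import Image

image = Image(np.full((4, 4), 0.5))  # grayscale float image
assert image.data.dtype == np.uint8
assert int(image.data[0, 0]) == 128  # 0.5 * 255 = 127.5, rounded to 128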
+ """ + try: + image_array = iio.imread(blob, index=False) # type: ignore + except Exception as e: + raise exceptions.InvalidFile( + "Image could not be read from the given bytes." + ) from e + return cls(image_array) + + @classmethod + def empty(cls) -> "Image": + """ + Create a transparent 1 x 1 image. + """ + return cls(np.zeros((1, 1, 4), np.uint8)) + + @classmethod + def decode(cls, value: Union[np.ndarray, np.void]) -> "Image": + if isinstance(value, np.void): + buffer = io.BytesIO(value.tolist()) + return cls(iio.imread(buffer, extension=".png", index=False)) + raise TypeError( + f"`value` should be a `numpy.void` instance, but {type(value)} " + f"received." + ) + + def encode(self, _target: Optional[str] = None) -> np.void: + buf = io.BytesIO() + iio.imwrite(buf, self.data, extension=".png") + return np.void(buf.getvalue()) diff --git a/renumics/spotlight/media/mesh.py b/renumics/spotlight/media/mesh.py new file mode 100644 index 00000000..c5638cba --- /dev/null +++ b/renumics/spotlight/media/mesh.py @@ -0,0 +1,552 @@ +import io +import math +import os +from typing import Dict, IO, List, Optional, Tuple, Union +from urllib.parse import urlparse + +import numpy as np +import pygltflib +import requests +import trimesh +import validators +from loguru import logger + +from renumics.spotlight.requests import headers +from renumics.spotlight.typing import PathType + +from renumics.spotlight.media.base import Array2dLike, FileMediaType + +from . import exceptions +from ..io import gltf + + +class Mesh(FileMediaType): + """ + Triangular 3D mesh with optional per-point and per-triangle attributes and + optional per-point displacements over time. + + Example: + >>> import numpy as np + >>> from renumics.spotlight import Dataset, Mesh + >>> points = np.array([[0,0,0],[1,1,1],[0,1,0],[-1,0,1]]) + >>> triangles = np.array([[0,1,2],[2,3,0]]) + >>> mesh = Mesh(points, triangles) + >>> with Dataset("docs/example.h5", "w") as dataset: + ... dataset.append_mesh_column("meshes", 5*[mesh]) + >>> with Dataset("docs/example.h5", "r") as dataset: + ... print(dataset["meshes", 2].triangles) + [[0 1 2] + [2 3 0]] + """ + + _points: np.ndarray + _triangles: np.ndarray + _point_attributes: Dict[str, np.ndarray] + _point_displacements: List[np.ndarray] + + _point_indices: np.ndarray + _triangle_indices: np.ndarray + _triangle_attribute_indices: np.ndarray + + def __init__( + self, + points: Array2dLike, + triangles: Array2dLike, + point_attributes: Optional[Dict[str, np.ndarray]] = None, + triangle_attributes: Optional[Dict[str, np.ndarray]] = None, + point_displacements: Optional[Union[np.ndarray, List[np.ndarray]]] = None, + ): + self._point_attributes = {} + self._point_displacements = [] + self._set_points_triangles(points, triangles) + + if point_displacements is None: + point_displacements = [] + self.point_displacements = point_displacements # type: ignore + self.update_attributes(point_attributes, triangle_attributes) + + @property + def points(self) -> np.ndarray: + """ + :code:`np.array` with shape `(num_points, 3)`. Mesh points. + """ + return self._points + + @property + def triangles(self) -> np.ndarray: + """ + :code:`np.array` with shape `(num_triangles, 3)`. Mesh triangles stored as their + CCW nodes referring to the `points` indices. + """ + return self._triangles + + @property + def point_attributes(self) -> Dict[str, np.ndarray]: + """ + Mapping str -> :code:`np.array` with shape `(num_points, ...)`. Point-wise + attributes corresponding to `points`. 
All possible shapes of a single
+        attribute can be found in
+        `renumics.spotlight.mesh_proc.gltf.GLTF_SHAPES`.
+        """
+        return self._point_attributes
+
+    @property
+    def point_displacements(self) -> List[np.ndarray]:
+        """
+        List of arrays with shape `(num_points, 3)`. Point-wise relative
+        displacements (offsets) over time, corresponding to `points`.
+        Timestep 0 is omitted since it is explicitly stored as absolute values
+        in `points`.
+        """
+        return self._point_displacements
+
+    @point_displacements.setter
+    def point_displacements(self, value: Union[np.ndarray, List[np.ndarray]]) -> None:
+        array = attribute_to_array(value)
+        if array.size == 0:
+            self._point_displacements = []
+        else:
+            array = array.astype(np.float32)
+            if array.shape[1] != len(self._points):
+                array = array[:, self._point_indices]
+            if array.shape[1:] != (len(self._points), 3):
+                raise ValueError(
+                    f"Point displacements should have the same shape as points "
+                    f"({self._points.shape}), but shape {array.shape[1:]} "
+                    f"received."
+                )
+            self._point_displacements = list(array)
+
+    @classmethod
+    def from_trimesh(cls, mesh: trimesh.Trimesh) -> "Mesh":
+        """
+        Import a `trimesh.Trimesh` mesh.
+        """
+        return cls(
+            mesh.vertices, mesh.faces, mesh.vertex_attributes, mesh.face_attributes
+        )
+
+    @classmethod
+    def from_file(cls, filepath: PathType) -> "Mesh":
+        """
+        Read mesh from a filepath or a URL.
+
+        `trimesh` is used inside, so only supported formats are allowed.
+        """
+        file: Union[str, IO] = (
+            str(filepath) if isinstance(filepath, os.PathLike) else filepath
+        )
+        extension = None
+        if isinstance(file, str):
+            if validators.url(file):
+                response = requests.get(file, headers=headers, timeout=30)
+                if not response.ok:
+                    raise exceptions.InvalidFile(f"URL {file} does not exist.")
+                extension = os.path.splitext(urlparse(file).path)[1]
+                if extension == "":
+                    raise exceptions.InvalidFile(f"URL {file} has no file extension.")
+                file = io.BytesIO(response.content)
+            elif not os.path.isfile(file):
+                raise exceptions.InvalidFile(
+                    f"File {file} is neither an existing file nor an existing URL."
+                )
+        try:
+            mesh = trimesh.load(file, file_type=extension, force="mesh")
+        except Exception as e:
+            raise exceptions.InvalidFile(
+                f"Mesh {filepath} does not exist or could not be read."
+            ) from e
+        return cls.from_trimesh(mesh)
+
+    @classmethod
+    def empty(cls) -> "Mesh":
+        """
+        Create an empty mesh.
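An illustration of the cleanup invariant established by the constructor (via `_set_points_triangles` below): degenerate triangles and unreferenced points are dropped on construction.

import numpy as np
from renumics.spotlight.media import Mesh

points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [9, 9, 9]], float)
triangles = np.array([[0, 1, 2], [1, 1, 2]])  # second triangle is degenerate
mesh = Mesh(points, triangles)
assert len(mesh.triangles) == 1  # degenerate triangle dropped
assert len(mesh.points) == 3     # unreferenced point dropped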
+ """ + return cls(np.empty((0, 3)), np.empty((0, 3), np.int64)) + + @classmethod + def decode(cls, value: Union[np.ndarray, np.void]) -> "Mesh": + gltf_mesh = pygltflib.GLTF2.load_from_bytes(value.tobytes()) + gltf.check_gltf(gltf_mesh) + arrays = gltf.decode_gltf_arrays(gltf_mesh) + primitive = gltf_mesh.meshes[0].primitives[0] + points = arrays[primitive.attributes.POSITION] + triangles = arrays[primitive.indices].reshape((-1, 3)) + point_attributes = { + k[1:]: arrays[v] + for k, v in primitive.attributes.__dict__.items() + if k.startswith("_") + } + point_displacements = [ + arrays[target["POSITION"]] for target in primitive.targets + ] + return cls( + points, triangles, point_attributes, point_displacements=point_displacements + ) + + def encode(self, _target: Optional[str] = None) -> np.void: + bin_data, buffer_views, accessors = gltf.encode_gltf_array( + self._triangles.flatten(), b"", [], [], pygltflib.ELEMENT_ARRAY_BUFFER + ) + mesh_primitive_attributes_kwargs = {"POSITION": 1} + bin_data, buffer_views, accessors = gltf.encode_gltf_array( + self._points, bin_data, buffer_views, accessors + ) + for attr_name, point_attr in self._point_attributes.items(): + mesh_primitive_attributes_kwargs["_" + attr_name] = len(buffer_views) + bin_data, buffer_views, accessors = gltf.encode_gltf_array( + point_attr, bin_data, buffer_views, accessors + ) + morph_targets = [] + for point_displacement in self._point_displacements: + morph_targets.append(pygltflib.Attributes(POSITION=len(buffer_views))) + bin_data, buffer_views, accessors = gltf.encode_gltf_array( + point_displacement, bin_data, buffer_views, accessors + ) + gltf_mesh = pygltflib.GLTF2( + asset=pygltflib.Asset(), + scene=0, + scenes=[pygltflib.Scene(nodes=[0])], + nodes=[pygltflib.Node(mesh=0)], + meshes=[ + pygltflib.Mesh( + primitives=[ + pygltflib.Primitive( + attributes=pygltflib.Attributes( + **mesh_primitive_attributes_kwargs + ), + indices=0, + mode=pygltflib.TRIANGLES, + targets=morph_targets, + ) + ], + ) + ], + accessors=accessors, + bufferViews=buffer_views, + buffers=[pygltflib.Buffer(byteLength=len(bin_data))], + ) + gltf_mesh.set_binary_blob(bin_data) + return np.void(b"".join(gltf_mesh.save_to_bytes())) + + def update_attributes( + self, + point_attributes: Optional[Dict[str, np.ndarray]] = None, + triangle_attributes: Optional[Dict[str, np.ndarray]] = None, + ) -> None: + """ + Update point and/or triangle attributes dict-like. + """ + if point_attributes: + point_attributes = self._sanitize_point_attributes(point_attributes) + self._point_attributes.update(point_attributes) + if triangle_attributes: + triangle_attributes = self._sanitize_triangle_attributes( + triangle_attributes + ) + logger.info("Triangle attributes will be converted to point attributes.") + self._point_attributes.update( + self._triangle_attributes_to_point_attributes(triangle_attributes) + ) + + def interpolate_point_displacements(self, num_timesteps: int) -> None: + """subsample time dependent attributes with new time step count""" + if num_timesteps < 1: + raise ValueError( + f"`num_timesteps` argument should be non-negative, but " + f"{num_timesteps} received." 
+ ) + current_num_timesteps = len(self._point_displacements) + if current_num_timesteps == 0: + logger.info("No displacements found, so cannot interpolate.") + return + if current_num_timesteps == num_timesteps: + return + + def _interpolated_list_access( + arrays: List[np.ndarray], index_float: float + ) -> np.ndarray: + """access a list equally sized numpy arrays with interpolation between two neighbors""" + array_left = arrays[math.floor(index_float)] + array_right = arrays[math.ceil(index_float)] + weight_right = index_float - math.floor(index_float) + return (array_left * (1 - weight_right)) + (array_right * weight_right) + + # simplification assumption : timesteps are equally sized + timesteps = np.linspace(0, current_num_timesteps, num_timesteps + 1)[1:] + + # add implicit 0 displacement for t=0 + displacements = [ + np.zeros_like(self._point_displacements[0]) + ] + self._point_displacements + self._point_displacements = [ + _interpolated_list_access(displacements, t) for t in timesteps + ] + + def _set_points_triangles( + self, points: Array2dLike, triangles: Array2dLike + ) -> None: + # Check points. + points_array = np.asarray(points, np.float32) + if points_array.ndim != 2 or points_array.shape[1] != 3: + raise ValueError( + f"`points` argument should be a numpy array with shape " + f"`(num_points, 3)`, but shape {points_array.shape} received." + ) + # Check triangles. + triangles_array = np.asarray(triangles, np.uint32) + if triangles_array.ndim != 2 or triangles_array.shape[1] != 3: + raise ValueError( + f"`triangles` argument should be a numpy array with shape " + f"`(num_triangles, 3)`, but shape {triangles_array.shape} received." + ) + # Subsample only valid points and triangles. + point_ids = np.arange(len(points_array)) + valid_triangles_mask = ( + (triangles_array[:, 0] != triangles_array[:, 1]) + & (triangles_array[:, 0] != triangles_array[:, 2]) + & (triangles_array[:, 1] != triangles_array[:, 2]) + & np.isin(triangles_array, point_ids).all(axis=1) + ) + self._triangle_indices = np.nonzero(valid_triangles_mask)[0] + self._triangles = triangles_array[self._triangle_indices] + valid_points_mask = np.isin(point_ids, self._triangles) + self._point_indices = np.nonzero(valid_points_mask)[0] + self._points = points_array[self._point_indices] + # Reindex triangles since there can be fewer points than before. + point_ids = point_ids[self._point_indices] + self._triangles, *_ = reindex(point_ids, self._triangles) + # Set indices for conversion of triangle attributes to point attributes. + self._triangle_attribute_indices = np.full((len(self._points)), 0, np.uint32) + triangle_indices = np.arange(len(self._triangles), dtype=np.uint32) + self._triangle_attribute_indices[self._triangles[:, 0]] = triangle_indices + self._triangle_attribute_indices[self._triangles[:, 1]] = triangle_indices + self._triangle_attribute_indices[self._triangles[:, 2]] = triangle_indices + + def _triangle_attributes_to_point_attributes( + self, triangle_attributes: Dict[str, np.ndarray] + ) -> Dict[str, np.ndarray]: + return { + f"element_{attr_name}": triangle_attr[self._triangle_attribute_indices] + for attr_name, triangle_attr in triangle_attributes.items() + } + + def _sanitize_point_attributes( + self, point_attributes: Dict[str, np.ndarray] + ) -> Dict[str, np.ndarray]: + if not isinstance(point_attributes, dict): + raise TypeError( + f"`point_attributes` argument should be a dict, but " + f"{type(point_attributes)} received." 
+            )
+        valid_point_attributes = {}
+        for attr_name, point_attr in point_attributes.items():
+            point_attr = np.asarray(point_attr)
+            if len(point_attr) != len(self._points):
+                point_attr = point_attr[self._point_indices]
+            if point_attr.dtype.str[1] not in "fiu":
+                raise ValueError(
+                    f"Point attributes should have an integer or float "
+                    f'dtype, but attribute "{attr_name}" of dtype '
+                    f"{point_attr.dtype.name} received."
+                )
+            point_attr = point_attr.squeeze()
+            if point_attr.shape[1:] not in gltf.GLTF_SHAPES_LOOKUP.keys():
+                logger.warning(
+                    f"Element shape {point_attr.shape[1:]} of the point "
+                    f'attribute "{attr_name}" not supported, attribute will be '
+                    f"removed."
+                )
+                continue
+            valid_point_attributes[attr_name] = point_attr.astype(
+                gltf.GLTF_DTYPES_CONVERSION[point_attr.dtype.str[1:]]
+            )
+        return valid_point_attributes
+
+    def _sanitize_triangle_attributes(
+        self, triangle_attributes: Dict[str, np.ndarray]
+    ) -> Dict[str, np.ndarray]:
+        if not isinstance(triangle_attributes, dict):
+            raise TypeError(
+                f"`triangle_attributes` argument should be a dict, but "
+                f"{type(triangle_attributes)} received."
+            )
+        valid_triangle_attributes = {}
+        for attr_name, triangle_attr in triangle_attributes.items():
+            triangle_attr = np.asarray(triangle_attr)
+            if len(triangle_attr) != len(self._triangles):
+                triangle_attr = triangle_attr[self._triangle_indices]
+            if triangle_attr.dtype.str[1] not in "fiu":
+                raise ValueError(
+                    f"Triangle attributes should have an integer or float "
+                    f'dtype, but attribute "{attr_name}" of dtype '
+                    f"{triangle_attr.dtype.name} received."
+                )
+            triangle_attr = triangle_attr.squeeze()
+            if triangle_attr.shape[1:] not in gltf.GLTF_SHAPES_LOOKUP.keys():
+                logger.warning(
+                    f"Element shape {triangle_attr.shape[1:]} of the triangle "
+                    f'attribute "{attr_name}" not supported, attribute will be '
+                    f"removed."
+                )
+                continue
+            valid_triangle_attributes[attr_name] = triangle_attr.astype(
+                gltf.GLTF_DTYPES_CONVERSION[triangle_attr.dtype.str[1:]]
+            )
+        return valid_triangle_attributes
+
+
+def reindex(point_ids: np.ndarray, *elements: np.ndarray) -> Tuple[np.ndarray, ...]:
+    """
+    Reindex elements that refer to the non-negative, unique point ids so that
+    they refer to the positions of those ids in `point_ids`
+    (`np.arange(len(point_ids))`) instead.
+    """
+    if elements:
+        inverse_point_ids = np.full(
+            max((x.max(initial=0) for x in (point_ids, *elements))) + 1,
+            -1,
+            np.int64,
+        )
+        inverse_point_ids[point_ids] = np.arange(len(point_ids))
+        return tuple(inverse_point_ids[x].astype(x.dtype) for x in elements)
+    return ()
+
+
+def attribute_to_array(attribute: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
+    """
+    Encode a single attribute or a list of attribute steps as an array with
+    shape `(num_steps, n, ...)` and squeeze all dimensions except for 0 and 1.
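+
+    Example (illustrative; a single attribute counts as one step):
+
+        >>> import numpy as np
+        >>> attribute_to_array(np.zeros((4, 1))).shape
+        (1, 4)
+        >>> attribute_to_array([np.zeros((4, 1)), np.ones((4, 1))]).shape
+        (2, 4)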
+ """ + if not isinstance(attribute, list): + attribute = [attribute] + attribute = np.asarray(attribute) + attribute = attribute.reshape( + (*attribute.shape[:2], *(x for x in attribute.shape[2:] if x != 1)) + ) + return attribute + + +def triangulate( + triangles: Optional[np.ndarray] = None, + triangle_attributes: Optional[ + Dict[str, Union[np.ndarray, List[np.ndarray]]] + ] = None, + quadrangles: Optional[np.ndarray] = None, + quadrangle_attributes: Optional[ + Dict[str, Union[np.ndarray, List[np.ndarray]]] + ] = None, +) -> Tuple[np.ndarray, Dict[str, Union[np.ndarray, List[np.ndarray]]]]: + """ + Triangulate quadrangles and respective attributes and append them to the + given triangles/attributes. + """ + + attrs = {} + if triangles is None: + trias = np.empty((0, 3), np.uint32) + if triangle_attributes: + raise ValueError( + f"`triangles` not given, but `triangle_attributes` have " + f"{len(triangle_attributes)} items." + ) + else: + trias = triangles + if triangle_attributes is not None: + for attr_name, triangle_attr in triangle_attributes.items(): + attrs[attr_name] = attribute_to_array(triangle_attr) + + if quadrangles is None: + if quadrangle_attributes is not None and len(quadrangle_attributes) != 0: + raise ValueError( + f"`quadrangles` not given, but `quadrangle_attributes` have " + f"{len(quadrangle_attributes)} items." + ) + else: + trias = np.concatenate( + (trias, quadrangles[:, [0, 1, 2]], quadrangles[:, [0, 2, 3]]) + ) + for attr_name, quadrangle_attr in (quadrangle_attributes or {}).items(): + quadrangle_attr = attribute_to_array(quadrangle_attr) + try: + attr = attrs[attr_name] + except KeyError: + attrs[attr_name] = np.concatenate( + (quadrangle_attr, quadrangle_attr), axis=1 + ) + else: + attrs[attr_name] = np.concatenate( + (attr, quadrangle_attr, quadrangle_attr), axis=1 + ) + + for attr_name, attr in attrs.items(): + if attr.shape[1] != len(trias): + raise ValueError( + f"Values of attributes should have the same length as " + f"triangles ({len(trias)}), but length {attr.shape[1]} " + f"received." + ) + return ( + trias, + { + attr_name: attr[0] if len(attr) == 1 else list(attr) + for attr_name, attr in attrs.items() + }, + ) + + +def clean( + points: np.ndarray, + triangles: np.ndarray, + point_attributes: Optional[Dict[str, np.ndarray]] = None, + triangle_attributes: Optional[Dict[str, np.ndarray]] = None, + point_displacements: Optional[List[np.ndarray]] = None, +) -> Tuple[ + np.ndarray, + np.ndarray, + Dict[str, np.ndarray], + Dict[str, np.ndarray], + List[np.ndarray], +]: + """ + Remove: + degenerated triangles and respective attributes; + invalid triangles and respective attributes; + invalid points and respective attributes; + empty attributes and point displacements. 
+ """ + point_ids = np.arange(len(points)) + valid_triangles_mask = ( + (triangles[:, 0] != triangles[:, 1]) + & (triangles[:, 0] != triangles[:, 2]) + & (triangles[:, 1] != triangles[:, 2]) + & np.isin(triangles, point_ids).all(axis=1) + ) + triangles = triangles[valid_triangles_mask] + triangle_attributes = { + k: [x[valid_triangles_mask] for x in v] + if isinstance(v, list) + else v[valid_triangles_mask] + for k, v in (triangle_attributes or {}).items() + } + + valid_points_mask = np.isin(point_ids, triangles) + points = points[valid_points_mask] + point_attributes = { + k: [x[valid_points_mask] for x in v] + if isinstance(v, list) + else v[valid_points_mask] + for k, v in (point_attributes or {}).items() + } + point_displacements = [x[valid_points_mask] for x in point_displacements or []] + point_ids = point_ids[valid_points_mask] + triangles, *_ = reindex(point_ids, triangles) + + return ( + points, + triangles, + {k: v for k, v in point_attributes.items() if len(v) > 0}, + {k: v for k, v in triangle_attributes.items() if len(v) > 0}, + [x for x in point_displacements if len(x) > 0], + ) diff --git a/renumics/spotlight/media/sequence_1d.py b/renumics/spotlight/media/sequence_1d.py new file mode 100644 index 00000000..70c49e2d --- /dev/null +++ b/renumics/spotlight/media/sequence_1d.py @@ -0,0 +1,125 @@ +from typing import Optional, Type, Union + +import numpy as np + + +from renumics.spotlight.media.base import Array1dLike, MediaType + + +class Sequence1D(MediaType): + """ + One-dimensional ndarray with optional index values. + + Attributes: + index: 1-dimensional array-like of length `num_steps`. Index values (x-axis). + value: 1-dimensional array-like of length `num_steps`. Respective values (y-axis). + dtype: Optional data type of sequence. If `None`, data type inferred + from data. + + Example: + >>> import numpy as np + >>> from renumics.spotlight import Dataset, Sequence1D + >>> index = np.arange(100) + >>> value = np.array(np.random.rand(100)) + >>> sequence = Sequence1D(index, value) + >>> with Dataset("docs/example.h5", "w") as dataset: + ... dataset.append_sequence_1d_column("sequences", 5*[sequence]) + >>> with Dataset("docs/example.h5", "r") as dataset: + ... print(len(dataset["sequences", 2].value)) + 100 + """ + + index: np.ndarray + value: np.ndarray + + def __init__( + self, + index: Optional[Array1dLike], + value: Optional[Array1dLike] = None, + dtype: Optional[Union[str, np.dtype, Type[np.number]]] = None, + ) -> None: + if value is None: + if index is None: + raise ValueError( + "At least one of arguments `index` or `value` should be " + "set, but both `None` values received." + ) + value = index + index = None + + value_array = np.asarray(value, dtype) + if value_array.dtype.str[1] not in "fiu": + raise ValueError( + f"Input values should be array-likes with integer or float " + f"dtype, but dtype {value_array.dtype.name} received." + ) + if index is None: + if value_array.ndim == 2: + if value_array.shape[0] == 2: + self.index = value_array[0] + self.value = value_array[1] + elif value_array.shape[1] == 2: + self.index = value_array[:, 0] + self.value = value_array[:, 1] + else: + raise ValueError( + f"A single 2-dimensional input value should have one " + f"dimension of length 2, but shape {value_array.shape} received." 
+                )
+            elif value_array.ndim == 1:
+                self.value = value_array
+                if dtype is None:
+                    dtype = self.value.dtype
+                self.index = np.arange(len(self.value), dtype=dtype)
+            else:
+                raise ValueError(
+                    f"A single input value should be 1- or 2-dimensional, but "
+                    f"shape {value_array.shape} received."
+                )
+        else:
+            if value_array.ndim != 1:
+                raise ValueError(
+                    f"Value should be 1-dimensional, but shape {value_array.shape} received."
+                )
+            index_array = np.asarray(index, dtype)
+            if index_array.ndim != 1:
+                raise ValueError(
+                    f"Index should be 1-dimensional array-like, but shape "
+                    f"{index_array.shape} received."
+                )
+            if index_array.dtype.str[1] not in "fiu":
+                raise ValueError(
+                    f"Index should be array-like with integer or float "
+                    f"dtype, but dtype {index_array.dtype.name} received."
+                )
+            self.value = value_array
+            self.index = index_array
+            if len(self.value) != len(self.index):
+                raise ValueError(
+                    f"Lengths of `index` and `value` should match, but lengths "
+                    f"{len(self.index)} and {len(self.value)} received."
+                )
+
+    @classmethod
+    def decode(cls, value: Union[np.ndarray, np.void]) -> "Sequence1D":
+        if not isinstance(value, np.ndarray):
+            raise TypeError(
+                f"`value` argument should be a numpy array, but {type(value)} "
+                f"received."
+            )
+        if value.ndim != 2 or value.shape[1] != 2:
+            raise ValueError(
+                f"`value` argument should be a numpy array with shape "
+                f"`(num_steps, 2)`, but shape {value.shape} received."
+            )
+        return cls(value[:, 0], value[:, 1])
+
+    def encode(self, _target: Optional[str] = None) -> np.ndarray:
+        return np.stack((self.index, self.value), axis=1)
+
+    @classmethod
+    def empty(cls) -> "Sequence1D":
+        """
+        Create an empty sequence.
+        """
+        return cls(np.empty(0), np.empty(0))
diff --git a/renumics/spotlight/media/video.py b/renumics/spotlight/media/video.py
new file mode 100644
index 00000000..dc1b2b6d
--- /dev/null
+++ b/renumics/spotlight/media/video.py
@@ -0,0 +1,84 @@
+import os
+from typing import Optional, Union
+
+import numpy as np
+import requests
+import validators
+
+from renumics.spotlight.requests import headers
+from renumics.spotlight.typing import PathType
+
+from renumics.spotlight.media.base import FileMediaType
+
+from ..media import exceptions
+
+
+class Video(FileMediaType):
+    """
+    A video object. No encoding or decoding is currently performed on the
+    Python side, so any format is saved into the dataset without a
+    compatibility check, but only formats supported by your browser (e.g.
+    .mp4, .ogg, .webm or .mov) can be played in Spotlight.
+    """
+
+    data: bytes
+
+    def __init__(self, data: bytes) -> None:
+        if not isinstance(data, bytes):
+            raise TypeError(
+                f"`data` argument should be video bytes, but type {type(data)} "
+                f"received."
+            )
+        self.data = data
+
+    @classmethod
+    def from_file(cls, filepath: PathType) -> "Video":
+        """
+        Read video from a filepath or a URL.
+        """
+        prepared_file = str(filepath) if isinstance(filepath, os.PathLike) else filepath
+        if not isinstance(prepared_file, str):
+            raise TypeError(
+                "`filepath` should be a string or an `os.PathLike` instance, "
+                f"but value {prepared_file} of type {type(prepared_file)} "
+                f"received."
+ ) + if validators.url(prepared_file): + response = requests.get( + prepared_file, headers=headers, stream=True, timeout=10 + ) + if not response.ok: + raise exceptions.InvalidFile(f"URL {prepared_file} does not exist.") + return cls(response.raw.data) + if os.path.isfile(prepared_file): + with open(filepath, "rb") as f: + return cls(f.read()) + raise exceptions.InvalidFile( + f"File {prepared_file} is neither an existing file nor an existing URL." + ) + + @classmethod + def from_bytes(cls, blob: bytes) -> "Video": + """ + Read video from raw bytes. + """ + return cls(blob) + + @classmethod + def empty(cls) -> "Video": + """ + Create an empty video instance. + """ + return cls(b"\x00") + + @classmethod + def decode(cls, value: Union[np.ndarray, np.void]) -> "Video": + if isinstance(value, np.void): + return cls(value.tolist()) + raise TypeError( + f"`value` should be a `numpy.void` instance, but {type(value)} " + f"received." + ) + + def encode(self, _target: Optional[str] = None) -> np.void: + return np.void(self.data) diff --git a/renumics/spotlight/viewer.py b/renumics/spotlight/viewer.py index da84ea1b..97afec4e 100644 --- a/renumics/spotlight/viewer.py +++ b/renumics/spotlight/viewer.py @@ -50,7 +50,7 @@ from pathlib import Path import time -from typing import Collection, List, Union, Optional +from typing import Any, Collection, Dict, List, Union, Optional import pandas as pd from typing_extensions import Literal @@ -59,7 +59,6 @@ import __main__ from renumics.spotlight.settings import settings -from renumics.spotlight.dtypes.typing import ColumnTypeMapping from renumics.spotlight.layout import _LayoutLike, parse from renumics.spotlight.typing import PathType, is_pathtype from renumics.spotlight.webbrowser import launch_browser_in_thread @@ -67,6 +66,8 @@ from renumics.spotlight.analysis.typing import DataIssue from renumics.spotlight.app_config import AppConfig +from renumics.spotlight.dtypes import create_dtype + class ViewerNotFoundError(Exception): """ @@ -106,7 +107,7 @@ def show( no_browser: bool = False, allow_filebrowsing: Union[bool, Literal["auto"]] = "auto", wait: Union[bool, Literal["auto", "forever"]] = "auto", - dtype: Optional[ColumnTypeMapping] = None, + dtype: Optional[Dict[str, Any]] = None, analyze: Optional[Union[bool, List[str]]] = None, issues: Optional[Collection[DataIssue]] = None, ) -> None: @@ -153,10 +154,15 @@ def show( layout = layout or settings.layout parsed_layout = parse(layout) if layout else None + converted_dtypes = ( + {column_name: create_dtype(d) for column_name, d in dtype.items()} + if dtype + else None + ) config = AppConfig( dataset=dataset, - dtypes=dtype, + dtypes=converted_dtypes, project_root=project_root, analyze=analyze, custom_issues=list(issues) if issues else None, @@ -320,7 +326,7 @@ def show( no_browser: bool = False, allow_filebrowsing: Union[bool, Literal["auto"]] = "auto", wait: Union[bool, Literal["auto", "forever"]] = "auto", - dtype: Optional[ColumnTypeMapping] = None, + dtype: Optional[Dict[str, Any]] = None, analyze: Optional[Union[bool, List[str]]] = None, issues: Optional[Collection[DataIssue]] = None, ) -> Viewer: diff --git a/renumics/spotlight_plugins/core/api/table.py b/renumics/spotlight_plugins/core/api/table.py index a1efcf7c..e6112da2 100644 --- a/renumics/spotlight_plugins/core/api/table.py +++ b/renumics/spotlight_plugins/core/api/table.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional -import numpy as np from fastapi import APIRouter, Request from fastapi.responses 
import ORJSONResponse, Response from pydantic import BaseModel @@ -12,24 +11,10 @@ from renumics.spotlight.backend.exceptions import FilebrowsingNotAllowed, InvalidPath from renumics.spotlight.app import SpotlightApp from renumics.spotlight.app_config import AppConfig -from renumics.spotlight.dtypes.typing import get_column_type_name from renumics.spotlight.io.path import is_path_relative_to from renumics.spotlight.reporting import emit_timed_event -from renumics.spotlight.dtypes import ( - Audio, - Category, - Embedding, - Image, - Mesh, - Sequence1D, - Video, -) - - -# for now specify all lazy dtypes right here -# we should probably move closer to the actual dtype definition for easier extensibility -LAZY_DTYPES = [Embedding, Mesh, Image, Video, Sequence1D, np.ndarray, Audio, str] +from renumics.spotlight.dtypes import is_category_dtype class Column(BaseModel): @@ -90,17 +75,13 @@ def get_table(request: Request) -> ORJSONResponse: dtype = data_store.dtypes[column_name] values = data_store.get_converted_values(column_name, simple=True) meta = data_store.get_column_metadata(column_name) - if dtype == Category: - categories = data_store._data_source.get_column_categories(column_name) - else: - categories = {} column = Column( name=column_name, values=values, editable=meta.editable, optional=meta.nullable, - role=get_column_type_name(dtype), - categories=categories, + role=dtype.name, + categories=dtype.categories if is_category_dtype(dtype) else None, description=meta.description, tags=meta.tags, ) diff --git a/renumics/spotlight_plugins/core/hdf5_data_source.py b/renumics/spotlight_plugins/core/hdf5_data_source.py index b002280a..69de9bd3 100644 --- a/renumics/spotlight_plugins/core/hdf5_data_source.py +++ b/renumics/spotlight_plugins/core/hdf5_data_source.py @@ -1,37 +1,23 @@ """ access h5 table data """ -from functools import lru_cache from hashlib import sha1 from pathlib import Path -from typing import Dict, List, Union, cast +from typing import List, Union, cast import h5py import numpy as np -from renumics.spotlight.dtypes import Embedding -from renumics.spotlight.dtypes.typing import ( - ColumnTypeMapping, -) from renumics.spotlight.typing import IndexType -from renumics.spotlight.dataset import ( - Dataset, - INTERNAL_COLUMN_NAMES, -) +from renumics.spotlight.dataset import Dataset, INTERNAL_COLUMN_NAMES from renumics.spotlight.data_source import DataSource, datasource from renumics.spotlight.backend.exceptions import ( NoTableFileFound, CouldNotOpenTableFile, ) - - -from renumics.spotlight.dtypes.conversion import ( - NormalizedValue, - convert_to_dtype, -) - from renumics.spotlight.data_source.data_source import ColumnMetadata +from renumics.spotlight.dtypes import DTypeMap, create_dtype, is_embedding_dtype class H5Dataset(Dataset): @@ -45,26 +31,6 @@ def get_generation_id(self) -> int: """ return int(self._h5_file.attrs.get("spotlight_generation_id", 0)) - def read_value(self, column_name: str, index: IndexType) -> NormalizedValue: - """ - Get a dataset value as it is stored in the H5 dataset, resolve references. 
- """ - - self._assert_column_exists(column_name, internal=True) - self._assert_index_exists(index) - column = cast(h5py.Dataset, self._h5_file[column_name]) - value = column[index] - if isinstance(value, bytes): - value = value.decode("utf-8") - if self._is_ref_column(column): - if value: - value = self._resolve_ref(value, column_name)[()] - return value.tolist() if isinstance(value, np.void) else value - return None - if self._get_column_type(column.attrs) is Embedding and len(value) == 0: - return None - return value - def read_column( self, column_name: str, @@ -91,7 +57,7 @@ def read_column( for value in self._resolve_refs(raw_values, column_name) ] return normalized_values - if self._get_column_type(column.attrs) is Embedding: + if is_embedding_dtype(self._get_dtype(column)): normalized_values = np.empty(len(raw_values), dtype=object) normalized_values[:] = [None if len(x) == 0 else x for x in raw_values] return normalized_values @@ -116,7 +82,7 @@ def duplicate_row(self, from_index: IndexType, to_index: IndexType) -> None: if to_index != length: # Shift all values after the insertion position by one. raw_values = column[int(to_index) : -1] - if self._get_column_type(column) is Embedding: + if is_embedding_dtype(self._get_dtype(column)): raw_values = list(raw_values) column[int(to_index) + 1 :] = raw_values column[int(to_index)] = column[from_index] @@ -169,9 +135,9 @@ def column_names(self) -> List[str]: def __len__(self) -> int: return len(self._table) - def guess_dtypes(self) -> ColumnTypeMapping: + def guess_dtypes(self) -> DTypeMap: return { - column_name: self._table.get_column_type(column_name) + column_name: create_dtype(self._table.get_dtype(column_name)) for column_name in self.column_names } @@ -199,19 +165,3 @@ def get_column_values( indices: Union[List[int], np.ndarray, slice] = slice(None), ) -> np.ndarray: return self._table.read_column(column_name, indices=indices) - - @lru_cache(maxsize=128) - def get_column_categories(self, column_name: str) -> Dict[str, int]: - attrs = self._table.get_column_attributes(column_name) - try: - return cast(Dict[str, int], attrs["categories"]) - except KeyError: - normalized_values = cast( - List[str], - [ - convert_to_dtype(value, str, simple=True) - for value in self._table.read_column(column_name) - ], - ) - category_names = sorted(set(normalized_values)) - return {category_name: i for i, category_name in enumerate(category_names)} diff --git a/renumics/spotlight_plugins/core/pandas_data_source.py b/renumics/spotlight_plugins/core/pandas_data_source.py index 3c31ab61..c4c8a3b9 100644 --- a/renumics/spotlight_plugins/core/pandas_data_source.py +++ b/renumics/spotlight_plugins/core/pandas_data_source.py @@ -2,21 +2,16 @@ access pandas DataFrame table data """ from pathlib import Path -from typing import Any, Dict, List, Union, cast -from functools import lru_cache +from typing import Any, List, Union, cast import numpy as np import pandas as pd import datasets -from renumics.spotlight.dtypes.typing import ( - ColumnTypeMapping, -) from renumics.spotlight.io.pandas import ( infer_dtype, prepare_hugging_face_dict, stringify_columns, - to_categorical, try_literal_eval, ) from renumics.spotlight.data_source import ( @@ -27,6 +22,7 @@ from renumics.spotlight.backend.exceptions import DatasetColumnsNotUnique from renumics.spotlight.dataset.exceptions import ColumnNotExistsError from renumics.spotlight.data_source.exceptions import InvalidDataSource +from renumics.spotlight.dtypes import DTypeMap @datasource(pd.DataFrame) @@ -121,12 +117,11 
@@ def df(self) -> pd.DataFrame: def __len__(self) -> int: return len(self._df) - def guess_dtypes(self) -> ColumnTypeMapping: - dtype_map = { + def guess_dtypes(self) -> DTypeMap: + return { str(column_name): infer_dtype(self.df[column_name]) for column_name in self.df } - return dtype_map def get_generation_id(self) -> int: return self._generation_id @@ -202,18 +197,3 @@ def _parse_column_index(self, column_name: str) -> Any: f"Column '{column_name}' doesn't exist in the dataset." ) from e return self._df.columns[index] - - @lru_cache(maxsize=128) - def get_column_categories(self, column_name: str) -> Dict[str, int]: - """ - Get categories of a categorical column. - - If `as_string` is True, convert the categories to their string - representation. - - At the moment, there is no way to add a new category in Spotlight, so we - rely on the previously cached ones. - """ - column_index = self._parse_column_index(column_name) - column = to_categorical(self._df[column_index], str_categories=True) - return {category: i for i, category in enumerate(column.cat.categories)} diff --git a/tests/integration/backend/test_cache.py b/tests/integration/backend/test_cache.py index 8c15d9dd..d6f99d3b 100644 --- a/tests/integration/backend/test_cache.py +++ b/tests/integration/backend/test_cache.py @@ -14,6 +14,7 @@ def test_external_data_cache(non_existing_image_df_viewer: spotlight.Viewer) -> """ Test loading non-existing external data, cache it and clear cache. """ + assert non_existing_image_df_viewer.df is not None image_path = non_existing_image_df_viewer.df["image"][0] app_url = str(non_existing_image_df_viewer) diff --git a/tests/integration/backend/test_reporting.py b/tests/integration/backend/test_reporting.py index 857e6b0b..0929204c 100644 --- a/tests/integration/backend/test_reporting.py +++ b/tests/integration/backend/test_reporting.py @@ -1,11 +1,13 @@ """ test reporting module """ +import pytest + from renumics.spotlight.reporting import skip_analytics from renumics.spotlight.settings import settings -def test_opt_out(monkeypatch) -> None: +def test_opt_out(monkeypatch: pytest.MonkeyPatch) -> None: """test opt_out is true""" monkeypatch.delenv("CI", raising=False) @@ -13,7 +15,7 @@ def test_opt_out(monkeypatch) -> None: assert skip_analytics() is True -def test_opt_in(monkeypatch): +def test_opt_in(monkeypatch: pytest.MonkeyPatch) -> None: """ test opt_in is true opt_out also as opt_in is False by default @@ -26,7 +28,7 @@ def test_opt_in(monkeypatch): assert skip_analytics() is False -def test_opt_in_and_opt_out(monkeypatch): +def test_opt_in_and_opt_out(monkeypatch: pytest.MonkeyPatch) -> None: """if opt_out is true and opt_in is false skip analytics""" @@ -36,7 +38,7 @@ def test_opt_in_and_opt_out(monkeypatch): assert skip_analytics() is True -def test_opt_in_and_ci(monkeypatch): +def test_opt_in_and_ci(monkeypatch: pytest.MonkeyPatch) -> None: """when ci is true always skip analytics""" settings.opt_out = False diff --git a/tests/integration/dataset/conftest.py b/tests/integration/dataset/conftest.py index 1351a9f2..8ddae68d 100644 --- a/tests/integration/dataset/conftest.py +++ b/tests/integration/dataset/conftest.py @@ -4,78 +4,27 @@ import os.path import tempfile -from dataclasses import dataclass, field from datetime import datetime -from typing import List, Optional, Sequence, Type, Union, Dict +from typing import Iterator, List import numpy as np import pytest from _pytest.fixtures import SubRequest -from renumics.spotlight import ( - Embedding, - Mesh, - Sequence1D, - Image, - 
Audio, - Category, - Video, - Dataset, - Window, +from renumics.spotlight import Embedding, Mesh, Sequence1D, Image, Audio, Dataset + +from .data import ( + ColumnData, + categorical_data, + array_data, + window_data, + embedding_data, + sequence_1d_data, + audio_data, + image_data, + mesh_data, + video_data, ) -from renumics.spotlight.dataset.typing import ColumnInputType -from renumics.spotlight.dtypes.typing import ColumnType - - -@dataclass -class ColumnData: - """ - Data for a dataset column. - """ - - name: str - column_type: Type[ColumnInputType] - values: Union[Sequence[ColumnInputType], np.ndarray] - optional: bool = False - default: ColumnInputType = None - description: Optional[str] = None - attrs: Dict = field(default_factory=dict) - - -def get_append_column_fn_name(column_type: Type[ColumnType]) -> str: - """ - Get name of the `append_column` dataset method for the given column type. - """ - - if column_type is bool: - return "append_bool_column" - if column_type is int: - return "append_int_column" - if column_type is float: - return "append_float_column" - if column_type is str: - return "append_string_column" - if column_type is datetime: - return "append_datetime_column" - if column_type is np.ndarray: - return "append_array_column" - if column_type is Embedding: - return "append_embedding_column" - if column_type is Image: - return "append_image_column" - if column_type is Sequence1D: - return "append_sequence_1d_column" - if column_type is Mesh: - return "append_mesh_column" - if column_type is Audio: - return "append_audio_column" - if column_type is Category: - return "append_categorical_column" - if column_type is Video: - return "append_video_column" - if column_type is Window: - return "append_window_column" - raise TypeError @pytest.fixture @@ -84,44 +33,44 @@ def optional_data() -> List[ColumnData]: Get a list of optional column data. 
""" return [ - ColumnData("bool", bool, [], default=True), - ColumnData("bool1", bool, [], default=False), - ColumnData("bool2", bool, [], default=np.bool_(True)), - ColumnData("int", int, [], default=-5), - ColumnData("int1", int, [], default=5), - ColumnData("int2", int, [], default=np.int16(1000)), - ColumnData("float", float, [], optional=True), - ColumnData("float1", float, [], default=5.0), - ColumnData("float2", float, [], default=np.float16(1000.0)), - ColumnData("string", str, [], optional=True), - ColumnData("string1", str, [], default="a"), - ColumnData("string2", str, [], default=np.str_("b")), - ColumnData("datetime", datetime, [], optional=True), - ColumnData("datetime1", datetime, [], default=datetime.now().astimezone()), - ColumnData("datetime2", datetime, [], default=np.datetime64("NaT")), - ColumnData("array", np.ndarray, [], optional=True), - ColumnData("array1", np.ndarray, [], default=np.empty(0)), - ColumnData("embedding", Embedding, [], optional=True), - ColumnData( - "embedding1", Embedding, [], default=Embedding(np.array([np.nan, np.nan])) - ), - ColumnData("image", Image, [], optional=True), - ColumnData("image1", Image, [], default=Image.empty()), - ColumnData("sequence_1d", Sequence1D, [], optional=True), - ColumnData("sequence_1d1", Sequence1D, [], default=Sequence1D.empty()), - ColumnData("mesh", Mesh, [], optional=True), - ColumnData("mesh1", Mesh, [], default=Mesh.empty()), - ColumnData("audio", Audio, [], optional=True), - ColumnData("audio1", Audio, [], default=Audio.empty()), + ColumnData("bool", "bool", [], default=True), + ColumnData("bool1", "bool", [], default=False), + ColumnData("bool2", "bool", [], default=np.bool_(True)), + ColumnData("int", "int", [], default=-5), + ColumnData("int1", "int", [], default=5), + ColumnData("int2", "int", [], default=np.int16(1000)), + ColumnData("float", "float", [], optional=True), + ColumnData("float1", "float", [], default=5.0), + ColumnData("float2", "float", [], default=np.float16(1000.0)), + ColumnData("string", "str", [], optional=True), + ColumnData("string1", "str", [], default="a"), + ColumnData("string2", "str", [], default=np.str_("b")), + ColumnData("datetime", "datetime", [], optional=True), + ColumnData("datetime1", "datetime", [], default=datetime.now().astimezone()), + ColumnData("datetime2", "datetime", [], default=np.datetime64("NaT")), + ColumnData("array", "array", [], optional=True), + ColumnData("array1", "array", [], default=np.empty(0)), + ColumnData("embedding", "Embedding", [], optional=True), + ColumnData( + "embedding1", "Embedding", [], default=Embedding(np.array([np.nan, np.nan])) + ), + ColumnData("image", "Image", [], optional=True), + ColumnData("image1", "Image", [], default=Image.empty()), + ColumnData("sequence_1d", "Sequence1D", [], optional=True), + ColumnData("sequence_1d1", "Sequence1D", [], default=Sequence1D.empty()), + ColumnData("mesh", "Mesh", [], optional=True), + ColumnData("mesh1", "Mesh", [], default=Mesh.empty()), + ColumnData("audio", "Audio", [], optional=True), + ColumnData("audio1", "Audio", [], default=Audio.empty()), ColumnData( "category1", - Category, + "Category", [], default="red", attrs={"categories": ["red", "green"]}, ), - ColumnData("window", Window, [], optional=True), - ColumnData("window1", Window, [], default=[-1, np.nan]), + ColumnData("window", "Window", [], optional=True), + ColumnData("window1", "Window", [], default=[-1, np.nan]), ] @@ -131,62 +80,62 @@ def simple_data() -> List[ColumnData]: Get a list of scalar column data. 
""" return [ - ColumnData("bool", bool, [True, False, True, True, False, True]), + ColumnData("bool", "bool", [True, False, True, True, False, True]), ColumnData( "bool1", - bool, + "bool", np.array([True, False, True, True, False, True]), description="np.bool_ column", ), - ColumnData("int", int, [-1000, -1, 0, 4, 5, 6]), + ColumnData("int", "int", [-1000, -1, 0, 4, 5, 6]), ColumnData( "int1", - int, + "int", np.array([-1000, -1, 0, 4, 5, 6]), description="numpy.int64 column", ), ColumnData( "int2", - int, + "int", np.array([-1000, -1, 0, 4, 5, 6], np.int16), description="np.int16 column", ), ColumnData( "int3", - int, + "int", np.array([1, 10, 0, 4, 5, 6], np.int16), description="np.uint16 column", ), ColumnData( "float", - float, + "float", [-float("inf"), -0.1, float("nan"), float("inf"), 0.1, 1000.0], ), ColumnData( "float1", - float, + "float", np.array([-float("inf"), -0.1, float("nan"), float("inf"), 0.1, 1000.0]), description="np.float64 column", ), ColumnData( "float2", - float, + "float", np.array( [-float("inf"), -0.1, float("nan"), float("inf"), 0.1, 1000.0], np.float16, ), description="np.float16 column", ), - ColumnData("string", str, ["", "a", "bc", "def", "ghijш", "klmnoü"]), + ColumnData("string", "str", ["", "a", "bc", "def", "ghijш", "klmnoü"]), ColumnData( "string1", - str, + "str", np.array(["", "a", "bc", "def", "ghij", "klmno"]), description="", ), ColumnData( "datetime", - datetime, + "datetime", [ datetime.now(), datetime.now().astimezone(), @@ -198,7 +147,7 @@ def simple_data() -> List[ColumnData]: ), ColumnData( "datetime1", - datetime, + "datetime", np.arange("2002-10-27T04:30", 6 * 60, 60, np.datetime64), description="np.datetime64 column", ), @@ -223,333 +172,8 @@ def complex_data() -> List[ColumnData]: ) -def array_data() -> List[ColumnData]: - """ - Get a list of array column data. - """ - return [ - ColumnData( - "array", - np.ndarray, - [ - np.array([1, 2]), - np.array([3, 4]), - np.array([5, 6]), - np.array([7, 8]), - np.array([9, 10]), - np.array([11, 12]), - ], - description="list of np.ndarray of fixed shape", - ), - ColumnData( - "array1", - np.ndarray, - [ - np.zeros(0, np.int64), - np.array([1], np.int64), - np.array([2, 3], np.int64), - np.array([4, 5, 6], np.int64), - np.array([7, 8, 9, 10], np.int64), - np.array([11, 12, 13, 14, 15], np.int64), - ], - description="list of np.ndarray of variable shape", - ), - ColumnData( - "array2", - np.ndarray, - [ - 1.0, - [], - [[[[]]]], - [[1.0, 2, 3], [4, 5, 6]], - (7.0, 8, 9, 10), - np.array([11.0, 12, 13, 14, 15]), - ], - description="mixed types", - ), - ColumnData( - "array3", np.ndarray, np.random.rand(6, 2, 2), description="batch array" - ), - ] - - -def embedding_data() -> List[ColumnData]: - """ - Get a list of embedding column data. - """ - return [ - ColumnData( - "embedding", - Embedding, - [ - Embedding(np.array([1.0, 2.0])), - Embedding(np.array([3.0, 4.0])), - Embedding(np.array([5.0, 6.0])), - Embedding(np.array([7.0, np.nan])), - Embedding(np.array([np.nan, 8.0])), - Embedding(np.array([np.nan, np.nan])), - ], - ), - ColumnData( - "embedding1", - Embedding, - [ - [1.0, 2.0], - (3.0, 4.0), - np.array([5.0, 6.0]), - [7.0, float("nan")], - (float("nan"), 8.0), - np.array([np.nan, np.nan]), - ], - description="mixed types", - ), - ColumnData( - "embedding2", Embedding, np.random.rand(6, 2), description="batch array" - ), - ] - - -def image_data() -> List[ColumnData]: - """ - Get a list of image column data. 
- """ - return [ - ColumnData( - "image", - Image, - [ - Image.empty(), - Image.empty(), - Image(np.zeros((10, 10), dtype=np.uint8)), - Image(np.zeros((10, 20, 1), dtype=np.int64)), - Image(np.zeros((20, 10, 3), dtype=np.float64)), - Image(np.zeros((20, 20, 4), dtype=np.uint8)), - ], - ), - ColumnData( - "image1", - Image, - [ - [[0]], - [[[1.0]]], - [[[127, 127, 127]]], - np.zeros((10, 10), dtype=np.uint8), - np.zeros((20, 10), dtype=np.int64), - np.zeros((20, 10, 3), dtype=np.float64), - ], - description="mixed types", - ), - ColumnData( - "image2", - Image, - np.random.randint(0, 256, (6, 10, 20, 3), "uint8"), - description="batch array", - ), - ] - - -def sequence_1d_data() -> List[ColumnData]: - """ - Get a list of 1d-sequence column data. - """ - return [ - ColumnData( - "sequence_1d", - Sequence1D, - [ - Sequence1D(np.array([0.0, 1.0]), np.array([1.0, 2.0])), - Sequence1D(np.array([0.0, 1.0]), np.array([3.0, 4.0])), - Sequence1D(np.array([0.0, 1.0]), np.array([5.0, 6.0])), - Sequence1D(np.array([0.0, 1.0]), np.array([7.0, 8.0])), - Sequence1D(np.array([np.inf, 1.0]), np.array([-np.inf, 1.0])), - Sequence1D(np.array([np.nan, np.nan]), np.array([np.nan, np.nan])), - ], - description="fixed shape", - ), - ColumnData( - "sequence_1d1", - Sequence1D, - [ - Sequence1D.empty(), - Sequence1D(np.array([0.0]), np.array([1.0])), - Sequence1D(np.array([0.0, 1.0, 2.0]), np.array([2.0, 3.0, 4.0])), - Sequence1D( - np.array([0.0, 1.0, 2.0, 3.0]), np.array([5.0, 6.0, 7.0, 8.0]) - ), - Sequence1D(np.array([np.inf, 1.0]), np.array([-np.inf, 1.0])), - Sequence1D(np.array([np.nan, np.nan]), np.array([np.nan, np.nan])), - ], - description="variable shape", - ), - ColumnData( - "sequence_1d2", - Sequence1D, - [ - [], - [1.0], - (float("nan"), float("inf"), -float("inf")), - np.array([5.0, 6.0, 7.0, 8.0]), - np.array([-np.inf, 1.0]), - np.array([np.nan, np.nan]), - ], - description="variable shape, mixed types", - ), - ] - - -def mesh_data() -> List[ColumnData]: - """ - Get a list of mesh column data. - """ - return [ - ColumnData( - "mesh", - Mesh, - [ - Mesh.empty(), - Mesh.empty(), - Mesh( - np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]), - np.array([[0, 1, 2]]), - ), - Mesh( - np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]), - np.array([[0, 1, 2]]), - ), - Mesh( - np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]), - np.array([[0, 1, 2]]), - ), - Mesh( - np.array( - [ - [0.0, 0.0, 0.0], - [0.0, 0.0, 1.0], - [0.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - ] - ), - np.array([[0, 1, 2], [0, 1, 3], [1, 2, 3]]), - ), - ], - description="", - ), - ] - - -def audio_data() -> List[ColumnData]: - """ - Get a list of audio column data. - """ - - samplerate = 44100 - time = np.linspace(0.0, 1.0, samplerate) - amplitude = np.iinfo(np.int16).max * 0.4 - data = np.array(amplitude * np.sin(2.0 * np.pi * 1000 * time), dtype=np.int16) - audio_data_left, audio_data_right = data, data - return [ - ColumnData( - "audio", - Audio, - 6 - * [ - Audio( - samplerate, - np.array([audio_data_left, audio_data_right]).transpose(), - ), - ], - description="List of 3 stereo Audio Signals", - ), - ColumnData( - "audio1", - Audio, - 6 - * [ - Audio( - samplerate, - np.array([audio_data_right, audio_data_right]).transpose(), - ), - ], - description="List of 2 stereo Audio Signals", - ), - ] - - -def categorical_data() -> List[ColumnData]: - """ - Get a list of categorical column data. 
- """ - return [ - ColumnData( - "category", - Category, - 2 * ["red", "blue", "red"], - description="strings of three categories", - attrs={"categories": ["red", "green", "blue"]}, - ), - ColumnData( - "category1", - Category, - 2 * ["red", "blue", "red"], - description="strings of three categories", - attrs={"categories": ["red", "green", "blue"]}, - ), - ] - - -def video_data() -> List[ColumnData]: - """ - Get a list of video column data. - """ - return [ - ColumnData( - "video", - Video, - [ - Video.empty(), - Video.empty(), - Video.from_file("data/videos/sea-360p.avi"), - Video.from_file("data/videos/sea-360p.mp4"), - Video.from_file("data/videos/sea-360p.wmv"), - Video.empty(), - ], - description="", - ), - ] - - -def window_data() -> List[ColumnData]: - """ - Get a list of window column data. - """ - return [ - ColumnData( - "window", - Window, - np.random.uniform(-1, 1, (6, 2)), - ), - ColumnData( - "window1", - Window, - np.random.randint(-1000, 1000, (6, 2)), - ), - ColumnData( - "window2", - Window, - [ - (1, 2), - (1.0, np.nan), - [np.nan, np.inf], - [1.0, -1.0], - np.array([1, 2]), - np.array([np.nan, -np.inf]), - ], - ), - ] - - @pytest.fixture -def empty_dataset() -> Dataset: +def empty_dataset() -> Iterator[Dataset]: """ An empty dataset. """ @@ -560,7 +184,7 @@ def empty_dataset() -> Dataset: @pytest.fixture -def categorical_color_dataset(request: SubRequest) -> Dataset: +def categorical_color_dataset(request: SubRequest) -> Iterator[Dataset]: """a dataset with category column""" with tempfile.TemporaryDirectory() as output_folder: output_h5_file = os.path.join(output_folder, "dataset.h5") @@ -576,7 +200,7 @@ def categorical_color_dataset(request: SubRequest) -> Dataset: @pytest.fixture -def fancy_indexing_dataset() -> Dataset: +def fancy_indexing_dataset() -> Iterator[Dataset]: """a dataset with a single range column""" with tempfile.TemporaryDirectory() as output_folder: output_h5_file = os.path.join(output_folder, "dataset.h5") @@ -589,7 +213,7 @@ def fancy_indexing_dataset() -> Dataset: @pytest.fixture -def descriptors_dataset_for_compress_dataset() -> Dataset: +def descriptors_dataset_for_compress_dataset() -> Iterator[Dataset]: """ dataset without faulty/ problematic columns """ @@ -619,7 +243,7 @@ def descriptors_dataset_for_compress_dataset() -> Dataset: @pytest.fixture -def descriptors_dataset() -> Dataset: +def descriptors_dataset() -> Iterator[Dataset]: """ builds descriptor testing dataset """ @@ -708,76 +332,3 @@ def descriptors_dataset() -> Dataset: data.append_audio_column("audio_none", [audio_1, audio_6, audio_1]) yield data - - -def approx( - expected: Optional[ColumnType], actual: ColumnInputType, type_: Type[ColumnType] -) -> bool: - """ - Check whether expected and actual dataset values are almost equal. - """ - - if expected is None and actual is None: - return True - if issubclass(type_, (bool, int, float, str)): - # Cast and compare scalars. - expected = np.array(expected, dtype=type_) - actual = np.array(actual, dtype=type_) - return approx(expected, actual, np.ndarray) - if issubclass(type_, datetime): - # Cast and compare datetimes. - expected = np.array(expected, dtype="datetime64") - actual = np.array(actual, dtype="datetime64") - return approx(expected, actual, np.ndarray) - if issubclass(type_, Window): - return approx( - np.array(expected, dtype=float), np.array(actual), type_=np.ndarray - ) - if issubclass(type_, np.ndarray): - # Cast and compare arrays. 
- expected = np.asarray(expected) - actual = np.asarray(actual) - if actual.shape != expected.shape: - return False - if issubclass(expected.dtype.type, np.inexact): - return np.allclose(actual, expected, equal_nan=True) - return actual.tolist() == expected.tolist() - if issubclass(type_, (Embedding, Image, Mesh, Sequence1D, Audio, Video)): - # Cast and compare custom types. - if not isinstance(expected, type_): - expected = type_(expected) - if not isinstance(actual, type_): - actual = type_(actual) - if issubclass(type_, (Embedding, Image, Video)): - return approx(actual.data, expected.data, np.ndarray) - if issubclass(type_, Audio): - return ( - approx(actual.data, expected.data, np.ndarray) - and actual.sampling_rate == expected.sampling_rate - ) - if issubclass(type_, Mesh): - return ( - approx(actual.points, expected.points, np.ndarray) - and approx(actual.triangles, expected.triangles, np.ndarray) - and len(actual.point_displacements) == len(expected.point_displacements) - and all( - approx(actual_displacement, expected_displacement, np.ndarray) - for actual_displacement, expected_displacement in zip( - actual.point_displacements, expected.point_displacements - ) - ) - and actual.point_attributes.keys() == expected.point_attributes.keys() - and all( - approx( - actual.point_attributes[attribute_name], - point_attribute, - np.ndarray, - ) - for attribute_name, point_attribute in expected.point_attributes.items() - ) - ) - if issubclass(type_, Sequence1D): - return approx(actual.index, expected.index, np.ndarray) and approx( - actual.value, expected.value, np.ndarray - ) - raise TypeError(f"Invalid `type_` received: value {type_} of type {type(type_)}.") diff --git a/tests/integration/dataset/data.py b/tests/integration/dataset/data.py new file mode 100644 index 00000000..85bca0fa --- /dev/null +++ b/tests/integration/dataset/data.py @@ -0,0 +1,347 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Sequence, Union, Dict + +import numpy as np + +from renumics.spotlight import Embedding, Mesh, Sequence1D, Image, Audio, Video +from renumics.spotlight.dataset.typing import ColumnInputType + + +@dataclass +class ColumnData: + """ + Data for a dataset column. + """ + + name: str + dtype_name: str + values: Union[Sequence[ColumnInputType], np.ndarray] + optional: bool = False + default: ColumnInputType = None + description: Optional[str] = None + attrs: Dict = field(default_factory=dict) + + +def array_data() -> List[ColumnData]: + """ + Get a list of array column data. + """ + return [ + ColumnData( + "array", + "array", + [ + np.array([1, 2]), + np.array([3, 4]), + np.array([5, 6]), + np.array([7, 8]), + np.array([9, 10]), + np.array([11, 12]), + ], + description="list of np.ndarray of fixed shape", + ), + ColumnData( + "array1", + "array", + [ + np.zeros(0, np.int64), + np.array([1], np.int64), + np.array([2, 3], np.int64), + np.array([4, 5, 6], np.int64), + np.array([7, 8, 9, 10], np.int64), + np.array([11, 12, 13, 14, 15], np.int64), + ], + description="list of np.ndarray of variable shape", + ), + ColumnData( + "array2", + "array", + [ + 1.0, + [], + [[[[]]]], + [[1.0, 2, 3], [4, 5, 6]], + (7.0, 8, 9, 10), + np.array([11.0, 12, 13, 14, 15]), + ], + description="mixed types", + ), + ColumnData( + "array3", "array", np.random.rand(6, 2, 2), description="batch array" + ), + ] + + +def embedding_data() -> List[ColumnData]: + """ + Get a list of embedding column data. 
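+
+    Example (illustrative; names follow the definitions below):
+
+        >>> [col.name for col in embedding_data()]
+        ['embedding', 'embedding1', 'embedding2']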
+ """ + return [ + ColumnData( + "embedding", + "Embedding", + [ + Embedding(np.array([1.0, 2.0])), + Embedding(np.array([3.0, 4.0])), + Embedding(np.array([5.0, 6.0])), + Embedding(np.array([7.0, np.nan])), + Embedding(np.array([np.nan, 8.0])), + Embedding(np.array([np.nan, np.nan])), + ], + ), + ColumnData( + "embedding1", + "Embedding", + [ + [1.0, 2.0], + (3.0, 4.0), + np.array([5.0, 6.0]), + [7.0, float("nan")], + (float("nan"), 8.0), + np.array([np.nan, np.nan]), + ], + description="mixed types", + ), + ColumnData( + "embedding2", "Embedding", np.random.rand(6, 2), description="batch array" + ), + ] + + +def image_data() -> List[ColumnData]: + """ + Get a list of image column data. + """ + return [ + ColumnData( + "image", + "Image", + [ + Image.empty(), + Image.empty(), + Image(np.zeros((10, 10), dtype=np.uint8)), + Image(np.zeros((10, 20, 1), dtype=np.int64)), + Image(np.zeros((20, 10, 3), dtype=np.float64)), + Image(np.zeros((20, 20, 4), dtype=np.uint8)), + ], + ), + ColumnData( + "image1", + "Image", + [ + [[0]], + [[[1.0]]], + [[[127, 127, 127]]], + np.zeros((10, 10), dtype=np.uint8), + np.zeros((20, 10), dtype=np.int64), + np.zeros((20, 10, 3), dtype=np.float64), + ], + description="mixed types", + ), + ColumnData( + "image2", + "Image", + np.random.randint(0, 256, (6, 10, 20, 3), "uint8"), + description="batch array", + ), + ] + + +def sequence_1d_data() -> List[ColumnData]: + """ + Get a list of 1d-sequence column data. + """ + return [ + ColumnData( + "sequence_1d", + "Sequence1D", + [ + Sequence1D(np.array([0.0, 1.0]), np.array([1.0, 2.0])), + Sequence1D(np.array([0.0, 1.0]), np.array([3.0, 4.0])), + Sequence1D(np.array([0.0, 1.0]), np.array([5.0, 6.0])), + Sequence1D(np.array([0.0, 1.0]), np.array([7.0, 8.0])), + Sequence1D(np.array([np.inf, 1.0]), np.array([-np.inf, 1.0])), + Sequence1D(np.array([np.nan, np.nan]), np.array([np.nan, np.nan])), + ], + description="fixed shape", + ), + ColumnData( + "sequence_1d1", + "Sequence1D", + [ + Sequence1D.empty(), + Sequence1D(np.array([0.0]), np.array([1.0])), + Sequence1D(np.array([0.0, 1.0, 2.0]), np.array([2.0, 3.0, 4.0])), + Sequence1D( + np.array([0.0, 1.0, 2.0, 3.0]), np.array([5.0, 6.0, 7.0, 8.0]) + ), + Sequence1D(np.array([np.inf, 1.0]), np.array([-np.inf, 1.0])), + Sequence1D(np.array([np.nan, np.nan]), np.array([np.nan, np.nan])), + ], + description="variable shape", + ), + ColumnData( + "sequence_1d2", + "Sequence1D", + [ + [], + [1.0], + (float("nan"), float("inf"), -float("inf")), + np.array([5.0, 6.0, 7.0, 8.0]), + np.array([-np.inf, 1.0]), + np.array([np.nan, np.nan]), + ], + description="variable shape, mixed types", + ), + ] + + +def mesh_data() -> List[ColumnData]: + """ + Get a list of mesh column data. + """ + return [ + ColumnData( + "mesh", + "Mesh", + [ + Mesh.empty(), + Mesh.empty(), + Mesh( + np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]), + np.array([[0, 1, 2]]), + ), + Mesh( + np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]), + np.array([[0, 1, 2]]), + ), + Mesh( + np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]), + np.array([[0, 1, 2]]), + ), + Mesh( + np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ] + ), + np.array([[0, 1, 2], [0, 1, 3], [1, 2, 3]]), + ), + ], + description="", + ), + ] + + +def audio_data() -> List[ColumnData]: + """ + Get a list of audio column data. 
+ """ + + samplerate = 44100 + time = np.linspace(0.0, 1.0, samplerate) + amplitude = np.iinfo(np.int16).max * 0.4 + data = np.array(amplitude * np.sin(2.0 * np.pi * 1000 * time), dtype=np.int16) + audio_data_left, audio_data_right = data, data + return [ + ColumnData( + "audio", + "Audio", + 6 + * [ + Audio( + samplerate, + np.array([audio_data_left, audio_data_right]).transpose(), + ), + ], + description="List of 3 stereo Audio Signals", + ), + ColumnData( + "audio1", + "Audio", + 6 + * [ + Audio( + samplerate, + np.array([audio_data_right, audio_data_right]).transpose(), + ), + ], + description="List of 2 stereo Audio Signals", + ), + ] + + +def categorical_data() -> List[ColumnData]: + """ + Get a list of categorical column data. + """ + return [ + ColumnData( + "category", + "Category", + 2 * ["red", "blue", "red"], + description="strings of three categories", + attrs={"categories": ["red", "green", "blue"]}, + ), + ColumnData( + "category1", + "Category", + 2 * ["red", "blue", "red"], + description="strings of three categories", + attrs={"categories": ["red", "green", "blue"]}, + ), + ] + + +def video_data() -> List[ColumnData]: + """ + Get a list of video column data. + """ + return [ + ColumnData( + "video", + "Video", + [ + Video.empty(), + Video.empty(), + Video.from_file("data/videos/sea-360p.avi"), + Video.from_file("data/videos/sea-360p.mp4"), + Video.from_file("data/videos/sea-360p.wmv"), + Video.empty(), + ], + description="", + ), + ] + + +def window_data() -> List[ColumnData]: + """ + Get a list of window column data. + """ + return [ + ColumnData( + "window", + "Window", + np.random.uniform(-1, 1, (6, 2)), + ), + ColumnData( + "window1", + "Window", + np.random.randint(-1000, 1000, (6, 2)), + ), + ColumnData( + "window2", + "Window", + [ + (1, 2), + (1.0, np.nan), + [np.nan, np.inf], + [1.0, -1.0], + np.array([1, 2]), + np.array([np.nan, -np.inf]), + ], + ), + ] diff --git a/tests/integration/dataset/helpers.py b/tests/integration/dataset/helpers.py new file mode 100644 index 00000000..54ad28d4 --- /dev/null +++ b/tests/integration/dataset/helpers.py @@ -0,0 +1,34 @@ +def get_append_column_fn_name(dtype_name: str) -> str: + """ + Get name of the `append_column` dataset method for the given column type. 
+ """ + + if dtype_name == "bool": + return "append_bool_column" + if dtype_name == "int": + return "append_int_column" + if dtype_name == "float": + return "append_float_column" + if dtype_name == "str": + return "append_string_column" + if dtype_name == "datetime": + return "append_datetime_column" + if dtype_name == "array": + return "append_array_column" + if dtype_name == "Embedding": + return "append_embedding_column" + if dtype_name == "Image": + return "append_image_column" + if dtype_name == "Sequence1D": + return "append_sequence_1d_column" + if dtype_name == "Mesh": + return "append_mesh_column" + if dtype_name == "Audio": + return "append_audio_column" + if dtype_name == "Category": + return "append_categorical_column" + if dtype_name == "Video": + return "append_video_column" + if dtype_name == "Window": + return "append_window_column" + raise TypeError(dtype_name) diff --git a/tests/integration/dataset/test_categorical.py b/tests/integration/dataset/test_categorical.py index 01802987..ecc0ad82 100644 --- a/tests/integration/dataset/test_categorical.py +++ b/tests/integration/dataset/test_categorical.py @@ -10,16 +10,19 @@ @pytest.mark.parametrize( "categorical_color_dataset", ["red", "green", None], indirect=True ) -def test_categorical_disallows(categorical_color_dataset: Dataset) -> None: - """adding not existing category value should raise""" +@pytest.mark.parametrize("value", ["invalid_color", ""]) +def test_categorical_unknown_category( + categorical_color_dataset: Dataset, value: str +) -> None: + """adding not existing category value should add `None`""" with pytest.raises(exceptions.InvalidValueError): - categorical_color_dataset.append_row(my_new_cat="invalid_color") + categorical_color_dataset.append_row(my_new_cat=value) @pytest.mark.parametrize( "categorical_color_dataset", ["red", "green", None], indirect=True ) -def test_categorical_new_catagory(categorical_color_dataset: Dataset) -> None: +def test_categorical_new_category(categorical_color_dataset: Dataset) -> None: """adding new category with existing int value should raise""" old_categories = categorical_color_dataset.get_column_attributes("my_new_cat")[ "categories" @@ -37,7 +40,7 @@ def test_categorical_new_catagory(categorical_color_dataset: Dataset) -> None: @pytest.mark.parametrize( "categorical_color_dataset", ["red", "green", None], indirect=True ) -def test_categorical_remvove_used_raises(categorical_color_dataset: Dataset) -> None: +def test_categorical_remove_used_raises(categorical_color_dataset: Dataset) -> None: """removing used categories should raise""" with pytest.raises(exceptions.InvalidAttributeError): categorical_color_dataset.set_column_attributes( @@ -48,7 +51,7 @@ def test_categorical_remvove_used_raises(categorical_color_dataset: Dataset) -> @pytest.mark.parametrize( "categorical_color_dataset", ["red", "green", None], indirect=True ) -def test_categorical_remvove_unused(categorical_color_dataset: Dataset) -> None: +def test_categorical_remove_unused(categorical_color_dataset: Dataset) -> None: """removing unused categories should work""" old_categories = categorical_color_dataset.get_column_attributes("my_new_cat")[ "categories" diff --git a/tests/integration/dataset/test_data_types.py b/tests/integration/dataset/test_data_types.py deleted file mode 100644 index 6f3b6f47..00000000 --- a/tests/integration/dataset/test_data_types.py +++ /dev/null @@ -1,663 +0,0 @@ -""" -Test custom data types. 
-""" -import io -from pathlib import Path -from typing import Optional -from urllib.parse import urljoin - -import numpy as np -import pytest - -from renumics.spotlight import Audio, Embedding, Mesh, Sequence1D, Image, Video -from .conftest import approx - - -SEED = 42 -BASE_URL = "https://spotlightpublic.blob.core.windows.net/internal-test-data/" - - -class TestEmbedding: - """ - Test `renumics.spotlight.Embedding` class. - """ - - @pytest.mark.parametrize("length", [1, 10, 100]) - @pytest.mark.parametrize("dtype", ["float32", "float64", "uint16", "int32"]) - @pytest.mark.parametrize("input_type", ["array", "list", "tuple"]) - def test_embedding(self, length: int, dtype: str, input_type: str) -> None: - """ - Test Embedding class initialization and to/from array conversion. - """ - np_dtype = np.dtype(dtype) - if np_dtype.str[1] == "f": - np.random.seed(SEED) - array = np.random.uniform(0, 1, length).astype(dtype) - else: - np.random.seed(SEED) - array = np.random.randint( # type: ignore - np.iinfo(np_dtype).min, np.iinfo(np_dtype).max, length, dtype - ) - if input_type == "list": - array = array.tolist() - elif input_type == "tuple": - array = tuple(array.tolist()) # type: ignore - embedding = Embedding(array) - - assert approx(array, embedding.data, Embedding) - encoded_image = embedding.encode() - decoded_image = Embedding.decode(encoded_image) - assert approx(embedding, decoded_image, Embedding) - - @pytest.mark.parametrize("input_type", ["array", "list", "tuple"]) - def test_zero_length_embedding(self, input_type: str) -> None: - """ - Test if `Embedding` fails with zero-length data. - """ - array = np.empty(0) - if input_type == "list": - array = array.tolist() - elif input_type == "tuple": - array = tuple(array.tolist()) # type: ignore - with pytest.raises(ValueError): - _ = Embedding(array) - - @pytest.mark.parametrize("num_dims", [0, 2, 3, 4]) - def test_multidimensional_embedding(self, num_dims: int) -> None: - """ - Test if `Embedding` fails with not one-dimensional data. - """ - np.random.seed(SEED) - dims = np.random.randint(1, 20, num_dims) - np.random.seed(SEED) - array = np.random.uniform(0, 1, dims) - with pytest.raises(ValueError): - _ = Embedding(array) - - -@pytest.mark.parametrize("length", [1, 10, 100]) -@pytest.mark.parametrize("input_type", ["array", "list", "tuple"]) -def test_sequence_1d(length: int, input_type: str) -> None: - """ - Test `renumics.spotlight.Sequence1D` class. - """ - index = np.random.rand(length) - value = np.random.rand(length) - if input_type == "list": - index = index.tolist() - value = value.tolist() - elif input_type == "tuple": - index = tuple(index.tolist()) # type: ignore - value = tuple(value.tolist()) # type: ignore - # Initialization with index and value. - sequence_1d = Sequence1D(index, value) - assert approx(index, sequence_1d.index, np.ndarray) - assert approx(value, sequence_1d.value, np.ndarray) - encoded_sequence_1d = sequence_1d.encode() - decoded_sequence_1d = Sequence1D.decode(encoded_sequence_1d) - assert approx(sequence_1d, decoded_sequence_1d, Sequence1D) - # Initialization with value only. - sequence_1d = Sequence1D(value) - assert approx(np.arange(len(value)), sequence_1d.index, np.ndarray) - assert approx(value, sequence_1d.value, np.ndarray) - encoded_sequence_1d = sequence_1d.encode() - decoded_sequence_1d = Sequence1D.decode(encoded_sequence_1d) - assert approx(sequence_1d, decoded_sequence_1d, Sequence1D) - - -class TestImage: - """ - Test `renumics.spotlight.Image` class. 
- """ - - data_folder = Path("data/images") - - @pytest.mark.parametrize("size", [1, 10, 100]) - @pytest.mark.parametrize("num_channels", [None, 1, 3, 4]) - @pytest.mark.parametrize("dtype", ["float32", "float64", "uint8", "int32"]) - @pytest.mark.parametrize("input_type", ["array", "list", "tuple"]) - def test_image_from_array( - self, size: int, num_channels: Optional[int], dtype: str, input_type: str - ) -> None: - """ - Test image creation. - """ - shape = (size, size) if num_channels is None else (size, size, num_channels) - np_dtype = np.dtype(dtype) - if np_dtype.str[1] == "f": - np.random.seed(SEED) - array = np.random.uniform(0, 1, shape).astype(dtype) - target = (255 * array).round().astype("uint8") - else: - np.random.seed(SEED) - array = np.random.randint(0, 256, shape, dtype) # type: ignore - target = array.astype("uint8") - if num_channels == 1: - target = target.squeeze(axis=2) - if input_type == "list": - array = array.tolist() - elif input_type == "tuple": - array = tuple(array.tolist()) # type: ignore - image = Image(array) - - assert approx(target, image.data, Image) - encoded_image = image.encode() - decoded_image = Image.decode(encoded_image) - assert approx(image, decoded_image, Image) - - @pytest.mark.parametrize( - "filename", - [ - "nature-256p.ico", - "nature-360p.bmp", - "nature-360p.gif", - "nature-360p.jpg", - "nature-360p.png", - "nature-360p.tif", - "nature-360p.webp", - "nature-720p.jpg", - "nature-1080p.jpg", - "sea-360p.gif", - "sea-360p.apng", - ], - ) - def test_image_from_filepath(self, filename: str) -> None: - """ - Test reading image from an existing file. - """ - filepath = self.data_folder / filename - assert filepath.is_file() - _ = Image.from_file(str(filepath)) - _ = Image.from_file(filepath) - - @pytest.mark.parametrize( - "filename", - [ - "nature-256p.ico", - "nature-360p.bmp", - "nature-360p.gif", - "nature-360p.jpg", - "nature-360p.png", - "nature-360p.tif", - "nature-360p.webp", - "nature-720p.jpg", - "nature-1080p.jpg", - "sea-360p.gif", - "sea-360p.apng", - ], - ) - def test_image_from_file(self, filename: str) -> None: - """ - Test reading image from a file descriptor. - """ - filepath = self.data_folder / filename - assert filepath.is_file() - with filepath.open("rb") as file: - _ = Image.from_file(file) - - @pytest.mark.parametrize( - "filename", - [ - "nature-256p.ico", - "nature-360p.bmp", - "nature-360p.gif", - "nature-360p.jpg", - "nature-360p.png", - "nature-360p.tif", - "nature-360p.webp", - "nature-720p.jpg", - "nature-1080p.jpg", - "sea-360p.gif", - "sea-360p.apng", - ], - ) - def test_image_from_io(self, filename: str) -> None: - """ - Test reading image from an IO object. - """ - filepath = self.data_folder / filename - assert filepath.is_file() - with filepath.open("rb") as file: - blob = file.read() - buffer = io.BytesIO(blob) - _ = Image.from_file(buffer) - - @pytest.mark.parametrize( - "filename", - [ - "nature-256p.ico", - "nature-360p.bmp", - "nature-360p.gif", - "nature-360p.jpg", - "nature-360p.png", - "nature-360p.tif", - "nature-360p.webp", - "nature-720p.jpg", - "nature-1080p.jpg", - "sea-360p.gif", - "sea-360p.apng", - ], - ) - def test_image_from_bytes(self, filename: str) -> None: - """ - Test reading image from bytes. 
- """ - filepath = self.data_folder / filename - assert filepath.is_file() - with filepath.open("rb") as file: - blob = file.read() - _ = Image.from_file(blob) - - @pytest.mark.parametrize( - "filename", - [ - "nature-256p.ico", - "nature-360p.bmp", - "nature-360p.gif", - "nature-360p.jpg", - "nature-360p.png", - "nature-360p.tif", - "nature-360p.webp", - "nature-720p.jpg", - "nature-1080p.jpg", - "sea-360p.gif", - "sea-360p.apng", - ], - ) - def test_image_from_url(self, filename: str) -> None: - """ - Test reading image from a URL. - """ - url = urljoin(BASE_URL, filename) - _ = Image.from_file(url) - - -class TestMesh: - """ - Test `renumics.spotlight.Mesh` class. - """ - - seed = 42 - data_folder = Path("data/meshes") - - @pytest.mark.parametrize("num_points", [10, 100, 10000]) - @pytest.mark.parametrize("num_triangles", [10, 100, 10000]) - def test_mesh_from_array(self, num_points: int, num_triangles: int) -> None: - """ - Test mesh creation. - """ - points = np.random.random((num_points, 3)) - triangles = np.random.randint(0, num_points, (num_triangles, 3)) - mesh = Mesh( - points, - triangles, - { - "float32": np.random.rand(num_points).astype("float32"), - "float64_3": np.random.rand(num_points, 3).astype("float64"), - "float16_4": np.random.rand(num_points, 4).astype("float16"), - "uint64_1": np.random.randint(0, num_points, (num_points, 1), "uint64"), - "int16": np.random.randint(-10000, 10000, num_points, "int16"), - }, - { - "float32_1": np.random.rand(num_triangles, 1).astype("float32"), - "float64_3": np.random.rand(num_triangles, 3).astype("float64"), - "float32_4": np.random.rand(num_triangles, 4).astype("float32"), - "uint32": np.random.randint( - 0, np.iinfo("uint32").max, num_triangles, "uint32" - ), - "int64_1": np.random.randint( - -10000, 10000, (num_triangles, 1), "int64" - ), - }, - [np.random.random((num_points, 3)) for _ in range(10)], - ) - assert len(mesh.points) <= len(triangles) * 3 - assert len(mesh.triangles) <= len(triangles) - encoded_mesh = mesh.encode() - decoded_mesh = Mesh.decode(encoded_mesh) - assert approx(mesh, decoded_mesh, Mesh) - - @pytest.mark.parametrize( - "filename", - [ - "tree.ascii.stl", - "tree.glb", - "tree.gltf", - "tree.obj", - "tree.off", - "tree.ply", - "tree.stl", - ], - ) - def test_mesh_from_filepath(self, filename: str) -> None: - """ - Test reading mesh from an existing file. - """ - filepath = self.data_folder / filename - assert filepath.is_file() - _ = Mesh.from_file(str(filepath)) - _ = Mesh.from_file(filepath) - - @pytest.mark.parametrize( - "filename", - [ - "tree.ascii.stl", - "tree.glb", - "tree.gltf", - "tree.obj", - "tree.off", - "tree.ply", - "tree.stl", - ], - ) - def test_mesh_from_url(self, filename: str) -> None: - """ - Test reading mesh from a URL. - """ - url = urljoin(BASE_URL, filename) - _ = Mesh.from_file(url) - - -class TestAudio: - """ - Test `renumics.spotlight.Audio` class. - """ - - data_folder = Path("data/audio") - - @pytest.mark.parametrize("sampling_rate", [100, 8000, 44100, 48000, 96000]) - @pytest.mark.parametrize("channels", [1, 2, 5]) - @pytest.mark.parametrize("length", [0.5, 1.0, 2.0]) - @pytest.mark.parametrize("dtype", ["f4", "i2", "u1"]) - @pytest.mark.parametrize("target", ["wav", "flac"]) - def test_lossless_audio( - self, - sampling_rate: int, - channels: int, - length: float, - dtype: str, - target: Optional[str], - ) -> None: - """ - Test audio creation and lossless saving. 
- """ - - time = np.linspace(0.0, length, round(sampling_rate * length)) - y = 0.4 * np.sin(2.0 * np.pi * 100 * time) - if dtype.startswith("f"): - data = y.astype(dtype) - elif dtype.startswith("i"): - data = (y * np.iinfo(dtype).max).astype(dtype) - elif dtype.startswith("u"): - data = ((y + 1) * np.iinfo(dtype).max / 2).astype(dtype) - else: - assert False - if channels > 1: - data = np.broadcast_to(data[:, np.newaxis], (len(data), channels)) - audio = Audio(sampling_rate, data) - - encoded_audio = audio.encode(target) - decoded_audio = Audio.decode(encoded_audio) - - decoded_sr = decoded_audio.sampling_rate - decoded_data = decoded_audio.data - - assert decoded_sr == sampling_rate - assert decoded_data.shape == (len(y), channels) - - @pytest.mark.parametrize("sampling_rate", [32000, 44100, 48000]) - @pytest.mark.parametrize("channels", [1, 2]) - @pytest.mark.parametrize("length", [0.5, 1.0, 2.0]) - @pytest.mark.parametrize("dtype", ["f4", "i2", "u1"]) - def test_lossy_audio( - self, sampling_rate: int, channels: int, length: float, dtype: str - ) -> None: - """ - Test audio creation and lossy saving. - """ - time = np.linspace(0.0, length, round(sampling_rate * length)) - y = 0.4 * np.sin(2.0 * np.pi * 100 * time) - if dtype.startswith("f"): - data = y.astype(dtype) - elif dtype.startswith("i"): - data = (y * np.iinfo(dtype).max).astype(dtype) - elif dtype.startswith("u"): - data = ((y + 1) * np.iinfo(dtype).max / 2).astype(dtype) - else: - assert False - if channels > 1: - data = np.broadcast_to(data[:, np.newaxis], (len(data), channels)) - audio = Audio(sampling_rate, data) - - encoded_audio = audio.encode("ogg") - decoded_audio = Audio.decode(encoded_audio) - - decoded_sr = decoded_audio.sampling_rate - _decoded_data = decoded_audio.data - - assert decoded_sr == sampling_rate - - @pytest.mark.parametrize( - "filename", - [ - "gs-16b-1c-44100hz.aac", - "gs-16b-1c-44100hz.ac3", - "gs-16b-1c-44100hz.aiff", - "gs-16b-1c-44100hz.flac", - "gs-16b-1c-44100hz.m4a", - "gs-16b-1c-44100hz.mp3", - "gs-16b-1c-44100hz.ogg", - "gs-16b-1c-44100hz.ogx", - "gs-16b-1c-44100hz.wav", - "gs-16b-1c-44100hz.wma", - ], - ) - def test_audio_from_filepath_mono(self, filename: str) -> None: - """ - Test `Audio.from_file` method on mono data. - """ - filepath = self.data_folder / "mono" / filename - assert filepath.is_file() - _ = Audio.from_file(str(filepath)) - _ = Audio.from_file(filepath) - - @pytest.mark.parametrize( - "filename", - [ - "gs-16b-2c-44100hz.aac", - "gs-16b-2c-44100hz.ac3", - "gs-16b-2c-44100hz.aiff", - "gs-16b-2c-44100hz.flac", - "gs-16b-2c-44100hz.m4a", - "gs-16b-2c-44100hz.mp3", - "gs-16b-2c-44100hz.mp4", - "gs-16b-2c-44100hz.ogg", - "gs-16b-2c-44100hz.ogx", - "gs-16b-2c-44100hz.wav", - "gs-16b-2c-44100hz.wma", - ], - ) - def test_audio_from_filepath_stereo(self, filename: str) -> None: - """ - Test `Audio.from_file` method on stereo data. - """ - filepath = self.data_folder / "stereo" / filename - assert filepath.is_file() - _ = Audio.from_file(str(filepath)) - _ = Audio.from_file(filepath) - - @pytest.mark.parametrize( - "filename", - [ - "gs-16b-2c-44100hz.aac", - "gs-16b-2c-44100hz.ac3", - "gs-16b-2c-44100hz.aiff", - "gs-16b-2c-44100hz.flac", - "gs-16b-2c-44100hz.m4a", - "gs-16b-2c-44100hz.mp3", - "gs-16b-2c-44100hz.mp4", - "gs-16b-2c-44100hz.ogg", - "gs-16b-2c-44100hz.ogx", - "gs-16b-2c-44100hz.wav", - "gs-16b-2c-44100hz.wma", - ], - ) - def test_audio_from_file(self, filename: str) -> None: - """ - Test reading audio from a file descriptor. 
- """ - filepath = self.data_folder / "stereo" / filename - assert filepath.is_file() - with filepath.open("rb") as file: - _ = Audio.from_file(file) - - @pytest.mark.parametrize( - "filename", - [ - "gs-16b-2c-44100hz.aac", - "gs-16b-2c-44100hz.ac3", - "gs-16b-2c-44100hz.aiff", - "gs-16b-2c-44100hz.flac", - "gs-16b-2c-44100hz.m4a", - "gs-16b-2c-44100hz.mp3", - "gs-16b-2c-44100hz.mp4", - "gs-16b-2c-44100hz.ogg", - "gs-16b-2c-44100hz.ogx", - "gs-16b-2c-44100hz.wav", - "gs-16b-2c-44100hz.wma", - ], - ) - def test_audio_from_io(self, filename: str) -> None: - """ - Test reading audio from an IO object. - """ - filepath = self.data_folder / "stereo" / filename - assert filepath.is_file() - with filepath.open("rb") as file: - blob = file.read() - buffer = io.BytesIO(blob) - _ = Audio.from_file(buffer) - - @pytest.mark.parametrize( - "filename", - [ - "gs-16b-2c-44100hz.aac", - "gs-16b-2c-44100hz.ac3", - "gs-16b-2c-44100hz.aiff", - "gs-16b-2c-44100hz.flac", - "gs-16b-2c-44100hz.m4a", - "gs-16b-2c-44100hz.mp3", - "gs-16b-2c-44100hz.mp4", - "gs-16b-2c-44100hz.ogg", - "gs-16b-2c-44100hz.ogx", - "gs-16b-2c-44100hz.wav", - "gs-16b-2c-44100hz.wma", - ], - ) - def test_audio_from_bytes(self, filename: str) -> None: - """ - Test reading audio from bytes. - """ - filepath = self.data_folder / "stereo" / filename - assert filepath.is_file() - with filepath.open("rb") as file: - blob = file.read() - _ = Audio.from_bytes(blob) - - @pytest.mark.parametrize( - "filename", - [ - "gs-16b-2c-44100hz.aac", - "gs-16b-2c-44100hz.ac3", - "gs-16b-2c-44100hz.aiff", - "gs-16b-2c-44100hz.flac", - "gs-16b-2c-44100hz.mp3", - "gs-16b-2c-44100hz.ogg", - "gs-16b-2c-44100hz.ogx", - "gs-16b-2c-44100hz.ts", - "gs-16b-2c-44100hz.wav", - "gs-16b-2c-44100hz.wma", - ], - ) - def test_audio_from_url(self, filename: str) -> None: - """ - Test reading audio from a URL. - """ - url = urljoin(BASE_URL, filename) - _ = Audio.from_file(url) - - -class TestVideo: - """ - Test `renumics.spotlight.Mesh` class. - """ - - data_folder = Path("data/videos") - - @pytest.mark.parametrize( - "filename", - [ - "sea-360p.avi", - "sea-360p.mkv", - "sea-360p.mov", - "sea-360p.mp4", - "sea-360p.mpg", - "sea-360p.ogg", - "sea-360p.webm", - "sea-360p.wmv", - "sea-360p-10s.mp4", - ], - ) - def test_video_from_filepath(self, filename: str) -> None: - """ - Test reading video from an existing file. - """ - filepath = self.data_folder / filename - assert filepath.is_file() - _ = Video.from_file(str(filepath)) - _ = Video.from_file(filepath) - - @pytest.mark.parametrize( - "filename", - [ - "sea-360p.avi", - "sea-360p.mkv", - "sea-360p.mov", - "sea-360p.mp4", - "sea-360p.mpg", - "sea-360p.ogg", - "sea-360p.webm", - "sea-360p.wmv", - "sea-360p-10s.mp4", - ], - ) - def test_video_from_bytes(self, filename: str) -> None: - """ - Test reading video from bytes. - """ - filepath = self.data_folder / filename - assert filepath.is_file() - with filepath.open("rb") as file: - blob = file.read() - _ = Video.from_bytes(blob) - - @pytest.mark.parametrize( - "filename", - [ - "sea-360p.avi", - "sea-360p.mkv", - "sea-360p.mov", - "sea-360p.mp4", - "sea-360p.mpg", - "sea-360p.ogg", - "sea-360p.webm", - "sea-360p.wmv", - "sea-360p-10s.mp4", - ], - ) - def test_video_from_url(self, filename: str) -> None: - """ - Test reading video from an URL. 
-    """
-    url = urljoin(BASE_URL, filename)
-    _ = Video.from_file(url)
diff --git a/tests/integration/dataset/test_dataset.py b/tests/integration/dataset/test_dataset.py
index cfec2624..42554003 100644
--- a/tests/integration/dataset/test_dataset.py
+++ b/tests/integration/dataset/test_dataset.py
@@ -8,7 +8,7 @@
 import tempfile
 from datetime import datetime
 from glob import glob
-from typing import List
+from typing import List, Optional, cast
 
 import numpy as np
 import pandas as pd
@@ -16,17 +16,19 @@
 from renumics.spotlight import (
     Audio,
-    Category,
     Dataset,
     Embedding,
     Image,
     Mesh,
     Sequence1D,
     Video,
-    Window,
 )
 from renumics.spotlight.dataset import escape_dataset_name, unescape_dataset_name
-from .conftest import approx, get_append_column_fn_name, ColumnData
+
+from renumics.spotlight.dataset.typing import OutputType
+from .conftest import ColumnData
+from .helpers import get_append_column_fn_name
+from ..helpers import approx
 
 
 @pytest.mark.parametrize(
@@ -84,21 +86,21 @@ def test_initialized_dataset(optional_data: List[ColumnData]) -> None:
     with Dataset(output_h5_file, "w") as dataset:
         for sample in optional_data:
             append_fn = getattr(
-                dataset, get_append_column_fn_name(sample.column_type)
+                dataset, get_append_column_fn_name(sample.dtype_name)
             )
             append_fn(sample.name, **sample.attrs)
         assert len(list(dataset.iterrows())) == 0
         assert len(dataset) == 0
         assert set(dataset.keys()) == column_names
         for sample in optional_data:
-            assert dataset.get_column_type(sample.name) is sample.column_type
+            assert dataset.get_dtype(sample.name).name == sample.dtype_name
         print(dataset)
     with Dataset(output_h5_file, "r") as dataset:
         assert len(list(dataset.iterrows())) == 0
         assert len(dataset) == 0
         assert set(dataset.keys()) == column_names
         for sample in optional_data:
-            assert dataset.get_column_type(sample.name) is sample.column_type
+            assert dataset.get_dtype(sample.name).name == sample.dtype_name
         print(dataset)
 
 
@@ -113,14 +115,14 @@ def test_optional_columns(optional_data: List[ColumnData]) -> None:
     with Dataset(output_h5_file, "w") as dataset:
         sample = optional_data[0]
         values = [None] * 10
-        append_fn = getattr(dataset, get_append_column_fn_name(sample.column_type))
+        append_fn = getattr(dataset, get_append_column_fn_name(sample.dtype_name))
         if sample.optional:
             append_fn(sample.name, values, optional=sample.optional, **sample.attrs)
         else:
             append_fn(sample.name, values, default=sample.default, **sample.attrs)
         for sample in optional_data[1:]:
             append_fn = getattr(
-                dataset, get_append_column_fn_name(sample.column_type)
+                dataset, get_append_column_fn_name(sample.dtype_name)
             )
             if sample.optional:
                 append_fn(sample.name, optional=sample.optional, **sample.attrs)
@@ -129,39 +131,39 @@
         assert len(dataset) == 10
         assert set(dataset.keys()) == column_names
         for sample in optional_data:
-            column_type = sample.column_type
+            dtype_name = sample.dtype_name
             default = sample.default
             if default is None:
-                if column_type is str:
+                if dtype_name == "str":
                     default = ""
-                elif column_type is float:
+                elif dtype_name == "float":
                     default = float("nan")
-                elif column_type is Window:
+                elif dtype_name == "Window":
                     default = [float("nan"), float("nan")]
-            assert dataset.get_column_type(sample.name) is column_type
+            assert dataset.get_dtype(sample.name).name == dtype_name
             assert approx(
                 default,
                 dataset.get_column_attributes(sample.name)["default"],
-                dataset.get_column_type(sample.name),
+                dataset.get_dtype(sample.name).name,
            )
        for _ in range(10):
dataset.append_row() assert len(dataset) == 20 assert set(dataset.keys()) == column_names for sample in optional_data: - column_type = sample.column_type + dtype_name = sample.dtype_name default = sample.default if default is None: - if column_type is str: + if dtype_name == "str": default = "" - elif column_type is float: + elif dtype_name == "float": default = float("nan") - elif column_type is Window: + elif dtype_name == "Window": default = [float("nan"), float("nan")] - assert dataset.get_column_type(sample.name) is sample.column_type + assert dataset.get_dtype(sample.name).name == sample.dtype_name for dataset_value in dataset[sample.name]: assert approx( - default, dataset_value, dataset.get_column_type(sample.name) + default, dataset_value, dataset.get_dtype(sample.name).name ) print(dataset) @@ -181,7 +183,7 @@ def test_append_row( with Dataset(output_h5_file, "w") as dataset: for sample in data: append_fn = getattr( - dataset, get_append_column_fn_name(sample.column_type) + dataset, get_append_column_fn_name(sample.dtype_name) ) append_fn(sample.name, **sample.attrs) for i in range(length): @@ -202,7 +204,8 @@ def test_append_row( for name in names: value = data_dict[name] dataset_value = dataset_row[name] - assert approx(value, dataset_value, dataset.get_column_type(name)) + dtype = dataset.get_dtype(name) + assert approx(value, dataset_value, dtype.name) with Dataset(output_h5_file, "a") as dataset: dataset_length = len(dataset) for i in range(length): @@ -227,7 +230,7 @@ def test_insert_row( with Dataset(output_h5_file, "w") as dataset: for sample in data: append_fn = getattr( - dataset, get_append_column_fn_name(sample.column_type) + dataset, get_append_column_fn_name(sample.dtype_name) ) append_fn(sample.name, **sample.attrs) for i in range(length): @@ -250,7 +253,7 @@ def test_insert_row( for name in names: value = data_dict[name] dataset_value = dataset_row[name] - assert approx(value, dataset_value, dataset.get_column_type(name)) + assert approx(value, dataset_value, dataset.get_dtype(name).name) def test_append_common_column( @@ -268,7 +271,7 @@ def test_append_common_column( for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, description=sample.description, **sample.attrs, @@ -283,11 +286,11 @@ def test_append_common_column( assert set(dataset.keys()) == names print(dataset) for sample in data: - column_type = sample.column_type - assert dataset.get_column_type(sample.name) is column_type + dtype_name = sample.dtype_name + assert dataset.get_dtype(sample.name).name == dtype_name dataset_values = dataset[sample.name] for value, dataset_value in zip(sample.values, dataset_values): - assert approx(value, dataset_value, column_type) + assert approx(value, dataset_value, dtype_name) def test_append_delete_column( @@ -303,7 +306,7 @@ def test_append_delete_column( with Dataset(output_h5_file, "w") as dataset: for sample in data: dataset.append_column( - sample.name, sample.column_type, sample.values, **sample.attrs + sample.name, sample.dtype_name, sample.values, **sample.attrs ) with Dataset(output_h5_file, "a") as dataset: @@ -329,14 +332,14 @@ def test_isnull_notnull(empty_dataset: Dataset) -> None: # Test non-nullable data types. 
empty_dataset.append_bool_column("bool", [True, False] * 5) null_mask = np.full(len(empty_dataset), False) - assert approx(null_mask, empty_dataset.isnull("bool"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("bool"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("bool"), "array") + assert approx(~null_mask, empty_dataset.notnull("bool"), "array") empty_dataset.append_int_column("int", range(len(empty_dataset))) - assert approx(null_mask, empty_dataset.isnull("int"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("int"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("int"), "array") + assert approx(~null_mask, empty_dataset.notnull("int"), "array") empty_dataset.append_string_column("string", ["", "foo", "barbaz", "", ""] * 2) - assert approx(null_mask, empty_dataset.isnull("string"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("string"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("string"), "array") + assert approx(~null_mask, empty_dataset.notnull("string"), "array") # Test simple nullable data types. empty_dataset.append_float_column( "float", [0, 1, np.nan, np.nan, np.nan, -1000, np.inf, -np.inf, 8, np.nan] @@ -344,8 +347,8 @@ def test_isnull_notnull(empty_dataset: Dataset) -> None: null_mask = np.array( [False, False, True, True, True, False, False, False, False, True] ) - assert approx(null_mask, empty_dataset.isnull("float"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("float"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("float"), "array") + assert approx(~null_mask, empty_dataset.notnull("float"), "array") now = datetime.now() empty_dataset.append_datetime_column( "datetime", @@ -363,41 +366,42 @@ def test_isnull_notnull(empty_dataset: Dataset) -> None: ], optional=True, ) - assert approx(null_mask, empty_dataset.isnull("datetime"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("datetime"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("datetime"), "array") + assert approx(~null_mask, empty_dataset.notnull("datetime"), "array") empty_dataset.append_categorical_column( "category", - ["foo", "foo", "", "", "", "barbaz", "barbaz", "barbaz", "foo", ""], + ["foo", "foo", None, None, None, "barbaz", "barbaz", "barbaz", "foo", None], optional=True, + categories=["barbaz", "foo"], ) - assert approx(null_mask, empty_dataset.isnull("category"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("category"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("category"), "array") + assert approx(~null_mask, empty_dataset.notnull("category"), "array") # Test complex nullable data types. 
values = np.full(len(empty_dataset), None) values[~null_mask] = Embedding([0, 1, 2, 3]) empty_dataset.append_embedding_column("embedding", values, optional=True) - assert approx(null_mask, empty_dataset.isnull("embedding"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("embedding"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("embedding"), "array") + assert approx(~null_mask, empty_dataset.notnull("embedding"), "array") values[~null_mask] = Sequence1D([0, 1, 2, 3]) empty_dataset.append_sequence_1d_column("sequence_1d", values, optional=True) - assert approx(null_mask, empty_dataset.isnull("sequence_1d"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("sequence_1d"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("sequence_1d"), "array") + assert approx(~null_mask, empty_dataset.notnull("sequence_1d"), "array") values[~null_mask] = Mesh.empty() empty_dataset.append_mesh_column("mesh", values, optional=True) - assert approx(null_mask, empty_dataset.isnull("mesh"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("mesh"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("mesh"), "array") + assert approx(~null_mask, empty_dataset.notnull("mesh"), "array") values[~null_mask] = Image.empty() empty_dataset.append_image_column("image", values, optional=True) - assert approx(null_mask, empty_dataset.isnull("image"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("image"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("image"), "array") + assert approx(~null_mask, empty_dataset.notnull("image"), "array") values[~null_mask] = Audio.empty() empty_dataset.append_audio_column("audio", values, optional=True) - assert approx(null_mask, empty_dataset.isnull("audio"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("audio"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("audio"), "array") + assert approx(~null_mask, empty_dataset.notnull("audio"), "array") values[~null_mask] = Video.empty() empty_dataset.append_video_column("video", values, optional=True) - assert approx(null_mask, empty_dataset.isnull("video"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("video"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("video"), "array") + assert approx(~null_mask, empty_dataset.notnull("video"), "array") # Test window data type. 
windows = np.random.random((len(empty_dataset), 2)) null_mask = np.full(len(windows), False) @@ -409,8 +413,8 @@ def test_isnull_notnull(empty_dataset: Dataset) -> None: null_mask[0] = True null_mask[-1] = True empty_dataset.append_window_column("window", windows) - assert approx(null_mask, empty_dataset.isnull("window"), np.ndarray) - assert approx(~null_mask, empty_dataset.notnull("window"), np.ndarray) + assert approx(null_mask, empty_dataset.isnull("window"), "array") + assert approx(~null_mask, empty_dataset.notnull("window"), "array") def test_rename_column( @@ -426,7 +430,7 @@ def test_rename_column( with Dataset(output_h5_file, "w") as dataset: for sample in data: dataset.append_column( - sample.name, sample.column_type, sample.values, **sample.attrs + sample.name, sample.dtype_name, sample.values, **sample.attrs ) with Dataset(output_h5_file, "a") as dataset: @@ -442,7 +446,7 @@ def test_rename_column( name = f"{sample.name}_" assert name in dataset.keys() dataset_values = dataset[name] - column_type = sample.column_type + column_type = sample.dtype_name for value, dataset_value in zip(sample.values, dataset_values): assert approx(value, dataset_value, column_type) @@ -459,7 +463,7 @@ def test_getitem(simple_data: List[ColumnData], complex_data: List[ColumnData]) for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, description=sample.description, **sample.attrs, @@ -468,12 +472,12 @@ def test_getitem(simple_data: List[ColumnData], complex_data: List[ColumnData]) with Dataset(output_h5_file, "r") as dataset: for sample in data: column_name, values = sample.name, sample.values - column_type = dataset.get_column_type(column_name) + dtype = dataset.get_dtype(column_name) # Test `dataset[column_name]` getter. dataset_values = list(dataset[column_name]) assert len(dataset_values) == len(values) for dataset_value, value in zip(dataset_values, values): - approx(value, dataset_value, column_type) + approx(value, dataset_value, dtype.name) for i in range(-len(dataset), len(dataset)): data_dict = {sample.name: sample.values[i] for sample in data} @@ -483,15 +487,15 @@ def test_getitem(simple_data: List[ColumnData], complex_data: List[ColumnData]) assert dataset_row.keys() == data_dict.keys() for key, value in data_dict.items(): dataset_value = dataset_row[key] - assert approx(value, dataset_value, dataset.get_column_type(key)) + assert approx(value, dataset_value, dataset.get_dtype(key).name) for key, value in data_dict.items(): - column_type = dataset.get_column_type(key) + dtype = dataset.get_dtype(key) # Test `dataset[column_name, row_index]` getter. dataset_value = dataset[key, i] - assert approx(value, dataset_value, column_type) + assert approx(value, dataset_value, dtype.name) # Test `dataset[row_index, column_name]` getter. 
dataset_value = dataset[i, key] - assert approx(value, dataset_value, column_type) + assert approx(value, dataset_value, dtype.name) def test_setitem(simple_data: List[ColumnData], complex_data: List[ColumnData]) -> None: @@ -505,7 +509,7 @@ def test_setitem(simple_data: List[ColumnData], complex_data: List[ColumnData]) for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, description=sample.description, **sample.attrs, @@ -514,18 +518,18 @@ def test_setitem(simple_data: List[ColumnData], complex_data: List[ColumnData]) with Dataset(output_h5_file, "a") as dataset: for key in dataset.keys(): values = dataset[key] - if dataset.get_column_type(key) is not np.ndarray: + if dataset.get_dtype(key).name != "array": for value in values: dataset[key] = value dataset[key] = values for i in range(-len(dataset), len(dataset)): - dataset[i] = dataset[i] - dataset[0] = dataset[-3] - dataset[-3] = dataset[5] - dataset[2] = dataset[-1] + dataset[i] = dataset[i] # type: ignore + dataset[0] = dataset[-3] # type: ignore + dataset[-3] = dataset[5] # type: ignore + dataset[2] = dataset[-1] # type: ignore dataset.append_row(**dataset[-2]) dataset.append_row(**dataset[1]) - dataset[-7] = dataset[7] + dataset[-7] = dataset[7] # type: ignore for key in dataset.keys(): for i in range(-len(dataset), len(dataset)): dataset[key, i] = dataset[key, i] @@ -550,7 +554,7 @@ def test_delitem(simple_data: List[ColumnData], complex_data: List[ColumnData]) for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, description=sample.description, **sample.attrs, @@ -603,7 +607,7 @@ def test_iterrows( for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, description=sample.description, **sample.attrs, @@ -611,16 +615,17 @@ def test_iterrows( for dataset_row in dataset.iterrows(): assert dataset_row.keys() == set(dataset.keys()) for keys in (keys1, keys2, keys3, keys4, keys5, keys6): - for dataset_row in dataset.iterrows(keys): - assert dataset_row.keys() == set(keys) - for dataset_row in dataset.iterrows(sample.name for sample in data): - assert dataset_row.keys() == set(sample.name for sample in data) + for dataset_row1 in dataset.iterrows(keys): + assert dataset_row1.keys() == set(keys) # type: ignore + for dataset_row2 in dataset.iterrows(sample.name for sample in data): + assert dataset_row2.keys() == set(sample.name for sample in data) # type: ignore for sample in data: values = sample.values - column_type = dataset.get_column_type(sample.name) + dtype_name = dataset.get_dtype(sample.name).name dataset_values = dataset.iterrows(sample.name) for value, dataset_value in zip(values, dataset_values): - assert approx(value, dataset_value, column_type) + dataset_value = cast(Optional[OutputType], dataset_value) + assert approx(value, dataset_value, dtype_name) def test_append_dataset( @@ -636,7 +641,7 @@ def test_append_dataset( for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, description=sample.description, **sample.attrs, @@ -655,7 +660,7 @@ def test_append_dataset( for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, description=sample.description, **sample.attrs, @@ -693,7 +698,7 @@ def test_copy_column( for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, 
description=sample.description, **sample.attrs, @@ -701,10 +706,10 @@ def test_copy_column( with Dataset(output_h5_file, "a") as dataset: for name in dataset.keys(): - column_type = dataset.get_column_type(name) + dtype = dataset.get_dtype(name) kwargs = dataset.get_column_attributes(name) values = dataset[name] - dataset.append_column(f"new {name}", column_type, values, **kwargs) + dataset.append_column(f"new {name}", dtype, values, **kwargs) def test_pop(simple_data: List[ColumnData], complex_data: List[ColumnData]) -> None: @@ -721,7 +726,7 @@ def test_pop(simple_data: List[ColumnData], complex_data: List[ColumnData]) -> N for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, description=sample.description, **sample.attrs, @@ -730,12 +735,12 @@ def test_pop(simple_data: List[ColumnData], complex_data: List[ColumnData]) -> N with Dataset(output_h5_file, "a") as dataset: for name in np.random.choice(list(names), 5, replace=False): - type_ = dataset.get_column_type(name) + dtype = dataset.get_dtype(name) actual_values = dataset.pop(name) for target, actual in zip( next((x.values for x in data if x.name == name)), actual_values ): - assert approx(target, actual, type_) + assert approx(target, actual, dtype.name) assert set(dataset.keys()) == names.difference({name}) names.discard(name) @@ -745,7 +750,7 @@ def test_pop(simple_data: List[ColumnData], complex_data: List[ColumnData]) -> N actual = dataset.pop(index) assert target.keys() == actual.keys() for key, value in target.items(): - assert approx(value, actual[key], dataset.get_column_type(key)) + assert approx(value, actual[key], dataset.get_dtype(key).name) assert len(dataset) == length - 1 length -= 1 assert len(dataset) == 0 @@ -765,7 +770,7 @@ def test_set_attributes_column( for sample in data: dataset.append_column( sample.name, - sample.column_type, + sample.dtype_name, sample.values, default=sample.values[0], description=sample.description, @@ -774,13 +779,13 @@ def test_set_attributes_column( with Dataset(output_h5_file, "a") as dataset: for name in dataset.keys(): - column_type = dataset.get_column_type(name) + dtype = dataset.get_dtype(name) kwargs = dataset.get_column_attributes(name) values = dataset[name] name = f"new_{name}" dataset.append_column( name, - column_type, + dtype, values, **( {"categories": kwargs["categories"]} @@ -803,7 +808,7 @@ def test_set_attributes_column( approx( column_attribute, column_attributes_new[attribute_key], - dataset.get_column_type(column_name), + dataset.get_dtype(column_name).name, ) else: assert ( @@ -839,25 +844,25 @@ def test_import_pandas_with_dtype() -> None: Test `Dataset.import_pandas` with defined `dtype` argument. 
""" df = pd.read_csv("build/datasets/multimodal-random-1000.csv") - dtype = { - "audio": str, - "image": str, - "mesh": str, - "video": str, - "embedding": Embedding, - "window": Window, - "bool": bool, - "int": int, - "float": float, - "str": str, - "datetime": datetime, - "category": Category, + dtypes = { + "audio": "str", + "image": "str", + "mesh": "str", + "video": "str", + "embedding": "Embedding", + "window": "Window", + "bool": "bool", + "int": "int", + "float": "float", + "str": "str", + "datetime": "datetime", + "category": "Category", } with tempfile.TemporaryDirectory() as output_folder: output_h5_file = os.path.join(output_folder, "dataset.h5") with Dataset(output_h5_file, "w") as dataset: - dataset.from_pandas(df, dtype=dtype) - assert dtype == {key: dataset.get_column_type(key) for key in dtype} + dataset.from_pandas(df, dtypes=dtypes) + assert dtypes == {key: dataset.get_dtype(key).name for key in dtypes} def test_import_csv() -> None: @@ -893,8 +898,8 @@ def test_import_csv_with_dtype() -> None: df["int"] = df["int1"] = df["int2"] = np.random.randint(-1000, 1000, 10) df["float"] = df["float1"] = np.random.random(10) indices = np.random.choice(len(df), 3, replace=False) - df.loc[indices, "float"] = [np.nan, np.inf, -np.inf] - df.loc[indices, "float1"] = [np.nan, np.inf, -np.inf] + df.loc[indices, "float"] = [np.nan, np.inf, -np.inf] # type: ignore + df.loc[indices, "float1"] = [np.nan, np.inf, -np.inf] # type: ignore df["string"] = df["string1"] = [ "".join( np.random.choice(list(string.ascii_letters), np.random.randint(1, 22)) @@ -968,37 +973,37 @@ def test_import_csv_with_dtype() -> None: df["image3"] = df["image2"] df["mesh3"] = df["mesh2"] df["video3"] = df["video2"] - optional_or_nan_columns = list(df.columns.difference(set(columns))) + optional_or_nan_columns: List[str] = list(df.columns.difference(columns)) - df.loc[indices, optional_or_nan_columns] = "" + df.loc[indices, optional_or_nan_columns] = "" # type: ignore df.to_csv(csv_file, index=False) dtypes = { - "string1": Category, - "datetime1": datetime, + "string1": "Category", + "datetime1": "datetime", } with Dataset(output_h5_file, "w") as dataset: dataset.from_csv(csv_file, dtypes) assert set(dataset.keys()) == set(df.keys()) - assert {key: dataset.get_column_type(key) for key in dtypes} == dtypes + assert {key: dataset.get_dtype(key).name for key in dtypes} == dtypes columns += optional_or_nan_columns dtypes.update( { - "string2": Category, - "datetime2": datetime, - "array5": np.ndarray, - "array6": Embedding, - "array7": Sequence1D, - "array8": np.ndarray, - "audio3": Audio, - "image3": Image, - "mesh3": Mesh, - "video3": Video, + "string2": "Category", + "datetime2": "datetime", + "array5": "array", + "array6": "Embedding", + "array7": "Sequence1D", + "array8": "array", + "audio3": "Audio", + "image3": "Image", + "mesh3": "Mesh", + "video3": "Video", } ) with Dataset(output_h5_file, "w") as dataset: dataset.from_csv(csv_file, dtypes) assert set(dataset.keys()) == set(columns) - assert {key: dataset.get_column_type(key) for key in dtypes} == dtypes + assert {key: dataset.get_dtype(key).name for key in dtypes} == dtypes diff --git a/tests/integration/dataset/test_fancy_indexing.py b/tests/integration/dataset/test_fancy_indexing.py index 0dd458cf..30a58772 100644 --- a/tests/integration/dataset/test_fancy_indexing.py +++ b/tests/integration/dataset/test_fancy_indexing.py @@ -118,13 +118,13 @@ def test_delitem(fancy_indexing_dataset: Dataset) -> None: length = len(fancy_indexing_dataset) column_name = 
fancy_indexing_dataset.keys()[0]
-    column_type = fancy_indexing_dataset.get_column_type(column_name)
+    dtype = fancy_indexing_dataset.get_dtype(column_name)
     target = np.array(fancy_indexing_dataset[column_name])
 
     def _restore_column() -> None:
         del fancy_indexing_dataset[column_name]
         assert column_name not in fancy_indexing_dataset.keys()
-        fancy_indexing_dataset.append_column(column_name, column_type, target)
+        fancy_indexing_dataset.append_column(column_name, dtype, target)
         assert column_name in fancy_indexing_dataset.keys()
         assert (target == fancy_indexing_dataset[column_name]).all()
diff --git a/tests/integration/dataset/test_prune.py b/tests/integration/dataset/test_prune.py
index 5d3a15ec..a180788e 100644
--- a/tests/integration/dataset/test_prune.py
+++ b/tests/integration/dataset/test_prune.py
@@ -31,7 +31,7 @@ def test_prune(
     for sample in data:
         dataset.append_column(
             sample.name,
-            sample.column_type,
+            sample.dtype_name,
             description=sample.description,
             **sample.attrs,
         )
diff --git a/tests/integration/formats/test_formats.py b/tests/integration/formats/test_formats.py
index e89fc7be..592f8745 100644
--- a/tests/integration/formats/test_formats.py
+++ b/tests/integration/formats/test_formats.py
@@ -7,7 +7,7 @@
 
 @pytest.mark.parametrize("extension", ["csv", "feather", "parquet", "orc"])
-def test_successful_load(extension: str, project_root: Path):
+def test_successful_load(extension: str, project_root: Path) -> None:
     """
     Check if the file loads without error when calling spotlight.show
     """
diff --git a/tests/integration/h5/conftest.py b/tests/integration/h5/conftest.py
index ac704cc6..0c3515eb 100644
--- a/tests/integration/h5/conftest.py
+++ b/tests/integration/h5/conftest.py
@@ -24,7 +24,7 @@ def dataset_path() -> Iterator[str]:
     for col, (dtype, values) in COLUMNS.items():
         dataset.append_column(
             col,
-            column_type=dtype,
+            dtype=dtype,
             values=values,
-            optional=dtype not in (int, bool),
+            optional=dtype not in ("int", "bool"),
         )
diff --git a/tests/integration/h5/data.py b/tests/integration/h5/data.py
index 3ac8befa..6c5a3f18 100644
--- a/tests/integration/h5/data.py
+++ b/tests/integration/h5/data.py
@@ -9,25 +9,19 @@
 
 COLUMNS = {
-    "bool": (bool, [True, False]),
-    "int": (int, [0, 1]),
-    "float": (float, [0.0, np.nan]),
-    "str": (str, ["foobar", ""]),
-    "datetime": (datetime.datetime, [datetime.datetime.min, np.datetime64("NaT")]),
-    "categorical": (spotlight.Category, ["foo", "bar"]),
-    "array": (np.ndarray, [[[0]], [1, 2, 3]]),
-    "window": (spotlight.Window, [[0, 1], [-np.inf, np.nan]]),
-    "embedding": (spotlight.Embedding, [[1, 2, 3], [4, np.nan, 5]]),
-    "sequence": (
-        spotlight.Sequence1D,
-        [[[1, 2, 3], [2, 3, 4]], [[2, 3], [5, 5]]],
-    ),
-    "optional_sequence": (
-        spotlight.Sequence1D,
-        [[[1, 2, 3], [2, 3, 4]], None],
-    ),
-    "image": (spotlight.Image, [spotlight.Image.empty(), None]),
-    "audio": (spotlight.Audio, [spotlight.Audio.empty(), None]),
-    "video": (spotlight.Video, [spotlight.Video.empty(), None]),
-    "mesh": (spotlight.Mesh, [spotlight.Mesh.empty(), None]),
+    "bool": ("bool", [True, False]),
+    "int": ("int", [0, 1]),
+    "float": ("float", [0.0, np.nan]),
+    "str": ("str", ["foobar", ""]),
+    "datetime": ("datetime", [datetime.datetime.min, np.datetime64("NaT")]),
+    "categorical": ("Category", ["foo", "bar"]),
+    "array": ("array", [[[0]], [1, 2, 3]]),
+    "window": ("Window", [[0, 1], [-np.inf, np.nan]]),
+    "embedding": ("Embedding", [[1, 2, 3], [4, np.nan, 5]]),
+    "sequence": ("Sequence1D", [[[1, 2, 3], [2, 3, 4]], [[2, 3], [5, 5]]]),
+    "optional_sequence": ("Sequence1D", [[[1, 2, 3], [2, 3, 4]], None]),
+    
"image": ("Image", [spotlight.Image.empty(), None]), + "audio": ("Audio", [spotlight.Audio.empty(), None]), + "video": ("Video", [spotlight.Video.empty(), None]), + "mesh": ("Mesh", [spotlight.Mesh.empty(), None]), } diff --git a/tests/integration/h5/test_h5.py b/tests/integration/h5/test_h5.py index 4981897a..b691d401 100644 --- a/tests/integration/h5/test_h5.py +++ b/tests/integration/h5/test_h5.py @@ -6,7 +6,7 @@ import pytest import httpx from renumics import spotlight -from renumics.spotlight.dtypes.typing import ColumnType +from renumics.spotlight.dataset.typing import OutputType from .data import COLUMNS @@ -31,7 +31,7 @@ def test_get_table_returns_http_ok(dataset_path: str) -> None: ("embedding", np.ndarray), ], ) -def test_custom_dtypes(dataset_path: str, col: str, dtype: Type[ColumnType]) -> None: +def test_custom_dtypes(dataset_path: str, col: str, dtype: Type[OutputType]) -> None: """ Test h5 data source with custom dtypes. """ diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py new file mode 100644 index 00000000..c2578f3a --- /dev/null +++ b/tests/integration/helpers.py @@ -0,0 +1,106 @@ +from typing import Optional + +import numpy as np + +from renumics.spotlight.dataset import VALUE_TYPE_BY_DTYPE_NAME +from renumics.spotlight.dataset.typing import ColumnInputType, OutputType +from renumics.spotlight.media import Embedding, Sequence1D, Audio, Image, Mesh, Video +from renumics.spotlight import dtypes + + +def approx( + expected: ColumnInputType, actual: Optional[OutputType], dtype_name: str +) -> bool: + """ + Check whether expected and actual dataset values are almost equal. + """ + if actual is None and expected is None: + return True + + dtype = dtypes.create_dtype(dtype_name) + if dtypes.is_scalar_dtype(dtype) or dtypes.is_str_dtype(dtype): + value_type = VALUE_TYPE_BY_DTYPE_NAME[dtype_name] + expected_value: np.ndarray = np.array(expected, dtype=value_type) + actual_value: np.ndarray = np.array(actual, dtype=value_type) + return approx(expected_value, actual_value, "array") + if dtypes.is_datetime_dtype(dtype): + expected_datetime = np.array(expected, dtype="datetime64") + actual_datetime = np.array(actual, dtype="datetime64") + return approx(expected_datetime, actual_datetime, "array") + if dtypes.is_category_dtype(dtype): + expected_category = np.array(expected, dtype="str") + actual_category = np.array(actual, dtype="str") + return approx(expected_category, actual_category, "array") + if dtypes.is_array_dtype(dtype): + expected_array = np.asarray(expected) + assert isinstance(actual, np.ndarray) + if actual.shape != expected_array.shape: + return False + if issubclass(expected_array.dtype.type, np.inexact): + return np.allclose(actual, expected_array, equal_nan=True) + return actual.tolist() == expected_array.tolist() + if dtypes.is_window_dtype(dtype): + expected_window = np.asarray(expected, dtype=float) + assert isinstance(actual, np.ndarray) + return approx(expected_window, actual, "array") + if dtypes.is_embedding_dtype(dtype): + if isinstance(expected, Embedding): + expected_embedding = expected + else: + expected_embedding = Embedding(expected) # type: ignore + assert isinstance(actual, np.ndarray) + return approx(expected_embedding.data, actual, "array") + if dtypes.is_sequence_1d_dtype(dtype): + if isinstance(expected, Sequence1D): + expected_sequence_1d = expected + else: + expected_sequence_1d = Sequence1D(expected) # type: ignore + assert isinstance(actual, Sequence1D) + return approx(expected_sequence_1d.index, actual.index, "array") 
and approx( + expected_sequence_1d.value, actual.value, "array" + ) + if dtypes.is_audio_dtype(dtype): + assert isinstance(expected, Audio) + assert isinstance(actual, Audio) + return ( + approx(expected.data, actual.data, "array") + and actual.sampling_rate == expected.sampling_rate + ) + if dtypes.is_image_dtype(dtype): + if isinstance(expected, Image): + expected_image = expected + else: + expected_image = Image(expected) # type: ignore + assert isinstance(actual, Image) + return approx(expected_image.data, actual.data, "array") + if dtypes.is_mesh_dtype(dtype): + assert isinstance(expected, Mesh) + assert isinstance(actual, Mesh) + return ( + approx(expected.points, actual.points, "array") + and approx(expected.triangles, actual.triangles, "array") + and len(actual.point_displacements) == len(expected.point_displacements) + and all( + approx(expected_displacement, actual_displacement, "array") + for actual_displacement, expected_displacement in zip( + actual.point_displacements, expected.point_displacements + ) + ) + and actual.point_attributes.keys() == expected.point_attributes.keys() + and all( + approx( + point_attribute, + actual.point_attributes[attribute_name], + "array", + ) + for attribute_name, point_attribute in expected.point_attributes.items() + ) + ) + if dtypes.is_video_dtype(dtype): + if isinstance(expected, Video): + expected_video = expected + else: + expected_video = Video(expected) # type: ignore + assert isinstance(actual, Video) + return actual.data == expected_video.data + raise TypeError(f"Invalid type name {dtype_name} received.") diff --git a/tests/integration/layout/test_setting.py b/tests/integration/layout/test_setting.py index b0ed1a9c..ad6f1570 100644 --- a/tests/integration/layout/test_setting.py +++ b/tests/integration/layout/test_setting.py @@ -11,7 +11,7 @@ from renumics.spotlight import settings -def test_settings_layout_is_used(monkeypatch: MonkeyPatch): +def test_settings_layout_is_used(monkeypatch: MonkeyPatch) -> None: """ Test if the layout set via env var is actually used in the frontend. """ @@ -32,7 +32,7 @@ def test_settings_layout_is_used(monkeypatch: MonkeyPatch): assert "env_layout_table" in app_layout -def test_layout_from_params_has_priority(monkeypatch: MonkeyPatch): +def test_layout_from_params_has_priority(monkeypatch: MonkeyPatch) -> None: """ Test if the layout set via layout= param in spotlight.show is preferred over the layout set via env. diff --git a/tests/integration/dataset/io/__init__.py b/tests/integration/media/__init__.py similarity index 100% rename from tests/integration/dataset/io/__init__.py rename to tests/integration/media/__init__.py diff --git a/tests/integration/media/data.py b/tests/integration/media/data.py new file mode 100644 index 00000000..c571b1c5 --- /dev/null +++ b/tests/integration/media/data.py @@ -0,0 +1 @@ +BASE_URL = "https://spotlightpublic.blob.core.windows.net/internal-test-data/" diff --git a/tests/integration/media/test_audio.py b/tests/integration/media/test_audio.py new file mode 100644 index 00000000..98f152f3 --- /dev/null +++ b/tests/integration/media/test_audio.py @@ -0,0 +1,33 @@ +""" +Test `renumics.spotlight.media.Audio` class. 
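+
+Only URL-based smoke tests live here; the file-, IO- and array-based `Audio`
+tests are covered in `tests/unit/media/test_audio.py`.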
+"""
+from urllib.parse import urljoin
+
+import pytest
+
+from renumics.spotlight.media import Audio
+
+from .data import BASE_URL
+
+
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "gs-16b-2c-44100hz.aac",
+        "gs-16b-2c-44100hz.ac3",
+        "gs-16b-2c-44100hz.aiff",
+        "gs-16b-2c-44100hz.flac",
+        "gs-16b-2c-44100hz.mp3",
+        "gs-16b-2c-44100hz.ogg",
+        "gs-16b-2c-44100hz.ogx",
+        "gs-16b-2c-44100hz.ts",
+        "gs-16b-2c-44100hz.wav",
+        "gs-16b-2c-44100hz.wma",
+    ],
+)
+def test_audio_from_url(filename: str) -> None:
+    """
+    Test reading audio from a URL.
+    """
+    url = urljoin(BASE_URL, filename)
+    _ = Audio.from_file(url)
diff --git a/tests/integration/media/test_image.py b/tests/integration/media/test_image.py
new file mode 100644
index 00000000..34088192
--- /dev/null
+++ b/tests/integration/media/test_image.py
@@ -0,0 +1,34 @@
+"""
+Test `renumics.spotlight.media.Image` class.
+"""
+from urllib.parse import urljoin
+
+import pytest
+
+from renumics.spotlight.media import Image
+
+from .data import BASE_URL
+
+
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "nature-256p.ico",
+        "nature-360p.bmp",
+        "nature-360p.gif",
+        "nature-360p.jpg",
+        "nature-360p.png",
+        "nature-360p.tif",
+        "nature-360p.webp",
+        "nature-720p.jpg",
+        "nature-1080p.jpg",
+        "sea-360p.gif",
+        "sea-360p.apng",
+    ],
+)
+def test_image_from_url(filename: str) -> None:
+    """
+    Test reading image from a URL.
+    """
+    url = urljoin(BASE_URL, filename)
+    _ = Image.from_file(url)
diff --git a/tests/integration/media/test_mesh.py b/tests/integration/media/test_mesh.py
new file mode 100644
index 00000000..2de1cf06
--- /dev/null
+++ b/tests/integration/media/test_mesh.py
@@ -0,0 +1,30 @@
+"""
+Test `renumics.spotlight.media.Mesh` class.
+"""
+from urllib.parse import urljoin
+
+import pytest
+
+from renumics.spotlight.media import Mesh
+
+from .data import BASE_URL
+
+
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "tree.ascii.stl",
+        "tree.glb",
+        "tree.gltf",
+        "tree.obj",
+        "tree.off",
+        "tree.ply",
+        "tree.stl",
+    ],
+)
+def test_mesh_from_url(filename: str) -> None:
+    """
+    Test reading mesh from a URL.
+    """
+    url = urljoin(BASE_URL, filename)
+    _ = Mesh.from_file(url)
diff --git a/tests/integration/media/test_video.py b/tests/integration/media/test_video.py
new file mode 100644
index 00000000..7d05c5c1
--- /dev/null
+++ b/tests/integration/media/test_video.py
@@ -0,0 +1,32 @@
+"""
+Test `renumics.spotlight.media.Video` class.
+"""
+from urllib.parse import urljoin
+
+import pytest
+
+from renumics.spotlight.media import Video
+
+from .data import BASE_URL
+
+
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "sea-360p.avi",
+        "sea-360p.mkv",
+        "sea-360p.mov",
+        "sea-360p.mp4",
+        "sea-360p.mpg",
+        "sea-360p.ogg",
+        "sea-360p.webm",
+        "sea-360p.wmv",
+        "sea-360p-10s.mp4",
+    ],
+)
+def test_video_from_url(filename: str) -> None:
+    """
+    Test reading video from a URL.
+ """ + url = urljoin(BASE_URL, filename) + _ = Video.from_file(url) diff --git a/tests/unit/dtypes/test_conversion.py b/tests/unit/dtypes/test_conversion.py index 633a2dec..391f772a 100644 --- a/tests/unit/dtypes/test_conversion.py +++ b/tests/unit/dtypes/test_conversion.py @@ -2,17 +2,15 @@ Tests for conversions from source to internal types """ -from typing import Any, Dict, Type, Union +from typing import Any, Dict, Union from pathlib import Path import io import datetime import pytest import numpy as np import PIL.Image +from renumics.spotlight.dtypes.conversion import convert_to_dtype from renumics.spotlight import dtypes -from renumics.spotlight.dtypes.conversion import convert_to_dtype, DTypeOptions - -from renumics.spotlight.dtypes.typing import ColumnType @pytest.mark.parametrize( @@ -33,7 +31,7 @@ def test_conversion_to_int(value: Any, target_value: int) -> None: """ Convert values to int """ - assert convert_to_dtype(value, int) == target_value + assert convert_to_dtype(value, dtypes.int_dtype) == target_value @pytest.mark.parametrize( @@ -58,7 +56,7 @@ def test_conversion_to_float(value: Any, target_value: int) -> None: """ Convert values to float """ - assert convert_to_dtype(value, float) == target_value + assert convert_to_dtype(value, dtypes.float_dtype) == target_value @pytest.mark.parametrize( @@ -79,7 +77,7 @@ def test_conversion_to_bool(value: Any, target_value: int) -> None: """ Convert values to bool """ - assert convert_to_dtype(value, bool) == target_value + assert convert_to_dtype(value, dtypes.bool_dtype) == target_value @pytest.mark.parametrize( @@ -100,7 +98,7 @@ def test_conversion_to_datetime(value: Any, target_value: datetime.datetime) -> """ Convert values to datetime """ - assert convert_to_dtype(value, datetime.datetime) == target_value + assert convert_to_dtype(value, dtypes.datetime_dtype) == target_value @pytest.mark.parametrize( @@ -121,7 +119,7 @@ def test_conversion_to_category( Convert values to category """ assert ( - convert_to_dtype(value, dtypes.Category, DTypeOptions(categories=categories)) + convert_to_dtype(value, dtypes.CategoryDType(categories=categories)) == target_value ) @@ -139,7 +137,7 @@ def test_conversion_to_array(value: Any, target_value: np.ndarray) -> None: """ Convert values to array """ - assert np.array_equal(convert_to_dtype(value, np.ndarray), target_value) # type: ignore + assert np.array_equal(convert_to_dtype(value, dtypes.array_dtype), target_value) # type: ignore @pytest.mark.parametrize( @@ -154,7 +152,7 @@ def test_conversion_to_window(value: Any, target_value: np.ndarray) -> None: Convert values to window """ assert np.array_equal( - convert_to_dtype(value, dtypes.Window), target_value, equal_nan=True # type: ignore + convert_to_dtype(value, dtypes.window_dtype), target_value, equal_nan=True # type: ignore ) @@ -170,7 +168,7 @@ def test_conversion_to_embedding(value: Any, target_value: np.ndarray) -> None: """ Convert values to embedding """ - assert np.array_equal(convert_to_dtype(value, dtypes.Embedding), target_value) # type: ignore + assert np.array_equal(convert_to_dtype(value, dtypes.embedding_dtype), target_value) # type: ignore @pytest.mark.parametrize( @@ -190,7 +188,9 @@ def test_conversion_to_sequence(value: Any, target_value: np.ndarray) -> None: """ Convert values to sequence """ - assert np.array_equal(convert_to_dtype(value, dtypes.Sequence1D), target_value) # type: ignore + assert np.array_equal( + convert_to_dtype(value, dtypes.sequence_1d_dtype), target_value # type: ignore + ) 
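Note on the rewritten conversion tests: `convert_to_dtype` now receives a `DType` instance instead of a plain Python type. Simple dtypes become module-level singletons (`dtypes.int_dtype`, `dtypes.sequence_1d_dtype`, ...), while parametrized dtypes such as `CategoryDType` carry their options themselves, replacing the removed `DTypeOptions` argument. A minimal sketch of the new call pattern; the concrete values below are illustrative assumptions, not taken from the test data:

```python
# Sketch of the dtype-object API exercised by the tests in this diff.
from renumics.spotlight import dtypes
from renumics.spotlight.dtypes.conversion import convert_to_dtype

# Simple dtypes are module-level singletons:
assert convert_to_dtype(42, dtypes.int_dtype) == 42

# Parametrized dtypes carry their options on the instance; the former
# `DTypeOptions(categories=...)` argument is gone:
category_dtype = dtypes.CategoryDType(categories=["foo", "bar"])
converted = convert_to_dtype("foo", category_dtype)
```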
@@ -206,7 +206,7 @@ def test_conversion_to_image(value: Union[str, bytes]) -> None:
     """
     Convert values to image
     """
-    image_bytes = convert_to_dtype(value, dtypes.Image)
+    image_bytes = convert_to_dtype(value, dtypes.image_dtype)
     image = PIL.Image.open(io.BytesIO(image_bytes))  # type: ignore
 
     assert image.width > 0
@@ -223,7 +223,7 @@ def test_conversion_to_audio(value: Union[str, bytes]) -> None:
     """
     Convert values to audio
     """
-    audio_bytes = convert_to_dtype(value, dtypes.Audio)
+    audio_bytes = convert_to_dtype(value, dtypes.audio_dtype)
 
     assert len(audio_bytes) > 0  # type: ignore
@@ -237,7 +237,7 @@ def test_conversion_to_video(value: Union[str, bytes]) -> None:
     """
     Convert values to video
     """
-    video_bytes = convert_to_dtype(value, dtypes.Video)
+    video_bytes = convert_to_dtype(value, dtypes.video_dtype)
 
     assert len(video_bytes) > 0  # type: ignore
@@ -251,36 +251,48 @@ def test_conversion_to_mesh(value: Union[str, bytes]) -> None:
     """
     Convert values to mesh
     """
-    mesh_bytes = convert_to_dtype(value, dtypes.Mesh)
+    mesh_bytes = convert_to_dtype(value, dtypes.mesh_dtype)
 
     assert len(mesh_bytes) > 0  # type: ignore
 
 
 @pytest.mark.parametrize(
     "dtype,value,target_value",
     [
-        (bool, True, True),
-        (int, 42, 42),
-        (float, 1.0, 1.0),
-        (str, "foobar", "foobar"),
-        (str, "foobar" * 20, ("foobar" * 20)[:97] + "..."),
-        (datetime.datetime, datetime.datetime.min, datetime.datetime.min),
-        (np.ndarray, np.array([1, 2, 3]), "[...]"),
-        (np.ndarray, [], "[...]"),
-        (np.ndarray, None, None),
-        (dtypes.Embedding, np.array([1, 2, 3]), "[...]"),
-        (dtypes.Sequence1D, np.array([1, 2, 3]), "[...]"),
-        (dtypes.Image, np.array([[0.5, 0.7], [0.5, 0.7]]), "[...]"),
+        (dtypes.bool_dtype, True, True),
+        (dtypes.int_dtype, 42, 42),
+        (dtypes.float_dtype, 1.0, 1.0),
+        (dtypes.str_dtype, "foobar", "foobar"),
+        (dtypes.str_dtype, "foobar" * 20, ("foobar" * 20)[:97] + "..."),
+        (dtypes.datetime_dtype, datetime.datetime.min, datetime.datetime.min),
+        (dtypes.array_dtype, np.array([1, 2, 3]), "[...]"),
+        (dtypes.array_dtype, [], "[...]"),
+        (dtypes.array_dtype, None, None),
+        (dtypes.embedding_dtype, np.array([1, 2, 3]), "[...]"),
+        (dtypes.sequence_1d_dtype, np.array([1, 2, 3]), "[...]"),
+        (dtypes.image_dtype, np.array([[0.5, 0.7], [0.5, 0.7]]), "[...]"),
         (
-            dtypes.Image,
+            dtypes.image_dtype,
             "./data/images/nature-360p.jpg",
             "./data/images/nature-360p.jpg",
         ),
-        (dtypes.Audio, "./data/audio/1.wav", "./data/audio/1.wav"),
-        (dtypes.Video, "./data/videos/sea-360p.ogg", "./data/videos/sea-360p.ogg"),
-        (dtypes.Mesh, "./data/meshes/tree.glb", "./data/meshes/tree.glb"),
-        (dtypes.Image, Path("./data/images/nature-360p.jpg").read_bytes(), ""),
-        (dtypes.Audio, Path("./data/audio/1.wav").read_bytes(), ""),
-        (dtypes.Video, Path("./data/videos/sea-360p.ogg").read_bytes(), ""),
+        (dtypes.audio_dtype, "./data/audio/1.wav", "./data/audio/1.wav"),
+        (
+            dtypes.audio_dtype,
+            "./data/videos/sea-360p.ogg",
+            "./data/videos/sea-360p.ogg",
+        ),
+        (dtypes.mesh_dtype, "./data/meshes/tree.glb", "./data/meshes/tree.glb"),
+        (
+            dtypes.image_dtype,
+            Path("./data/images/nature-360p.jpg").read_bytes(),
+            "",
+        ),
+        (dtypes.audio_dtype, Path("./data/audio/1.wav").read_bytes(), ""),
+        (
+            dtypes.video_dtype,
+            Path("./data/videos/sea-360p.ogg").read_bytes(),
+            "",
+        ),
     ],
     ids=[
         "bool",
@@ -304,7 +316,7 @@ def test_conversion_to_mesh(value: Union[str, bytes]) -> None:
         "video-bytes",
     ],
 )
-def test_simple_conversion(dtype: Type[ColumnType], value: Any, target_value: Any):
+def test_simple_conversion(dtype: dtypes.DType, value: Any, target_value: Any) -> None:
     """
     Convert values for simple view.
     """
diff --git a/tests/unit/io/__init__.py b/tests/unit/io/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/integration/dataset/io/test_audio.py b/tests/unit/io/test_audio.py
similarity index 100%
rename from tests/integration/dataset/io/test_audio.py
rename to tests/unit/io/test_audio.py
diff --git a/tests/unit/media/__init__.py b/tests/unit/media/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/media/data.py b/tests/unit/media/data.py
new file mode 100644
index 00000000..d9fb1e22
--- /dev/null
+++ b/tests/unit/media/data.py
@@ -0,0 +1,10 @@
+from pathlib import Path
+
+
+DATA_FOLDER = Path("data")
+AUDIO_FOLDER = DATA_FOLDER / "audio"
+IMAGES_FOLDER = DATA_FOLDER / "images"
+MESHES_FOLDER = DATA_FOLDER / "meshes"
+VIDEOS_FOLDER = DATA_FOLDER / "videos"
+
+SEED = 42
diff --git a/tests/unit/media/test_audio.py b/tests/unit/media/test_audio.py
new file mode 100644
index 00000000..356d6a8c
--- /dev/null
+++ b/tests/unit/media/test_audio.py
@@ -0,0 +1,212 @@
+"""
+Test `renumics.spotlight.media.Audio` class.
+"""
+import io
+
+import numpy as np
+import pytest
+
+from renumics.spotlight.media import Audio
+
+from .data import AUDIO_FOLDER
+
+
+@pytest.mark.parametrize("sampling_rate", [100, 8000, 44100, 48000, 96000])
+@pytest.mark.parametrize("channels", [1, 2, 5])
+@pytest.mark.parametrize("length", [0.5, 1.0, 2.0])
+@pytest.mark.parametrize("dtype", ["f4", "i2", "u1"])
+@pytest.mark.parametrize("target", ["wav", "flac"])
+def test_lossless_audio(
+    sampling_rate: int, channels: int, length: float, dtype: str, target: str
+) -> None:
+    """
+    Test audio creation and lossless saving.
+    """
+
+    time = np.linspace(0.0, length, round(sampling_rate * length))
+    y = 0.4 * np.sin(2.0 * np.pi * 100 * time)
+    if dtype.startswith("f"):
+        data = y.astype(dtype)
+    elif dtype.startswith("i"):
+        data = (y * np.iinfo(dtype).max).astype(dtype)
+    elif dtype.startswith("u"):
+        data = ((y + 1) * np.iinfo(dtype).max / 2).astype(dtype)
+    else:
+        assert False
+    if channels > 1:
+        data = np.broadcast_to(data[:, np.newaxis], (len(data), channels))
+    audio = Audio(sampling_rate, data)
+
+    encoded_audio = audio.encode(target)
+    decoded_audio = Audio.decode(encoded_audio)
+
+    decoded_sr = decoded_audio.sampling_rate
+    decoded_data = decoded_audio.data
+
+    assert decoded_sr == sampling_rate
+    assert decoded_data.shape == (len(y), channels)
+
+
+@pytest.mark.parametrize("sampling_rate", [32000, 44100, 48000])
+@pytest.mark.parametrize("channels", [1, 2])
+@pytest.mark.parametrize("length", [0.5, 1.0, 2.0])
+@pytest.mark.parametrize("dtype", ["f4", "i2", "u1"])
+def test_lossy_audio(
+    sampling_rate: int, channels: int, length: float, dtype: str
+) -> None:
+    """
+    Test audio creation and lossy saving.
+ """ + time = np.linspace(0.0, length, round(sampling_rate * length)) + y = 0.4 * np.sin(2.0 * np.pi * 100 * time) + if dtype.startswith("f"): + data = y.astype(dtype) + elif dtype.startswith("i"): + data = (y * np.iinfo(dtype).max).astype(dtype) + elif dtype.startswith("u"): + data = ((y + 1) * np.iinfo(dtype).max / 2).astype(dtype) + else: + assert False + if channels > 1: + data = np.broadcast_to(data[:, np.newaxis], (len(data), channels)) + audio = Audio(sampling_rate, data) + + encoded_audio = audio.encode("ogg") + decoded_audio = Audio.decode(encoded_audio) + + decoded_sr = decoded_audio.sampling_rate + _decoded_data = decoded_audio.data + + assert decoded_sr == sampling_rate + + +@pytest.mark.parametrize( + "filename", + [ + "gs-16b-1c-44100hz.aac", + "gs-16b-1c-44100hz.ac3", + "gs-16b-1c-44100hz.aiff", + "gs-16b-1c-44100hz.flac", + "gs-16b-1c-44100hz.m4a", + "gs-16b-1c-44100hz.mp3", + "gs-16b-1c-44100hz.ogg", + "gs-16b-1c-44100hz.ogx", + "gs-16b-1c-44100hz.wav", + "gs-16b-1c-44100hz.wma", + ], +) +def test_audio_from_filepath_mono(filename: str) -> None: + """ + Test `Audio.from_file` method on mono data. + """ + filepath = AUDIO_FOLDER / "mono" / filename + assert filepath.is_file() + _ = Audio.from_file(str(filepath)) + _ = Audio.from_file(filepath) + + +@pytest.mark.parametrize( + "filename", + [ + "gs-16b-2c-44100hz.aac", + "gs-16b-2c-44100hz.ac3", + "gs-16b-2c-44100hz.aiff", + "gs-16b-2c-44100hz.flac", + "gs-16b-2c-44100hz.m4a", + "gs-16b-2c-44100hz.mp3", + "gs-16b-2c-44100hz.mp4", + "gs-16b-2c-44100hz.ogg", + "gs-16b-2c-44100hz.ogx", + "gs-16b-2c-44100hz.wav", + "gs-16b-2c-44100hz.wma", + ], +) +def test_audio_from_filepath_stereo(filename: str) -> None: + """ + Test `Audio.from_file` method on stereo data. + """ + filepath = AUDIO_FOLDER / "stereo" / filename + assert filepath.is_file() + _ = Audio.from_file(str(filepath)) + _ = Audio.from_file(filepath) + + +@pytest.mark.parametrize( + "filename", + [ + "gs-16b-2c-44100hz.aac", + "gs-16b-2c-44100hz.ac3", + "gs-16b-2c-44100hz.aiff", + "gs-16b-2c-44100hz.flac", + "gs-16b-2c-44100hz.m4a", + "gs-16b-2c-44100hz.mp3", + "gs-16b-2c-44100hz.mp4", + "gs-16b-2c-44100hz.ogg", + "gs-16b-2c-44100hz.ogx", + "gs-16b-2c-44100hz.wav", + "gs-16b-2c-44100hz.wma", + ], +) +def test_audio_from_file(filename: str) -> None: + """ + Test reading audio from a file descriptor. + """ + filepath = AUDIO_FOLDER / "stereo" / filename + assert filepath.is_file() + with filepath.open("rb") as file: + _ = Audio.from_file(file) + + +@pytest.mark.parametrize( + "filename", + [ + "gs-16b-2c-44100hz.aac", + "gs-16b-2c-44100hz.ac3", + "gs-16b-2c-44100hz.aiff", + "gs-16b-2c-44100hz.flac", + "gs-16b-2c-44100hz.m4a", + "gs-16b-2c-44100hz.mp3", + "gs-16b-2c-44100hz.mp4", + "gs-16b-2c-44100hz.ogg", + "gs-16b-2c-44100hz.ogx", + "gs-16b-2c-44100hz.wav", + "gs-16b-2c-44100hz.wma", + ], +) +def test_audio_from_io(filename: str) -> None: + """ + Test reading audio from an IO object. 
+ """ + filepath = AUDIO_FOLDER / "stereo" / filename + assert filepath.is_file() + with filepath.open("rb") as file: + blob = file.read() + buffer = io.BytesIO(blob) + _ = Audio.from_file(buffer) + + +@pytest.mark.parametrize( + "filename", + [ + "gs-16b-2c-44100hz.aac", + "gs-16b-2c-44100hz.ac3", + "gs-16b-2c-44100hz.aiff", + "gs-16b-2c-44100hz.flac", + "gs-16b-2c-44100hz.m4a", + "gs-16b-2c-44100hz.mp3", + "gs-16b-2c-44100hz.mp4", + "gs-16b-2c-44100hz.ogg", + "gs-16b-2c-44100hz.ogx", + "gs-16b-2c-44100hz.wav", + "gs-16b-2c-44100hz.wma", + ], +) +def test_audio_from_bytes(filename: str) -> None: + """ + Test reading audio from bytes. + """ + filepath = AUDIO_FOLDER / "stereo" / filename + assert filepath.is_file() + with filepath.open("rb") as file: + blob = file.read() + _ = Audio.from_bytes(blob) diff --git a/tests/unit/media/test_embedding.py b/tests/unit/media/test_embedding.py new file mode 100644 index 00000000..2f32be34 --- /dev/null +++ b/tests/unit/media/test_embedding.py @@ -0,0 +1,70 @@ +""" +Test `renumics.spotlight.media.Embedding` class. +""" + +import numpy as np +import pytest + +from renumics.spotlight.media import Embedding + +from .data import SEED +from ...integration.helpers import approx + + +@pytest.mark.parametrize("length", [1, 10, 100]) +@pytest.mark.parametrize("dtype", ["float32", "float64", "uint16", "int32"]) +@pytest.mark.parametrize("input_type", ["array", "list", "tuple"]) +def test_embedding(length: int, dtype: str, input_type: str) -> None: + """ + Test embedding initialization, encoding/decoding. + """ + np.random.seed(SEED) + np_dtype = np.dtype(dtype) + if np_dtype.str[1] == "f": + array = np.random.uniform(0, 1, length).astype(dtype) + else: + array = np.random.randint( # type: ignore + np.iinfo(np_dtype).min, np.iinfo(np_dtype).max, length, dtype + ) + if input_type == "list": + array = array.tolist() + elif input_type == "tuple": + array = tuple(array.tolist()) # type: ignore + + # Test instantiation + embedding = Embedding(array) + assert approx(array, embedding.data, "Embedding") + + # Test encode + encoded_embedding = embedding.encode() + assert isinstance(encoded_embedding, np.ndarray) + + # Test decode + decoded_embedding = Embedding.decode(encoded_embedding) + assert approx(embedding, decoded_embedding.data, "Embedding") + + +@pytest.mark.parametrize("input_type", ["array", "list", "tuple"]) +def test_zero_length_embedding_fails(input_type: str) -> None: + """ + Test if `Embedding` fails with zero-length data. + """ + array = np.empty(0) + if input_type == "list": + array = array.tolist() + elif input_type == "tuple": + array = tuple(array.tolist()) # type: ignore + with pytest.raises(ValueError): + _ = Embedding(array) + + +@pytest.mark.parametrize("num_dims", [0, 2, 3, 4]) +def test_multidimensional_embedding_fails(num_dims: int) -> None: + """ + Test if `Embedding` fails with not one-dimensional data. + """ + np.random.seed(SEED) + dims = np.random.randint(1, 20, num_dims) + array = np.random.uniform(0, 1, dims) + with pytest.raises(ValueError): + _ = Embedding(array) diff --git a/tests/unit/media/test_image.py b/tests/unit/media/test_image.py new file mode 100644 index 00000000..47b7550a --- /dev/null +++ b/tests/unit/media/test_image.py @@ -0,0 +1,153 @@ +""" +Test `renumics.spotlight.media.Image` class. 
+""" +import io +from typing import Optional + +import numpy as np +import pytest + +from renumics.spotlight.media import Image + +from .data import IMAGES_FOLDER, SEED +from ...integration.helpers import approx + + +@pytest.mark.parametrize("size", [1, 10, 100]) +@pytest.mark.parametrize("num_channels", [None, 1, 3, 4]) +@pytest.mark.parametrize("dtype", ["float32", "float64", "uint8", "int32"]) +@pytest.mark.parametrize("input_type", ["array", "list", "tuple"]) +def test_image_from_array( + size: int, num_channels: Optional[int], dtype: str, input_type: str +) -> None: + """ + Test image creation. + """ + np.random.seed(SEED) + shape = (size, size) if num_channels is None else (size, size, num_channels) + np_dtype = np.dtype(dtype) + if np_dtype.str[1] == "f": + array = np.random.uniform(0, 1, shape).astype(dtype) + target = (255 * array).round().astype("uint8") + else: + array = np.random.randint(0, 256, shape, dtype) # type: ignore + target = array.astype("uint8") + if num_channels == 1: + target = target.squeeze(axis=2) + if input_type == "list": + array = array.tolist() + elif input_type == "tuple": + array = tuple(array.tolist()) # type: ignore + image = Image(array) + + assert approx(target, image, "Image") + encoded_image = image.encode() + decoded_image = Image.decode(encoded_image) + assert approx(image, decoded_image, "Image") + + +@pytest.mark.parametrize( + "filename", + [ + "nature-256p.ico", + "nature-360p.bmp", + "nature-360p.gif", + "nature-360p.jpg", + "nature-360p.png", + "nature-360p.tif", + "nature-360p.webp", + "nature-720p.jpg", + "nature-1080p.jpg", + "sea-360p.gif", + "sea-360p.apng", + ], +) +def test_image_from_filepath(filename: str) -> None: + """ + Test reading image from an existing file. + """ + filepath = IMAGES_FOLDER / filename + assert filepath.is_file() + _ = Image.from_file(str(filepath)) + _ = Image.from_file(filepath) + + +@pytest.mark.parametrize( + "filename", + [ + "nature-256p.ico", + "nature-360p.bmp", + "nature-360p.gif", + "nature-360p.jpg", + "nature-360p.png", + "nature-360p.tif", + "nature-360p.webp", + "nature-720p.jpg", + "nature-1080p.jpg", + "sea-360p.gif", + "sea-360p.apng", + ], +) +def test_image_from_file(filename: str) -> None: + """ + Test reading image from a file descriptor. + """ + filepath = IMAGES_FOLDER / filename + assert filepath.is_file() + with filepath.open("rb") as file: + _ = Image.from_file(file) + + +@pytest.mark.parametrize( + "filename", + [ + "nature-256p.ico", + "nature-360p.bmp", + "nature-360p.gif", + "nature-360p.jpg", + "nature-360p.png", + "nature-360p.tif", + "nature-360p.webp", + "nature-720p.jpg", + "nature-1080p.jpg", + "sea-360p.gif", + "sea-360p.apng", + ], +) +def test_image_from_io(filename: str) -> None: + """ + Test reading image from an IO object. + """ + filepath = IMAGES_FOLDER / filename + assert filepath.is_file() + with filepath.open("rb") as file: + blob = file.read() + buffer = io.BytesIO(blob) + _ = Image.from_file(buffer) + + +@pytest.mark.parametrize( + "filename", + [ + "nature-256p.ico", + "nature-360p.bmp", + "nature-360p.gif", + "nature-360p.jpg", + "nature-360p.png", + "nature-360p.tif", + "nature-360p.webp", + "nature-720p.jpg", + "nature-1080p.jpg", + "sea-360p.gif", + "sea-360p.apng", + ], +) +def test_image_from_bytes(filename: str) -> None: + """ + Test reading image from bytes. 
+ """ + filepath = IMAGES_FOLDER / filename + assert filepath.is_file() + with filepath.open("rb") as file: + blob = file.read() + _ = Image.from_bytes(blob) diff --git a/tests/unit/media/test_mesh.py b/tests/unit/media/test_mesh.py new file mode 100644 index 00000000..2053b951 --- /dev/null +++ b/tests/unit/media/test_mesh.py @@ -0,0 +1,68 @@ +""" +Test `renumics.spotlight.media.Mesh` class. +""" +import numpy as np +import pytest + +from renumics.spotlight.media import Mesh + +from .data import MESHES_FOLDER +from ...integration.helpers import approx + + +@pytest.mark.parametrize("num_points", [10, 100, 10000]) +@pytest.mark.parametrize("num_triangles", [10, 100, 10000]) +def test_mesh_from_array(num_points: int, num_triangles: int) -> None: + """ + Test mesh creation. + """ + points = np.random.random((num_points, 3)) + triangles = np.random.randint(0, num_points, (num_triangles, 3)) + mesh = Mesh( + points, + triangles, + { + "float32": np.random.rand(num_points).astype("float32"), + "float64_3": np.random.rand(num_points, 3).astype("float64"), + "float16_4": np.random.rand(num_points, 4).astype("float16"), + "uint64_1": np.random.randint(0, num_points, (num_points, 1), "uint64"), + "int16": np.random.randint(-10000, 10000, num_points, "int16"), + }, + { + "float32_1": np.random.rand(num_triangles, 1).astype("float32"), + "float64_3": np.random.rand(num_triangles, 3).astype("float64"), + "float32_4": np.random.rand(num_triangles, 4).astype("float32"), + "uint32": np.random.randint( + 0, np.iinfo("uint32").max, num_triangles, "uint32" + ), + "int64_1": np.random.randint(-10000, 10000, (num_triangles, 1), "int64"), + }, + [np.random.random((num_points, 3)) for _ in range(10)], + ) + assert len(mesh.points) <= len(triangles) * 3 + assert len(mesh.triangles) <= len(triangles) + encoded_mesh = mesh.encode() + decoded_mesh = Mesh.decode(encoded_mesh) + assert approx(mesh, decoded_mesh, "Mesh") + + +@pytest.mark.parametrize( + "filename", + [ + "tree.ascii.stl", + "tree.glb", + "tree.gltf", + "tree.obj", + "tree.off", + "tree.ply", + "tree.stl", + ], +) +def test_mesh_from_filepath(filename: str) -> None: + """ + Test reading mesh from an existing file. + """ + filepath = MESHES_FOLDER / filename + assert filepath.is_file() + _ = Mesh.from_file(str(filepath)) + _ = Mesh.from_file(filepath) diff --git a/tests/unit/media/test_sequence_1d.py b/tests/unit/media/test_sequence_1d.py new file mode 100644 index 00000000..6e852635 --- /dev/null +++ b/tests/unit/media/test_sequence_1d.py @@ -0,0 +1,40 @@ +""" +Test `renumics.spotlight.media.Sequence1D` class. +""" +import numpy as np +import pytest + +from renumics.spotlight.media import Sequence1D + +from ...integration.helpers import approx + + +@pytest.mark.parametrize("length", [1, 10, 100]) +@pytest.mark.parametrize("input_type", ["array", "list", "tuple"]) +def test_sequence_1d(length: int, input_type: str) -> None: + """ + Test sequence initialization, encoding/decoding. + """ + + index = np.random.rand(length) + value = np.random.rand(length) + if input_type == "list": + index = index.tolist() + value = value.tolist() + elif input_type == "tuple": + index = tuple(index.tolist()) # type: ignore + value = tuple(value.tolist()) # type: ignore + # Initialization with index and value. 
+    sequence_1d = Sequence1D(index, value)
+    assert approx(index, sequence_1d.index, "array")
+    assert approx(value, sequence_1d.value, "array")
+    encoded_sequence_1d = sequence_1d.encode()
+    decoded_sequence_1d = Sequence1D.decode(encoded_sequence_1d)
+    assert approx(sequence_1d, decoded_sequence_1d, "Sequence1D")
+    # Initialization with value only.
+    sequence_1d = Sequence1D(value)
+    assert approx(np.arange(len(value)), sequence_1d.index, "array")
+    assert approx(value, sequence_1d.value, "array")
+    encoded_sequence_1d = sequence_1d.encode()
+    decoded_sequence_1d = Sequence1D.decode(encoded_sequence_1d)
+    assert approx(sequence_1d, decoded_sequence_1d, "Sequence1D")
diff --git a/tests/unit/media/test_video.py b/tests/unit/media/test_video.py
new file mode 100644
index 00000000..224df6c8
--- /dev/null
+++ b/tests/unit/media/test_video.py
@@ -0,0 +1,57 @@
+"""
+Test `renumics.spotlight.media.Video` class.
+"""
+import pytest
+
+from renumics.spotlight.media import Video
+
+from .data import VIDEOS_FOLDER
+
+
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "sea-360p.avi",
+        "sea-360p.mkv",
+        "sea-360p.mov",
+        "sea-360p.mp4",
+        "sea-360p.mpg",
+        "sea-360p.ogg",
+        "sea-360p.webm",
+        "sea-360p.wmv",
+        "sea-360p-10s.mp4",
+    ],
+)
+def test_video_from_filepath(filename: str) -> None:
+    """
+    Test reading video from an existing file.
+    """
+    filepath = VIDEOS_FOLDER / filename
+    assert filepath.is_file()
+    _ = Video.from_file(str(filepath))
+    _ = Video.from_file(filepath)
+
+
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "sea-360p.avi",
+        "sea-360p.mkv",
+        "sea-360p.mov",
+        "sea-360p.mp4",
+        "sea-360p.mpg",
+        "sea-360p.ogg",
+        "sea-360p.webm",
+        "sea-360p.wmv",
+        "sea-360p-10s.mp4",
+    ],
+)
+def test_video_from_bytes(filename: str) -> None:
+    """
+    Test reading video from bytes.
+    """
+    filepath = VIDEOS_FOLDER / filename
+    assert filepath.is_file()
+    with filepath.open("rb") as file:
+        blob = file.read()
+    _ = Video.from_bytes(blob)
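For reference, test_lossless_audio and test_lossy_audio above synthesize their input the same way. Pulled out of the diff for readability, the fixture logic amounts to the following standalone sketch; it assumes only NumPy, and make_test_signal is a hypothetical name (the tests inline this code instead of sharing a helper):

    import numpy as np

    def make_test_signal(sampling_rate: int, length: float, dtype: str) -> np.ndarray:
        """Reproduce the 100 Hz sine fixture used by the audio tests above."""
        time = np.linspace(0.0, length, round(sampling_rate * length))
        y = 0.4 * np.sin(2.0 * np.pi * 100 * time)  # 100 Hz tone, amplitude 0.4
        if dtype.startswith("f"):
            return y.astype(dtype)  # floats keep the [-0.4, 0.4] range
        if dtype.startswith("i"):
            # signed ints: scale the unit signal into the dtype's full range
            return (y * np.iinfo(dtype).max).astype(dtype)
        if dtype.startswith("u"):
            # unsigned ints: shift from [-1, 1] into [0, max] before scaling
            return ((y + 1) * np.iinfo(dtype).max / 2).astype(dtype)
        raise ValueError(f"unsupported dtype: {dtype}")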
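Taken together, the new media tests exercise one uniform pattern: construct a media object, encode it, decode it, and compare. An end-to-end sketch of that round trip, reusing the hypothetical helper above and assuming this revision's `renumics.spotlight.media.Audio` with a FLAC backend available:

    from renumics.spotlight.media import Audio

    sampling_rate = 44100
    # make_test_signal is the hypothetical helper from the previous sketch.
    data = make_test_signal(sampling_rate, 1.0, "f4")

    audio = Audio(sampling_rate, data)
    encoded = audio.encode("flac")  # lossless target, as in test_lossless_audio
    decoded = Audio.decode(encoded)

    assert decoded.sampling_rate == sampling_rate
    # Decoded samples come back as (num_samples, num_channels), matching the
    # shape assertion in test_lossless_audio.
    assert decoded.data.shape == (len(data), 1)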