diff --git a/renumics/spotlight/dataset/__init__.py b/renumics/spotlight/dataset/__init__.py index fea5415b..612bba74 100644 --- a/renumics/spotlight/dataset/__init__.py +++ b/renumics/spotlight/dataset/__init__.py @@ -47,52 +47,9 @@ is_integer, is_iterable, ) -from renumics.spotlight.dtypes import ( - Embedding, - Mesh, - Sequence1D, - Image, - Audio, - Category, - Video, - Window, -) from renumics.spotlight.io.pandas import create_typed_series from renumics.spotlight.dtypes.conversion import prepare_path_or_url -from renumics.spotlight.dtypes import ( - CategoryDType, - Sequence1DDType, - create_dtype, - is_bool_dtype, - is_int_dtype, - is_float_dtype, - is_str_dtype, - is_datetime_dtype, - is_category_dtype, - is_array_dtype, - is_window_dtype, - is_embedding_dtype, - is_sequence_1d_dtype, - is_audio_dtype, - is_image_dtype, - is_mesh_dtype, - is_video_dtype, - is_file_dtype, - is_scalar_dtype, - DType, - bool_dtype, - int_dtype, - float_dtype, - str_dtype, - datetime_dtype, - array_dtype, - window_dtype, - embedding_dtype, - audio_dtype, - image_dtype, - mesh_dtype, - video_dtype, -) +from renumics.spotlight import dtypes as spotlight_dtypes from . import exceptions from .typing import ( @@ -122,7 +79,7 @@ INTERNAL_COLUMN_NAMES = ["__last_edited_by__", "__last_edited_at__"] -INTERNAL_COLUMN_DTYPES = [str_dtype, datetime_dtype] +INTERNAL_COLUMN_DTYPES = [spotlight_dtypes.str_dtype, spotlight_dtypes.datetime_dtype] _EncodedColumnType = Optional[Union[bool, int, float, str, np.ndarray, h5py.Reference]] @@ -133,15 +90,15 @@ "float": float, "str": str, "datetime": datetime, - "Category": Category, + "Category": spotlight_dtypes.Category, "array": np.ndarray, - "Window": Window, - "Embedding": Embedding, - "Sequence1D": Sequence1D, - "Audio": Audio, - "Image": Image, - "Video": Video, - "Mesh": Mesh, + "Window": spotlight_dtypes.Window, + "Embedding": spotlight_dtypes.Embedding, + "Sequence1D": spotlight_dtypes.Sequence1D, + "Audio": spotlight_dtypes.Audio, + "Image": spotlight_dtypes.Image, + "Video": spotlight_dtypes.Video, + "Mesh": spotlight_dtypes.Mesh, } @@ -201,12 +158,12 @@ def unescape_dataset_name(escaped_name: str) -> str: np.floating, ), "Window": (np.ndarray, list, tuple), - "Embedding": (Embedding, np.ndarray, list, tuple), - "Sequence1D": (Sequence1D, np.ndarray, list, tuple), - "Audio": (Audio, bytes, str, os.PathLike), - "Image": (Image, bytes, str, os.PathLike, np.ndarray, list, tuple), - "Mesh": (Mesh, trimesh.Trimesh, str, os.PathLike), - "Video": (Video, bytes, str, os.PathLike), + "Embedding": (spotlight_dtypes.Embedding, np.ndarray, list, tuple), + "Sequence1D": (spotlight_dtypes.Sequence1D, np.ndarray, list, tuple), + "Audio": (spotlight_dtypes.Audio, bytes, str, os.PathLike), + "Image": (spotlight_dtypes.Image, bytes, str, os.PathLike, np.ndarray, list, tuple), + "Mesh": (spotlight_dtypes.Mesh, trimesh.Trimesh, str, os.PathLike), + "Video": (spotlight_dtypes.Video, bytes, str, os.PathLike), } _ALLOWED_COLUMN_DTYPES: Dict[str, Tuple[Type, ...]] = { "bool": (np.bool_,), @@ -218,7 +175,7 @@ def unescape_dataset_name(escaped_name: str) -> str: } -def _check_valid_value_type(value: Any, dtype: DType) -> bool: +def _check_valid_value_type(value: Any, dtype: spotlight_dtypes.DType) -> bool: """ Check if a value is suitable for the given column type. Instances of the given type are always suitable for its type, but extra types from @@ -228,7 +185,9 @@ def _check_valid_value_type(value: Any, dtype: DType) -> bool: return isinstance(value, allowed_types) -def _check_valid_value_dtype(value_dtype: np.dtype, dtype: DType) -> bool: +def _check_valid_value_dtype( + value_dtype: np.dtype, dtype: spotlight_dtypes.DType +) -> bool: """ Check if an array with the given dtype is suitable for the given column type. Only types from `_ALLOWED_COLUMN_DTYPES` are checked. All other column types @@ -240,7 +199,9 @@ def _check_valid_value_dtype(value_dtype: np.dtype, dtype: DType) -> bool: ) -def _check_valid_array(value: Any, dtype: DType) -> TypeGuard[np.ndarray]: +def _check_valid_array( + value: Any, dtype: spotlight_dtypes.DType +) -> TypeGuard[np.ndarray]: """ Check if a value is an array and its type is suitable for the given column type. """ @@ -263,7 +224,7 @@ class Dataset: _length: int @staticmethod - def _user_column_attributes(dtype: DType) -> Dict[str, Type]: + def _user_column_attributes(dtype: spotlight_dtypes.DType) -> Dict[str, Type]: attribute_names = { "order": int, "hidden": bool, @@ -273,37 +234,37 @@ def _user_column_attributes(dtype: DType) -> Dict[str, Type]: "tags": list, } if ( - is_scalar_dtype(dtype) - or is_str_dtype(dtype) - or is_category_dtype(dtype) - or is_window_dtype(dtype) + spotlight_dtypes.is_scalar_dtype(dtype) + or spotlight_dtypes.is_str_dtype(dtype) + or spotlight_dtypes.is_category_dtype(dtype) + or spotlight_dtypes.is_window_dtype(dtype) ): attribute_names["editable"] = bool - if is_category_dtype(dtype): + if spotlight_dtypes.is_category_dtype(dtype): attribute_names["categories"] = dict - if is_sequence_1d_dtype(dtype): + if spotlight_dtypes.is_sequence_1d_dtype(dtype): attribute_names["x_label"] = str attribute_names["y_label"] = str - if is_file_dtype(dtype): + if spotlight_dtypes.is_file_dtype(dtype): attribute_names["lookup"] = dict attribute_names["external"] = bool - if is_audio_dtype(dtype): + if spotlight_dtypes.is_audio_dtype(dtype): attribute_names["lossy"] = bool return attribute_names @classmethod - def _default_default(cls, dtype: DType) -> Any: - if is_bool_dtype(dtype): + def _default_default(cls, dtype: spotlight_dtypes.DType) -> Any: + if spotlight_dtypes.is_bool_dtype(dtype): return False - if is_int_dtype(dtype): + if spotlight_dtypes.is_int_dtype(dtype): return 0 - if is_float_dtype(dtype): + if spotlight_dtypes.is_float_dtype(dtype): return float("nan") - if is_str_dtype(dtype): + if spotlight_dtypes.is_str_dtype(dtype): return "" - if is_datetime_dtype(dtype): + if spotlight_dtypes.is_datetime_dtype(dtype): return np.datetime64("NaT") - if is_window_dtype(dtype): + if spotlight_dtypes.is_window_dtype(dtype): return [np.nan, np.nan] return None @@ -413,7 +374,7 @@ def __delitem__(self, item: Union[str, IndexType, Indices1dType]) -> None: for column_name in self.keys() + INTERNAL_COLUMN_NAMES: column = self._h5_file[column_name] raw_values = column[item + 1 :] - if is_embedding_dtype(self._get_dtype(column)): + if spotlight_dtypes.is_embedding_dtype(self._get_dtype(column)): raw_values = list(raw_values) column[item:-1] = raw_values column.resize(self._length - 1, axis=0) @@ -783,7 +744,11 @@ def from_pandas( dtypes = {} inferred_dtypes = infer_dtypes( - df, {col: create_dtype(dtype) for col, dtype in dtypes.items()} + df, + { + col: spotlight_dtypes.create_dtype(dtype) + for col, dtype in dtypes.items() + }, ) for column_name in df.columns: @@ -793,7 +758,7 @@ def from_pandas( column = prepare_column(column, dtype) - if workdir is not None and is_file_dtype(dtype): + if workdir is not None and spotlight_dtypes.is_file_dtype(dtype): # For file-based data types, relative paths should be resolved. str_mask = is_string_mask(column) column[str_mask] = column[str_mask].apply( @@ -802,17 +767,17 @@ def from_pandas( attrs = {} - if is_category_dtype(dtype): + if spotlight_dtypes.is_category_dtype(dtype): attrs["categories"] = column.cat.categories.to_list() values = column.to_numpy() # `pandas` uses `NaN`s for unknown values, we use `None`. values = np.where(pd.isna(values), np.array(None), values) - elif is_datetime_dtype(dtype): + elif spotlight_dtypes.is_datetime_dtype(dtype): values = column.to_numpy("datetime64[us]") else: values = column.to_numpy() - if is_file_dtype(dtype): + if spotlight_dtypes.is_file_dtype(dtype): attrs["external"] = False # type: ignore attrs["lookup"] = False # type: ignore @@ -886,12 +851,12 @@ def to_pandas(self) -> pd.DataFrame: df = pd.DataFrame() for column_name in self._column_names: dtype = self.get_dtype(column_name) - if is_datetime_dtype(dtype): + if spotlight_dtypes.is_datetime_dtype(dtype): df[column_name] = create_typed_series(dtype, self[column_name]) elif ( - is_scalar_dtype(dtype) - or is_str_dtype(dtype) - or is_category_dtype(dtype) + spotlight_dtypes.is_scalar_dtype(dtype) + or spotlight_dtypes.is_str_dtype(dtype) + or spotlight_dtypes.is_category_dtype(dtype) ): df[column_name] = create_typed_series( dtype, self._h5_file[column_name][:] @@ -946,7 +911,7 @@ def append_bool_column( """ self._append_column( name, - bool_dtype, + spotlight_dtypes.bool_dtype, values, np.dtype(bool), order, @@ -993,7 +958,7 @@ def append_int_column( """ self._append_column( name, - int_dtype, + spotlight_dtypes.int_dtype, values, np.dtype(int), order, @@ -1041,7 +1006,7 @@ def append_float_column( """ self._append_column( name, - float_dtype, + spotlight_dtypes.float_dtype, values, np.dtype(float), order, @@ -1089,7 +1054,7 @@ def append_string_column( """ self._append_column( name, - str_dtype, + spotlight_dtypes.str_dtype, values, h5py.string_dtype(), order, @@ -1144,7 +1109,7 @@ def append_datetime_column( """ self._append_column( name, - datetime_dtype, + spotlight_dtypes.datetime_dtype, values, h5py.string_dtype(), order, @@ -1195,7 +1160,7 @@ def append_array_column( """ self._append_column( name, - array_dtype, + spotlight_dtypes.array_dtype, values, h5py.string_dtype(), order, @@ -1245,7 +1210,7 @@ def append_categorical_column( """ self._append_column( name, - CategoryDType(categories), + spotlight_dtypes.CategoryDType(categories), values, np.dtype("int32"), order, @@ -1299,7 +1264,7 @@ def append_embedding_column( ) self._append_column( name, - embedding_dtype, + spotlight_dtypes.embedding_dtype, values, h5py.vlen_dtype(np_dtype), order, @@ -1353,7 +1318,7 @@ def append_sequence_1d_column( y_label = name self._append_column( name, - Sequence1DDType(x_label, y_label), + spotlight_dtypes.Sequence1DDType(x_label, y_label), values, h5py.string_dtype(), order, @@ -1414,7 +1379,7 @@ def append_mesh_column( """ self._append_column( name, - mesh_dtype, + spotlight_dtypes.mesh_dtype, values, h5py.string_dtype(), order, @@ -1475,7 +1440,7 @@ def append_image_column( """ self._append_column( name, - image_dtype, + spotlight_dtypes.image_dtype, values, h5py.string_dtype(), order, @@ -1548,7 +1513,7 @@ def append_audio_column( attrs["lossy"] = lossy self._append_column( name, - audio_dtype, + spotlight_dtypes.audio_dtype, values, h5py.string_dtype(), order, @@ -1609,7 +1574,7 @@ def append_video_column( """ self._append_column( name, - video_dtype, + spotlight_dtypes.video_dtype, values, h5py.string_dtype(), order, @@ -1659,7 +1624,7 @@ def append_window_column( """ self._append_column( name, - window_dtype, + spotlight_dtypes.window_dtype, values, np.dtype("float32"), order, @@ -1682,7 +1647,7 @@ def append_column( default: ColumnInputType = None, description: Optional[str] = None, tags: Optional[List[str]] = None, - **attrs: Optional[Union[str, bool]], + **attrs: Any, ) -> None: """ Create and optionally fill a dataset column of the given type. @@ -1722,38 +1687,46 @@ def append_column( [ True True True True True] [1. 1. 1. 1. 1.] """ - dtype = create_dtype(dtype) + dtype = spotlight_dtypes.create_dtype(dtype) - if dtype.name == "bool": + if spotlight_dtypes.is_bool_dtype(dtype): append_column_fn: Callable = self.append_bool_column - elif dtype.name == "int": + elif spotlight_dtypes.is_int_dtype(dtype): append_column_fn = self.append_int_column - elif dtype.name == "float": + elif spotlight_dtypes.is_float_dtype(dtype): append_column_fn = self.append_float_column - elif dtype.name == "str": + elif spotlight_dtypes.is_str_dtype(dtype): append_column_fn = self.append_string_column - elif dtype.name == "datetime": + elif spotlight_dtypes.is_datetime_dtype(dtype): append_column_fn = self.append_datetime_column - elif dtype.name == "array": + elif spotlight_dtypes.is_category_dtype(dtype): + append_column_fn = self.append_categorical_column + if dtype.categories: + if "categories" in attrs and attrs["categories"] != dtype.categories: + raise exceptions.InvalidAttributeError( + f"Categories differ between `dtype` ({dtype.categories}) " + f"and `categories` ({attrs['categories']}) keyword argument." + ) + attrs["categories"] = dtype.categories + elif spotlight_dtypes.is_array_dtype(dtype): append_column_fn = self.append_array_column - elif dtype.name == "Embedding": + elif spotlight_dtypes.is_window_dtype(dtype): + append_column_fn = self.append_window_column + elif spotlight_dtypes.is_embedding_dtype(dtype): append_column_fn = self.append_embedding_column - elif dtype.name == "Image": - append_column_fn = self.append_image_column - elif dtype.name == "Mesh": - append_column_fn = self.append_mesh_column - elif dtype.name == "Sequence1D": + elif spotlight_dtypes.is_sequence_1d_dtype(dtype): append_column_fn = self.append_sequence_1d_column - elif dtype.name == "Audio": + elif spotlight_dtypes.is_audio_dtype(dtype): append_column_fn = self.append_audio_column - elif dtype.name == "Category": - append_column_fn = self.append_categorical_column - elif dtype.name == "Video": + elif spotlight_dtypes.is_image_dtype(dtype): + append_column_fn = self.append_image_column + elif spotlight_dtypes.is_mesh_dtype(dtype): + append_column_fn = self.append_mesh_column + elif spotlight_dtypes.is_video_dtype(dtype): append_column_fn = self.append_video_column - elif dtype.name == "Window": - append_column_fn = self.append_window_column else: raise exceptions.InvalidDTypeError(f"Unknown column type: {dtype}.") + append_column_fn( name=name, values=values, # type: ignore @@ -1857,7 +1830,7 @@ def insert_row(self, index: IndexType, values: Dict[str, ColumnInputType]) -> No column = self._h5_file[column_name] column.resize(length + 1, axis=0) raw_values = column[index:-1] - if self._get_dtype(column).name == "Embedding": + if spotlight_dtypes.is_embedding_dtype(self._get_dtype(column)): raw_values = list(raw_values) column[index + 1 :] = raw_values self._length += 1 @@ -1903,15 +1876,15 @@ def isnull(self, column_name: str) -> np.ndarray: dtype = self._get_dtype(column) if self._is_ref_column(column): return ~raw_values.astype(bool) - if dtype.name == "datetime": + if spotlight_dtypes.is_datetime_dtype(dtype): return np.array([raw_value in ["", b""] for raw_value in raw_values]) - if dtype.name == "float": + if spotlight_dtypes.is_float_dtype(dtype): return np.isnan(raw_values) - if dtype.name == "Category": + if spotlight_dtypes.is_category_dtype(dtype): return raw_values == -1 - if dtype.name == "Window": + if spotlight_dtypes.is_window_dtype(dtype): return np.isnan(raw_values).all(axis=1) - if dtype.name == "Embedding": + if spotlight_dtypes.is_embedding_dtype(dtype): return np.array([len(x) == 0 for x in raw_values]) return np.full(len(self), False) @@ -2000,14 +1973,14 @@ def prune(self) -> None: else: refs.append(None) raw_values = refs - if self._get_dtype(column).name == "Embedding": + if spotlight_dtypes.is_embedding_dtype(self._get_dtype(column)): raw_values = list(raw_values) new_column[:] = raw_values self.close() shutil.move(new_dataset, os.path.realpath(self._filepath)) self.open() - def get_dtype(self, column_name: str) -> DType: + def get_dtype(self, column_name: str) -> spotlight_dtypes.DType: """ Get type of dataset column. @@ -2339,12 +2312,12 @@ def set_column_attributes( if default is None and old_default is None: default = self._default_default(dtype) if default is None: - if is_embedding_dtype(dtype) and not self._is_ref_column( - column - ): + if spotlight_dtypes.is_embedding_dtype( + dtype + ) and not self._is_ref_column(column): # For a non-ref `Embedding` column, replace `None` with an empty array. default = np.empty(0, column.dtype.metadata["vlen"]) - if is_category_dtype(dtype) and default is not None: + if spotlight_dtypes.is_category_dtype(dtype) and default is not None: categories: List[str] = column.attrs["category_keys"].tolist() if default not in categories: column.attrs["category_values"] = np.append( @@ -2353,10 +2326,13 @@ def set_column_attributes( column.attrs["category_keys"] = categories + [default] if default is not None: encoded_value = self._encode_value(default, column) - if dtype.name == "datetime" and encoded_value is None: + if ( + spotlight_dtypes.is_datetime_dtype(dtype) + and encoded_value is None + ): encoded_value = "" column.attrs["default"] = encoded_value - elif is_category_dtype(dtype): + elif spotlight_dtypes.is_category_dtype(dtype): column.attrs["default"] = -1 except Exception as e: @@ -2401,7 +2377,7 @@ def _format(value: _EncodedColumnType, ref_type_name: str) -> str: for column in columns: attrs = column.attrs type_name = attrs["type"] - dtype = create_dtype(type_name) + dtype = spotlight_dtypes.create_dtype(type_name) optional_keys = set(self._user_column_attributes(dtype).keys()).difference( required_keys ) @@ -2425,7 +2401,7 @@ def _format(value: _EncodedColumnType, ref_type_name: str) -> str: def _append_column( self, name: str, - dtype: DType, + dtype: spotlight_dtypes.DType, values: Union[ColumnInputType, Iterable[ColumnInputType]], np_dtype: np.dtype, order: Optional[int] = None, @@ -2443,7 +2419,7 @@ def _append_column( # `set_column_attributes` method. shape: Tuple[int, ...] = (0,) maxshape: Tuple[Optional[int], ...] = (None,) - if is_category_dtype(dtype): + if spotlight_dtypes.is_category_dtype(dtype): if dtype.categories is None: if is_iterable(values): categories: List[str] = sorted(set(values)) @@ -2451,15 +2427,15 @@ def _append_column( categories = [] else: categories = cast(List[str], [values]) - dtype = CategoryDType(categories) + dtype = spotlight_dtypes.CategoryDType(categories) attrs["categories"] = dtype.categories - elif is_window_dtype(dtype): + elif spotlight_dtypes.is_window_dtype(dtype): shape = (0, 2) maxshape = (None, 2) - elif is_sequence_1d_dtype(dtype): + elif spotlight_dtypes.is_sequence_1d_dtype(dtype): attrs["x_label"] = dtype.x_label attrs["y_label"] = dtype.y_label - elif is_file_dtype(dtype): + elif spotlight_dtypes.is_file_dtype(dtype): lookup = attrs.get("lookup", None) if is_iterable(lookup) and not isinstance(lookup, dict): # Assume that we can keep all the lookup values in memory. @@ -2569,7 +2545,7 @@ def _set_column( else: # Reorder values according to the given indices. encoded_values = encoded_values[values_indices] - if self._get_dtype(column).name == "Embedding": + if spotlight_dtypes.is_embedding_dtype(self._get_dtype(column)): encoded_values = list(encoded_values) elif values is not None: # A single value is given. `Window` and `Embedding` values should @@ -2681,30 +2657,30 @@ def _decode_values(self, values: np.ndarray, column: h5py.Dataset) -> np.ndarray @staticmethod def _decode_simple_values( - values: np.ndarray, column: h5py.Dataset, dtype: DType + values: np.ndarray, column: h5py.Dataset, dtype: spotlight_dtypes.DType ) -> np.ndarray: - if is_category_dtype(dtype): + if spotlight_dtypes.is_category_dtype(dtype): if dtype.inverted_categories is None: return np.full(len(values), None) return np.array([dtype.inverted_categories.get(x) for x in values]) if h5py.check_string_dtype(column.dtype): # `dtype` is `str` or `datetime`. values = np.array([x.decode("utf-8") for x in values]) - if is_str_dtype(dtype): + if spotlight_dtypes.is_str_dtype(dtype): return values # Decode datetimes. return np.array( [None if x == "" else datetime.fromisoformat(x) for x in values], dtype=object, ) - if is_embedding_dtype(dtype): + if spotlight_dtypes.is_embedding_dtype(dtype): null_mask = [len(x) == 0 for x in values] values[null_mask] = None # For column types `bool`, `int`, `float` or `Window`, return the array as-is. return values def _decode_ref_values( - self, values: np.ndarray, column: h5py.Dataset, dtype: DType + self, values: np.ndarray, column: h5py.Dataset, dtype: spotlight_dtypes.DType ) -> np.ndarray: column_name = self._get_column_name(column) if dtype.name in ("array", "Embedding"): @@ -2720,7 +2696,9 @@ def _decode_ref_values( dtype=object, ) - def _decode_external_values(self, values: np.ndarray, dtype: DType) -> np.ndarray: + def _decode_external_values( + self, values: np.ndarray, dtype: spotlight_dtypes.DType + ) -> np.ndarray: return np.array( [self._decode_external_value(value, dtype) for value in values], dtype=object, @@ -2797,7 +2775,7 @@ def _encode_simple_values( self, values: Iterable[SimpleColumnInputType], column: h5py.Dataset ) -> np.ndarray: dtype = self._get_dtype(column) - if is_category_dtype(dtype): + if spotlight_dtypes.is_category_dtype(dtype): categories = cast(Dict[Optional[str], int], (dtype.categories or {}).copy()) if column.attrs.get("optional", False): default = column.attrs.get("default", -1) @@ -2822,7 +2800,7 @@ def _encode_simple_values( f"'{column_name}'. Valid values for this column are: " f"{categories_str}." ) from e - if is_datetime_dtype(dtype): + if spotlight_dtypes.is_datetime_dtype(dtype): if _check_valid_array(values, dtype): encoded_values = np.array( [None if x is None else x.isoformat() for x in values.tolist()] @@ -2837,7 +2815,7 @@ def _encode_simple_values( # That means, we have all strings in array, no `None`s. return encoded_values return self._replace_none(encoded_values, column) - if is_window_dtype(dtype): + if spotlight_dtypes.is_window_dtype(dtype): encoded_values = self._asarray(values, column, dtype) if encoded_values.ndim == 1: if len(encoded_values) == 2: @@ -2855,7 +2833,7 @@ def _encode_simple_values( f"one of shapes (2,) (a single window) or (n, 2) (multiple " f"windows), but values with shape {encoded_values.shape} received." ) - if is_embedding_dtype(dtype): + if spotlight_dtypes.is_embedding_dtype(dtype): if _check_valid_array(values, dtype): # This is the only case we can handle fast and easily, otherwise # embedding should go through `_encode_value` element-wise. @@ -2904,7 +2882,7 @@ def _asarray( self, values: Iterable[SimpleColumnInputType], column: h5py.Dataset, - dtype: DType, + dtype: spotlight_dtypes.DType, ) -> np.ndarray: if isinstance(values, np.ndarray): if _check_valid_value_dtype(values.dtype, dtype): @@ -3024,7 +3002,7 @@ def _encode_value( if self._is_ref_column(column): value = cast(RefColumnInputType, value) self._assert_valid_value_type(value, dtype, column_name) - if is_file_dtype(dtype) and isinstance(value, str): + if spotlight_dtypes.is_file_dtype(dtype) and isinstance(value, str): try: return self._find_lookup_ref(value, column) except KeyError: @@ -3050,7 +3028,7 @@ def _encode_simple_value( self, value: SimpleColumnInputType, column: h5py.Dataset, - dtype: DType, + dtype: spotlight_dtypes.DType, column_name: str, ) -> _EncodedColumnType: """ @@ -3061,7 +3039,7 @@ def _encode_simple_value( """ self._assert_valid_value_type(value, dtype, column_name) attrs = column.attrs - if is_category_dtype(dtype): + if spotlight_dtypes.is_category_dtype(dtype): value = cast(str, value) if dtype.categories: try: @@ -3075,7 +3053,7 @@ def _encode_simple_value( f"categorical column '{column_name}'. Valid values for this " f"column are: {categories_str}." ) - if is_window_dtype(dtype): + if spotlight_dtypes.is_window_dtype(dtype): value = np.asarray(value, dtype=column.dtype) if value.shape == (2,): return value @@ -3083,9 +3061,9 @@ def _encode_simple_value( f"Windows should consist of 2 values, but window of shape " f"{value.shape} received for column {column_name}." ) - if is_embedding_dtype(dtype): + if spotlight_dtypes.is_embedding_dtype(dtype): # `Embedding` column is not a ref column. - if isinstance(value, Embedding): + if isinstance(value, spotlight_dtypes.Embedding): value = value.encode(attrs.get("format", None)) value = np.asarray(value, dtype=column.dtype.metadata["vlen"]) self._assert_valid_or_set_embedding_shape(value.shape, column) @@ -3119,7 +3097,10 @@ def _write_ref_value( return ref def _encode_ref_value( - self, value: RefColumnInputType, column: h5py.Dataset, dtype: DType + self, + value: RefColumnInputType, + column: h5py.Dataset, + dtype: spotlight_dtypes.DType, ) -> Optional[Union[np.ndarray, np.void]]: """ Encode a ref value, e.g. np.ndarray, Sequence1D, Image, Mesh, Audio, @@ -3127,63 +3108,63 @@ def _encode_ref_value( Value *cannot* be `None` already. """ - if is_array_dtype(dtype): + if spotlight_dtypes.is_array_dtype(dtype): value = np.asarray(value) self._assert_valid_or_set_value_dtype(value.dtype, column) return value - if is_embedding_dtype(dtype): - if not isinstance(value, Embedding): - value = Embedding(value) # type: ignore + if spotlight_dtypes.is_embedding_dtype(dtype): + if not isinstance(value, spotlight_dtypes.Embedding): + value = spotlight_dtypes.Embedding(value) # type: ignore value = value.encode() self._assert_valid_or_set_value_dtype(value.dtype, column) self._assert_valid_or_set_embedding_shape(value.shape, column) return value - if is_sequence_1d_dtype(dtype): - if not isinstance(value, Sequence1D): - value = Sequence1D(value) # type: ignore + if spotlight_dtypes.is_sequence_1d_dtype(dtype): + if not isinstance(value, spotlight_dtypes.Sequence1D): + value = spotlight_dtypes.Sequence1D(value) # type: ignore value = value.encode() self._assert_valid_or_set_value_dtype(value.dtype, column) return value - if is_audio_dtype(dtype): + if spotlight_dtypes.is_audio_dtype(dtype): if isinstance(value, (str, os.PathLike)): try: - value = Audio.from_file(value) + value = spotlight_dtypes.Audio.from_file(value) except Exception: return None if isinstance(value, bytes): - value = Audio.from_bytes(value) - assert isinstance(value, Audio) + value = spotlight_dtypes.Audio.from_bytes(value) + assert isinstance(value, spotlight_dtypes.Audio) return value.encode(column.attrs.get("format", None)) - if is_image_dtype(dtype): + if spotlight_dtypes.is_image_dtype(dtype): if isinstance(value, (str, os.PathLike)): try: - value = Image.from_file(value) + value = spotlight_dtypes.Image.from_file(value) except Exception: return None if isinstance(value, bytes): - value = Image.from_bytes(value) - if not isinstance(value, Image): - value = Image(value) # type: ignore + value = spotlight_dtypes.Image.from_bytes(value) + if not isinstance(value, spotlight_dtypes.Image): + value = spotlight_dtypes.Image(value) # type: ignore return value.encode() - if is_mesh_dtype(dtype): + if spotlight_dtypes.is_mesh_dtype(dtype): if isinstance(value, (str, os.PathLike)): try: - value = Mesh.from_file(value) + value = spotlight_dtypes.Mesh.from_file(value) except Exception: return None if isinstance(value, trimesh.Trimesh): - value = Mesh.from_trimesh(value) - assert isinstance(value, Mesh) + value = spotlight_dtypes.Mesh.from_trimesh(value) + assert isinstance(value, spotlight_dtypes.Mesh) return value.encode() - if is_video_dtype(dtype): + if spotlight_dtypes.is_video_dtype(dtype): if isinstance(value, (str, os.PathLike)): try: - value = Video.from_file(value) + value = spotlight_dtypes.Video.from_file(value) except Exception: return None if isinstance(value, bytes): - value = Video.from_bytes(value) - assert isinstance(value, Video) + value = spotlight_dtypes.Video.from_bytes(value) + assert isinstance(value, spotlight_dtypes.Video) return value.encode(column.attrs.get("format", None)) assert False @@ -3239,7 +3220,7 @@ def _encode_external_value(self, value: PathOrUrlType, column: h5py.Dataset) -> @staticmethod def _assert_valid_value_type( - value: ColumnInputType, dtype: DType, column_name: str + value: ColumnInputType, dtype: spotlight_dtypes.DType, column_name: str ) -> None: if not _check_valid_value_type(value, dtype): allowed_types = _ALLOWED_COLUMN_TYPES.get(dtype.name, ()) @@ -3270,22 +3251,22 @@ def _decode_value( @staticmethod def _decode_simple_value( value: Union[np.bool_, np.integer, np.floating, bytes, str, np.ndarray], - dtype: DType, + dtype: spotlight_dtypes.DType, ) -> Optional[Union[bool, int, float, str, datetime, np.ndarray]]: - if is_window_dtype(dtype): + if spotlight_dtypes.is_window_dtype(dtype): return value # type: ignore - if is_embedding_dtype(dtype): + if spotlight_dtypes.is_embedding_dtype(dtype): value = cast(np.ndarray, value) if len(value) == 0: return None return value - if is_category_dtype(dtype): + if spotlight_dtypes.is_category_dtype(dtype): if dtype.inverted_categories is None: return None return dtype.inverted_categories.get(cast(int, value), None) if isinstance(value, bytes): value = value.decode("utf-8") - if is_datetime_dtype(dtype): + if spotlight_dtypes.is_datetime_dtype(dtype): value = cast(str, value) if value == "": return None @@ -3295,21 +3276,32 @@ def _decode_simple_value( def _decode_ref_value( self, ref: Union[bytes, str, h5py.Reference], - dtype: DType, + dtype: spotlight_dtypes.DType, column_name: str, - ) -> Optional[Union[np.ndarray, Audio, Image, Mesh, Sequence1D, Video]]: + ) -> Optional[ + Union[ + np.ndarray, + spotlight_dtypes.Audio, + spotlight_dtypes.Image, + spotlight_dtypes.Mesh, + spotlight_dtypes.Sequence1D, + spotlight_dtypes.Video, + ] + ]: # Value can be a H5 reference or a string reference. if not ref: return None value = self._resolve_ref(ref, column_name)[()] - if is_array_dtype(dtype) or is_embedding_dtype(dtype): + if spotlight_dtypes.is_array_dtype( + dtype + ) or spotlight_dtypes.is_embedding_dtype(dtype): return np.asarray(value) return VALUE_TYPE_BY_DTYPE_NAME[dtype.name].decode(value) # type: ignore def _decode_external_value( self, value: Union[str, bytes], - dtype: DType, + dtype: spotlight_dtypes.DType, ) -> Optional[ExternalOutputType]: if not value: return None @@ -3354,7 +3346,7 @@ def _append_internal_columns(self) -> None: f"has no type stored in attributes. Remove or rename " f"the respective h5 dataset." ) from e - if create_dtype(type_name).name != dtype.name: + if spotlight_dtypes.create_dtype(type_name).name != dtype.name: raise exceptions.InconsistentDatasetError( f'Internal column "{column_name}" already exists, ' f"but has invalid type `{type_name}` " @@ -3423,7 +3415,9 @@ def _resolve_ref( def _get_username() -> str: return "" - def _get_dtype(self, x: Union[h5py.Dataset, h5py.AttributeManager]) -> DType: + def _get_dtype( + self, x: Union[h5py.Dataset, h5py.AttributeManager] + ) -> spotlight_dtypes.DType: """ Get column type by its name, or extract it from `h5py` entities. """ @@ -3432,12 +3426,14 @@ def _get_dtype(self, x: Union[h5py.Dataset, h5py.AttributeManager]) -> DType: type_name = x["type"] if type_name == "Category": - return CategoryDType( + return spotlight_dtypes.CategoryDType( dict(zip(x.get("category_keys", []), x.get("category_values", []))) ) if type_name == "Sequence1D": - return Sequence1DDType(x.get("x_label", "x"), x.get("y_label", "y")) - return create_dtype(type_name) + return spotlight_dtypes.Sequence1DDType( + x.get("x_label", "x"), x.get("y_label", "y") + ) + return spotlight_dtypes.create_dtype(type_name) @staticmethod def _get_column_name(column: h5py.Dataset) -> str: diff --git a/renumics/spotlight/dataset/descriptors/__init__.py b/renumics/spotlight/dataset/descriptors/__init__.py index e4dab9a2..9925c590 100644 --- a/renumics/spotlight/dataset/descriptors/__init__.py +++ b/renumics/spotlight/dataset/descriptors/__init__.py @@ -6,6 +6,7 @@ import pycatch22 from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler +from renumics.spotlight import dtypes from renumics.spotlight.dataset import Dataset from renumics.spotlight.dataset.exceptions import ColumnExistsError, InvalidDTypeError @@ -77,7 +78,7 @@ def catch22( if suffix is None: suffix = "catch24" if catch24 else "catch22" dtype = dataset.get_dtype(column) - if dtype.name == ("Audio", "Sequence1D"): + if not dtypes.is_audio_dtype(dtype) and not dtypes.is_sequence_1d_dtype(dtype): raise InvalidDTypeError( f"catch22 is only applicable to columns of type `Audio` and " f'`Sequence1D`, but column "{column}" of type {dtype} received.' diff --git a/renumics/spotlight/dataset/descriptors/data_alignment.py b/renumics/spotlight/dataset/descriptors/data_alignment.py index 0c54fae9..b409eda4 100644 --- a/renumics/spotlight/dataset/descriptors/data_alignment.py +++ b/renumics/spotlight/dataset/descriptors/data_alignment.py @@ -11,6 +11,7 @@ from renumics.spotlight import ( Dataset, + dtypes, ) from renumics.spotlight.dataset import exceptions @@ -21,7 +22,7 @@ def align_audio_data(dataset: Dataset, column: str) -> Tuple[np.ndarray, np.ndar """ dtype = dataset.get_dtype(column) - if dtype.name != "Audio": + if not dtypes.is_audio_dtype(dtype): raise exceptions.InvalidDTypeError( f'An audio column expected, but column "{column}" of type {dtype} received.' ) @@ -66,7 +67,7 @@ def align_embedding_data( Align data from an embedding column. """ dtype = dataset.get_dtype(column) - if dtype.name != "Embedding": + if not dtypes.is_embedding_dtype(dtype): raise exceptions.InvalidDTypeError( f'An embedding column expected, but column "{column}" of type {dtype} received.' ) @@ -82,7 +83,7 @@ def align_image_data(dataset: Dataset, column: str) -> Tuple[np.ndarray, np.ndar Align data from an image column. """ dtype = dataset.get_dtype(column) - if dtype.name != "Image": + if not dtypes.is_image_dtype(dtype): raise exceptions.InvalidDTypeError( f'An image column expected, but column "{column}" of type {dtype} received.' ) @@ -124,7 +125,7 @@ def align_sequence_1d_data( Align data from an sequence 1D column. """ dtype = dataset.get_dtype(column) - if dtype.name != "Sequence1D": + if not dtypes.is_sequence_1d_dtype(dtype): raise exceptions.InvalidDTypeError( f'A sequence 1D column expected, but column "{column}" of type {dtype} received.' ) @@ -161,15 +162,15 @@ def align_column_data( Align data from an Spotlight dataset column if possible. """ dtype = dataset.get_dtype(column) - if dtype.name == "Audio": + if dtypes.is_audio_dtype(dtype): data, mask = align_audio_data(dataset, column) - elif dtype.name == "Embedding": + elif dtypes.is_embedding_dtype(dtype): data, mask = align_embedding_data(dataset, column) - elif dtype.name == "Image": + elif dtypes.is_image_dtype(dtype): data, mask = align_image_data(dataset, column) - elif dtype.name == "Sequence1D": + elif dtypes.is_sequence_1d_dtype(dtype): data, mask = align_sequence_1d_data(dataset, column) - elif dtype.name in ("bool", "int", "float", "Window"): + elif dtypes.is_scalar_dtype(dtype) or dtypes.is_window_dtype(dtype): data = dataset[column].astype(np.float64).reshape((len(dataset), -1)) mask = np.full(len(dataset), True) else: diff --git a/renumics/spotlight/dtypes/conversion.py b/renumics/spotlight/dtypes/conversion.py index 0f4a97cc..48c90dde 100644 --- a/renumics/spotlight/dtypes/conversion.py +++ b/renumics/spotlight/dtypes/conversion.py @@ -37,6 +37,7 @@ import trimesh import PIL.Image import validators +from renumics.spotlight import dtypes from renumics.spotlight.typing import PathOrUrlType, PathType from renumics.spotlight.cache import external_data_cache @@ -198,18 +199,18 @@ def convert_to_dtype( try: if value is None: return None - if dtype.name == "bool": + if dtypes.is_bool_dtype(dtype): return bool(value) # type: ignore - if dtype.name == "int": + if dtypes.is_int_dtype(dtype): return int(value) # type: ignore - if dtype.name == "float": + if dtypes.is_float_dtype(dtype): return float(value) # type: ignore - if dtype.name == "str": + if dtypes.is_str_dtype(dtype): str_value = str(value) if simple and len(str_value) > 100: return str_value[:97] + "..." return str_value - if dtype.name == "array": + if dtypes.is_array_dtype(dtype): if simple: return "[...]" if isinstance(value, list): @@ -502,7 +503,7 @@ def _decode_external_value( """ path_or_url = prepare_path_or_url(path_or_url, workdir) - if dtype.name == "Audio": + if dtypes.is_audio_dtype(dtype): file = audio.prepare_input_file(path_or_url, reusable=True) # `file` is a filepath of type `str` or an URL downloaded as `io.BytesIO`. input_format, input_codec = audio.get_format_codec(file) @@ -534,7 +535,7 @@ def _decode_external_value( audio.transcode_audio(file, buffer, output_format, output_codec) return np.void(buffer.getvalue()) - if dtype.name == "Image": + if dtypes.is_image_dtype(dtype): with as_file(path_or_url) as file: kind = filetype.guess(file) if kind is not None and kind.mime.split("/")[1] in ( @@ -551,8 +552,8 @@ def _decode_external_value( # `image/tiff`s become blank in frontend, so convert them too. return Image.from_file(file).encode(target_format) - if dtype.name == "Mesh": + if dtypes.is_mesh_dtype(dtype): return Mesh.from_file(path_or_url).encode(target_format) - if dtype.name == "Video": + if dtypes.is_video_dtype(dtype): return Video.from_file(path_or_url).encode(target_format) assert False diff --git a/renumics/spotlight/io/pandas.py b/renumics/spotlight/io/pandas.py index 19f62708..3cba1736 100644 --- a/renumics/spotlight/io/pandas.py +++ b/renumics/spotlight/io/pandas.py @@ -235,6 +235,8 @@ def is_string_mask(column: pd.Series) -> pd.Series: """ Return mask of column's elements of type string. """ + if len(column) == 0: + return pd.Series([], dtype=bool) return column.map(type) == str diff --git a/tests/integration/dataset/test_dataset.py b/tests/integration/dataset/test_dataset.py index 7c2aa186..10f7ed75 100644 --- a/tests/integration/dataset/test_dataset.py +++ b/tests/integration/dataset/test_dataset.py @@ -1030,4 +1030,8 @@ def test_to_pandas() -> None: column_name = dtype.name assert column_name in df inferred_dtype = infer_dtype(df[column_name]) - assert inferred_dtype.name == dtype.name + if dtypes.is_category_dtype(dtype): + assert dtypes.is_category_dtype(inferred_dtype) + assert inferred_dtype.categories == dtype.categories + else: + assert inferred_dtype.name == dtype.name