Skip to content

Commit

Permalink
Merge pull request #376 from Renumics/feature/aabb-dtype-in-dataset
Browse files Browse the repository at this point in the history
Feature/aabb dtype in dataset
  • Loading branch information
druzsan authored Nov 21, 2023
2 parents ab6c0d9 + 23c3d89 commit 9f6b824
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 29 deletions.
126 changes: 110 additions & 16 deletions renumics/spotlight/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
cast,
)

import PIL.Image
import h5py
import numpy as np
import pandas as pd
Expand All @@ -46,6 +47,7 @@
from . import exceptions
from .pandas import create_typed_series, infer_dtypes, is_string_mask, prepare_column
from .typing import (
BoundingBoxColumnInputType,
OutputType,
ExternalOutputType,
BoolColumnInputType,
Expand Down Expand Up @@ -82,11 +84,6 @@
"int": int,
"float": float,
"str": str,
"datetime": datetime,
"Category": spotlight_dtypes.Category,
"array": np.ndarray,
"Window": spotlight_dtypes.Window,
"Embedding": spotlight_dtypes.Embedding,
"Sequence1D": spotlight_dtypes.Sequence1D,
"Audio": spotlight_dtypes.Audio,
"Image": spotlight_dtypes.Image,
Expand Down Expand Up @@ -152,10 +149,20 @@ def unescape_dataset_name(escaped_name: str) -> str:
np.floating,
),
"Window": (np.ndarray, list, tuple, range),
"BoundingBox": (np.ndarray, list, tuple, range),
"Embedding": (spotlight_dtypes.Embedding, np.ndarray, list, tuple, range),
"Sequence1D": (spotlight_dtypes.Sequence1D, np.ndarray, list, tuple, range),
"Audio": (spotlight_dtypes.Audio, bytes, str, os.PathLike),
"Image": (spotlight_dtypes.Image, bytes, str, os.PathLike, np.ndarray, list, tuple),
"Image": (
spotlight_dtypes.Image,
bytes,
str,
os.PathLike,
np.ndarray,
list,
tuple,
PIL.Image.Image,
),
"Mesh": (spotlight_dtypes.Mesh, trimesh.Trimesh, str, os.PathLike),
"Video": (spotlight_dtypes.Video, bytes, str, os.PathLike),
}
Expand All @@ -165,6 +172,7 @@ def unescape_dataset_name(escaped_name: str) -> str:
"float": (np.floating,),
"datetime": (np.datetime64,),
"Window": (np.floating,),
"BoundingBox": (np.floating,),
"Embedding": (np.floating,),
}

Expand Down Expand Up @@ -232,6 +240,7 @@ def _user_column_attributes(dtype: spotlight_dtypes.DType) -> Dict[str, Type]:
or spotlight_dtypes.is_str_dtype(dtype)
or spotlight_dtypes.is_category_dtype(dtype)
or spotlight_dtypes.is_window_dtype(dtype)
or spotlight_dtypes.is_bounding_box_dtype(dtype)
):
attribute_names["editable"] = bool
if spotlight_dtypes.is_category_dtype(dtype):
Expand Down Expand Up @@ -259,7 +268,9 @@ def _default_default(cls, dtype: spotlight_dtypes.DType) -> Any:
if spotlight_dtypes.is_datetime_dtype(dtype):
return np.datetime64("NaT")
if spotlight_dtypes.is_window_dtype(dtype):
return [np.nan, np.nan]
return np.full(2, np.nan)
if spotlight_dtypes.is_bounding_box_dtype(dtype):
return np.full(4, np.nan)
return None

def __init__(self, filepath: PathType, mode: str):
Expand Down Expand Up @@ -1631,6 +1642,52 @@ def append_window_column(
editable=editable,
)

def append_bounding_box_column(
self,
name: str,
values: Optional[
Union[BoundingBoxColumnInputType, Iterable[BoundingBoxColumnInputType]]
] = None,
order: Optional[int] = None,
hidden: bool = False,
optional: bool = False,
default: BoundingBoxColumnInputType = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
editable: bool = True,
) -> None:
"""
Create and optionally fill axis-aligned bounding box column.
Args:
name: Column name.
values: Optional column values. If a single value, the whole column
filled with this value.
order: Optional Spotlight priority order value. `None` means the
lowest priority.
hidden: Whether column is hidden in Spotlight.
optional: Whether column is optional. If `default` other than `None`
is specified, `optional` is automatically set to `True`.
default: Value to use by default if column is optional and no value
or `None` is given.
description: Optional column description.
tags: Optional tags for the column.
editable: Whether column is editable in Spotlight.
"""
self._append_column(
name,
spotlight_dtypes.bounding_box_dtype,
values,
np.dtype("float32"),
order,
hidden,
optional,
default,
description,
tags,
editable=editable,
)

def append_column(
self,
name: str,
Expand Down Expand Up @@ -1707,6 +1764,8 @@ def append_column(
append_column_fn = self.append_array_column
elif spotlight_dtypes.is_window_dtype(dtype):
append_column_fn = self.append_window_column
elif spotlight_dtypes.is_bounding_box_dtype(dtype):
append_column_fn = self.append_bounding_box_column
elif spotlight_dtypes.is_embedding_dtype(dtype):
append_column_fn = self.append_embedding_column
if dtype.length is not None:
Expand Down Expand Up @@ -1884,7 +1943,9 @@ def isnull(self, column_name: str) -> np.ndarray:
return np.isnan(raw_values)
if spotlight_dtypes.is_category_dtype(dtype):
return raw_values == -1
if spotlight_dtypes.is_window_dtype(dtype):
if spotlight_dtypes.is_window_dtype(
dtype
) or spotlight_dtypes.is_bounding_box_dtype(dtype):
return np.isnan(raw_values).all(axis=1)
if spotlight_dtypes.is_embedding_dtype(dtype):
return np.array([len(x) == 0 for x in raw_values])
Expand Down Expand Up @@ -2477,6 +2538,9 @@ def _append_column(
elif spotlight_dtypes.is_window_dtype(dtype):
shape = (0, 2)
maxshape = (None, 2)
elif spotlight_dtypes.is_bounding_box_dtype(dtype):
shape = (0, 4)
maxshape = (None, 4)
elif spotlight_dtypes.is_sequence_1d_dtype(dtype):
attrs["x_label"] = dtype.x_label
attrs["y_label"] = dtype.y_label
Expand Down Expand Up @@ -2566,7 +2630,7 @@ def _set_column(

encoded_values: Union[np.ndarray, List[_EncodedColumnType]]
if is_iterable(values):
# Single windows and embeddings also come here.
# Single windows, bounding boxes and embeddings also come here.
encoded_values = self._encode_values(values, column)
if len(encoded_values) != indices_length:
if indices_length == 0:
Expand Down Expand Up @@ -2729,14 +2793,16 @@ def _decode_simple_values(
if spotlight_dtypes.is_embedding_dtype(dtype):
null_mask = [len(x) == 0 for x in values]
values[null_mask] = None
# For column types `bool`, `int`, `float` or `Window`, return the array as-is.
# For bool, int, float, window or bounding box columns, return the array as-is.
return values

def _decode_ref_values(
self, values: np.ndarray, column: h5py.Dataset, dtype: spotlight_dtypes.DType
) -> np.ndarray:
column_name = self._get_column_name(column)
if dtype.name in ("array", "Embedding"):
if spotlight_dtypes.is_array_dtype(
dtype
) or spotlight_dtypes.is_embedding_dtype(dtype):
# `np.array([<...>], dtype=object)` creation does not work for
# some cases and erases dtypes of sub-arrays, so we use assignment.
decoded_values = np.empty(len(values), dtype=object)
Expand Down Expand Up @@ -2882,10 +2948,28 @@ def _encode_simple_values(
return encoded_values
column_name = self._get_column_name(column)
raise exceptions.InvalidShapeError(
f'Input values to `Window` column "{column_name}" should have '
f'Input values to window column "{column_name}" should have '
f"one of shapes (2,) (a single window) or (n, 2) (multiple "
f"windows), but values with shape {encoded_values.shape} received."
)
if spotlight_dtypes.is_bounding_box_dtype(dtype):
encoded_values = self._asarray(values, column, dtype)
if encoded_values.ndim == 1:
if len(encoded_values) == 4:
# A single bounding box, reshape it to an array.
return np.broadcast_to(values, (1, 4)) # type: ignore
if len(encoded_values) == 0:
# An empty array, reshape for compatibility.
return np.broadcast_to(values, (0, 4)) # type: ignore
elif encoded_values.ndim == 2 and encoded_values.shape[1] == 4:
# An array with valid bounding boxes.
return encoded_values
column_name = self._get_column_name(column)
raise exceptions.InvalidShapeError(
f'Input values to bounding box column "{column_name}" should have '
f"one of shapes (4,) (a single bounding box) or (n, 4) (multiple "
f"bounding boxes), but values with shape {encoded_values.shape} received."
)
if spotlight_dtypes.is_embedding_dtype(dtype):
if _check_valid_array(values, dtype):
# This is the only case we can handle fast and easily, otherwise
Expand Down Expand Up @@ -3110,9 +3194,17 @@ def _encode_simple_value(
value = np.asarray(value, dtype=column.dtype)
if value.shape == (2,):
return value
raise exceptions.InvalidDTypeError(
raise exceptions.InvalidShapeError(
f"Windows should consist of 2 values, but window of shape "
f"{value.shape} received for column {column_name}."
f"{value.shape} received for column '{column_name}'."
)
if spotlight_dtypes.is_bounding_box_dtype(dtype):
value = np.asarray(value, dtype=column.dtype)
if value.shape == (4,):
return value
raise exceptions.InvalidShapeError(
f"Bounding boxes should consist of 4 values, but bounding box "
f"of shape {value.shape} received for column '{column_name}'."
)
if spotlight_dtypes.is_embedding_dtype(dtype):
# `Embedding` column is not a ref column.
Expand Down Expand Up @@ -3306,8 +3398,10 @@ def _decode_simple_value(
value: Union[np.bool_, np.integer, np.floating, bytes, str, np.ndarray],
dtype: spotlight_dtypes.DType,
) -> Optional[Union[bool, int, float, str, datetime, np.ndarray]]:
if spotlight_dtypes.is_window_dtype(dtype):
return value # type: ignore
if spotlight_dtypes.is_window_dtype(
dtype
) or spotlight_dtypes.is_bounding_box_dtype(dtype):
return cast(np.ndarray, value)
if spotlight_dtypes.is_embedding_dtype(dtype):
value = cast(np.ndarray, value)
if len(value) == 0:
Expand Down
8 changes: 8 additions & 0 deletions renumics/spotlight/dataset/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@
WindowColumnInputType = Optional[
Union[List[NumberType], Tuple[NumberType, NumberType], np.ndarray]
]
BoundingBoxColumnInputType = Optional[
Union[
List[NumberType],
Tuple[NumberType, NumberType, NumberType, NumberType],
np.ndarray,
]
]
ArrayColumnInputType = Optional[Union[np.ndarray, Sequence]]
EmbeddingColumnInputType = Optional[Union[Embedding, Array1dLike]]
AudioColumnInputType = Optional[Union[Audio, PathOrUrlType, bytes]]
Expand All @@ -56,6 +63,7 @@
DatetimeColumnInputType,
CategoricalColumnInputType,
WindowColumnInputType,
BoundingBoxColumnInputType,
EmbeddingColumnInputType,
]
RefColumnInputType = Union[
Expand Down
1 change: 1 addition & 0 deletions renumics/spotlight/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def register_dtype(dtype: DType, aliases: Optional[list] = None) -> None:
coordinates [x_min, y_min, x_max, y_max] (float values scaled onto 0 to 1).
Top-left image corner is assumed to be (0, 0).
"""
register_dtype(bounding_box_dtype, [])

bounding_boxes_dtype = SequenceDType(bounding_box_dtype)
"""
Expand Down
22 changes: 9 additions & 13 deletions renumics/spotlight_plugins/core/hdf5_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@
CouldNotOpenTableFile,
)
from renumics.spotlight.data_source.data_source import ColumnMetadata
from renumics.spotlight.dtypes import (
DTypeMap,
create_dtype,
is_embedding_dtype,
is_window_dtype,
)
from renumics.spotlight import dtypes


@datasource(".h5")
Expand Down Expand Up @@ -66,16 +61,16 @@ def column_names(self) -> List[str]:
return column_names

@property
def intermediate_dtypes(self) -> DTypeMap:
def intermediate_dtypes(self) -> dtypes.DTypeMap:
return self.semantic_dtypes

def __len__(self) -> int:
return len(self._table)

@property
def semantic_dtypes(self) -> DTypeMap:
def semantic_dtypes(self) -> dtypes.DTypeMap:
return {
column_name: create_dtype(self._table.get_dtype(column_name))
column_name: dtypes.create_dtype(self._table.get_dtype(column_name))
for column_name in self.column_names
}

Expand Down Expand Up @@ -106,6 +101,7 @@ def get_column_values(
self._table._assert_column_exists(column_name, internal=True)

column = cast(h5py.Dataset, self._table._h5_file[column_name])
dtype = self._table._get_dtype(column)
is_string_dtype = h5py.check_string_dtype(column.dtype)

raw_values = column[indices]
Expand All @@ -122,11 +118,11 @@ def get_column_values(
else:
value = self._table._resolve_ref(ref, column_name)[()]
yield value.tolist() if isinstance(value, np.void) else value
elif is_embedding_dtype(self._table._get_dtype(column)):
for x in raw_values:
yield None if len(x) == 0 else x
elif is_window_dtype(self._table._get_dtype(column)):
elif dtypes.is_window_dtype(dtype) or dtypes.is_bounding_box_dtype(dtype):
for x in raw_values:
yield None if np.isnan(x).all() else x
elif dtypes.is_embedding_dtype(dtype):
for x in raw_values:
yield None if len(x) == 0 else x
else:
yield from raw_values
Loading

0 comments on commit 9f6b824

Please sign in to comment.