Skip to content

Commit

Permalink
Merge pull request #311 from Renumics/fix/232-fallback-for-simple-converters
Browse files (browse the repository at this point in the history)

Fix/232 fallback for simple converters
  • Loading branch information
druzsan authored Oct 26, 2023
2 parents d16dbb3 + 122015f commit 8b363d8
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 74 deletions.
148 changes: 75 additions & 73 deletions renumics/spotlight/dtypes/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,14 @@
import trimesh
import PIL.Image
import validators
from renumics.spotlight import dtypes

from renumics.spotlight import dtypes, media
from renumics.spotlight.typing import PathOrUrlType, PathType
from renumics.spotlight.cache import external_data_cache
from renumics.spotlight.io import audio
from renumics.spotlight.io.file import as_file
from renumics.spotlight.media.exceptions import InvalidFile
from renumics.spotlight.backend.exceptions import Problem
from renumics.spotlight.dtypes import (
CategoryDType,
DType,
audio_dtype,
image_dtype,
mesh_dtype,
video_dtype,
)
from renumics.spotlight.media import Sequence1D, Image, Audio, Video, Mesh


NormalizedValue = Union[
Expand Down Expand Up @@ -96,7 +87,7 @@ class ConversionFailed(Problem):
def __init__(
self,
value: NormalizedValue,
dtype: DType,
dtype: dtypes.DType,
reason: Optional[str] = None,
) -> None:
super().__init__(
Expand All @@ -116,14 +107,14 @@ class NoConverterAvailable(Problem):
No matching converter could be applied
"""

def __init__(self, value: NormalizedValue, dtype: DType) -> None:
def __init__(self, value: NormalizedValue, dtype: dtypes.DType) -> None:
msg = f"No Converter for {type(value)} -> {dtype}"
super().__init__(title="No matching converter", detail=msg, status_code=422)


N = TypeVar("N", bound=NormalizedValue)

Converter = Callable[[N, DType], ConvertedValue]
Converter = Callable[[N, dtypes.DType], ConvertedValue]
_converters_table: Dict[
Type[NormalizedValue], Dict[str, List[Converter]]
] = defaultdict(lambda: defaultdict(list))
Expand Down Expand Up @@ -176,7 +167,10 @@ def _decorate(func: Converter[N]) -> Converter[N]:


def convert_to_dtype(
value: NormalizedValue, dtype: DType, simple: bool = False, check: bool = True
value: NormalizedValue,
dtype: dtypes.DType,
simple: bool = False,
check: bool = True,
) -> ConvertedValue:
"""
Convert normalized type from data source to internal Spotlight DType
Expand Down Expand Up @@ -221,45 +215,49 @@ def convert_to_dtype(
except (TypeError, ValueError) as e:
if check:
raise ConversionFailed(value, dtype) from e
else:
return None

if check:
if last_conversion_error:
raise ConversionFailed(value, dtype, last_conversion_error.reason)
else:
else:
if check:
if last_conversion_error:
raise ConversionFailed(value, dtype, last_conversion_error.reason)
raise NoConverterAvailable(value, dtype)

if simple and (
dtypes.is_array_dtype(dtype)
or dtypes.is_embedding_dtype(dtype)
or dtypes.is_sequence_1d_dtype(dtype)
or dtypes.is_filebased_dtype(dtype)
):
return "<invalid>"
return None


@convert("datetime")
def _(value: datetime.datetime, _: DType) -> datetime.datetime:
def _(value: datetime.datetime, _: dtypes.DType) -> datetime.datetime:
return value


@convert("datetime")
def _(value: Union[str, np.str_], _: DType) -> Optional[datetime.datetime]:
def _(value: Union[str, np.str_], _: dtypes.DType) -> Optional[datetime.datetime]:
if value == "":
return None
return datetime.datetime.fromisoformat(value)


@convert("datetime")
def _(value: np.datetime64, _: DType) -> Optional[datetime.datetime]:
def _(value: np.datetime64, _: dtypes.DType) -> Optional[datetime.datetime]:
return value.tolist()


@convert("Category")
def _(value: Union[str, np.str_], dtype: CategoryDType) -> int:
def _(value: Union[str, np.str_], dtype: dtypes.CategoryDType) -> int:
categories = dtype.categories
if not categories:
return -1
return categories[value]


@convert("Category")
def _(_: None, _dtype: CategoryDType) -> int:
def _(_: None, _dtype: dtypes.CategoryDType) -> int:
return -1


Expand All @@ -276,23 +274,23 @@ def _(
np.uint32,
np.uint64,
],
_: CategoryDType,
_: dtypes.CategoryDType,
) -> int:
return int(value)


@convert("Window")
def _(value: list, _: DType) -> np.ndarray:
def _(value: list, _: dtypes.DType) -> np.ndarray:
return np.array(value, dtype=np.float64)


@convert("Window")
def _(value: np.ndarray, _: DType) -> np.ndarray:
def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray:
return value.astype(np.float64)


@convert("Window")
def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
array = np.array(obj, dtype=np.float64)
Expand All @@ -304,17 +302,17 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray:


@convert("Embedding", simple=False)
def _(value: list, _: DType) -> np.ndarray:
def _(value: list, _: dtypes.DType) -> np.ndarray:
return np.array(value, dtype=np.float64)


@convert("Embedding", simple=False)
def _(value: np.ndarray, _: DType) -> np.ndarray:
def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray:
return value.astype(np.float64)


@convert("Embedding", simple=False)
def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
array = np.array(obj, dtype=np.float64)
Expand All @@ -326,138 +324,142 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray:


@convert("Sequence1D", simple=False)
def _(value: Union[np.ndarray, list], _: DType) -> np.ndarray:
return Sequence1D(value).encode()
def _(value: Union[np.ndarray, list], _: dtypes.DType) -> np.ndarray:
return media.Sequence1D(value).encode()


@convert("Sequence1D", simple=False)
def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
return Sequence1D(obj).encode()
return media.Sequence1D(obj).encode()
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
raise ConversionError("Cannot interpret string as a 1D sequence")


@convert("Image", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
def _(value: Union[np.ndarray, list], _: dtypes.DType) -> bytes:
return media.Image(value).encode().tolist()


@convert("Image", simple=False)
def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
if data := read_external_value(value, image_dtype):
if data := read_external_value(value, dtypes.image_dtype):
return data.tolist()
except InvalidFile:
raise ConversionError()
try:
obj = ast.literal_eval(value)
return media.Image(obj).encode().tolist()
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
raise ConversionError()
raise ConversionError()


@convert("Image", simple=False)
def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
return Image.from_bytes(value).encode().tolist()


@convert("Image", simple=False)
def _(value: np.ndarray, _: DType) -> bytes:
return Image(value).encode().tolist()
def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return media.Image.from_bytes(value).encode().tolist()


@convert("Image", simple=False)
def _(value: PIL.Image.Image, _: DType) -> bytes:
def _(value: PIL.Image.Image, _: dtypes.DType) -> bytes:
buffer = io.BytesIO()
value.save(buffer, format="PNG")
return buffer.getvalue()


@convert("Audio", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
if data := read_external_value(value, audio_dtype):
if data := read_external_value(value, dtypes.audio_dtype):
return data.tolist()
except (InvalidFile, IndexError, ValueError):
raise ConversionError()
raise ConversionError()


@convert("Audio", simple=False)
def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
return Audio.from_bytes(value).encode().tolist()
def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return media.Audio.from_bytes(value).encode().tolist()


@convert("Video", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
if data := read_external_value(value, video_dtype):
if data := read_external_value(value, dtypes.video_dtype):
return data.tolist()
except InvalidFile:
raise ConversionError()
raise ConversionError()


@convert("Video", simple=False)
def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
return Video.from_bytes(value).encode().tolist()
def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return media.Video.from_bytes(value).encode().tolist()


@convert("Mesh", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
if data := read_external_value(value, mesh_dtype):
if data := read_external_value(value, dtypes.mesh_dtype):
return data.tolist()
except InvalidFile:
raise ConversionError()
raise ConversionError()


@convert("Mesh", simple=False)
def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return value


# this should not be necessary
@convert("Mesh", simple=False) # type: ignore
def _(value: trimesh.Trimesh, _: DType) -> bytes:
return Mesh.from_trimesh(value).encode().tolist()
def _(value: trimesh.Trimesh, _: dtypes.DType) -> bytes:
return media.Mesh.from_trimesh(value).encode().tolist()


@convert("Embedding", simple=True)
@convert("Sequence1D", simple=True)
def _(_: Union[np.ndarray, list, str, np.str_], _dtype: DType) -> str:
def _(_: Union[np.ndarray, list, str, np.str_], _dtype: dtypes.DType) -> str:
return "[...]"


@convert("Image", simple=True)
def _(_: np.ndarray, _dtype: DType) -> str:
def _(_: Union[np.ndarray, list], _dtype: dtypes.DType) -> str:
return "[...]"


@convert("Image", simple=True)
def _(_: PIL.Image.Image, _dtype: DType) -> str:
def _(_: PIL.Image.Image, _dtype: dtypes.DType) -> str:
return "<PIL.Image>"


@convert("Image", simple=True)
@convert("Audio", simple=True)
@convert("Video", simple=True)
@convert("Mesh", simple=True)
def _(value: Union[str, np.str_], _: DType) -> str:
def _(value: Union[str, np.str_], _: dtypes.DType) -> str:
return str(value)


@convert("Image", simple=True)
@convert("Audio", simple=True)
@convert("Video", simple=True)
@convert("Mesh", simple=True)
def _(_: Union[bytes, np.bytes_], _dtype: DType) -> str:
def _(_: Union[bytes, np.bytes_], _dtype: dtypes.DType) -> str:
return "<bytes>"


# this should not be necessary
@convert("Mesh", simple=True) # type: ignore
def _(_: trimesh.Trimesh, _dtype: DType) -> str:
def _(_: trimesh.Trimesh, _dtype: dtypes.DType) -> str:
return "<Trimesh>"


def read_external_value(
path_or_url: Optional[str],
dtype: DType,
dtype: dtypes.DType,
target_format: Optional[str] = None,
workdir: PathType = ".",
) -> Optional[np.void]:
Expand Down Expand Up @@ -494,7 +496,7 @@ def prepare_path_or_url(path_or_url: PathOrUrlType, workdir: PathType) -> str:

def _decode_external_value(
path_or_url: PathOrUrlType,
dtype: DType,
dtype: dtypes.DType,
target_format: Optional[str] = None,
workdir: PathType = ".",
) -> np.void:
Expand Down Expand Up @@ -524,7 +526,7 @@ def _decode_external_value(
# Convert all other formats/codecs to flac.
output_format, output_codec = "flac", "flac"
else:
output_format, output_codec = Audio.get_format_codec(target_format)
output_format, output_codec = media.Audio.get_format_codec(target_format)
if output_format == input_format and output_codec == input_codec:
# Nothing to transcode
if isinstance(file, str):
Expand All @@ -550,10 +552,10 @@ def _decode_external_value(
):
return np.void(file.read())
# `image/tiff`s become blank in frontend, so convert them too.
return Image.from_file(file).encode(target_format)
return media.Image.from_file(file).encode(target_format)

if dtypes.is_mesh_dtype(dtype):
return Mesh.from_file(path_or_url).encode(target_format)
return media.Mesh.from_file(path_or_url).encode(target_format)
if dtypes.is_video_dtype(dtype):
return Video.from_file(path_or_url).encode(target_format)
return media.Video.from_file(path_or_url).encode(target_format)
assert False
2 changes: 1 addition & 1 deletion renumics/spotlight_plugins/core/api/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_table(request: Request) -> ORJSONResponse:
columns = []
for column_name in data_store.column_names:
dtype = data_store.dtypes[column_name]
values = data_store.get_converted_values(column_name, simple=True)
values = data_store.get_converted_values(column_name, simple=True, check=False)
meta = data_store.get_column_metadata(column_name)
column = Column(
name=column_name,
Expand Down

0 comments on commit 8b363d8

Please sign in to comment.