Skip to content

Commit

Permalink
Standardize imports
Browse files Browse the repository at this point in the history
  • Loading branch information
druzsan committed Oct 25, 2023
1 parent 6d53da2 commit 723092a
Showing 1 changed file with 56 additions and 62 deletions.
118 changes: 56 additions & 62 deletions renumics/spotlight/dtypes/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,13 @@
import PIL.Image
import validators

from renumics.spotlight import dtypes
from renumics.spotlight import dtypes, media
from renumics.spotlight.typing import PathOrUrlType, PathType
from renumics.spotlight.cache import external_data_cache
from renumics.spotlight.io import audio
from renumics.spotlight.io.file import as_file
from renumics.spotlight.media.exceptions import InvalidFile
from renumics.spotlight.backend.exceptions import Problem
from renumics.spotlight.dtypes import (
CategoryDType,
DType,
audio_dtype,
image_dtype,
mesh_dtype,
video_dtype,
)
from renumics.spotlight.media import Sequence1D, Image, Audio, Video, Mesh


NormalizedValue = Union[
Expand Down Expand Up @@ -96,7 +87,7 @@ class ConversionFailed(Problem):
def __init__(
self,
value: NormalizedValue,
dtype: DType,
dtype: dtypes.DType,
reason: Optional[str] = None,
) -> None:
super().__init__(
Expand All @@ -116,14 +107,14 @@ class NoConverterAvailable(Problem):
No matching converter could be applied
"""

def __init__(self, value: NormalizedValue, dtype: DType) -> None:
def __init__(self, value: NormalizedValue, dtype: dtypes.DType) -> None:
msg = f"No Converter for {type(value)} -> {dtype}"
super().__init__(title="No matching converter", detail=msg, status_code=422)


N = TypeVar("N", bound=NormalizedValue)

Converter = Callable[[N, DType], ConvertedValue]
Converter = Callable[[N, dtypes.DType], ConvertedValue]
_converters_table: Dict[
Type[NormalizedValue], Dict[str, List[Converter]]
] = defaultdict(lambda: defaultdict(list))
Expand Down Expand Up @@ -176,7 +167,10 @@ def _decorate(func: Converter[N]) -> Converter[N]:


def convert_to_dtype(
value: NormalizedValue, dtype: DType, simple: bool = False, check: bool = True
value: NormalizedValue,
dtype: dtypes.DType,
simple: bool = False,
check: bool = True,
) -> ConvertedValue:
"""
Convert normalized type from data source to internal Spotlight DType
Expand Down Expand Up @@ -238,32 +232,32 @@ def convert_to_dtype(


@convert("datetime")
def _(value: datetime.datetime, _: DType) -> datetime.datetime:
def _(value: datetime.datetime, _: dtypes.DType) -> datetime.datetime:
return value


@convert("datetime")
def _(value: Union[str, np.str_], _: DType) -> Optional[datetime.datetime]:
def _(value: Union[str, np.str_], _: dtypes.DType) -> Optional[datetime.datetime]:
if value == "":
return None
return datetime.datetime.fromisoformat(value)


@convert("datetime")
def _(value: np.datetime64, _: DType) -> Optional[datetime.datetime]:
def _(value: np.datetime64, _: dtypes.DType) -> Optional[datetime.datetime]:
return value.tolist()


@convert("Category")
def _(value: Union[str, np.str_], dtype: CategoryDType) -> int:
def _(value: Union[str, np.str_], dtype: dtypes.CategoryDType) -> int:
categories = dtype.categories
if not categories:
return -1
return categories[value]


@convert("Category")
def _(_: None, _dtype: CategoryDType) -> int:
def _(_: None, _dtype: dtypes.CategoryDType) -> int:
return -1


Expand All @@ -280,23 +274,23 @@ def _(
np.uint32,
np.uint64,
],
_: CategoryDType,
_: dtypes.CategoryDType,
) -> int:
return int(value)


@convert("Window")
def _(value: list, _: DType) -> np.ndarray:
def _(value: list, _: dtypes.DType) -> np.ndarray:
return np.array(value, dtype=np.float64)


@convert("Window")
def _(value: np.ndarray, _: DType) -> np.ndarray:
def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray:
return value.astype(np.float64)


@convert("Window")
def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
array = np.array(obj, dtype=np.float64)
Expand All @@ -308,17 +302,17 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray:


@convert("Embedding", simple=False)
def _(value: list, _: DType) -> np.ndarray:
def _(value: list, _: dtypes.DType) -> np.ndarray:
return np.array(value, dtype=np.float64)


@convert("Embedding", simple=False)
def _(value: np.ndarray, _: DType) -> np.ndarray:
def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray:
return value.astype(np.float64)


@convert("Embedding", simple=False)
def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
array = np.array(obj, dtype=np.float64)
Expand All @@ -330,138 +324,138 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray:


@convert("Sequence1D", simple=False)
def _(value: Union[np.ndarray, list], _: DType) -> np.ndarray:
return Sequence1D(value).encode()
def _(value: Union[np.ndarray, list], _: dtypes.DType) -> np.ndarray:
return media.Sequence1D(value).encode()


@convert("Sequence1D", simple=False)
def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
return Sequence1D(obj).encode()
return media.Sequence1D(obj).encode()
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
raise ConversionError("Cannot interpret string as a 1D sequence")


@convert("Image", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
if data := read_external_value(value, image_dtype):
if data := read_external_value(value, dtypes.image_dtype):
return data.tolist()
except InvalidFile:
raise ConversionError()
raise ConversionError()


@convert("Image", simple=False)
def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
return Image.from_bytes(value).encode().tolist()
def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return media.Image.from_bytes(value).encode().tolist()


@convert("Image", simple=False)
def _(value: np.ndarray, _: DType) -> bytes:
return Image(value).encode().tolist()
def _(value: np.ndarray, _: dtypes.DType) -> bytes:
return media.Image(value).encode().tolist()


@convert("Image", simple=False)
def _(value: PIL.Image.Image, _: DType) -> bytes:
def _(value: PIL.Image.Image, _: dtypes.DType) -> bytes:
buffer = io.BytesIO()
value.save(buffer, format="PNG")
return buffer.getvalue()


@convert("Audio", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
if data := read_external_value(value, audio_dtype):
if data := read_external_value(value, dtypes.audio_dtype):
return data.tolist()
except (InvalidFile, IndexError, ValueError):
raise ConversionError()
raise ConversionError()


@convert("Audio", simple=False)
def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
return Audio.from_bytes(value).encode().tolist()
def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return media.Audio.from_bytes(value).encode().tolist()


@convert("Video", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
if data := read_external_value(value, video_dtype):
if data := read_external_value(value, dtypes.video_dtype):
return data.tolist()
except InvalidFile:
raise ConversionError()
raise ConversionError()


@convert("Video", simple=False)
def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
return Video.from_bytes(value).encode().tolist()
def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return media.Video.from_bytes(value).encode().tolist()


@convert("Mesh", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
if data := read_external_value(value, mesh_dtype):
if data := read_external_value(value, dtypes.mesh_dtype):
return data.tolist()
except InvalidFile:
raise ConversionError()
raise ConversionError()


@convert("Mesh", simple=False)
def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return value


# this should not be necessary
@convert("Mesh", simple=False) # type: ignore
def _(value: trimesh.Trimesh, _: DType) -> bytes:
return Mesh.from_trimesh(value).encode().tolist()
def _(value: trimesh.Trimesh, _: dtypes.DType) -> bytes:
return media.Mesh.from_trimesh(value).encode().tolist()


@convert("Embedding", simple=True)
@convert("Sequence1D", simple=True)
def _(_: Union[np.ndarray, list, str, np.str_], _dtype: DType) -> str:
def _(_: Union[np.ndarray, list, str, np.str_], _dtype: dtypes.DType) -> str:
return "[...]"


@convert("Image", simple=True)
def _(_: np.ndarray, _dtype: DType) -> str:
def _(_: np.ndarray, _dtype: dtypes.DType) -> str:
return "[...]"


@convert("Image", simple=True)
def _(_: PIL.Image.Image, _dtype: DType) -> str:
def _(_: PIL.Image.Image, _dtype: dtypes.DType) -> str:
return "<PIL.Image>"


@convert("Image", simple=True)
@convert("Audio", simple=True)
@convert("Video", simple=True)
@convert("Mesh", simple=True)
def _(value: Union[str, np.str_], _: DType) -> str:
def _(value: Union[str, np.str_], _: dtypes.DType) -> str:
return str(value)


@convert("Image", simple=True)
@convert("Audio", simple=True)
@convert("Video", simple=True)
@convert("Mesh", simple=True)
def _(_: Union[bytes, np.bytes_], _dtype: DType) -> str:
def _(_: Union[bytes, np.bytes_], _dtype: dtypes.DType) -> str:
return "<bytes>"


# this should not be necessary
@convert("Mesh", simple=True) # type: ignore
def _(_: trimesh.Trimesh, _dtype: DType) -> str:
def _(_: trimesh.Trimesh, _dtype: dtypes.DType) -> str:
return "<Trimesh>"


def read_external_value(
path_or_url: Optional[str],
dtype: DType,
dtype: dtypes.DType,
target_format: Optional[str] = None,
workdir: PathType = ".",
) -> Optional[np.void]:
Expand Down Expand Up @@ -498,7 +492,7 @@ def prepare_path_or_url(path_or_url: PathOrUrlType, workdir: PathType) -> str:

def _decode_external_value(
path_or_url: PathOrUrlType,
dtype: DType,
dtype: dtypes.DType,
target_format: Optional[str] = None,
workdir: PathType = ".",
) -> np.void:
Expand Down Expand Up @@ -528,7 +522,7 @@ def _decode_external_value(
# Convert all other formats/codecs to flac.
output_format, output_codec = "flac", "flac"
else:
output_format, output_codec = Audio.get_format_codec(target_format)
output_format, output_codec = media.Audio.get_format_codec(target_format)
if output_format == input_format and output_codec == input_codec:
# Nothing to transcode
if isinstance(file, str):
Expand All @@ -554,10 +548,10 @@ def _decode_external_value(
):
return np.void(file.read())
# `image/tiff`s become blank in frontend, so convert them too.
return Image.from_file(file).encode(target_format)
return media.Image.from_file(file).encode(target_format)

if dtypes.is_mesh_dtype(dtype):
return Mesh.from_file(path_or_url).encode(target_format)
return media.Mesh.from_file(path_or_url).encode(target_format)
if dtypes.is_video_dtype(dtype):
return Video.from_file(path_or_url).encode(target_format)
return media.Video.from_file(path_or_url).encode(target_format)
assert False

0 comments on commit 723092a

Please sign in to comment.