Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add intermediate dtypes to data sources #224

Merged
merged 4 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions renumics/spotlight/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ def update(self, config: AppConfig) -> None:
self._broadcast(RefreshMessage())
self._update_issues()

for plugin in load_plugins():
plugin.update(self, config)

if not self._startup_complete:
self._startup_complete = True
self._connection.send({"kind": "startup_complete"})
Expand Down
7 changes: 7 additions & 0 deletions renumics/spotlight/data_source/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def column_names(self) -> List[str]:
Dataset's available column names.
"""

@property
@abstractmethod
def intermediate_dtypes(self) -> DTypeMap:
"""
The dtypes of intermediate values
"""

@property
def df(self) -> Optional[pd.DataFrame]:
"""
Expand Down
57 changes: 33 additions & 24 deletions renumics/spotlight/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,14 @@


class DataStore:
_dtypes: DTypeMap
_data_source: DataSource
_user_dtypes: DTypeMap
_dtypes: DTypeMap

def __init__(self, data_source: DataSource, user_dtypes: DTypeMap) -> None:
self._data_source = data_source
guessed_dtypes = self._data_source.guess_dtypes()
dtypes = {
**guessed_dtypes,
**{
column_name: dtype
for column_name, dtype in user_dtypes.items()
if column_name in guessed_dtypes
},
}
for column_name, dtype in dtypes.items():
if (
is_category_dtype(dtype)
and dtype.categories is None
and is_str_dtype(guessed_dtypes[column_name])
):
normalized_values = self._data_source.get_column_values(column_name)
converted_values = [
convert_to_dtype(value, str_dtype, simple=True, check=True)
for value in normalized_values
]
category_names = sorted(cast(Set[str], set(converted_values)))
dtypes[column_name] = CategoryDType(category_names)
self._dtypes = dtypes
self._user_dtypes = user_dtypes
self._update_dtypes()

def __len__(self) -> int:
return len(self._data_source)
Expand All @@ -67,6 +47,10 @@ def generation_id(self) -> int:
def column_names(self) -> List[str]:
return self._data_source.column_names

@property
def data_source(self) -> DataSource:
return self._data_source

@property
def dtypes(self) -> DTypeMap:
return self._dtypes
Expand Down Expand Up @@ -116,3 +100,28 @@ def get_waveform(self, column_name: str, index: int) -> Optional[np.ndarray]:
waveform = audio.get_waveform(io.BytesIO(blob)) # type: ignore
external_data_cache[cache_key] = waveform
return waveform

def _update_dtypes(self) -> None:
guessed_dtypes = self._data_source.guess_dtypes()
dtypes = {
**guessed_dtypes,
**{
column_name: dtype
for column_name, dtype in self._user_dtypes.items()
if column_name in guessed_dtypes
},
}
for column_name, dtype in dtypes.items():
if (
is_category_dtype(dtype)
and dtype.categories is None
and is_str_dtype(guessed_dtypes[column_name])
):
normalized_values = self._data_source.get_column_values(column_name)
converted_values = [
convert_to_dtype(value, str_dtype, simple=True, check=True)
for value in normalized_values
]
category_names = sorted(cast(Set[str], set(converted_values)))
dtypes[column_name] = CategoryDType(category_names)
self._dtypes = dtypes
2 changes: 1 addition & 1 deletion renumics/spotlight/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def open(self, mode: Optional[str] = None) -> None:
self.close()
self._mode = mode
if self._closed:
self._h5_file = h5py.File(self._filepath, self._mode)
self._h5_file = h5py.File(self._filepath, self._mode, locking=False)
self._closed = False
self._column_names, self._length = self._get_column_names_and_length()
if self._is_writable():
Expand Down
15 changes: 14 additions & 1 deletion renumics/spotlight/dtypes/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import numpy as np
import trimesh
import PIL.Image
import validators

from renumics.spotlight.typing import PathOrUrlType, PathType
Expand Down Expand Up @@ -357,6 +358,13 @@ def _(value: np.ndarray, _: DType) -> bytes:
return Image(value).encode().tolist()


@convert("Image", simple=False)
def _(value: PIL.Image.Image, _: DType) -> bytes:
buffer = io.BytesIO()
value.save(buffer, format="PNG")
return buffer.getvalue()


@convert("Audio", simple=False)
def _(value: Union[str, np.str_], _: DType) -> bytes:
try:
Expand Down Expand Up @@ -419,6 +427,11 @@ def _(_: np.ndarray, _dtype: DType) -> str:
return "[...]"


@convert("Image", simple=True)
def _(_: PIL.Image.Image, _dtype: DType) -> str:
return "<PIL.Image>"


@convert("Image", simple=True)
@convert("Audio", simple=True)
@convert("Video", simple=True)
Expand All @@ -438,7 +451,7 @@ def _(_: Union[bytes, np.bytes_], _dtype: DType) -> str:
# this should not be necessary
@convert("Mesh", simple=True) # type: ignore
def _(_: trimesh.Trimesh, _dtype: DType) -> str:
return "<object>"
return "<Trimesh>"


def read_external_value(
Expand Down
18 changes: 8 additions & 10 deletions renumics/spotlight/plugin_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from dataclasses import dataclass
from types import ModuleType
from pathlib import Path
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional
from fastapi import FastAPI

from renumics.spotlight.settings import settings
from renumics.spotlight.develop.project import get_project_info
from renumics.spotlight.io.path import is_path_relative_to
from renumics.spotlight.app_config import AppConfig

import renumics.spotlight_plugins as plugins_namespace

Expand All @@ -28,6 +29,7 @@ class Plugin:
module: ModuleType
init: Callable[[], None]
activate: Callable[[FastAPI], None]
update: Callable[[FastAPI, AppConfig], None]
dev: bool
frontend_entrypoint: Optional[Path]

Expand All @@ -46,14 +48,9 @@ def load_plugins() -> List[Plugin]:
if _plugins is not None:
return _plugins

def noinit() -> None:
def noop(*_args: Any, **_kwargs: Any) -> None:
"""
noop impl for __init__
"""

def noactivate(_: FastAPI) -> None:
"""
noop impl for __activate__
noop impl for plugin hooks
"""

plugins = {}
Expand All @@ -73,8 +70,9 @@ def noactivate(_: FastAPI) -> None:
plugins[name] = Plugin(
name=name,
priority=getattr(module, "__priority__", 1000),
init=getattr(module, "__register__", noinit),
activate=getattr(module, "__activate__", noactivate),
init=getattr(module, "__register__", noop),
activate=getattr(module, "__activate__", noop),
update=getattr(module, "__update__", noop),
module=module,
dev=dev,
frontend_entrypoint=main_js if main_js.exists() else None,
Expand Down
5 changes: 2 additions & 3 deletions renumics/spotlight_plugins/core/api/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from renumics.spotlight.app_config import AppConfig
from renumics.spotlight.io.path import is_path_relative_to
from renumics.spotlight.reporting import emit_timed_event

from renumics.spotlight.dtypes import is_category_dtype
from renumics.spotlight import dtypes


class Column(BaseModel):
Expand Down Expand Up @@ -81,7 +80,7 @@ def get_table(request: Request) -> ORJSONResponse:
editable=meta.editable,
optional=meta.nullable,
role=dtype.name,
categories=dtype.categories if is_category_dtype(dtype) else None,
categories=dtype.categories if dtypes.is_category_dtype(dtype) else None,
description=meta.description,
tags=meta.tags,
)
Expand Down
33 changes: 5 additions & 28 deletions renumics/spotlight_plugins/core/hdf5_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import h5py
import numpy as np

from renumics.spotlight.typing import IndexType
from renumics.spotlight.dataset import Dataset, INTERNAL_COLUMN_NAMES
from renumics.spotlight.dataset import Dataset

from renumics.spotlight.data_source import DataSource, datasource
from renumics.spotlight.backend.exceptions import (
Expand Down Expand Up @@ -63,32 +62,6 @@ def read_column(
return normalized_values
return raw_values

def duplicate_row(self, from_index: IndexType, to_index: IndexType) -> None:
"""
Duplicate a dataset's row. Increases the dataset's length by 1.
"""
self._assert_is_writable()
self._assert_index_exists(from_index)
length = self._length
if from_index < 0:
from_index += length
if to_index < 0:
to_index += length
if to_index != length:
self._assert_index_exists(to_index)
for column_name in self.keys() + INTERNAL_COLUMN_NAMES:
column = cast(h5py.Dataset, self._h5_file[column_name])
column.resize(length + 1, axis=0)
if to_index != length:
# Shift all values after the insertion position by one.
raw_values = column[int(to_index) : -1]
if is_embedding_dtype(self._get_dtype(column)):
raw_values = list(raw_values)
column[int(to_index) + 1 :] = raw_values
column[int(to_index)] = column[from_index]
self._length += 1
self._update_generation_id()

def _resolve_refs(self, refs: np.ndarray, column_name: str) -> np.ndarray:
raw_values = np.empty(len(refs), dtype=object)
raw_values[:] = [
Expand Down Expand Up @@ -132,6 +105,10 @@ def __del__(self) -> None:
def column_names(self) -> List[str]:
return self._table.keys()

@property
def intermediate_dtypes(self) -> DTypeMap:
return self.guess_dtypes()

def __len__(self) -> int:
return len(self._table)

Expand Down
32 changes: 32 additions & 0 deletions renumics/spotlight_plugins/core/pandas_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
import datasets
from renumics.spotlight import dtypes

from renumics.spotlight.io.pandas import (
infer_dtype,
Expand Down Expand Up @@ -102,6 +103,11 @@ def __init__(self, source: Union[Path, pd.DataFrame]):
self._generation_id = 0
self._uid = str(id(df))
self._df = df.convert_dtypes()
self._intermediate_dtypes = {
# TODO: convert column name
col: _determine_intermediate_dtype(self._df[col])
for col in self._df.columns
}

@property
def column_names(self) -> List[str]:
Expand All @@ -114,6 +120,10 @@ def df(self) -> pd.DataFrame:
"""
return self._df.copy()

@property
def intermediate_dtypes(self) -> DTypeMap:
return self._intermediate_dtypes

def __len__(self) -> int:
return len(self._df)

Expand Down Expand Up @@ -197,3 +207,25 @@ def _parse_column_index(self, column_name: str) -> Any:
f"Column '{column_name}' doesn't exist in the dataset."
) from e
return self._df.columns[index]


def _determine_intermediate_dtype(column: pd.Series) -> dtypes.DType:
if pd.api.types.is_bool_dtype(column):
return dtypes.bool_dtype
if pd.api.types.is_categorical_dtype(column):
return dtypes.CategoryDType(
{
category: code
for code, category in zip(column.cat.codes, column.cat.categories)
}
)
if pd.api.types.is_integer_dtype(column):
return dtypes.int_dtype
if pd.api.types.is_float_dtype(column):
return dtypes.float_dtype
if pd.api.types.is_datetime64_any_dtype(column):
return dtypes.datetime_dtype
if pd.api.types.is_string_dtype(column):
return dtypes.str_dtype
else:
return dtypes.mixed_dtype