diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index d6628c4f..e71d164e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -11,10 +11,10 @@ Technical details on how to contribute can be found in our [documentation](https
There are several ways you can contribute to Spotlight:
-* Fix outstanding issues.
-* Implement new features.
-* Submit issues related to bugs or desired new features.
-* Share your use case
+- Fix outstanding issues.
+- Implement new features.
+- Submit issues related to bugs or desired new features.
+- Share your use case.
If you don't know where to start, you might want to have a look at [hacktoberfest issues](https://github.com/Renumics/spotlight/issues?q=is%3Aissue+is%3Aopen+label%3Ahacktoberfest)
and our guide on how to create a [new Lens](https://renumics.com/docs/development/lenses).
diff --git a/README.md b/README.md
index 6c68d669..320f9e13 100644
--- a/README.md
+++ b/README.md
@@ -17,9 +17,10 @@
-Spotlight helps you to **understand unstructured datasets** fast. You can quickly create **interactive visualizations** and leverage data enrichments (e.g. embeddings, prediction, uncertainties) to **identify critical clusters** in your data.
+Spotlight helps you to **understand unstructured datasets** fast. You can quickly create **interactive visualizations** and leverage data enrichments (e.g. embeddings, predictions, uncertainties) to **identify critical clusters** in your data.
Spotlight supports most unstructured data types including **images, audio, text, videos, time-series and geometric data**. You can start from your existing dataframe:
+
And start Spotlight with just a few lines of code:
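The snippet itself sits outside this hunk; for context, the documented pattern looks roughly like this (the CSV path and column name are invented):

```python
# Minimal sketch, not the literal README snippet: launch Spotlight on a dataframe.
import pandas as pd

from renumics import spotlight

df = pd.read_csv("dataset.csv")  # invented path
spotlight.show(df, dtype={"image": spotlight.Image})  # mark "image" as an image column
```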
@@ -49,7 +50,7 @@ Machine learning and engineering teams use Spotlight to understand and communica
[Classification] |
Find Issues in Any Image Classification Dataset |
👨‍💻 📓 🕹️ |
-
+
Find data issues in the CIFAR-100 image dataset |
🕹️ |
@@ -91,7 +92,6 @@ Machine learning and engineering teams use Spotlight to understand and communica
-
## ⏱️ Quickstart
Get started by installing Spotlight and loading your first dataset.
@@ -132,12 +132,11 @@ ds = datasets.load_dataset('renumics/emodb-enriched', split='all')
layout= spotlight.layouts.debug_classification(label='gender', prediction='m1_gender_prediction', embedding='m1_embedding', features=['age', 'emotion'])
spotlight.show(ds, layout=layout)
```
+
Here, the data types are discovered automatically from the dataset, and we use a pre-defined layout for model debugging. Custom layouts can be built programmatically or via the UI, as sketched below.
> The `datasets[audio]` package can be installed via pip.
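As a sketch of the programmatic route mentioned above (assuming the public `renumics.spotlight.layout` helpers; the widget selection is illustrative and exact helper signatures may differ):

```python
# Rough sketch of a hand-built layout, assuming the layout helper module.
from renumics import spotlight
from renumics.spotlight import layout

my_layout = layout.layout(
    [[layout.table()], [layout.similaritymap(), layout.histogram()]]
)
spotlight.show(ds, layout=my_layout)
```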
-
-
#### Usage Tracking
We have added crash report and performance collection. We do NOT collect user data other than an anonymized machine ID obtained via py-machineid, and we only log our own actions. We do NOT collect folder names, dataset names, or row data of any kind, only aggregate performance statistics such as the total time of a `table_load`, crash data, etc. Collecting Spotlight crashes will help us improve stability. To opt out of the crash report collection, define an environment variable called `SPOTLIGHT_OPT_OUT` and set it to true, e.g. `export SPOTLIGHT_OPT_OUT=true`.
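For scripts, the same opt-out can be set from Python before Spotlight starts; a minimal sketch:

```python
# Equivalent to `export SPOTLIGHT_OPT_OUT=true`; set before Spotlight starts.
import os

os.environ["SPOTLIGHT_OPT_OUT"] = "true"
```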
@@ -150,9 +149,9 @@ We have added crash report and performance collection. We do NOT collect user da
## Learn more about unstructured data workflows
-- 🤗 [Huggingface](https://huggingface.co/renumics) example spaces and datasets
-- 📖 [Playbook](https://renumics.com/docs/playbook/) for data-centric AI workflows
-- 🍰 [Sliceguard](https://github.com/Renumics/sliceguard) library for automatic slice detection
+- 🤗 [Huggingface](https://huggingface.co/renumics) example spaces and datasets
+- 📖 [Playbook](https://renumics.com/docs/playbook/) for data-centric AI workflows
+- 🍰 [Sliceguard](https://github.com/Renumics/sliceguard) library for automatic slice detection
## Contribute
diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py
index 62220d2f..5baa724d 100644
--- a/renumics/spotlight/app.py
+++ b/renumics/spotlight/app.py
@@ -15,6 +15,7 @@
from typing_extensions import Annotated
from fastapi import Cookie, FastAPI, Request, status
+from fastapi.datastructures import Headers
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
@@ -56,10 +57,23 @@
from renumics.spotlight.dtypes import DTypeMap
-
CURRENT_LAYOUT_KEY = "layout.current"
+class UncachedStaticFiles(StaticFiles):
+ """
+ FastAPI StaticFiles but without caching
+ """
+
+ def is_not_modified(
+ self, response_headers: Headers, request_headers: Headers
+ ) -> bool:
+ """
+        Never cache.
+ """
+ return False
+
+
class IssuesUpdatedMessage(Message):
"""
Notify about updated issues.
@@ -75,7 +89,6 @@ class SpotlightApp(FastAPI):
"""
# lifecycle
- _startup_complete: bool
_loop: asyncio.AbstractEventLoop
# connection
@@ -106,7 +119,6 @@ class SpotlightApp(FastAPI):
def __init__(self) -> None:
super().__init__()
- self._startup_complete = False
self.task_manager = TaskManager()
self.websocket_manager = None
self.config = Config()
@@ -207,9 +219,13 @@ async def _(_: Request, problem: Problem) -> JSONResponse:
plugin.activate(self)
try:
+            # Mount frontend files as uncached,
+            # so that we always get the new frontend after updating Spotlight.
+            # NOTE: we might not need this if we added a version hash
+            # to our built JS files.
self.mount(
"/static",
- StaticFiles(packages=["renumics.spotlight.backend"]),
+ UncachedStaticFiles(packages=["renumics.spotlight.backend"]),
name="assets",
)
except AssertionError:
@@ -295,44 +311,45 @@ def update(self, config: AppConfig) -> None:
"""
Update application config.
"""
- if config.project_root is not None:
- self.project_root = config.project_root
- if config.dtypes is not None:
- self._user_dtypes = config.dtypes
- if config.analyze is not None:
- self.analyze_columns = config.analyze
- if config.custom_issues is not None:
- self.custom_issues = config.custom_issues
- if config.dataset is not None:
- self._dataset = config.dataset
- self._data_source = create_datasource(self._dataset)
- if config.layout is not None:
- self._layout = config.layout or layouts.default()
- if config.filebrowsing_allowed is not None:
- self.filebrowsing_allowed = config.filebrowsing_allowed
-
- if config.dtypes is not None or config.dataset is not None:
- data_source = self._data_source
- assert data_source is not None
- self._data_store = DataStore(data_source, self._user_dtypes)
- self._broadcast(RefreshMessage())
- self._update_issues()
- if config.layout is not None:
- if self._data_store is not None:
- dataset_uid = self._data_store.uid
- future = asyncio.run_coroutine_threadsafe(
- self.config.remove_all(CURRENT_LAYOUT_KEY, dataset=dataset_uid),
- self._loop,
- )
- future.result()
- self._broadcast(ResetLayoutMessage())
+ try:
+ if config.project_root is not None:
+ self.project_root = config.project_root
+ if config.dtypes is not None:
+ self._user_dtypes = config.dtypes
+ if config.analyze is not None:
+ self.analyze_columns = config.analyze
+ if config.custom_issues is not None:
+ self.custom_issues = config.custom_issues
+ if config.dataset is not None:
+ self._dataset = config.dataset
+ self._data_source = create_datasource(self._dataset)
+ if config.layout is not None:
+ self._layout = config.layout or layouts.default()
+ if config.filebrowsing_allowed is not None:
+ self.filebrowsing_allowed = config.filebrowsing_allowed
+
+ if config.dtypes is not None or config.dataset is not None:
+ data_source = self._data_source
+ assert data_source is not None
+ self._data_store = DataStore(data_source, self._user_dtypes)
+ self._broadcast(RefreshMessage())
+ self._update_issues()
+ if config.layout is not None:
+ if self._data_store is not None:
+ dataset_uid = self._data_store.uid
+ future = asyncio.run_coroutine_threadsafe(
+ self.config.remove_all(CURRENT_LAYOUT_KEY, dataset=dataset_uid),
+ self._loop,
+ )
+ future.result()
+ self._broadcast(ResetLayoutMessage())
- for plugin in load_plugins():
- plugin.update(self, config)
+ for plugin in load_plugins():
+ plugin.update(self, config)
+ except Exception as e:
+ self._connection.send({"kind": "update_complete", "error": e})
- if not self._startup_complete:
- self._startup_complete = True
- self._connection.send({"kind": "startup_complete"})
+ self._connection.send({"kind": "update_complete"})
def _handle_message(self, message: Any) -> None:
kind = message.get("kind")
diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py
index 00207c38..9dae9021 100644
--- a/renumics/spotlight/backend/tasks/reduction.py
+++ b/renumics/spotlight/backend/tasks/reduction.py
@@ -6,11 +6,9 @@
import numpy as np
import pandas as pd
-from sklearn import preprocessing
-from renumics.spotlight.dataset.exceptions import ColumnNotExistsError
from renumics.spotlight.data_store import DataStore
-from renumics.spotlight.dtypes import is_category_dtype, is_embedding_dtype
+from renumics.spotlight import dtypes
SEED = 42
@@ -27,6 +25,7 @@ def align_data(
"""
Align data from table's columns, remove `NaN`'s.
"""
+ from sklearn import preprocessing
if not column_names or not indices:
return np.empty(0, np.float64), []
@@ -35,7 +34,7 @@ def align_data(
for column_name in column_names:
dtype = data_store.dtypes[column_name]
column_values = data_store.get_converted_values(column_name, indices)
- if is_embedding_dtype(dtype):
+ if dtypes.is_embedding_dtype(dtype):
embedding_length = max(
0 if x is None else len(cast(np.ndarray, x)) for x in column_values
)
@@ -49,17 +48,19 @@ def align_data(
]
)
)
- elif is_category_dtype(dtype):
+ elif dtypes.is_category_dtype(dtype):
na_mask = np.array(column_values) == -1
one_hot_values = preprocessing.label_binarize(
column_values, classes=sorted(set(column_values).difference({-1})) # type: ignore
).astype(float)
one_hot_values[na_mask] = np.nan
aligned_values.append(one_hot_values)
- elif dtype in (int, bool, float):
+ elif dtypes.is_scalar_dtype(dtype):
aligned_values.append(np.array(column_values, dtype=float))
else:
- raise ColumnNotEmbeddable
+ raise ColumnNotEmbeddable(
+ f"Column '{column_name}' of type {dtype} is not embeddable."
+ )
data = np.hstack([col.reshape((len(indices), -1)) for col in aligned_values])
mask = ~pd.isna(data).any(axis=1)
@@ -78,10 +79,8 @@ def compute_umap(
Prepare data from table and compute U-Map on them.
"""
- try:
- data, indices = align_data(data_store, column_names, indices)
- except (ColumnNotExistsError, ColumnNotEmbeddable):
- return np.empty(0, np.float64), []
+ data, indices = align_data(data_store, column_names, indices)
+
if data.size == 0:
return np.empty(0, np.float64), []
@@ -114,14 +113,13 @@ def compute_pca(
Prepare data from table and compute PCA on them.
"""
- from sklearn import preprocessing, decomposition
+ data, indices = align_data(data_store, column_names, indices)
- try:
- data, indices = align_data(data_store, column_names, indices)
- except (ColumnNotExistsError, ValueError):
- return np.empty(0, np.float64), []
if data.size == 0:
return np.empty(0, np.float64), []
+
+ from sklearn import preprocessing, decomposition
+
if data.shape[1] == 1:
return np.hstack((data, np.zeros_like(data))), indices
if normalization == "standardize":
@@ -129,5 +127,6 @@ def compute_pca(
elif normalization == "robust standardize":
data = preprocessing.RobustScaler(copy=False).fit_transform(data)
reducer = decomposition.PCA(n_components=2, copy=False, random_state=SEED)
- embeddings = reducer.fit_transform(data)
+ # `fit_transform` returns Fortran-ordered array.
+ embeddings = np.ascontiguousarray(reducer.fit_transform(data))
return embeddings, indices
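For context on the `ascontiguousarray` call: orjson's numpy serializer, used when sending results over the websocket, rejects arrays that are not C-contiguous. A minimal numpy illustration (not Spotlight code):

```python
# Fortran-ordered arrays fail C-contiguity checks until copied.
import numpy as np

f_ordered = np.asfortranarray(np.zeros((3, 2)))
print(f_ordered.flags["C_CONTIGUOUS"])                        # False
print(np.ascontiguousarray(f_ordered).flags["C_CONTIGUOUS"])  # True
```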
diff --git a/renumics/spotlight/backend/tasks/task_manager.py b/renumics/spotlight/backend/tasks/task_manager.py
index e120fa88..8c392a55 100644
--- a/renumics/spotlight/backend/tasks/task_manager.py
+++ b/renumics/spotlight/backend/tasks/task_manager.py
@@ -6,7 +6,7 @@
import multiprocessing
from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
-from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
from .exceptions import TaskCancelled
from .task import Task
@@ -30,16 +30,20 @@ def create_task(
self,
func: Callable,
args: Sequence[Any],
+ kwargs: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
tag: Optional[Union[str, int]] = None,
) -> Task:
"""
create and launch a new task
"""
+ if kwargs is None:
+ kwargs = {}
+
# cancel running task with same name
self.cancel(name=name)
- future = self.pool.submit(func, *args)
+ future = self.pool.submit(func, *args, **kwargs)
task = Task(name, tag, future)
self.tasks.append(task)
@@ -59,6 +63,7 @@ async def run_async(
self,
func: Callable[..., T],
args: Sequence[Any],
+ kwargs: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
tag: Optional[Union[str, int]] = None,
) -> T:
@@ -66,7 +71,7 @@ async def run_async(
Launch a new task. Await and return result.
"""
- task = self.create_task(func, args, name, tag)
+ task = self.create_task(func, args=args, kwargs=kwargs, name=name, tag=tag)
try:
return await asyncio.wrap_future(task.future)
except BrokenProcessPool as e:
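The new `kwargs` parameter lets callers forward keyword arguments to the pooled function. A hedged usage sketch (argument names are assumed from the surrounding reduction code, values are invented, and an async context is required):

```python
# Hypothetical call through the extended API.
points, valid_indices = await task_manager.run_async(
    compute_umap,
    args=(data_store,),
    kwargs={
        "column_names": ["embedding"],
        "indices": [0, 1, 2],
        "n_neighbors": 15,
        "metric": "euclidean",
        "min_dist": 0.1,
    },
    name="umap-widget",
    tag=42,
)
```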
diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py
index 8be62f4a..aed9a1ef 100644
--- a/renumics/spotlight/backend/websockets.py
+++ b/renumics/spotlight/backend/websockets.py
@@ -3,11 +3,20 @@
"""
import asyncio
-import functools
-import orjson
-from typing import Any, List, Optional, Set, Callable
+from typing import (
+ Any,
+ Coroutine,
+ Dict,
+ Optional,
+ Set,
+ Callable,
+ Tuple,
+ Type,
+ cast,
+)
import numpy as np
+import orjson
from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger
from pydantic import BaseModel
@@ -17,7 +26,7 @@
from .tasks import TaskManager, TaskCancelled
from .tasks.reduction import compute_umap, compute_pca
-from .exceptions import GenerationIDMismatch
+from .exceptions import GenerationIDMismatch, Problem
class Message(BaseModel):
@@ -47,115 +56,46 @@ class ResetLayoutMessage(Message):
data: Any = None
-class ReductionMessage(Message):
- """
- Common data reduction message model.
- """
-
+class TaskData(BaseModel):
+ task: str
widget_id: str
- uid: str
- generation_id: int
-
-
-class ReductionRequestData(BaseModel):
- """
- Base data reduction request payload.
- """
-
- indices: List[int]
- columns: List[str]
-
-
-class UMapRequestData(ReductionRequestData):
- """
- U-Map request payload.
- """
-
- n_neighbors: int
- metric: str
- min_dist: float
-
-
-class PCARequestData(ReductionRequestData):
- """
- PCA request payload.
- """
-
- normalization: str
+ task_id: str
+ generation_id: Optional[int]
+ args: Any
-class UMapRequest(ReductionMessage):
- """
- U-Map request model.
- """
-
- data: UMapRequestData
-
-
-class PCARequest(ReductionMessage):
+class UnknownMessageType(Exception):
"""
- PCA request model.
+ Websocket message type is unknown.
"""
- data: PCARequestData
-
-class ReductionResponseData(BaseModel):
+class SerializationError(Exception):
"""
- Data reduction response payload.
+    Failed to serialize the WS message.
"""
- indices: List[int]
- points: np.ndarray
- class Config:
- arbitrary_types_allowed = True
+PayloadType = Type[BaseModel]
+MessageHandler = Callable[[Any, "WebsocketConnection"], Coroutine[Any, Any, Any]]
+MessageHandlerSpec = Tuple[PayloadType, MessageHandler]
+MESSAGE_HANDLERS: Dict[str, MessageHandlerSpec] = {}
-class ReductionResponse(ReductionMessage):
- """
- Data reduction response model.
- """
+def register_message_handler(
+ message_type: str, handler_spec: MessageHandlerSpec
+) -> None:
+ MESSAGE_HANDLERS[message_type] = handler_spec
- data: ReductionResponseData
+def message_handler(
+ message_type: str, payload_type: Type[BaseModel]
+) -> Callable[[MessageHandler], MessageHandler]:
+ def decorator(handler: MessageHandler) -> MessageHandler:
+ register_message_handler(message_type, (payload_type, handler))
+ return handler
-MESSAGE_BY_TYPE = {
- "umap": UMapRequest,
- "umap_result": ReductionResponse,
- "pca": PCARequest,
- "pca_result": ReductionResponse,
- "refresh": RefreshMessage,
-}
-
-
-class UnknownMessageType(Exception):
- """
- Websocket message type is unknown.
- """
-
-
-def parse_message(raw_message: str) -> Message:
- """
- Parse a websocket message from a raw text.
- """
- json_message = orjson.loads(raw_message)
- message_type = json_message["type"]
- message_class = MESSAGE_BY_TYPE.get(message_type)
- if message_class is None:
- raise UnknownMessageType(f"Message type {message_type} is unknown.")
- return message_class(**json_message)
-
-
-@functools.singledispatch
-async def handle_message(request: Message, connection: "WebsocketConnection") -> None:
- """
- Handle incoming messages.
-
- New message types should be registered by decorating with `@handle_message.register`.
- """
-
- raise NotImplementedError
+ return decorator
class WebsocketConnection:
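The decorator above replaces the previous `singledispatch` mechanism; registering a new message type now looks like this (a hypothetical `ping` handler, reusing the module's own `Message` and `WebsocketConnection` names):

```python
# Hypothetical handler registration; "ping" and PingData are invented names.
class PingData(BaseModel):
    text: str

@message_handler("ping", PingData)
async def _(data: PingData, connection: "WebsocketConnection") -> None:
    await connection.send_async(Message(type="pong", data=data.text))
```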
@@ -174,11 +114,15 @@ async def send_async(self, message: Message) -> None:
"""
Send a message async.
"""
+
try:
- message_data = message.dict()
- await self.websocket.send_text(
- orjson.dumps(message_data, option=orjson.OPT_SERIALIZE_NUMPY).decode()
- )
+ json_text = orjson.dumps(
+ message.dict(), option=orjson.OPT_SERIALIZE_NUMPY
+ ).decode()
+ except TypeError as e:
+ raise SerializationError(str(e))
+ try:
+ await self.websocket.send_text(json_text)
except WebSocketDisconnect:
self._on_disconnect()
except RuntimeError:
@@ -204,12 +148,17 @@ async def listen(self) -> None:
try:
while True:
try:
- message = parse_message(await self.websocket.receive_text())
+ raw_message = await self.websocket.receive_text()
+ message = Message(**orjson.loads(raw_message))
except UnknownMessageType as e:
logger.warning(str(e))
else:
logger.info(f"WS message with type {message.type} received.")
- asyncio.create_task(handle_message(message, self))
+ if handler_spec := MESSAGE_HANDLERS.get(message.type):
+ payload_type, handler = handler_spec
+ asyncio.create_task(handler(payload_type(**message.data), self))
+ else:
+ logger.error(f"Unknown message received: {message.type}")
except WebSocketDisconnect:
self._on_disconnect()
@@ -300,73 +249,80 @@ def on_disconnect(self, connection: WebsocketConnection) -> None:
callback(len(self.connections))
-@handle_message.register
-async def _(request: UMapRequest, connection: "WebsocketConnection") -> None:
- data_store: Optional[DataStore] = connection.websocket.app.data_store
- if data_store is None:
- return None
- try:
- data_store.check_generation_id(request.generation_id)
- except GenerationIDMismatch:
- return
-
- try:
- points, valid_indices = await connection.task_manager.run_async(
- compute_umap,
- (
- data_store,
- request.data.columns,
- request.data.indices,
- request.data.n_neighbors,
- request.data.metric,
- request.data.min_dist,
- ),
- name=request.widget_id,
- tag=id(connection),
- )
- except TaskCancelled:
- ...
- else:
- response = ReductionResponse(
- type="umap_result",
- widget_id=request.widget_id,
- uid=request.uid,
- generation_id=request.generation_id,
- data=ReductionResponseData(indices=valid_indices, points=points),
- )
- await connection.send_async(response)
+TASK_FUNCS = {"umap": compute_umap, "pca": compute_pca}
-@handle_message.register
-async def _(request: PCARequest, connection: "WebsocketConnection") -> None:
+@message_handler("task", TaskData)
+async def _(data: TaskData, connection: WebsocketConnection) -> None:
data_store: Optional[DataStore] = connection.websocket.app.data_store
if data_store is None:
return None
- try:
- data_store.check_generation_id(request.generation_id)
- except GenerationIDMismatch:
- return
+
+ if data.generation_id:
+ try:
+ data_store.check_generation_id(data.generation_id)
+ except GenerationIDMismatch:
+ return
try:
- points, valid_indices = await connection.task_manager.run_async(
- compute_pca,
- (
- data_store,
- request.data.columns,
- request.data.indices,
- request.data.normalization,
- ),
- name=request.widget_id,
+ task_func = TASK_FUNCS[data.task]
+ result = await connection.task_manager.run_async(
+ task_func, # type: ignore
+ args=(data_store,),
+ kwargs=data.args,
+ name=data.widget_id,
tag=id(connection),
)
+ points = cast(np.ndarray, result[0])
+ valid_indices = cast(np.ndarray, result[1])
except TaskCancelled:
- ...
+ pass
+ except Problem as e:
+ msg = Message(
+ type="task.error",
+ data={
+ "task_id": data.task_id,
+ "error": {
+ "type": type(e).__name__,
+ "title": e.title,
+ "detail": e.detail,
+ },
+ },
+ )
+ await connection.send_async(msg)
+ except Exception as e:
+ msg = Message(
+ type="task.error",
+ data={
+ "task_id": data.task_id,
+ "error": {
+ "type": type(e).__name__,
+ "title": type(e).__name__,
+ "detail": type(e).__doc__,
+ },
+ },
+ )
+ await connection.send_async(msg)
else:
- response = ReductionResponse(
- type="pca_result",
- widget_id=request.widget_id,
- uid=request.uid,
- generation_id=request.generation_id,
- data=ReductionResponseData(indices=valid_indices, points=points),
+ msg = Message(
+ type="task.result",
+ data={
+ "task_id": data.task_id,
+ "result": {"points": points, "indices": valid_indices},
+ },
)
- await connection.send_async(response)
+ try:
+ await connection.send_async(msg)
+ except SerializationError as e:
+ error_msg = Message(
+ type="task.error",
+ data={
+ "task_id": data.task_id,
+ "error": {
+ "type": type(e).__name__,
+ "title": "Serialization Error",
+ "detail": str(e),
+ },
+ },
+ )
+ await connection.send_async(error_msg)
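On the wire, the consolidated protocol exchanges messages of this shape (all field values invented for illustration):

```python
# Client request dispatched to the "task" handler above.
request = {
    "type": "task",
    "data": {
        "task": "umap",                 # key into TASK_FUNCS
        "widget_id": "similaritymap-1",
        "task_id": "7f3a",
        "generation_id": 1,
        "args": {"column_names": ["embedding"], "indices": [0, 1, 2]},
    },
}
# Replies carry type "task.result" on success, or "task.error" when the task
# raises or its result cannot be serialized.
```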
diff --git a/renumics/spotlight/data_source/data_source.py b/renumics/spotlight/data_source/data_source.py
index d06133f3..695acb5d 100644
--- a/renumics/spotlight/data_source/data_source.py
+++ b/renumics/spotlight/data_source/data_source.py
@@ -6,7 +6,6 @@
import pandas as pd
import numpy as np
-from pydantic.dataclasses import dataclass
from renumics.spotlight.dataset.exceptions import (
ColumnExistsError,
@@ -30,17 +29,6 @@ class ColumnMetadata:
tags: List[str] = dataclasses.field(default_factory=list)
-@dataclass
-class CellsUpdate:
- """
- A dataset's cell update.
- """
-
- value: Any
- author: str
- edited_at: str
-
-
class DataSource(ABC):
"""abstract base class for different data sources"""
@@ -61,7 +49,7 @@ def column_names(self) -> List[str]:
@abstractmethod
def intermediate_dtypes(self) -> DTypeMap:
"""
- The dtypes of intermediate values
+ The dtypes of intermediate values. Values for all columns must be filled.
"""
@property
@@ -94,7 +82,7 @@ def check_generation_id(self, generation_id: int) -> None:
@abstractmethod
def semantic_dtypes(self) -> DTypeMap:
"""
- Semantic dtypes for viewer.
+        Semantic dtypes for viewer. Some values may not be present.
"""
@abstractmethod
diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py
index 2eb0108b..ba940254 100644
--- a/renumics/spotlight/data_store.py
+++ b/renumics/spotlight/data_store.py
@@ -21,13 +21,16 @@
DType,
DTypeMap,
EmbeddingDType,
+ array_dtype,
is_array_dtype,
is_audio_dtype,
is_category_dtype,
+ is_embedding_dtype,
is_file_dtype,
is_str_dtype,
is_mixed_dtype,
is_bytes_dtype,
+ is_window_dtype,
str_dtype,
audio_dtype,
image_dtype,
@@ -173,33 +176,32 @@ def _guess_dtype(self, col: str) -> DType:
return semantic_dtype
sample_values = self._data_source.get_column_values(col, slice(10))
- sample_dtypes = [_guess_value_dtype(value) for value in sample_values]
-
- try:
- mode_dtype = statistics.mode(sample_dtypes)
- except statistics.StatisticsError:
+ sample_dtypes: List[DType] = []
+ for value in sample_values:
+ guessed_dtype = _guess_value_dtype(value)
+ if guessed_dtype is not None:
+ sample_dtypes.append(guessed_dtype)
+ if not sample_dtypes:
return semantic_dtype
- return mode_dtype or semantic_dtype
+ mode_dtype = statistics.mode(sample_dtypes)
+    # For windows and embeddings, at least the sampled values must be aligned.
+ if is_window_dtype(mode_dtype) and any(
+ not is_window_dtype(dtype) for dtype in sample_dtypes
+ ):
+ return array_dtype
+ if is_embedding_dtype(mode_dtype) and any(
+ (not is_embedding_dtype(dtype)) or dtype.length != mode_dtype.length
+ for dtype in sample_dtypes
+ ):
+ return array_dtype
+
+ return mode_dtype
def _intermediate_to_semantic_dtype(intermediate_dtype: DType) -> DType:
if is_array_dtype(intermediate_dtype):
- if intermediate_dtype.shape is None:
- return intermediate_dtype
- if intermediate_dtype.shape == (2,):
- return window_dtype
- if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is not None:
- return EmbeddingDType(intermediate_dtype.shape[0])
- if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is None:
- return sequence_1d_dtype
- if intermediate_dtype.ndim == 2 and (
- intermediate_dtype.shape[0] == 2 or intermediate_dtype.shape[1] == 2
- ):
- return sequence_1d_dtype
- if intermediate_dtype.ndim == 3 and intermediate_dtype.shape[-1] in (1, 3, 4):
- return image_dtype
- return intermediate_dtype
+ return _guess_array_dtype(intermediate_dtype)
if is_file_dtype(intermediate_dtype):
return str_dtype
if is_mixed_dtype(intermediate_dtype):
@@ -262,5 +264,21 @@ def _guess_value_dtype(value: Any) -> Optional[DType]:
except (TypeError, ValueError):
pass
else:
- return ArrayDType(value.shape)
+ return _guess_array_dtype(ArrayDType(value.shape))
return None
+
+
+def _guess_array_dtype(dtype: ArrayDType) -> DType:
+ if dtype.shape is None:
+ return dtype
+ if dtype.shape == (2,):
+ return window_dtype
+ if dtype.ndim == 1 and dtype.shape[0] is not None:
+ return EmbeddingDType(dtype.shape[0])
+ if dtype.ndim == 1 and dtype.shape[0] is None:
+ return sequence_1d_dtype
+ if dtype.ndim == 2 and (dtype.shape[0] == 2 or dtype.shape[1] == 2):
+ return sequence_1d_dtype
+ if dtype.ndim == 3 and dtype.shape[-1] in (1, 3, 4):
+ return image_dtype
+ return dtype
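Read as examples, the consolidated guessing rules map shapes to semantic dtypes as follows (assumes the module's own names):

```python
# Shape-to-dtype guesses implemented by _guess_array_dtype above.
_guess_array_dtype(ArrayDType((2,)))         # window_dtype: exactly two values
_guess_array_dtype(ArrayDType((128,)))       # EmbeddingDType(128): fixed 1-D length
_guess_array_dtype(ArrayDType((None,)))      # sequence_1d_dtype: variable 1-D length
_guess_array_dtype(ArrayDType((2, 500)))     # sequence_1d_dtype: 2xN or Nx2
_guess_array_dtype(ArrayDType((64, 64, 3)))  # image_dtype: channels in (1, 3, 4)
_guess_array_dtype(ArrayDType((5, 7)))       # no rule matches: stays an array
```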
diff --git a/renumics/spotlight/dataset/__init__.py b/renumics/spotlight/dataset/__init__.py
index 99828ddf..e586499d 100644
--- a/renumics/spotlight/dataset/__init__.py
+++ b/renumics/spotlight/dataset/__init__.py
@@ -32,12 +32,7 @@
from typing_extensions import TypeGuard
from renumics.spotlight.__version__ import __version__
-from renumics.spotlight.io.pandas import (
- infer_dtypes,
- prepare_column,
- is_string_mask,
- stringify_columns,
-)
+from .pandas import create_typed_series, infer_dtypes, is_string_mask, prepare_column
from renumics.spotlight.typing import (
BoolType,
IndexType,
@@ -47,7 +42,6 @@
is_integer,
is_iterable,
)
-from renumics.spotlight.io.pandas import create_typed_series
from renumics.spotlight.dtypes.conversion import prepare_path_or_url
from renumics.spotlight import dtypes as spotlight_dtypes
@@ -738,7 +732,7 @@ def from_pandas(
df = df.reset_index(level=df.index.names) # type: ignore
else:
df = df.copy()
- df.columns = pd.Index(stringify_columns(df))
+ df.columns = pd.Index([str(column) for column in df.columns])
if dtypes is None:
dtypes = {}
diff --git a/renumics/spotlight/dataset/descriptors/__init__.py b/renumics/spotlight/dataset/descriptors/__init__.py
index bb5cf837..3f52372d 100644
--- a/renumics/spotlight/dataset/descriptors/__init__.py
+++ b/renumics/spotlight/dataset/descriptors/__init__.py
@@ -1,5 +1,6 @@
"""make descriptor methods more available
"""
+import warnings
from typing import Optional, Tuple
import numpy as np
@@ -11,6 +12,13 @@
from renumics.spotlight.dataset.exceptions import ColumnExistsError, InvalidDTypeError
from .data_alignment import align_column_data
+warnings.warn(
+ "`renumics.spotlight.dataset.descriptors` module is deprecated and will "
+ "be removed in future versions.",
+ DeprecationWarning,
+ stacklevel=2,
+)
+
def pca(
dataset: Dataset,
diff --git a/renumics/spotlight/io/pandas.py b/renumics/spotlight/dataset/pandas.py
similarity index 86%
rename from renumics/spotlight/io/pandas.py
rename to renumics/spotlight/dataset/pandas.py
index 4cf84f9e..75ccf00f 100644
--- a/renumics/spotlight/io/pandas.py
+++ b/renumics/spotlight/dataset/pandas.py
@@ -1,30 +1,22 @@
"""
-This module contains helpers for importing `pandas.DataFrame`s.
+Helpers for conversion between H5 datasets and `pandas.DataFrame`s.
"""
-import ast
import os.path
import statistics
-from contextlib import suppress
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Union
import PIL.Image
import filetype
-import trimesh
import numpy as np
import pandas as pd
+import trimesh
-from renumics.spotlight.dtypes import (
- Audio,
- Embedding,
- Image,
- Mesh,
- Sequence1D,
- Video,
-)
-from renumics.spotlight.media.exceptions import UnsupportedDType
-from renumics.spotlight.typing import is_iterable, is_pathtype
from renumics.spotlight import dtypes
+from renumics.spotlight.io import prepare_hugging_face_dict, try_literal_eval
+from renumics.spotlight.media import Audio, Embedding, Image, Mesh, Sequence1D, Video
+from renumics.spotlight.typing import is_iterable, is_pathtype
+from .exceptions import InvalidDTypeError
def create_typed_series(
@@ -58,32 +50,62 @@ def create_typed_series(
return pd.Series([] if values is None else values, dtype=pandas_dtype)
-def is_empty(value: Any) -> bool:
- """
- Check if value is `NA` or an empty string.
+def prepare_column(column: pd.Series, dtype: dtypes.DType) -> pd.Series:
"""
- if is_iterable(value):
- # `pd.isna` with an iterable argument returns an iterable result. But
- # an iterable cannot be NA or empty string by default.
- return False
- return pd.isna(value) or value == ""
+    Convert a `pandas` column to the desired `dtype` and prepare some values,
+    but keep it as a `pandas` column.
+
+ Args:
+ column: A `pandas` column to prepare.
+ dtype: Target data type.
+ Returns:
+ Prepared `pandas` column.
-def try_literal_eval(x: str) -> Any:
- """
- Try to evaluate a literal expression, otherwise return value as is.
+ Raises:
+ TypeError: If `dtype` is not a Spotlight data type.
"""
- with suppress(Exception):
- return ast.literal_eval(x)
- return x
+ column = column.copy()
+ if dtypes.is_category_dtype(dtype):
+ # We only support string/`NA` categories, but `pandas` can more, so
+ # force categories to be strings (does not affect `NA`s).
+ return to_categorical(column, str_categories=True)
-def stringify_columns(df: pd.DataFrame) -> List[str]:
- """
- Convert `pandas.DataFrame`'s column names to strings, no matter which index
- is used.
- """
- return [str(column_name) for column_name in df.columns]
+ if dtypes.is_datetime_dtype(dtype):
+ # `errors="coerce"` will produce `NaT`s instead of fail.
+ return pd.to_datetime(column, errors="coerce")
+
+ if dtypes.is_str_dtype(dtype):
+ # Allow `NA`s, convert all other elements to strings.
+ return column.astype(str).mask(column.isna(), None) # type: ignore
+
+ if dtypes.is_bool_dtype(dtype):
+ return column.astype(bool)
+
+ if dtypes.is_int_dtype(dtype):
+ return column.astype(int)
+
+ if dtypes.is_float_dtype(dtype):
+ return column.astype(float)
+
+    # We explicitly don't want to change the original `DataFrame`.
+ with pd.option_context("mode.chained_assignment", None):
+ # We consider empty strings as `NA`s.
+ str_mask = is_string_mask(column)
+ column[str_mask] = column[str_mask].replace("", None)
+ na_mask = column.isna()
+
+ # When `pandas` reads a csv, arrays and lists are read as literal strings,
+ # try to interpret them.
+ str_mask = is_string_mask(column)
+ column[str_mask] = column[str_mask].apply(try_literal_eval)
+
+ if dtypes.is_filebased_dtype(dtype):
+ dict_mask = column.map(type) == dict
+ column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict)
+
+ return column.mask(na_mask, None) # type: ignore
def infer_dtype(column: pd.Series) -> dtypes.DType:
@@ -225,7 +247,7 @@ def infer_dtypes(df: pd.DataFrame, dtype: Optional[dtypes.DTypeMap]) -> dtypes.D
if column_index not in inferred_dtype:
try:
column_type = infer_dtype(df[column_index])
- except UnsupportedDType:
+ except InvalidDTypeError:
column_type = dtypes.str_dtype
inferred_dtype[str(column_index)] = column_type
return inferred_dtype
@@ -255,73 +277,3 @@ def to_categorical(column: pd.Series, str_categories: bool = False) -> pd.Series
if str_categories:
return column.cat.rename_categories(column.cat.categories.astype(str))
return column
-
-
-def prepare_hugging_face_dict(x: Dict) -> Any:
- """
- Prepare HuggingFace format for files to be used in Spotlight.
- """
- if x.keys() != {"bytes", "path"}:
- return x
- blob = x["bytes"]
- if blob is not None:
- return blob
- return x["path"]
-
-
-def prepare_column(column: pd.Series, dtype: dtypes.DType) -> pd.Series:
- """
- Convert a `pandas` column to the desired `dtype` and prepare some values,
- but still as `pandas` column.
-
- Args:
- column: A `pandas` column to prepare.
- dtype: Target data type.
-
- Returns:
- Prepared `pandas` column.
-
- Raises:
- TypeError: If `dtype` is not a Spotlight data type.
- """
- column = column.copy()
-
- if dtypes.is_category_dtype(dtype):
- # We only support string/`NA` categories, but `pandas` can more, so
- # force categories to be strings (does not affect `NA`s).
- return to_categorical(column, str_categories=True)
-
- if dtypes.is_datetime_dtype(dtype):
- # `errors="coerce"` will produce `NaT`s instead of fail.
- return pd.to_datetime(column, errors="coerce")
-
- if dtypes.is_str_dtype(dtype):
- # Allow `NA`s, convert all other elements to strings.
- return column.astype(str).mask(column.isna(), None) # type: ignore
-
- if dtypes.is_bool_dtype(dtype):
- return column.astype(bool)
-
- if dtypes.is_int_dtype(dtype):
- return column.astype(int)
-
- if dtypes.is_float_dtype(dtype):
- return column.astype(float)
-
- # We explicitely don't want to change the original `DataFrame`.
- with pd.option_context("mode.chained_assignment", None):
- # We consider empty strings as `NA`s.
- str_mask = is_string_mask(column)
- column[str_mask] = column[str_mask].replace("", None)
- na_mask = column.isna()
-
- # When `pandas` reads a csv, arrays and lists are read as literal strings,
- # try to interpret them.
- str_mask = is_string_mask(column)
- column[str_mask] = column[str_mask].apply(try_literal_eval)
-
- if dtypes.is_filebased_dtype(dtype):
- dict_mask = column.map(type) == dict
- column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict)
-
- return column.mask(na_mask, None) # type: ignore
diff --git a/renumics/spotlight/dtypes/__init__.py b/renumics/spotlight/dtypes/__init__.py
index 0e24ea10..63910215 100644
--- a/renumics/spotlight/dtypes/__init__.py
+++ b/renumics/spotlight/dtypes/__init__.py
@@ -9,6 +9,8 @@
__all__ = [
"CategoryDType",
+ "ArrayDType",
+ "EmbeddingDType",
"Sequence1DDType",
"bool_dtype",
"int_dtype",
@@ -36,6 +38,14 @@ def __init__(self, name: str):
def __str__(self) -> str:
return self.name
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, DType):
+ return other._name == self._name
+ return False
+
+ def __hash__(self) -> int:
+ return hash(self._name)
+
@property
def name(self) -> str:
return self._name
@@ -53,8 +63,10 @@ def __init__(
self, categories: Optional[Union[Iterable[str], Dict[str, int]]] = None
):
super().__init__("Category")
- if isinstance(categories, dict) or categories is None:
- self._categories = categories
+ if isinstance(categories, dict):
+ self._categories = dict(sorted(categories.items(), key=lambda x: x[1]))
+ elif categories is None:
+ self._categories = None
else:
self._categories = {
category: code for code, category in enumerate(categories)
@@ -71,6 +83,20 @@ def __init__(
category: code for code, category in self._inverted_categories.items()
}
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, CategoryDType):
+ return other._categories == self._categories
+ return False
+
+ def __hash__(self) -> int:
+ if self._categories is None:
+ return hash(self._name) ^ hash(None)
+ return (
+ hash(self._name)
+ ^ hash(tuple(self._categories.keys()))
+ ^ hash(tuple(self._categories.values()))
+ )
+
@property
def categories(self) -> Optional[Dict[str, int]]:
return self._categories
@@ -91,6 +117,14 @@ def __init__(self, shape: Optional[Tuple[Optional[int], ...]] = None):
super().__init__("array")
self.shape = shape
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, ArrayDType):
+ return other.shape == self.shape
+ return False
+
+ def __hash__(self) -> int:
+ return hash(self._name) ^ hash(self.shape)
+
@property
def ndim(self) -> int:
if self.shape is None:
@@ -111,6 +145,14 @@ def __init__(self, length: Optional[int] = None):
raise ValueError(f"Length must be non-negative, but {length} received.")
self.length = length
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, EmbeddingDType):
+ return other.length == self.length
+ return False
+
+ def __hash__(self) -> int:
+ return hash(self._name) ^ hash(self.length)
+
class Sequence1DDType(DType):
"""
@@ -125,6 +167,14 @@ def __init__(self, x_label: str = "x", y_label: str = "y"):
self.x_label = x_label
self.y_label = y_label
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, Sequence1DDType):
+ return other.x_label == self.x_label and other.y_label == self.y_label
+ return False
+
+ def __hash__(self) -> int:
+ return hash(self._name) ^ hash(self.x_label) ^ hash(self.y_label)
+
ALIASES: Dict[Any, DType] = {}
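The new `__eq__`/`__hash__` pairs make dtypes compare by value rather than by identity; some consequences:

```python
from renumics.spotlight.dtypes import (
    ArrayDType,
    CategoryDType,
    EmbeddingDType,
    Sequence1DDType,
)

assert EmbeddingDType(8) == EmbeddingDType(8)
assert CategoryDType(["a", "b"]) == CategoryDType({"a": 0, "b": 1})
assert ArrayDType((2, 3)) != ArrayDType((3, 2))
assert hash(Sequence1DDType()) == hash(Sequence1DDType("x", "y"))  # default labels
```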
diff --git a/renumics/spotlight/dtypes/conversion.py b/renumics/spotlight/dtypes/conversion.py
index 48c90dde..c5b7cc12 100644
--- a/renumics/spotlight/dtypes/conversion.py
+++ b/renumics/spotlight/dtypes/conversion.py
@@ -37,23 +37,14 @@
import trimesh
import PIL.Image
import validators
-from renumics.spotlight import dtypes
+from renumics.spotlight import dtypes, media
from renumics.spotlight.typing import PathOrUrlType, PathType
from renumics.spotlight.cache import external_data_cache
from renumics.spotlight.io import audio
from renumics.spotlight.io.file import as_file
from renumics.spotlight.media.exceptions import InvalidFile
from renumics.spotlight.backend.exceptions import Problem
-from renumics.spotlight.dtypes import (
- CategoryDType,
- DType,
- audio_dtype,
- image_dtype,
- mesh_dtype,
- video_dtype,
-)
-from renumics.spotlight.media import Sequence1D, Image, Audio, Video, Mesh
NormalizedValue = Union[
@@ -96,7 +87,7 @@ class ConversionFailed(Problem):
def __init__(
self,
value: NormalizedValue,
- dtype: DType,
+ dtype: dtypes.DType,
reason: Optional[str] = None,
) -> None:
super().__init__(
@@ -116,14 +107,14 @@ class NoConverterAvailable(Problem):
No matching converter could be applied
"""
- def __init__(self, value: NormalizedValue, dtype: DType) -> None:
+ def __init__(self, value: NormalizedValue, dtype: dtypes.DType) -> None:
msg = f"No Converter for {type(value)} -> {dtype}"
super().__init__(title="No matching converter", detail=msg, status_code=422)
N = TypeVar("N", bound=NormalizedValue)
-Converter = Callable[[N, DType], ConvertedValue]
+Converter = Callable[[N, dtypes.DType], ConvertedValue]
_converters_table: Dict[
Type[NormalizedValue], Dict[str, List[Converter]]
] = defaultdict(lambda: defaultdict(list))
@@ -176,7 +167,10 @@ def _decorate(func: Converter[N]) -> Converter[N]:
def convert_to_dtype(
- value: NormalizedValue, dtype: DType, simple: bool = False, check: bool = True
+ value: NormalizedValue,
+ dtype: dtypes.DType,
+ simple: bool = False,
+ check: bool = True,
) -> ConvertedValue:
"""
Convert normalized type from data source to internal Spotlight DType
@@ -221,37 +215,41 @@ def convert_to_dtype(
except (TypeError, ValueError) as e:
if check:
raise ConversionFailed(value, dtype) from e
- else:
- return None
-
- if check:
- if last_conversion_error:
- raise ConversionFailed(value, dtype, last_conversion_error.reason)
- else:
+ else:
+ if check:
+ if last_conversion_error:
+ raise ConversionFailed(value, dtype, last_conversion_error.reason)
raise NoConverterAvailable(value, dtype)
+ if simple and (
+ dtypes.is_array_dtype(dtype)
+ or dtypes.is_embedding_dtype(dtype)
+ or dtypes.is_sequence_1d_dtype(dtype)
+ or dtypes.is_filebased_dtype(dtype)
+ ):
+ return ""
return None
@convert("datetime")
-def _(value: datetime.datetime, _: DType) -> datetime.datetime:
+def _(value: datetime.datetime, _: dtypes.DType) -> datetime.datetime:
return value
@convert("datetime")
-def _(value: Union[str, np.str_], _: DType) -> Optional[datetime.datetime]:
+def _(value: Union[str, np.str_], _: dtypes.DType) -> Optional[datetime.datetime]:
if value == "":
return None
return datetime.datetime.fromisoformat(value)
@convert("datetime")
-def _(value: np.datetime64, _: DType) -> Optional[datetime.datetime]:
+def _(value: np.datetime64, _: dtypes.DType) -> Optional[datetime.datetime]:
return value.tolist()
@convert("Category")
-def _(value: Union[str, np.str_], dtype: CategoryDType) -> int:
+def _(value: Union[str, np.str_], dtype: dtypes.CategoryDType) -> int:
categories = dtype.categories
if not categories:
return -1
@@ -259,7 +257,7 @@ def _(value: Union[str, np.str_], dtype: CategoryDType) -> int:
@convert("Category")
-def _(_: None, _dtype: CategoryDType) -> int:
+def _(_: None, _dtype: dtypes.CategoryDType) -> int:
return -1
@@ -276,23 +274,23 @@ def _(
np.uint32,
np.uint64,
],
- _: CategoryDType,
+ _: dtypes.CategoryDType,
) -> int:
return int(value)
@convert("Window")
-def _(value: list, _: DType) -> np.ndarray:
+def _(value: list, _: dtypes.DType) -> np.ndarray:
return np.array(value, dtype=np.float64)
@convert("Window")
-def _(value: np.ndarray, _: DType) -> np.ndarray:
+def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray:
return value.astype(np.float64)
@convert("Window")
-def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
+def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
array = np.array(obj, dtype=np.float64)
@@ -304,17 +302,17 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
@convert("Embedding", simple=False)
-def _(value: list, _: DType) -> np.ndarray:
+def _(value: list, _: dtypes.DType) -> np.ndarray:
return np.array(value, dtype=np.float64)
@convert("Embedding", simple=False)
-def _(value: np.ndarray, _: DType) -> np.ndarray:
+def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray:
return value.astype(np.float64)
@convert("Embedding", simple=False)
-def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
+def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
array = np.array(obj, dtype=np.float64)
@@ -326,50 +324,54 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
@convert("Sequence1D", simple=False)
-def _(value: Union[np.ndarray, list], _: DType) -> np.ndarray:
- return Sequence1D(value).encode()
+def _(value: Union[np.ndarray, list], _: dtypes.DType) -> np.ndarray:
+ return media.Sequence1D(value).encode()
@convert("Sequence1D", simple=False)
-def _(value: Union[str, np.str_], _: DType) -> np.ndarray:
+def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray:
try:
obj = ast.literal_eval(value)
- return Sequence1D(obj).encode()
+ return media.Sequence1D(obj).encode()
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
raise ConversionError("Cannot interpret string as a 1D sequence")
@convert("Image", simple=False)
-def _(value: Union[str, np.str_], _: DType) -> bytes:
+def _(value: Union[np.ndarray, list], _: dtypes.DType) -> bytes:
+ return media.Image(value).encode().tolist()
+
+
+@convert("Image", simple=False)
+def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
- if data := read_external_value(value, image_dtype):
+ if data := read_external_value(value, dtypes.image_dtype):
return data.tolist()
except InvalidFile:
- raise ConversionError()
+ try:
+ obj = ast.literal_eval(value)
+ return media.Image(obj).encode().tolist()
+ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+ raise ConversionError()
raise ConversionError()
@convert("Image", simple=False)
-def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
- return Image.from_bytes(value).encode().tolist()
-
-
-@convert("Image", simple=False)
-def _(value: np.ndarray, _: DType) -> bytes:
- return Image(value).encode().tolist()
+def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
+ return media.Image.from_bytes(value).encode().tolist()
@convert("Image", simple=False)
-def _(value: PIL.Image.Image, _: DType) -> bytes:
+def _(value: PIL.Image.Image, _: dtypes.DType) -> bytes:
buffer = io.BytesIO()
value.save(buffer, format="PNG")
return buffer.getvalue()
@convert("Audio", simple=False)
-def _(value: Union[str, np.str_], _: DType) -> bytes:
+def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
- if data := read_external_value(value, audio_dtype):
+ if data := read_external_value(value, dtypes.audio_dtype):
return data.tolist()
except (InvalidFile, IndexError, ValueError):
raise ConversionError()
@@ -377,14 +379,14 @@ def _(value: Union[str, np.str_], _: DType) -> bytes:
@convert("Audio", simple=False)
-def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
- return Audio.from_bytes(value).encode().tolist()
+def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
+ return media.Audio.from_bytes(value).encode().tolist()
@convert("Video", simple=False)
-def _(value: Union[str, np.str_], _: DType) -> bytes:
+def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
- if data := read_external_value(value, video_dtype):
+ if data := read_external_value(value, dtypes.video_dtype):
return data.tolist()
except InvalidFile:
raise ConversionError()
@@ -392,14 +394,14 @@ def _(value: Union[str, np.str_], _: DType) -> bytes:
@convert("Video", simple=False)
-def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
- return Video.from_bytes(value).encode().tolist()
+def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
+ return media.Video.from_bytes(value).encode().tolist()
@convert("Mesh", simple=False)
-def _(value: Union[str, np.str_], _: DType) -> bytes:
+def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes:
try:
- if data := read_external_value(value, mesh_dtype):
+ if data := read_external_value(value, dtypes.mesh_dtype):
return data.tolist()
except InvalidFile:
raise ConversionError()
@@ -407,29 +409,29 @@ def _(value: Union[str, np.str_], _: DType) -> bytes:
@convert("Mesh", simple=False)
-def _(value: Union[bytes, np.bytes_], _: DType) -> bytes:
+def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes:
return value
# this should not be necessary
@convert("Mesh", simple=False) # type: ignore
-def _(value: trimesh.Trimesh, _: DType) -> bytes:
- return Mesh.from_trimesh(value).encode().tolist()
+def _(value: trimesh.Trimesh, _: dtypes.DType) -> bytes:
+ return media.Mesh.from_trimesh(value).encode().tolist()
@convert("Embedding", simple=True)
@convert("Sequence1D", simple=True)
-def _(_: Union[np.ndarray, list, str, np.str_], _dtype: DType) -> str:
+def _(_: Union[np.ndarray, list, str, np.str_], _dtype: dtypes.DType) -> str:
return "[...]"
@convert("Image", simple=True)
-def _(_: np.ndarray, _dtype: DType) -> str:
+def _(_: Union[np.ndarray, list], _dtype: dtypes.DType) -> str:
return "[...]"
@convert("Image", simple=True)
-def _(_: PIL.Image.Image, _dtype: DType) -> str:
+def _(_: PIL.Image.Image, _dtype: dtypes.DType) -> str:
return ""
@@ -437,7 +439,7 @@ def _(_: PIL.Image.Image, _dtype: DType) -> str:
@convert("Audio", simple=True)
@convert("Video", simple=True)
@convert("Mesh", simple=True)
-def _(value: Union[str, np.str_], _: DType) -> str:
+def _(value: Union[str, np.str_], _: dtypes.DType) -> str:
return str(value)
@@ -445,19 +447,19 @@ def _(value: Union[str, np.str_], _: DType) -> str:
@convert("Audio", simple=True)
@convert("Video", simple=True)
@convert("Mesh", simple=True)
-def _(_: Union[bytes, np.bytes_], _dtype: DType) -> str:
+def _(_: Union[bytes, np.bytes_], _dtype: dtypes.DType) -> str:
return ""
# this should not be necessary
@convert("Mesh", simple=True) # type: ignore
-def _(_: trimesh.Trimesh, _dtype: DType) -> str:
+def _(_: trimesh.Trimesh, _dtype: dtypes.DType) -> str:
return ""
def read_external_value(
path_or_url: Optional[str],
- dtype: DType,
+ dtype: dtypes.DType,
target_format: Optional[str] = None,
workdir: PathType = ".",
) -> Optional[np.void]:
@@ -494,7 +496,7 @@ def prepare_path_or_url(path_or_url: PathOrUrlType, workdir: PathType) -> str:
def _decode_external_value(
path_or_url: PathOrUrlType,
- dtype: DType,
+ dtype: dtypes.DType,
target_format: Optional[str] = None,
workdir: PathType = ".",
) -> np.void:
@@ -524,7 +526,7 @@ def _decode_external_value(
# Convert all other formats/codecs to flac.
output_format, output_codec = "flac", "flac"
else:
- output_format, output_codec = Audio.get_format_codec(target_format)
+ output_format, output_codec = media.Audio.get_format_codec(target_format)
if output_format == input_format and output_codec == input_codec:
# Nothing to transcode
if isinstance(file, str):
@@ -550,10 +552,10 @@ def _decode_external_value(
):
return np.void(file.read())
# `image/tiff`s become blank in frontend, so convert them too.
- return Image.from_file(file).encode(target_format)
+ return media.Image.from_file(file).encode(target_format)
if dtypes.is_mesh_dtype(dtype):
- return Mesh.from_file(path_or_url).encode(target_format)
+ return media.Mesh.from_file(path_or_url).encode(target_format)
if dtypes.is_video_dtype(dtype):
- return Video.from_file(path_or_url).encode(target_format)
+ return media.Video.from_file(path_or_url).encode(target_format)
assert False
diff --git a/renumics/spotlight/io/__init__.py b/renumics/spotlight/io/__init__.py
index 2d2a6d26..8162a843 100644
--- a/renumics/spotlight/io/__init__.py
+++ b/renumics/spotlight/io/__init__.py
@@ -1,6 +1,9 @@
"""
Reading and writing of different data formats.
"""
+import ast
+from contextlib import suppress
+from typing import Any
from .audio import (
get_format_codec,
@@ -19,6 +22,8 @@
decode_gltf_arrays,
encode_gltf_array,
)
+from .huggingface import prepare_hugging_face_dict
+
__all__ = [
"get_format_codec",
@@ -34,4 +39,15 @@
"check_gltf",
"decode_gltf_arrays",
"encode_gltf_array",
+ "prepare_hugging_face_dict",
+ "try_literal_eval",
]
+
+
+def try_literal_eval(x: str) -> Any:
+ """
+ Try to evaluate a literal expression, otherwise return value as is.
+ """
+ with suppress(Exception):
+ return ast.literal_eval(x)
+ return x
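Behaviour of the relocated helper, for reference:

```python
from renumics.spotlight.io import try_literal_eval

assert try_literal_eval("[1, 2, 3]") == [1, 2, 3]
assert try_literal_eval("{'a': 1}") == {"a": 1}
assert try_literal_eval("not a literal") == "not a literal"  # returned unchanged
```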
diff --git a/renumics/spotlight/io/huggingface.py b/renumics/spotlight/io/huggingface.py
new file mode 100644
index 00000000..06c0d441
--- /dev/null
+++ b/renumics/spotlight/io/huggingface.py
@@ -0,0 +1,16 @@
+"""
+Helpers for HuggingFace formats.
+"""
+from typing import Any, Dict
+
+
+def prepare_hugging_face_dict(x: Dict) -> Any:
+ """
+ Prepare HuggingFace format for files to be used in Spotlight.
+ """
+ if x.keys() != {"bytes", "path"}:
+ return x
+ blob = x["bytes"]
+ if blob is not None:
+ return blob
+ return x["path"]
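The helper unwraps HuggingFace file dicts; for reference:

```python
from renumics.spotlight.io import prepare_hugging_face_dict

assert prepare_hugging_face_dict({"bytes": b"\x00", "path": "a.wav"}) == b"\x00"
assert prepare_hugging_face_dict({"bytes": None, "path": "a.wav"}) == "a.wav"
assert prepare_hugging_face_dict({"other": 1}) == {"other": 1}  # passed through
```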
diff --git a/renumics/spotlight/server.py b/renumics/spotlight/server.py
index 83d688db..bf56918d 100644
--- a/renumics/spotlight/server.py
+++ b/renumics/spotlight/server.py
@@ -41,6 +41,8 @@ class Server:
process: Optional[subprocess.Popen]
_startup_event: threading.Event
+ _update_complete_event: threading.Event
+ _update_error: Optional[Exception]
connection: Optional[multiprocessing.connection.Connection]
_connection_message_queue: Queue
@@ -77,7 +79,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8000) -> None:
)
self._startup_event = threading.Event()
- self._startup_complete_event = threading.Event()
+ self._update_complete_event = threading.Event()
self._connection_thread_online = threading.Event()
self._connection_thread = threading.Thread(
@@ -151,7 +153,6 @@ def start(self, config: AppConfig) -> None:
command.append("--reload")
# start uvicorn
-
self.process = subprocess.Popen(
command,
env=env,
@@ -164,7 +165,7 @@ def start(self, config: AppConfig) -> None:
)
if platform.system() != "Windows":
sock.close()
- self._startup_complete_event.wait(timeout=120)
+ self._wait_for_update()
def stop(self) -> None:
"""
@@ -197,7 +198,6 @@ def stop(self) -> None:
self._port = self._requested_port
self._startup_event.clear()
- self._startup_complete_event.clear()
@property
def running(self) -> bool:
@@ -217,9 +217,21 @@ def update(self, config: AppConfig) -> None:
"""
Update app config
"""
+ self._update(config)
+ self._wait_for_update()
+
+ def _update(self, config: AppConfig) -> None:
self._app_config = config
self.send({"kind": "update", "data": config})
+ def _wait_for_update(self) -> None:
+ self._update_complete_event.wait(timeout=120)
+ self._update_complete_event.clear()
+ err = self._update_error
+ self._update_error = None
+ if err:
+ raise err
+
def get_df(self) -> Optional[pd.DataFrame]:
"""
Request and return the current DataFrame from the server process (if possible).
@@ -236,9 +248,10 @@ def _handle_message(self, message: Any) -> None:
if kind == "startup":
self._startup_event.set()
- self.update(self._app_config)
- elif kind == "startup_complete":
- self._startup_complete_event.set()
+ self._update(self._app_config)
+ elif kind == "update_complete":
+ self._update_error = message.get("error")
+ self._update_complete_event.set()
elif kind == "frontend_connected":
self.connected_frontends = message["data"]
self._all_frontends_disconnected.clear()
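Condensed, the replaced startup handshake becomes a per-update handshake; a paraphrase of the flow above (not additional API):

```python
# Paraphrase of Server.update() after this change, not Spotlight's literal code.
def update(self, config: AppConfig) -> None:
    self.send({"kind": "update", "data": config})  # -> app process
    self._update_complete_event.wait(timeout=120)  # <- set on "update_complete"
    self._update_complete_event.clear()
    if self._update_error is not None:
        raise self._update_error  # surface remote failures in the caller
```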
diff --git a/renumics/spotlight/viewer.py b/renumics/spotlight/viewer.py
index b7630b17..817f31b2 100644
--- a/renumics/spotlight/viewer.py
+++ b/renumics/spotlight/viewer.py
@@ -176,7 +176,11 @@ def show(
if self not in _VIEWERS:
_VIEWERS.append(self)
else:
- self._server.update(config)
+ try:
+ self._server.update(config)
+ except Exception as e:
+ self.close()
+ raise e
if not no_browser and self._server.connected_frontends == 0:
self.open_browser()
@@ -220,7 +224,7 @@ def open_browser(self) -> None:
"""
if not self.port:
return
- launch_browser_in_thread(self.host, self.port)
+ launch_browser_in_thread("localhost", self.port)
def refresh(self) -> None:
"""
diff --git a/renumics/spotlight_plugins/core/api/table.py b/renumics/spotlight_plugins/core/api/table.py
index 1446bb74..807fdc40 100644
--- a/renumics/spotlight_plugins/core/api/table.py
+++ b/renumics/spotlight_plugins/core/api/table.py
@@ -73,7 +73,7 @@ def get_table(request: Request) -> ORJSONResponse:
columns = []
for column_name in data_store.column_names:
dtype = data_store.dtypes[column_name]
- values = data_store.get_converted_values(column_name, simple=True)
+ values = data_store.get_converted_values(column_name, simple=True, check=False)
meta = data_store.get_column_metadata(column_name)
column = Column(
name=column_name,
diff --git a/renumics/spotlight_plugins/core/pandas_data_source.py b/renumics/spotlight_plugins/core/pandas_data_source.py
index 430a69ca..4404a14d 100644
--- a/renumics/spotlight_plugins/core/pandas_data_source.py
+++ b/renumics/spotlight_plugins/core/pandas_data_source.py
@@ -7,14 +7,9 @@
import numpy as np
import pandas as pd
import datasets
-from renumics.spotlight import dtypes
-from renumics.spotlight.io.pandas import (
- infer_dtype,
- prepare_hugging_face_dict,
- stringify_columns,
- try_literal_eval,
-)
+from renumics.spotlight import dtypes
+from renumics.spotlight.io import prepare_hugging_face_dict, try_literal_eval
from renumics.spotlight.data_source import (
datasource,
ColumnMetadata,
@@ -23,7 +18,6 @@
from renumics.spotlight.backend.exceptions import DatasetColumnsNotUnique
from renumics.spotlight.dataset.exceptions import ColumnNotExistsError
from renumics.spotlight.data_source.exceptions import InvalidDataSource
-from renumics.spotlight.dtypes import DTypeMap
@datasource(pd.DataFrame)
@@ -41,6 +35,7 @@ class PandasDataSource(DataSource):
_uid: str
_df: pd.DataFrame
_name: str
+ _intermediate_dtypes: dtypes.DTypeMap
def __init__(self, source: Union[Path, pd.DataFrame]):
if isinstance(source, Path):
@@ -108,7 +103,7 @@ def __init__(self, source: Union[Path, pd.DataFrame]):
@property
def column_names(self) -> List[str]:
- return stringify_columns(self._df)
+ return [str(column) for column in self._df.columns]
@property
def df(self) -> pd.DataFrame:
@@ -118,18 +113,15 @@ def df(self) -> pd.DataFrame:
return self._df.copy()
@property
- def intermediate_dtypes(self) -> DTypeMap:
+ def intermediate_dtypes(self) -> dtypes.DTypeMap:
return self._intermediate_dtypes
def __len__(self) -> int:
return len(self._df)
@property
- def semantic_dtypes(self) -> DTypeMap:
- return {
- str(column_name): infer_dtype(self.df[column_name])
- for column_name in self.df
- }
+ def semantic_dtypes(self) -> dtypes.DTypeMap:
+ return {}
def get_generation_id(self) -> int:
return self._generation_id
@@ -167,12 +159,14 @@ def get_column_values(
if pd.api.types.is_categorical_dtype(column):
return column.cat.codes
if pd.api.types.is_string_dtype(column):
- values = column.to_numpy()
- na_mask = column.isna()
- values[na_mask] = None
- return values
+ column = column.astype(object).mask(column.isna(), None)
+ str_mask = column.map(type) == str
+ column[str_mask] = column[str_mask].apply(try_literal_eval)
+ dict_mask = column.map(type) == dict
+ column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict)
+ return column.to_numpy()
if pd.api.types.is_object_dtype(column):
- column = column.mask(column.isna(), None)
+ column = column.astype(object).mask(column.isna(), None)
str_mask = column.map(type) == str
column[str_mask] = column[str_mask].apply(try_literal_eval)
dict_mask = column.map(type) == dict
@@ -222,5 +216,4 @@ def _determine_intermediate_dtype(column: pd.Series) -> dtypes.DType:
return dtypes.datetime_dtype
if pd.api.types.is_string_dtype(column):
return dtypes.str_dtype
- else:
- return dtypes.mixed_dtype
+ return dtypes.mixed_dtype
diff --git a/src/lenses/SpectrogramLens/SpectrogramWorker.ts b/src/lenses/SpectrogramLens/SpectrogramWorker.ts
index 55cad783..ff47fcf3 100644
--- a/src/lenses/SpectrogramLens/SpectrogramWorker.ts
+++ b/src/lenses/SpectrogramLens/SpectrogramWorker.ts
@@ -233,12 +233,7 @@ const calculateFrequencies = (
tables.cosTable
);
- const array = new Uint8Array(fftSamples / 2);
- let j;
- for (j = 0; j < fftSamples / 2; j++) {
- array[j] = Math.max(-255, Math.log10(spectrum[j]) * 45);
- }
- channelFreq.push(array);
+ channelFreq.push(spectrum);
currentOffset += fftSamples - noverlap;
}
frequencies.push(channelFreq);
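
The deleted loop quantized each bin via `Math.log10(spectrum[j]) * 45` into a `Uint8Array`, which wraps negative results modulo 256 and bakes a log scaling into data the lens may later treat as linear; pushing the raw `spectrum` defers all scaling to the renderer. A standalone sketch of the wrap-around (illustrative, not widget code):

```ts
// Uint8Array stores values modulo 256, so the negative numbers that
// Math.log10 produces for quiet bins wrapped around to large bytes.
const quantized = new Uint8Array(1);
quantized[0] = Math.max(-255, Math.log10(0.5) * 45); // ≈ -13.5
console.log(quantized[0]); // 243 — a quiet bin lands near the top of the range
```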
diff --git a/src/lenses/SpectrogramLens/index.tsx b/src/lenses/SpectrogramLens/index.tsx
index 60522df1..70e48ba1 100644
--- a/src/lenses/SpectrogramLens/index.tsx
+++ b/src/lenses/SpectrogramLens/index.tsx
@@ -10,7 +10,6 @@ import { ColorsState, useColors } from '../../stores/colors';
import { Lens } from '../../types';
import useSetting from '../useSetting';
import MenuBar from './MenuBar';
-import chroma from 'chroma-js';
import { fixWindow, freqType, unitType, amplitudeToDb } from './Spectrogram';
const Container = tw.div`flex flex-col w-full h-full items-stretch justify-center`;
@@ -209,7 +208,8 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => {
const widthScale = d3.scaleLinear([0, width], [0, frequenciesData.length]);
let drawData = [];
- let colorScale: chroma.Scale;
+ let min = 0;
+ let max = 0;
if (ampScale === 'decibel') {
let ref = 0;
@@ -223,9 +223,6 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => {
//const top_db = 80;
const amin = 1e-5;
- let log_spec_max = 0;
- let log_spec_min = 0;
-
// Convert amplitudes to decibels
for (let i = 0; i < frequenciesData.length; i++) {
const col = [];
@@ -234,23 +231,33 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => {
const amplitude = frequenciesData[i][j];
col[j] = amplitudeToDb(amplitude, ref, amin);
- if (col[j] > log_spec_max) {
- log_spec_max = col[j];
+ if (col[j] > max) {
+ max = col[j];
}
- if (col[j] < log_spec_min) {
- log_spec_min = col[j];
+ if (col[j] < min) {
+ min = col[j];
}
}
drawData[i] = col;
}
- colorScale = colorPalette.scale().domain([log_spec_min, log_spec_max]);
} else {
// ampScale === 'linear'
- colorScale = colorPalette.scale().domain([0, 256]);
+ for (let i = 0; i < frequenciesData.length; i++) {
+ const maxI = Math.max(...frequenciesData[i]);
+ const minI = Math.min(...frequenciesData[i]);
+
+ if (maxI > max) {
+ max = maxI;
+ }
+ if (minI < min) {
+ min = minI;
+ }
+ }
drawData = frequenciesData;
}
+ const colorScale = colorPalette.scale().domain([min, max]);
for (let y = 0; y < height; y++) {
let value = 0;
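
With raw magnitudes arriving from the worker, both amplitude scales now derive the palette domain from the data itself rather than assuming a fixed `[0, 256]` byte range. A minimal sketch of that mapping with chroma-js (palette and values are illustrative):

```ts
import chroma from 'chroma-js';

// decibel values as computed above; the palette is stretched over
// whatever range the data actually covers
const drawData = [
    [-80, -20],
    [-60, 0],
];
const flat = drawData.flat();
const colorScale = chroma
    .scale(['#000000', '#ffffff'])
    .domain([Math.min(...flat), Math.max(...flat)]);

console.log(colorScale(-20).hex()); // color for a fairly loud bin
```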
diff --git a/src/services/data.ts b/src/services/data.ts
index 16aa82f5..069e01dc 100644
--- a/src/services/data.ts
+++ b/src/services/data.ts
@@ -1,7 +1,5 @@
-import { useDataset } from '../stores/dataset';
-import { useLayout } from '../stores/layout';
import { IndexArray } from '../types';
-import { v4 as uuidv4 } from 'uuid';
+import taskService from './task';
export const umapMetricNames = [
'euclidean',
@@ -10,11 +8,8 @@ export const umapMetricNames = [
'cosine',
'mahalanobis',
] as const;
-
export type UmapMetric = typeof umapMetricNames[number];
-
export const pcaNormalizations = ['none', 'standardize', 'robust standardize'] as const;
-
export type PCANormalization = typeof pcaNormalizations[number];
interface ReductionResult {
@@ -22,84 +17,7 @@ interface ReductionResult {
indices: IndexArray;
}
-const MAX_QUEUED_MESSAGES = 16;
-
-class Connection {
- url: string;
- socket?: WebSocket;
- messageQueue: string[];
- onmessage?: (data: unknown) => void;
-
- constructor(host: string, port: string) {
- this.messageQueue = [];
- if (globalThis.location.protocol === 'https:') {
- this.url = `wss://${host}:${port}/api/ws`;
- } else {
- this.url = `ws://${host}:${port}/api/ws`;
- }
- this.#connect();
- }
-
- send(message: unknown): void {
- const data = JSON.stringify(message);
- if (this.socket) {
- this.socket.send(data);
- } else {
- if (this.messageQueue.length > MAX_QUEUED_MESSAGES) {
- this.messageQueue.shift();
- }
- this.messageQueue.push(data);
- }
- }
-
- #connect() {
- const webSocket = new WebSocket(this.url);
-
- webSocket.onopen = () => {
- this.socket = webSocket;
- this.messageQueue.forEach((message) => this.socket?.send(message));
- this.messageQueue.length = 0;
- };
- webSocket.onmessage = (event) => {
- const message = JSON.parse(event.data);
- this.onmessage?.(message);
- };
- webSocket.onerror = () => {
- webSocket.close();
- };
- webSocket.onclose = () => {
- this.socket = undefined;
- setTimeout(() => {
- this.#connect();
- }, 500);
- };
- }
-}
-
export class DataService {
- connection: Connection;
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    dispatchTable: Map<string, (message: any) => void>;
-
- constructor(host: string, port: string) {
- this.connection = new Connection(host, port);
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- this.connection.onmessage = (message: any) => {
- if (message.uid) {
- this.dispatchTable.get(message.uid)?.(message);
- return;
- }
- if (message.type === 'refresh') {
- useDataset.getState().refresh();
- } else if (message.type === 'resetLayout') {
- useLayout.getState().reset();
- } else if (message.type === 'issuesUpdated') {
- useDataset.getState().fetchIssues();
- }
- };
- this.dispatchTable = new Map();
- }
-
async computeUmap(
widgetId: string,
columnNames: string[],
@@ -108,34 +26,13 @@ export class DataService {
metric: UmapMetric,
min_dist: number
    ): Promise<ReductionResult> {
- const messageId = uuidv4();
-
-        const promise = new Promise<ReductionResult>((resolve) => {
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- this.dispatchTable.set(messageId, (message: any) => {
- const result = {
- points: message.data.points,
- indices: message.data.indices,
- };
- resolve(result);
- });
- });
-
- this.connection.send({
- type: 'umap',
- widget_id: widgetId,
- uid: messageId,
- generation_id: useDataset.getState().generationID,
- data: {
- indices: Array.from(indices),
- columns: columnNames,
- n_neighbors: n_neighbors,
- metric: metric,
- min_dist: min_dist,
- },
+ return await taskService.run('umap', widgetId, {
+ column_names: columnNames,
+ indices: Array.from(indices),
+ n_neighbors,
+ metric,
+ min_dist,
});
-
- return promise;
}
async computePCA(
@@ -144,38 +41,13 @@ export class DataService {
indices: IndexArray,
pcaNormalization: PCANormalization
    ): Promise<ReductionResult> {
- const messageId = uuidv4();
-
-        const promise = new Promise<ReductionResult>((resolve) => {
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- this.dispatchTable.set(messageId, (message: any) => {
- const result = {
- points: message.data.points,
- indices: message.data.indices,
- };
- resolve(result);
- });
- });
-
- this.connection.send({
- type: 'pca',
- widget_id: widgetId,
- uid: messageId,
- generation_id: useDataset.getState().generationID,
- data: {
- indices: Array.from(indices),
- columns: columnNames,
- normalization: pcaNormalization,
- },
+ return await taskService.run('pca', widgetId, {
+ indices: Array.from(indices),
+ column_names: columnNames,
+ normalization: pcaNormalization,
});
-
- return promise;
}
}
-const dataService = new DataService(
- globalThis.location.hostname,
- globalThis.location.port
-);
-
+const dataService = new DataService();
export default dataService;
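
Each reduction method is now a single `taskService.run` call, and backend failures surface as promise rejections instead of promises that never settle. An illustrative call site (the column name and parameter values are made up; the import path assumes code living under `src/`):

```ts
import dataService from './services/data';

async function projectSelection(widgetId: string, indices: number[]) {
    try {
        // resolves with { points, indices } once the backend task finishes
        return await dataService.computeUmap(
            widgetId,
            ['embedding'], // columnNames
            indices,
            15, // n_neighbors
            'euclidean', // metric
            0.1 // min_dist
        );
    } catch (problem) {
        // rejected via a 'task.error' message carrying the backend's problem
        console.error(problem);
        return undefined;
    }
}
```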
diff --git a/src/services/task.ts b/src/services/task.ts
new file mode 100644
index 00000000..f95a5d5e
--- /dev/null
+++ b/src/services/task.ts
@@ -0,0 +1,47 @@
+import { useDataset } from '../lib';
+import websocketService, { WebsocketService } from './websocket';
+
+interface ResponseHandler {
+ resolve: (result: unknown) => void;
+ reject: (error: unknown) => void;
+}
+
+// service handling the execution of remote tasks
+class TaskService {
+ // dispatch table for task responses/errors
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    dispatchTable: Map<string, ResponseHandler>;
+ websocketService: WebsocketService;
+
+ constructor(websocketService: WebsocketService) {
+ this.dispatchTable = new Map();
+ this.websocketService = websocketService;
+ this.websocketService.registerMessageHandler('task.result', (data) => {
+ this.dispatchTable.get(data.task_id)?.resolve(data.result);
+ });
+ this.websocketService.registerMessageHandler('task.error', (data) => {
+ this.dispatchTable.get(data.task_id)?.reject(data.error);
+ });
+ }
+
+ async run(task: string, name: string, args: unknown) {
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        return new Promise<any>((resolve, reject) => {
+ const task_id = crypto.randomUUID();
+ this.dispatchTable.set(task_id, { resolve, reject });
+ websocketService.send({
+ type: 'task',
+ data: {
+ task: task,
+ widget_id: name,
+ task_id,
+ generation_id: useDataset.getState().generationID,
+ args,
+ },
+ });
+ });
+ }
+}
+
+const taskService = new TaskService(websocketService);
+export default taskService;
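
`TaskService` correlates each request with its response through a generated `task_id`: `run` parks the promise's `resolve`/`reject` in `dispatchTable`, and the `task.result`/`task.error` handlers look the entry up again. A minimal standalone version of the pattern (names are illustrative; unlike the service above, this sketch also deletes entries once they settle, so the table only holds in-flight tasks):

```ts
type Settle = {
    resolve: (value: unknown) => void;
    reject: (error: unknown) => void;
};

class RequestCorrelator {
    private pending = new Map<string, Settle>();

    constructor(private transport: { send: (message: unknown) => void }) {}

    request(payload: unknown): Promise<unknown> {
        return new Promise((resolve, reject) => {
            const id = crypto.randomUUID();
            this.pending.set(id, { resolve, reject });
            this.transport.send({ id, payload });
        });
    }

    // feed incoming responses here, e.g. from registered message handlers
    settle(id: string, result?: unknown, error?: unknown): void {
        const entry = this.pending.get(id);
        if (!entry) return;
        this.pending.delete(id); // drop the entry once the task settles
        if (error === undefined) {
            entry.resolve(result);
        } else {
            entry.reject(error);
        }
    }
}
```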
diff --git a/src/services/websocket/connection.ts b/src/services/websocket/connection.ts
new file mode 100644
index 00000000..84452ca7
--- /dev/null
+++ b/src/services/websocket/connection.ts
@@ -0,0 +1,61 @@
+import { Message } from './types';
+
+// the maximum number of outgoing messages that are queued,
+// while the connection is down
+const MAX_QUEUED_MESSAGES = 16;
+
+// a websocket connection to the spotlight backend
+// automatically reconnects when necessary
+class Connection {
+ url: string;
+ socket?: WebSocket;
+ messageQueue: string[];
+ onmessage?: (data: unknown) => void;
+
+ constructor(host: string, port: string) {
+ this.messageQueue = [];
+ if (globalThis.location.protocol === 'https:') {
+ this.url = `wss://${host}:${port}/api/ws`;
+ } else {
+ this.url = `ws://${host}:${port}/api/ws`;
+ }
+ this.#connect();
+ }
+
+ send(message: Message): void {
+ const data = JSON.stringify(message);
+ if (this.socket) {
+ this.socket.send(data);
+ } else {
+ if (this.messageQueue.length > MAX_QUEUED_MESSAGES) {
+ this.messageQueue.shift();
+ }
+ this.messageQueue.push(data);
+ }
+ }
+
+ #connect() {
+ const webSocket = new WebSocket(this.url);
+
+ webSocket.onopen = () => {
+ this.socket = webSocket;
+ this.messageQueue.forEach((message) => this.socket?.send(message));
+ this.messageQueue.length = 0;
+ };
+ webSocket.onmessage = (event) => {
+ const message = JSON.parse(event.data);
+ this.onmessage?.(message);
+ };
+ webSocket.onerror = () => {
+ webSocket.close();
+ };
+ webSocket.onclose = () => {
+ this.socket = undefined;
+ setTimeout(() => {
+ this.#connect();
+ }, 500);
+ };
+ }
+}
+
+export default Connection;
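
`Connection` buffers writes while the socket is down: up to `MAX_QUEUED_MESSAGES` outgoing messages are held (oldest dropped first) and flushed in `onopen`, and a closed socket schedules a reconnect after 500 ms. A sketch of driving it directly (normally only the websocket service does; host, port, and the message type are placeholders):

```ts
import Connection from './connection';

const connection = new Connection('localhost', '8000');
connection.onmessage = (data) => console.log('received', data);

// safe to call before the socket has opened — the message is queued
// and sent as soon as the connection is established
connection.send({ type: 'ping', data: null });
```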
diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts
new file mode 100644
index 00000000..135ea37c
--- /dev/null
+++ b/src/services/websocket/index.ts
@@ -0,0 +1,45 @@
+import { notifyProblem } from '../../notify';
+import { Problem } from '../../types';
+import Connection from './connection';
+import { Message, MessageHandler } from './types';
+
+// this service provides a websocket connection to the spotlight backend
+// and handles incoming and outgoing messages
+export class WebsocketService {
+ connection: Connection;
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    messageHandlers: Map<string, MessageHandler>;
+
+ constructor(host: string, port: string) {
+ this.messageHandlers = new Map();
+ this.connection = new Connection(host, port);
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ this.connection.onmessage = (message: any) => {
+ const messageHandler = this.messageHandlers.get(message.type);
+ if (messageHandler) {
+ messageHandler(message.data);
+ } else {
+ console.error(`Unknown websocket message: ${message.type}`);
+ }
+ };
+ }
+
+ registerMessageHandler(messageType: string, handler: MessageHandler): void {
+ this.messageHandlers.set(messageType, handler);
+ }
+
+ send(message: Message): void {
+ this.connection.send(message);
+ }
+}
+
+const websocketService = new WebsocketService(
+ globalThis.location.hostname,
+ globalThis.location.port
+);
+
+websocketService.registerMessageHandler('error', (problem: Problem) => {
+ notifyProblem(problem);
+});
+
+export default websocketService;
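
Because `messageHandlers` is a plain `Map`, each message type has exactly one handler, and a later `registerMessageHandler` call for the same type silently replaces the earlier one. The stores below each claim distinct types, so subscribing stays a one-liner (the message type here is hypothetical):

```ts
import websocketService from './services/websocket';

// a feature can react to a server push without the service knowing about it
websocketService.registerMessageHandler('datasetRenamed', (data) => {
    console.log('dataset is now called', data.name);
});
```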
diff --git a/src/services/websocket/types.ts b/src/services/websocket/types.ts
new file mode 100644
index 00000000..18840150
--- /dev/null
+++ b/src/services/websocket/types.ts
@@ -0,0 +1,7 @@
+export interface Message {
+ type: string;
+ data: unknown;
+}
+
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+export type MessageHandler = (data: any) => void;
diff --git a/src/stores/dataset/dataset.ts b/src/stores/dataset/dataset.ts
index 0cd3e3c0..4916c015 100644
--- a/src/stores/dataset/dataset.ts
+++ b/src/stores/dataset/dataset.ts
@@ -22,6 +22,7 @@ import { notifyAPIError, notifyError } from '../../notify';
import { makeColumnsColorTransferFunctions } from './colorTransferFunctionFactory';
import { makeColumn } from './columnFactory';
import { makeColumnsStats } from './statisticsFactory';
+import websocketService from '../../services/websocket';
export type CallbackOrData = ((data: T) => T) | T;
@@ -540,3 +541,11 @@ useDataset.subscribe(
);
useColors.subscribe(() => useDataset.getState().recomputeColorTransferFunctions());
+
+websocketService.registerMessageHandler('refresh', () => {
+ useDataset.getState().refresh();
+});
+
+websocketService.registerMessageHandler('issuesUpdated', () => {
+ useDataset.getState().fetchIssues();
+});
diff --git a/src/stores/layout.ts b/src/stores/layout.ts
index 8eb5f517..53d4a993 100644
--- a/src/stores/layout.ts
+++ b/src/stores/layout.ts
@@ -2,6 +2,7 @@ import { create } from 'zustand';
import { AppLayout } from '../types';
import api from '../api';
import { saveAs } from 'file-saver';
+import websocketService from '../services/websocket';
export interface State {
layout: AppLayout;
export const useLayout = create<State>((set) => ({
reader.readAsText(file);
},
}));
+
+websocketService.registerMessageHandler('resetLayout', () => {
+ useLayout.getState().reset();
+});
diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx
index 193a46f8..023c4dd1 100644
--- a/src/widgets/SimilarityMap/SimilarityMap.tsx
+++ b/src/widgets/SimilarityMap/SimilarityMap.tsx
@@ -23,6 +23,7 @@ import {
IndexArray,
isNumberColumn,
NumberColumn,
+ Problem,
TableData,
} from '../../types';
import { createSizeTransferFunction } from '../../dataformat';
@@ -220,8 +221,8 @@ const SimilarityMap: Widget = () => {
const placeByColumns = useDataset(placeByColumnsSelector, shallow);
const [positions, setPositions] = useState<[number, number][]>([]);
-
    const [visibleIndices, setVisibleIndices] = useState<IndexArray>([]);
+    const [problem, setProblem] = useState<Problem | null>(null);
const selected = useMemo(() => {
const selected: boolean[] = [];
@@ -275,13 +276,14 @@ const SimilarityMap: Widget = () => {
const widgetId = useMemo(() => uuidv4(), []);
useEffect(() => {
+ setVisibleIndices([]);
+ setPositions([]);
+ setProblem(null);
+
if (!indices.length || !placeByColumnKeys.length) {
- setVisibleIndices([]);
- setPositions([]);
setIsComputing(false);
return;
}
-
setIsComputing(true);
const reductionPromise =
@@ -302,12 +304,18 @@ const SimilarityMap: Widget = () => {
);
let cancelled = false;
- reductionPromise.then(({ points, indices }) => {
- if (cancelled) return;
- setVisibleIndices(indices);
- setPositions(points);
- setIsComputing(false);
- });
+ reductionPromise
+ .then(({ points, indices }) => {
+ if (cancelled) return;
+ setVisibleIndices(indices);
+ setPositions(points);
+ setIsComputing(false);
+ })
+ .catch((problem: Problem) => {
+ if (cancelled) return;
+ setProblem(problem);
+ setIsComputing(false);
+ });
return () => {
cancelled = true;
@@ -443,23 +451,36 @@ const SimilarityMap: Widget = () => {
const areColumnsSelected = !!placeByColumnKeys.length;
const hasVisibleRows = !!visibleIndices.length;
-    return (
-        <Container>
-            {isComputing && <LoadingIndicator />}
-            {!areColumnsSelected && !isComputing && (
-                <div>
-                    <span>
-                        Select columns and show at least two rows to display similarity
-                        map
-                    </span>
-                </div>
-            )}
-            {areColumnsSelected && !hasVisibleRows && !isComputing && (
-                <div>
-                    <span>Not enough rows</span>
-                </div>
-            )}
-            {areColumnsSelected && hasVisibleRows && !isComputing && (
+    let content: JSX.Element;
+    if (problem) {
+        content = (
+            <div>
+                <div>
+                    <span>{problem.title}</span>
+                    <span>{problem.detail}</span>
+                </div>
+            </div>
+        );
+    } else if (isComputing) {
+        content = <LoadingIndicator />;
+    } else if (!areColumnsSelected) {
+        content = (
+            <div>
+                <span>
+                    Select columns and show at least two rows to display similarity map
+                </span>
+            </div>
+        );
+    } else if (!hasVisibleRows) {
+        content = (
+            <div>
+                <span>Not enough rows</span>
+            </div>
+        );
+    } else {
+        content = (
+            <>
                 {/* ... similarity-map component markup (unchanged) ... */}
-            )}
-            {!isComputing && (
                 {visibleIndices.length} of {indices.length} rows
-            )}
+            </>
+        );
+    }
+    return (
+        <Container>
+            {content}