diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d6628c4f..e71d164e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,10 +11,10 @@ Technical details on how to contribute can be found in our [documentation](https There are several ways you can contribute to Spotlight: -* Fix outstanding issues. -* Implement new features. -* Submit issues related to bugs or desired new features. -* Share your use case +- Fix outstanding issues. +- Implement new features. +- Submit issues related to bugs or desired new features. +- Share your use case. If you don't know where to start, you might want to have a look at [hacktoberfest issues](https://github.com/Renumics/spotlight/issues?q=is%3Aissue+is%3Aopen+label%3Ahacktoberfest) and our guide on how to create a [new Lens](https://renumics.com/docs/development/lenses). diff --git a/README.md b/README.md index 6c68d669..320f9e13 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,10 @@

-Spotlight helps you to **understand unstructured datasets** fast. You can quickly create **interactive visualizations** and leverage data enrichments (e.g. embeddings, prediction, uncertainties) to **identify critical clusters** in your data. +Spotlight helps you **understand unstructured datasets** fast. You can quickly create **interactive visualizations** and leverage data enrichments (e.g. embeddings, predictions, uncertainties) to **identify critical clusters** in your data. Spotlight supports most unstructured data types, including **images, audio, text, videos, time series and geometric data**. You can start from your existing dataframe: +

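The dataframe workflow that the added sentence points to looks roughly like this; a minimal sketch, assuming the public `renumics.spotlight` API and a hypothetical `df` (this is not the README's elided snippet):

```python
# Minimal sketch: launch Spotlight on an existing pandas DataFrame.
import pandas as pd

from renumics import spotlight

# Hypothetical stand-in for "your existing dataframe".
df = pd.DataFrame({"text": ["a short note", "another note"], "score": [0.2, 0.9]})

# Opens the interactive viewer; column types are inferred automatically.
spotlight.show(df)
```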
And start Spotlight with just a few lines of code: @@ -49,7 +50,7 @@ Machine learning and engineering teams use Spotlight to understand and communica [Classification] Find Issues in Any Image Classification Dataset 👨‍đŸ’ģ 📝 🕹ī¸ - + Find data issues in the CIFAR-100 image dataset 🕹ī¸ @@ -91,7 +92,6 @@ Machine learning and engineering teams use Spotlight to understand and communica - ## ⏱ī¸ Quickstart Get started by installing Spotlight and loading your first dataset. @@ -132,12 +132,11 @@ ds = datasets.load_dataset('renumics/emodb-enriched', split='all') layout= spotlight.layouts.debug_classification(label='gender', prediction='m1_gender_prediction', embedding='m1_embedding', features=['age', 'emotion']) spotlight.show(ds, layout=layout) ``` + Here, the data types are discovered automatically from the dataset and we use a pre-defined layout for model debugging. Custom layouts can be built programmatically or via the UI. > The `datasets[audio]` package can be installed via pip. - - #### Usage Tracking We have added crash report and performance collection. We do NOT collect user data other than an anonymized Machine Id obtained by py-machineid, and only log our own actions. We do NOT collect folder names, dataset names, or row data of any kind, only aggregate performance statistics like the total time of a table_load, crash data, etc. Collecting Spotlight crashes will help us improve stability. To opt out of the crash report collection, define an environment variable called `SPOTLIGHT_OPT_OUT` and set it to true, e.g. `export SPOTLIGHT_OPT_OUT=true`. @@ -150,9 +149,9 @@ We have added crash report and performance collection. We do NOT collect user da ## Learn more about unstructured data workflows -- 🤗 [Huggingface](https://huggingface.co/renumics) example spaces and datasets -- 🏀 [Playbook](https://renumics.com/docs/playbook/) for data-centric AI workflows -- 🍰 [Sliceguard](https://github.com/Renumics/sliceguard) library for automatic slice detection +- 🤗 [Huggingface](https://huggingface.co/renumics) example spaces and datasets +- 🏀 [Playbook](https://renumics.com/docs/playbook/) for data-centric AI workflows +- 🍰 [Sliceguard](https://github.com/Renumics/sliceguard) library for automatic slice detection ## Contribute diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py index 62220d2f..5baa724d 100644 --- a/renumics/spotlight/app.py +++ b/renumics/spotlight/app.py @@ -15,6 +15,7 @@ from typing_extensions import Annotated from fastapi import Cookie, FastAPI, Request, status +from fastapi.datastructures import Headers from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response from fastapi.staticfiles import StaticFiles @@ -56,10 +57,23 @@ from renumics.spotlight.dtypes import DTypeMap - CURRENT_LAYOUT_KEY = "layout.current" +class UncachedStaticFiles(StaticFiles): + """ + FastAPI StaticFiles but without caching + """ + + def is_not_modified( + self, response_headers: Headers, request_headers: Headers + ) -> bool: + """ + Never cache. + """ + return False + + class IssuesUpdatedMessage(Message): """ Notify about updated issues.
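Some context on the `UncachedStaticFiles` hunk above: Starlette's `StaticFiles` calls `is_not_modified(response_headers, request_headers)` for conditional requests (`If-None-Match` / `If-Modified-Since`) and answers `304 Not Modified` when it returns true, so pinning it to `False` forces a full `200` with fresh file contents on every request. A self-contained sketch of the same pattern, assuming a plain Starlette app and a local `static/` directory (both illustrative, not part of the patch):

```python
# Sketch: a StaticFiles subclass that never answers with 304 Not Modified.
from starlette.applications import Starlette
from starlette.datastructures import Headers
from starlette.staticfiles import StaticFiles


class NoCacheStaticFiles(StaticFiles):
    def is_not_modified(
        self, response_headers: Headers, request_headers: Headers
    ) -> bool:
        # Ignore ETag / Last-Modified validators: always resend the file.
        return False


app = Starlette()
app.mount("/static", NoCacheStaticFiles(directory="static"), name="static")
```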
@@ -75,7 +89,6 @@ class SpotlightApp(FastAPI): """ # lifecycle - _startup_complete: bool _loop: asyncio.AbstractEventLoop # connection @@ -106,7 +119,6 @@ class SpotlightApp(FastAPI): def __init__(self) -> None: super().__init__() - self._startup_complete = False self.task_manager = TaskManager() self.websocket_manager = None self.config = Config() @@ -207,9 +219,13 @@ async def _(_: Request, problem: Problem) -> JSONResponse: plugin.activate(self) try: + # Mount frontend files as uncached, + # so that we always get the new frontend after updating spotlight. + # NOTE: we might not need this if we added a version hash + # to our built js files self.mount( "/static", - StaticFiles(packages=["renumics.spotlight.backend"]), + UncachedStaticFiles(packages=["renumics.spotlight.backend"]), name="assets", ) except AssertionError: @@ -295,44 +311,45 @@ def update(self, config: AppConfig) -> None: """ Update application config. """ - if config.project_root is not None: - self.project_root = config.project_root - if config.dtypes is not None: - self._user_dtypes = config.dtypes - if config.analyze is not None: - self.analyze_columns = config.analyze - if config.custom_issues is not None: - self.custom_issues = config.custom_issues - if config.dataset is not None: - self._dataset = config.dataset - self._data_source = create_datasource(self._dataset) - if config.layout is not None: - self._layout = config.layout or layouts.default() - if config.filebrowsing_allowed is not None: - self.filebrowsing_allowed = config.filebrowsing_allowed - - if config.dtypes is not None or config.dataset is not None: - data_source = self._data_source - assert data_source is not None - self._data_store = DataStore(data_source, self._user_dtypes) - self._broadcast(RefreshMessage()) - self._update_issues() - if config.layout is not None: - if self._data_store is not None: - dataset_uid = self._data_store.uid - future = asyncio.run_coroutine_threadsafe( - self.config.remove_all(CURRENT_LAYOUT_KEY, dataset=dataset_uid), - self._loop, - ) - future.result() - self._broadcast(ResetLayoutMessage()) + try: + if config.project_root is not None: + self.project_root = config.project_root + if config.dtypes is not None: + self._user_dtypes = config.dtypes + if config.analyze is not None: + self.analyze_columns = config.analyze + if config.custom_issues is not None: + self.custom_issues = config.custom_issues + if config.dataset is not None: + self._dataset = config.dataset + self._data_source = create_datasource(self._dataset) + if config.layout is not None: + self._layout = config.layout or layouts.default() + if config.filebrowsing_allowed is not None: + self.filebrowsing_allowed = config.filebrowsing_allowed + + if config.dtypes is not None or config.dataset is not None: + data_source = self._data_source + assert data_source is not None + self._data_store = DataStore(data_source, self._user_dtypes) + self._broadcast(RefreshMessage()) + self._update_issues() + if config.layout is not None: + if self._data_store is not None: + dataset_uid = self._data_store.uid + future = asyncio.run_coroutine_threadsafe( + self.config.remove_all(CURRENT_LAYOUT_KEY, dataset=dataset_uid), + self._loop, + ) + future.result() + self._broadcast(ResetLayoutMessage()) - for plugin in load_plugins(): - plugin.update(self, config) + for plugin in load_plugins(): + plugin.update(self, config) + except Exception as e: + self._connection.send({"kind": "update_complete", "error": e}) - if not self._startup_complete: - self._startup_complete = True - 
self._connection.send({"kind": "startup_complete"}) + self._connection.send({"kind": "update_complete"}) def _handle_message(self, message: Any) -> None: kind = message.get("kind") diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index 00207c38..9dae9021 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -6,11 +6,9 @@ import numpy as np import pandas as pd -from sklearn import preprocessing -from renumics.spotlight.dataset.exceptions import ColumnNotExistsError from renumics.spotlight.data_store import DataStore -from renumics.spotlight.dtypes import is_category_dtype, is_embedding_dtype +from renumics.spotlight import dtypes SEED = 42 @@ -27,6 +25,7 @@ def align_data( """ Align data from table's columns, remove `NaN`'s. """ + from sklearn import preprocessing if not column_names or not indices: return np.empty(0, np.float64), [] @@ -35,7 +34,7 @@ def align_data( for column_name in column_names: dtype = data_store.dtypes[column_name] column_values = data_store.get_converted_values(column_name, indices) - if is_embedding_dtype(dtype): + if dtypes.is_embedding_dtype(dtype): embedding_length = max( 0 if x is None else len(cast(np.ndarray, x)) for x in column_values ) @@ -49,17 +48,19 @@ def align_data( ] ) ) - elif is_category_dtype(dtype): + elif dtypes.is_category_dtype(dtype): na_mask = np.array(column_values) == -1 one_hot_values = preprocessing.label_binarize( column_values, classes=sorted(set(column_values).difference({-1})) # type: ignore ).astype(float) one_hot_values[na_mask] = np.nan aligned_values.append(one_hot_values) - elif dtype in (int, bool, float): + elif dtypes.is_scalar_dtype(dtype): aligned_values.append(np.array(column_values, dtype=float)) else: - raise ColumnNotEmbeddable + raise ColumnNotEmbeddable( + f"Column '{column_name}' of type {dtype} is not embeddable." + ) data = np.hstack([col.reshape((len(indices), -1)) for col in aligned_values]) mask = ~pd.isna(data).any(axis=1) @@ -78,10 +79,8 @@ def compute_umap( Prepare data from table and compute U-Map on them. """ - try: - data, indices = align_data(data_store, column_names, indices) - except (ColumnNotExistsError, ColumnNotEmbeddable): - return np.empty(0, np.float64), [] + data, indices = align_data(data_store, column_names, indices) + if data.size == 0: return np.empty(0, np.float64), [] @@ -114,14 +113,13 @@ def compute_pca( Prepare data from table and compute PCA on them. """ - from sklearn import preprocessing, decomposition + data, indices = align_data(data_store, column_names, indices) - try: - data, indices = align_data(data_store, column_names, indices) - except (ColumnNotExistsError, ValueError): - return np.empty(0, np.float64), [] if data.size == 0: return np.empty(0, np.float64), [] + + from sklearn import preprocessing, decomposition + if data.shape[1] == 1: return np.hstack((data, np.zeros_like(data))), indices if normalization == "standardize": @@ -129,5 +127,6 @@ def compute_pca( elif normalization == "robust standardize": data = preprocessing.RobustScaler(copy=False).fit_transform(data) reducer = decomposition.PCA(n_components=2, copy=False, random_state=SEED) - embeddings = reducer.fit_transform(data) + # `fit_transform` returns Fortran-ordered array. 
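+ # orjson, which serializes the websocket responses, handles only C-contiguous arrays.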
+ embeddings = np.ascontiguousarray(reducer.fit_transform(data)) return embeddings, indices diff --git a/renumics/spotlight/backend/tasks/task_manager.py b/renumics/spotlight/backend/tasks/task_manager.py index e120fa88..8c392a55 100644 --- a/renumics/spotlight/backend/tasks/task_manager.py +++ b/renumics/spotlight/backend/tasks/task_manager.py @@ -6,7 +6,7 @@ import multiprocessing from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool -from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union from .exceptions import TaskCancelled from .task import Task @@ -30,16 +30,20 @@ def create_task( self, func: Callable, args: Sequence[Any], + kwargs: Optional[Dict[str, Any]] = None, name: Optional[str] = None, tag: Optional[Union[str, int]] = None, ) -> Task: """ create and launch a new task """ + if kwargs is None: + kwargs = {} + # cancel running task with same name self.cancel(name=name) - future = self.pool.submit(func, *args) + future = self.pool.submit(func, *args, **kwargs) task = Task(name, tag, future) self.tasks.append(task) @@ -59,6 +63,7 @@ async def run_async( self, func: Callable[..., T], args: Sequence[Any], + kwargs: Optional[Dict[str, Any]] = None, name: Optional[str] = None, tag: Optional[Union[str, int]] = None, ) -> T: @@ -66,7 +71,7 @@ async def run_async( Launch a new task. Await and return result. """ - task = self.create_task(func, args, name, tag) + task = self.create_task(func, args=args, kwargs=kwargs, name=name, tag=tag) try: return await asyncio.wrap_future(task.future) except BrokenProcessPool as e: diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 8be62f4a..aed9a1ef 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -3,11 +3,20 @@ """ import asyncio -import functools -import orjson -from typing import Any, List, Optional, Set, Callable +from typing import ( + Any, + Coroutine, + Dict, + Optional, + Set, + Callable, + Tuple, + Type, + cast, +) import numpy as np +import orjson from fastapi import WebSocket, WebSocketDisconnect from loguru import logger from pydantic import BaseModel @@ -17,7 +26,7 @@ from .tasks import TaskManager, TaskCancelled from .tasks.reduction import compute_umap, compute_pca -from .exceptions import GenerationIDMismatch +from .exceptions import GenerationIDMismatch, Problem class Message(BaseModel): @@ -47,115 +56,46 @@ class ResetLayoutMessage(Message): data: Any = None -class ReductionMessage(Message): - """ - Common data reduction message model. - """ - +class TaskData(BaseModel): + task: str widget_id: str - uid: str - generation_id: int - - -class ReductionRequestData(BaseModel): - """ - Base data reduction request payload. - """ - - indices: List[int] - columns: List[str] - - -class UMapRequestData(ReductionRequestData): - """ - U-Map request payload. - """ - - n_neighbors: int - metric: str - min_dist: float - - -class PCARequestData(ReductionRequestData): - """ - PCA request payload. - """ - - normalization: str + task_id: str + generation_id: Optional[int] + args: Any -class UMapRequest(ReductionMessage): - """ - U-Map request model. - """ - - data: UMapRequestData - - -class PCARequest(ReductionMessage): +class UnknownMessageType(Exception): """ - PCA request model. + Websocket message type is unknown. 
""" - data: PCARequestData - -class ReductionResponseData(BaseModel): +class SerializationError(Exception): """ - Data reduction response payload. + Failed to serialize the WS message """ - indices: List[int] - points: np.ndarray - class Config: - arbitrary_types_allowed = True +PayloadType = Type[BaseModel] +MessageHandler = Callable[[Any, "WebsocketConnection"], Coroutine[Any, Any, Any]] +MessageHandlerSpec = Tuple[PayloadType, MessageHandler] +MESSAGE_HANDLERS: Dict[str, MessageHandlerSpec] = {} -class ReductionResponse(ReductionMessage): - """ - Data reduction response model. - """ +def register_message_handler( + message_type: str, handler_spec: MessageHandlerSpec +) -> None: + MESSAGE_HANDLERS[message_type] = handler_spec - data: ReductionResponseData +def message_handler( + message_type: str, payload_type: Type[BaseModel] +) -> Callable[[MessageHandler], MessageHandler]: + def decorator(handler: MessageHandler) -> MessageHandler: + register_message_handler(message_type, (payload_type, handler)) + return handler -MESSAGE_BY_TYPE = { - "umap": UMapRequest, - "umap_result": ReductionResponse, - "pca": PCARequest, - "pca_result": ReductionResponse, - "refresh": RefreshMessage, -} - - -class UnknownMessageType(Exception): - """ - Websocket message type is unknown. - """ - - -def parse_message(raw_message: str) -> Message: - """ - Parse a websocket message from a raw text. - """ - json_message = orjson.loads(raw_message) - message_type = json_message["type"] - message_class = MESSAGE_BY_TYPE.get(message_type) - if message_class is None: - raise UnknownMessageType(f"Message type {message_type} is unknown.") - return message_class(**json_message) - - -@functools.singledispatch -async def handle_message(request: Message, connection: "WebsocketConnection") -> None: - """ - Handle incoming messages. - - New message types should be registered by decorating with `@handle_message.register`. - """ - - raise NotImplementedError + return decorator class WebsocketConnection: @@ -174,11 +114,15 @@ async def send_async(self, message: Message) -> None: """ Send a message async. 
""" + try: - message_data = message.dict() - await self.websocket.send_text( - orjson.dumps(message_data, option=orjson.OPT_SERIALIZE_NUMPY).decode() - ) + json_text = orjson.dumps( + message.dict(), option=orjson.OPT_SERIALIZE_NUMPY + ).decode() + except TypeError as e: + raise SerializationError(str(e)) + try: + await self.websocket.send_text(json_text) except WebSocketDisconnect: self._on_disconnect() except RuntimeError: @@ -204,12 +148,17 @@ async def listen(self) -> None: try: while True: try: - message = parse_message(await self.websocket.receive_text()) + raw_message = await self.websocket.receive_text() + message = Message(**orjson.loads(raw_message)) except UnknownMessageType as e: logger.warning(str(e)) else: logger.info(f"WS message with type {message.type} received.") - asyncio.create_task(handle_message(message, self)) + if handler_spec := MESSAGE_HANDLERS.get(message.type): + payload_type, handler = handler_spec + asyncio.create_task(handler(payload_type(**message.data), self)) + else: + logger.error(f"Unknown message received: {message.type}") except WebSocketDisconnect: self._on_disconnect() @@ -300,73 +249,80 @@ def on_disconnect(self, connection: WebsocketConnection) -> None: callback(len(self.connections)) -@handle_message.register -async def _(request: UMapRequest, connection: "WebsocketConnection") -> None: - data_store: Optional[DataStore] = connection.websocket.app.data_store - if data_store is None: - return None - try: - data_store.check_generation_id(request.generation_id) - except GenerationIDMismatch: - return - - try: - points, valid_indices = await connection.task_manager.run_async( - compute_umap, - ( - data_store, - request.data.columns, - request.data.indices, - request.data.n_neighbors, - request.data.metric, - request.data.min_dist, - ), - name=request.widget_id, - tag=id(connection), - ) - except TaskCancelled: - ... - else: - response = ReductionResponse( - type="umap_result", - widget_id=request.widget_id, - uid=request.uid, - generation_id=request.generation_id, - data=ReductionResponseData(indices=valid_indices, points=points), - ) - await connection.send_async(response) +TASK_FUNCS = {"umap": compute_umap, "pca": compute_pca} -@handle_message.register -async def _(request: PCARequest, connection: "WebsocketConnection") -> None: +@message_handler("task", TaskData) +async def _(data: TaskData, connection: WebsocketConnection) -> None: data_store: Optional[DataStore] = connection.websocket.app.data_store if data_store is None: return None - try: - data_store.check_generation_id(request.generation_id) - except GenerationIDMismatch: - return + + if data.generation_id: + try: + data_store.check_generation_id(data.generation_id) + except GenerationIDMismatch: + return try: - points, valid_indices = await connection.task_manager.run_async( - compute_pca, - ( - data_store, - request.data.columns, - request.data.indices, - request.data.normalization, - ), - name=request.widget_id, + task_func = TASK_FUNCS[data.task] + result = await connection.task_manager.run_async( + task_func, # type: ignore + args=(data_store,), + kwargs=data.args, + name=data.widget_id, tag=id(connection), ) + points = cast(np.ndarray, result[0]) + valid_indices = cast(np.ndarray, result[1]) except TaskCancelled: - ... 
+ pass + except Problem as e: + msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": e.title, + "detail": e.detail, + }, + }, + ) + await connection.send_async(msg) + except Exception as e: + msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": type(e).__name__, + "detail": type(e).__doc__, + }, + }, + ) + await connection.send_async(msg) else: - response = ReductionResponse( - type="pca_result", - widget_id=request.widget_id, - uid=request.uid, - generation_id=request.generation_id, - data=ReductionResponseData(indices=valid_indices, points=points), + msg = Message( + type="task.result", + data={ + "task_id": data.task_id, + "result": {"points": points, "indices": valid_indices}, + }, ) - await connection.send_async(response) + try: + await connection.send_async(msg) + except SerializationError as e: + error_msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": "Serialization Error", + "detail": str(e), + }, + }, + ) + await connection.send_async(error_msg) diff --git a/renumics/spotlight/data_source/data_source.py b/renumics/spotlight/data_source/data_source.py index d06133f3..695acb5d 100644 --- a/renumics/spotlight/data_source/data_source.py +++ b/renumics/spotlight/data_source/data_source.py @@ -6,7 +6,6 @@ import pandas as pd import numpy as np -from pydantic.dataclasses import dataclass from renumics.spotlight.dataset.exceptions import ( ColumnExistsError, @@ -30,17 +29,6 @@ class ColumnMetadata: tags: List[str] = dataclasses.field(default_factory=list) -@dataclass -class CellsUpdate: - """ - A dataset's cell update. - """ - - value: Any - author: str - edited_at: str - - class DataSource(ABC): """abstract base class for different data sources""" @@ -61,7 +49,7 @@ def column_names(self) -> List[str]: @abstractmethod def intermediate_dtypes(self) -> DTypeMap: """ - The dtypes of intermediate values + The dtypes of intermediate values. Values for all columns must be filled. """ @property @@ -94,7 +82,7 @@ def check_generation_id(self, generation_id: int) -> None: @abstractmethod def semantic_dtypes(self) -> DTypeMap: """ - Semantic dtypes for viewer. + Semantic dtypes for viewer. Some values may be not present. """ @abstractmethod diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index 2eb0108b..ba940254 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -21,13 +21,16 @@ DType, DTypeMap, EmbeddingDType, + array_dtype, is_array_dtype, is_audio_dtype, is_category_dtype, + is_embedding_dtype, is_file_dtype, is_str_dtype, is_mixed_dtype, is_bytes_dtype, + is_window_dtype, str_dtype, audio_dtype, image_dtype, @@ -173,33 +176,32 @@ def _guess_dtype(self, col: str) -> DType: return semantic_dtype sample_values = self._data_source.get_column_values(col, slice(10)) - sample_dtypes = [_guess_value_dtype(value) for value in sample_values] - - try: - mode_dtype = statistics.mode(sample_dtypes) - except statistics.StatisticsError: + sample_dtypes: List[DType] = [] + for value in sample_values: + guessed_dtype = _guess_value_dtype(value) + if guessed_dtype is not None: + sample_dtypes.append(guessed_dtype) + if not sample_dtypes: return semantic_dtype - return mode_dtype or semantic_dtype + mode_dtype = statistics.mode(sample_dtypes) + # For windows and embeddings, at least sample values must be aligned. 
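+ # If they are not, fall back to the generic array dtype.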
+ if is_window_dtype(mode_dtype) and any( + not is_window_dtype(dtype) for dtype in sample_dtypes + ): + return array_dtype + if is_embedding_dtype(mode_dtype) and any( + (not is_embedding_dtype(dtype)) or dtype.length != mode_dtype.length + for dtype in sample_dtypes + ): + return array_dtype + + return mode_dtype def _intermediate_to_semantic_dtype(intermediate_dtype: DType) -> DType: if is_array_dtype(intermediate_dtype): - if intermediate_dtype.shape is None: - return intermediate_dtype - if intermediate_dtype.shape == (2,): - return window_dtype - if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is not None: - return EmbeddingDType(intermediate_dtype.shape[0]) - if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is None: - return sequence_1d_dtype - if intermediate_dtype.ndim == 2 and ( - intermediate_dtype.shape[0] == 2 or intermediate_dtype.shape[1] == 2 - ): - return sequence_1d_dtype - if intermediate_dtype.ndim == 3 and intermediate_dtype.shape[-1] in (1, 3, 4): - return image_dtype - return intermediate_dtype + return _guess_array_dtype(intermediate_dtype) if is_file_dtype(intermediate_dtype): return str_dtype if is_mixed_dtype(intermediate_dtype): @@ -262,5 +264,21 @@ def _guess_value_dtype(value: Any) -> Optional[DType]: except (TypeError, ValueError): pass else: - return ArrayDType(value.shape) + return _guess_array_dtype(ArrayDType(value.shape)) return None + + +def _guess_array_dtype(dtype: ArrayDType) -> DType: + if dtype.shape is None: + return dtype + if dtype.shape == (2,): + return window_dtype + if dtype.ndim == 1 and dtype.shape[0] is not None: + return EmbeddingDType(dtype.shape[0]) + if dtype.ndim == 1 and dtype.shape[0] is None: + return sequence_1d_dtype + if dtype.ndim == 2 and (dtype.shape[0] == 2 or dtype.shape[1] == 2): + return sequence_1d_dtype + if dtype.ndim == 3 and dtype.shape[-1] in (1, 3, 4): + return image_dtype + return dtype diff --git a/renumics/spotlight/dataset/__init__.py b/renumics/spotlight/dataset/__init__.py index 99828ddf..e586499d 100644 --- a/renumics/spotlight/dataset/__init__.py +++ b/renumics/spotlight/dataset/__init__.py @@ -32,12 +32,7 @@ from typing_extensions import TypeGuard from renumics.spotlight.__version__ import __version__ -from renumics.spotlight.io.pandas import ( - infer_dtypes, - prepare_column, - is_string_mask, - stringify_columns, -) +from .pandas import create_typed_series, infer_dtypes, is_string_mask, prepare_column from renumics.spotlight.typing import ( BoolType, IndexType, @@ -47,7 +42,6 @@ is_integer, is_iterable, ) -from renumics.spotlight.io.pandas import create_typed_series from renumics.spotlight.dtypes.conversion import prepare_path_or_url from renumics.spotlight import dtypes as spotlight_dtypes @@ -738,7 +732,7 @@ def from_pandas( df = df.reset_index(level=df.index.names) # type: ignore else: df = df.copy() - df.columns = pd.Index(stringify_columns(df)) + df.columns = pd.Index([str(column) for column in df.columns]) if dtypes is None: dtypes = {} diff --git a/renumics/spotlight/dataset/descriptors/__init__.py b/renumics/spotlight/dataset/descriptors/__init__.py index bb5cf837..3f52372d 100644 --- a/renumics/spotlight/dataset/descriptors/__init__.py +++ b/renumics/spotlight/dataset/descriptors/__init__.py @@ -1,5 +1,6 @@ """make descriptor methods more available """ +import warnings from typing import Optional, Tuple import numpy as np @@ -11,6 +12,13 @@ from renumics.spotlight.dataset.exceptions import ColumnExistsError, InvalidDTypeError from .data_alignment import 
align_column_data +warnings.warn( + "`renumics.spotlight.dataset.descriptors` module is deprecated and will " + "be removed in future versions.", + DeprecationWarning, + stacklevel=2, +) + def pca( dataset: Dataset, diff --git a/renumics/spotlight/io/pandas.py b/renumics/spotlight/dataset/pandas.py similarity index 86% rename from renumics/spotlight/io/pandas.py rename to renumics/spotlight/dataset/pandas.py index 4cf84f9e..75ccf00f 100644 --- a/renumics/spotlight/io/pandas.py +++ b/renumics/spotlight/dataset/pandas.py @@ -1,30 +1,22 @@ """ -This module contains helpers for importing `pandas.DataFrame`s. +Helper for conversion between H5 dataset and `pandas.DataFrame`. """ -import ast import os.path import statistics -from contextlib import suppress -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Union import PIL.Image import filetype -import trimesh import numpy as np import pandas as pd +import trimesh -from renumics.spotlight.dtypes import ( - Audio, - Embedding, - Image, - Mesh, - Sequence1D, - Video, -) -from renumics.spotlight.media.exceptions import UnsupportedDType -from renumics.spotlight.typing import is_iterable, is_pathtype from renumics.spotlight import dtypes +from renumics.spotlight.io import prepare_hugging_face_dict, try_literal_eval +from renumics.spotlight.media import Audio, Embedding, Image, Mesh, Sequence1D, Video +from renumics.spotlight.typing import is_iterable, is_pathtype +from .exceptions import InvalidDTypeError def create_typed_series( @@ -58,32 +50,62 @@ def create_typed_series( return pd.Series([] if values is None else values, dtype=pandas_dtype) -def is_empty(value: Any) -> bool: - """ - Check if value is `NA` or an empty string. +def prepare_column(column: pd.Series, dtype: dtypes.DType) -> pd.Series: """ - if is_iterable(value): - # `pd.isna` with an iterable argument returns an iterable result. But - # an iterable cannot be NA or empty string by default. - return False - return pd.isna(value) or value == "" + Convert a `pandas` column to the desired `dtype` and prepare some values, + but still as `pandas` column. + + Args: + column: A `pandas` column to prepare. + dtype: Target data type. + Returns: + Prepared `pandas` column. -def try_literal_eval(x: str) -> Any: - """ - Try to evaluate a literal expression, otherwise return value as is. + Raises: + TypeError: If `dtype` is not a Spotlight data type. """ - with suppress(Exception): - return ast.literal_eval(x) - return x + column = column.copy() + if dtypes.is_category_dtype(dtype): + # We only support string/`NA` categories, but `pandas` can more, so + # force categories to be strings (does not affect `NA`s). + return to_categorical(column, str_categories=True) -def stringify_columns(df: pd.DataFrame) -> List[str]: - """ - Convert `pandas.DataFrame`'s column names to strings, no matter which index - is used. - """ - return [str(column_name) for column_name in df.columns] + if dtypes.is_datetime_dtype(dtype): + # `errors="coerce"` will produce `NaT`s instead of fail. + return pd.to_datetime(column, errors="coerce") + + if dtypes.is_str_dtype(dtype): + # Allow `NA`s, convert all other elements to strings. + return column.astype(str).mask(column.isna(), None) # type: ignore + + if dtypes.is_bool_dtype(dtype): + return column.astype(bool) + + if dtypes.is_int_dtype(dtype): + return column.astype(int) + + if dtypes.is_float_dtype(dtype): + return column.astype(float) + + # We explicitely don't want to change the original `DataFrame`. 
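+ # (Setting `mode.chained_assignment` to None silences pandas' chained-assignment warning for the masked writes below.)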
+ with pd.option_context("mode.chained_assignment", None): + # We consider empty strings as `NA`s. + str_mask = is_string_mask(column) + column[str_mask] = column[str_mask].replace("", None) + na_mask = column.isna() + + # When `pandas` reads a csv, arrays and lists are read as literal strings, + # try to interpret them. + str_mask = is_string_mask(column) + column[str_mask] = column[str_mask].apply(try_literal_eval) + + if dtypes.is_filebased_dtype(dtype): + dict_mask = column.map(type) == dict + column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict) + + return column.mask(na_mask, None) # type: ignore def infer_dtype(column: pd.Series) -> dtypes.DType: @@ -225,7 +247,7 @@ def infer_dtypes(df: pd.DataFrame, dtype: Optional[dtypes.DTypeMap]) -> dtypes.D if column_index not in inferred_dtype: try: column_type = infer_dtype(df[column_index]) - except UnsupportedDType: + except InvalidDTypeError: column_type = dtypes.str_dtype inferred_dtype[str(column_index)] = column_type return inferred_dtype @@ -255,73 +277,3 @@ def to_categorical(column: pd.Series, str_categories: bool = False) -> pd.Series if str_categories: return column.cat.rename_categories(column.cat.categories.astype(str)) return column - - -def prepare_hugging_face_dict(x: Dict) -> Any: - """ - Prepare HuggingFace format for files to be used in Spotlight. - """ - if x.keys() != {"bytes", "path"}: - return x - blob = x["bytes"] - if blob is not None: - return blob - return x["path"] - - -def prepare_column(column: pd.Series, dtype: dtypes.DType) -> pd.Series: - """ - Convert a `pandas` column to the desired `dtype` and prepare some values, - but still as `pandas` column. - - Args: - column: A `pandas` column to prepare. - dtype: Target data type. - - Returns: - Prepared `pandas` column. - - Raises: - TypeError: If `dtype` is not a Spotlight data type. - """ - column = column.copy() - - if dtypes.is_category_dtype(dtype): - # We only support string/`NA` categories, but `pandas` can more, so - # force categories to be strings (does not affect `NA`s). - return to_categorical(column, str_categories=True) - - if dtypes.is_datetime_dtype(dtype): - # `errors="coerce"` will produce `NaT`s instead of fail. - return pd.to_datetime(column, errors="coerce") - - if dtypes.is_str_dtype(dtype): - # Allow `NA`s, convert all other elements to strings. - return column.astype(str).mask(column.isna(), None) # type: ignore - - if dtypes.is_bool_dtype(dtype): - return column.astype(bool) - - if dtypes.is_int_dtype(dtype): - return column.astype(int) - - if dtypes.is_float_dtype(dtype): - return column.astype(float) - - # We explicitely don't want to change the original `DataFrame`. - with pd.option_context("mode.chained_assignment", None): - # We consider empty strings as `NA`s. - str_mask = is_string_mask(column) - column[str_mask] = column[str_mask].replace("", None) - na_mask = column.isna() - - # When `pandas` reads a csv, arrays and lists are read as literal strings, - # try to interpret them. 
- str_mask = is_string_mask(column) - column[str_mask] = column[str_mask].apply(try_literal_eval) - - if dtypes.is_filebased_dtype(dtype): - dict_mask = column.map(type) == dict - column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict) - - return column.mask(na_mask, None) # type: ignore diff --git a/renumics/spotlight/dtypes/__init__.py b/renumics/spotlight/dtypes/__init__.py index 0e24ea10..63910215 100644 --- a/renumics/spotlight/dtypes/__init__.py +++ b/renumics/spotlight/dtypes/__init__.py @@ -9,6 +9,8 @@ __all__ = [ "CategoryDType", + "ArrayDType", + "EmbeddingDType", "Sequence1DDType", "bool_dtype", "int_dtype", @@ -36,6 +38,14 @@ def __init__(self, name: str): def __str__(self) -> str: return self.name + def __eq__(self, other: Any) -> bool: + if isinstance(other, DType): + return other._name == self._name + return False + + def __hash__(self) -> int: + return hash(self._name) + @property def name(self) -> str: return self._name @@ -53,8 +63,10 @@ def __init__( self, categories: Optional[Union[Iterable[str], Dict[str, int]]] = None ): super().__init__("Category") - if isinstance(categories, dict) or categories is None: - self._categories = categories + if isinstance(categories, dict): + self._categories = dict(sorted(categories.items(), key=lambda x: x[1])) + elif categories is None: + self._categories = None else: self._categories = { category: code for code, category in enumerate(categories) @@ -71,6 +83,20 @@ def __init__( category: code for code, category in self._inverted_categories.items() } + def __eq__(self, other: Any) -> bool: + if isinstance(other, CategoryDType): + return other._categories == self._categories + return False + + def __hash__(self) -> int: + if self._categories is None: + return hash(self._name) ^ hash(None) + return ( + hash(self._name) + ^ hash(tuple(self._categories.keys())) + ^ hash(tuple(self._categories.values())) + ) + @property def categories(self) -> Optional[Dict[str, int]]: return self._categories @@ -91,6 +117,14 @@ def __init__(self, shape: Optional[Tuple[Optional[int], ...]] = None): super().__init__("array") self.shape = shape + def __eq__(self, other: Any) -> bool: + if isinstance(other, ArrayDType): + return other.shape == self.shape + return False + + def __hash__(self) -> int: + return hash(self._name) ^ hash(self.shape) + @property def ndim(self) -> int: if self.shape is None: @@ -111,6 +145,14 @@ def __init__(self, length: Optional[int] = None): raise ValueError(f"Length must be non-negative, but {length} received.") self.length = length + def __eq__(self, other: Any) -> bool: + if isinstance(other, EmbeddingDType): + return other.length == self.length + return False + + def __hash__(self) -> int: + return hash(self._name) ^ hash(self.length) + class Sequence1DDType(DType): """ @@ -125,6 +167,14 @@ def __init__(self, x_label: str = "x", y_label: str = "y"): self.x_label = x_label self.y_label = y_label + def __eq__(self, other: Any) -> bool: + if isinstance(other, Sequence1DDType): + return other.x_label == self.x_label and other.y_label == self.y_label + return False + + def __hash__(self) -> int: + return hash(self._name) ^ hash(self.x_label) ^ hash(self.y_label) + ALIASES: Dict[Any, DType] = {} diff --git a/renumics/spotlight/dtypes/conversion.py b/renumics/spotlight/dtypes/conversion.py index 48c90dde..c5b7cc12 100644 --- a/renumics/spotlight/dtypes/conversion.py +++ b/renumics/spotlight/dtypes/conversion.py @@ -37,23 +37,14 @@ import trimesh import PIL.Image import validators -from renumics.spotlight import dtypes 
+from renumics.spotlight import dtypes, media from renumics.spotlight.typing import PathOrUrlType, PathType from renumics.spotlight.cache import external_data_cache from renumics.spotlight.io import audio from renumics.spotlight.io.file import as_file from renumics.spotlight.media.exceptions import InvalidFile from renumics.spotlight.backend.exceptions import Problem -from renumics.spotlight.dtypes import ( - CategoryDType, - DType, - audio_dtype, - image_dtype, - mesh_dtype, - video_dtype, -) -from renumics.spotlight.media import Sequence1D, Image, Audio, Video, Mesh NormalizedValue = Union[ @@ -96,7 +87,7 @@ class ConversionFailed(Problem): def __init__( self, value: NormalizedValue, - dtype: DType, + dtype: dtypes.DType, reason: Optional[str] = None, ) -> None: super().__init__( @@ -116,14 +107,14 @@ class NoConverterAvailable(Problem): No matching converter could be applied """ - def __init__(self, value: NormalizedValue, dtype: DType) -> None: + def __init__(self, value: NormalizedValue, dtype: dtypes.DType) -> None: msg = f"No Converter for {type(value)} -> {dtype}" super().__init__(title="No matching converter", detail=msg, status_code=422) N = TypeVar("N", bound=NormalizedValue) -Converter = Callable[[N, DType], ConvertedValue] +Converter = Callable[[N, dtypes.DType], ConvertedValue] _converters_table: Dict[ Type[NormalizedValue], Dict[str, List[Converter]] ] = defaultdict(lambda: defaultdict(list)) @@ -176,7 +167,10 @@ def _decorate(func: Converter[N]) -> Converter[N]: def convert_to_dtype( - value: NormalizedValue, dtype: DType, simple: bool = False, check: bool = True + value: NormalizedValue, + dtype: dtypes.DType, + simple: bool = False, + check: bool = True, ) -> ConvertedValue: """ Convert normalized type from data source to internal Spotlight DType @@ -221,37 +215,41 @@ def convert_to_dtype( except (TypeError, ValueError) as e: if check: raise ConversionFailed(value, dtype) from e - else: - return None - - if check: - if last_conversion_error: - raise ConversionFailed(value, dtype, last_conversion_error.reason) - else: + else: + if check: + if last_conversion_error: + raise ConversionFailed(value, dtype, last_conversion_error.reason) raise NoConverterAvailable(value, dtype) + if simple and ( + dtypes.is_array_dtype(dtype) + or dtypes.is_embedding_dtype(dtype) + or dtypes.is_sequence_1d_dtype(dtype) + or dtypes.is_filebased_dtype(dtype) + ): + return "" return None @convert("datetime") -def _(value: datetime.datetime, _: DType) -> datetime.datetime: +def _(value: datetime.datetime, _: dtypes.DType) -> datetime.datetime: return value @convert("datetime") -def _(value: Union[str, np.str_], _: DType) -> Optional[datetime.datetime]: +def _(value: Union[str, np.str_], _: dtypes.DType) -> Optional[datetime.datetime]: if value == "": return None return datetime.datetime.fromisoformat(value) @convert("datetime") -def _(value: np.datetime64, _: DType) -> Optional[datetime.datetime]: +def _(value: np.datetime64, _: dtypes.DType) -> Optional[datetime.datetime]: return value.tolist() @convert("Category") -def _(value: Union[str, np.str_], dtype: CategoryDType) -> int: +def _(value: Union[str, np.str_], dtype: dtypes.CategoryDType) -> int: categories = dtype.categories if not categories: return -1 @@ -259,7 +257,7 @@ def _(value: Union[str, np.str_], dtype: CategoryDType) -> int: @convert("Category") -def _(_: None, _dtype: CategoryDType) -> int: +def _(_: None, _dtype: dtypes.CategoryDType) -> int: return -1 @@ -276,23 +274,23 @@ def _( np.uint32, np.uint64, ], - _: CategoryDType, + _: 
dtypes.CategoryDType, ) -> int: return int(value) @convert("Window") -def _(value: list, _: DType) -> np.ndarray: +def _(value: list, _: dtypes.DType) -> np.ndarray: return np.array(value, dtype=np.float64) @convert("Window") -def _(value: np.ndarray, _: DType) -> np.ndarray: +def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray: return value.astype(np.float64) @convert("Window") -def _(value: Union[str, np.str_], _: DType) -> np.ndarray: +def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray: try: obj = ast.literal_eval(value) array = np.array(obj, dtype=np.float64) @@ -304,17 +302,17 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray: @convert("Embedding", simple=False) -def _(value: list, _: DType) -> np.ndarray: +def _(value: list, _: dtypes.DType) -> np.ndarray: return np.array(value, dtype=np.float64) @convert("Embedding", simple=False) -def _(value: np.ndarray, _: DType) -> np.ndarray: +def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray: return value.astype(np.float64) @convert("Embedding", simple=False) -def _(value: Union[str, np.str_], _: DType) -> np.ndarray: +def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray: try: obj = ast.literal_eval(value) array = np.array(obj, dtype=np.float64) @@ -326,50 +324,54 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray: @convert("Sequence1D", simple=False) -def _(value: Union[np.ndarray, list], _: DType) -> np.ndarray: - return Sequence1D(value).encode() +def _(value: Union[np.ndarray, list], _: dtypes.DType) -> np.ndarray: + return media.Sequence1D(value).encode() @convert("Sequence1D", simple=False) -def _(value: Union[str, np.str_], _: DType) -> np.ndarray: +def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray: try: obj = ast.literal_eval(value) - return Sequence1D(obj).encode() + return media.Sequence1D(obj).encode() except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): raise ConversionError("Cannot interpret string as a 1D sequence") @convert("Image", simple=False) -def _(value: Union[str, np.str_], _: DType) -> bytes: +def _(value: Union[np.ndarray, list], _: dtypes.DType) -> bytes: + return media.Image(value).encode().tolist() + + +@convert("Image", simple=False) +def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: - if data := read_external_value(value, image_dtype): + if data := read_external_value(value, dtypes.image_dtype): return data.tolist() except InvalidFile: - raise ConversionError() + try: + obj = ast.literal_eval(value) + return media.Image(obj).encode().tolist() + except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): + raise ConversionError() raise ConversionError() @convert("Image", simple=False) -def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: - return Image.from_bytes(value).encode().tolist() - - -@convert("Image", simple=False) -def _(value: np.ndarray, _: DType) -> bytes: - return Image(value).encode().tolist() +def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: + return media.Image.from_bytes(value).encode().tolist() @convert("Image", simple=False) -def _(value: PIL.Image.Image, _: DType) -> bytes: +def _(value: PIL.Image.Image, _: dtypes.DType) -> bytes: buffer = io.BytesIO() value.save(buffer, format="PNG") return buffer.getvalue() @convert("Audio", simple=False) -def _(value: Union[str, np.str_], _: DType) -> bytes: +def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: - if data := read_external_value(value, audio_dtype): + if data := read_external_value(value, 
dtypes.audio_dtype): return data.tolist() except (InvalidFile, IndexError, ValueError): raise ConversionError() @@ -377,14 +379,14 @@ def _(value: Union[str, np.str_], _: DType) -> bytes: @convert("Audio", simple=False) -def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: - return Audio.from_bytes(value).encode().tolist() +def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: + return media.Audio.from_bytes(value).encode().tolist() @convert("Video", simple=False) -def _(value: Union[str, np.str_], _: DType) -> bytes: +def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: - if data := read_external_value(value, video_dtype): + if data := read_external_value(value, dtypes.video_dtype): return data.tolist() except InvalidFile: raise ConversionError() @@ -392,14 +394,14 @@ def _(value: Union[str, np.str_], _: DType) -> bytes: @convert("Video", simple=False) -def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: - return Video.from_bytes(value).encode().tolist() +def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: + return media.Video.from_bytes(value).encode().tolist() @convert("Mesh", simple=False) -def _(value: Union[str, np.str_], _: DType) -> bytes: +def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: - if data := read_external_value(value, mesh_dtype): + if data := read_external_value(value, dtypes.mesh_dtype): return data.tolist() except InvalidFile: raise ConversionError() @@ -407,29 +409,29 @@ def _(value: Union[str, np.str_], _: DType) -> bytes: @convert("Mesh", simple=False) -def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: +def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: return value # this should not be necessary @convert("Mesh", simple=False) # type: ignore -def _(value: trimesh.Trimesh, _: DType) -> bytes: - return Mesh.from_trimesh(value).encode().tolist() +def _(value: trimesh.Trimesh, _: dtypes.DType) -> bytes: + return media.Mesh.from_trimesh(value).encode().tolist() @convert("Embedding", simple=True) @convert("Sequence1D", simple=True) -def _(_: Union[np.ndarray, list, str, np.str_], _dtype: DType) -> str: +def _(_: Union[np.ndarray, list, str, np.str_], _dtype: dtypes.DType) -> str: return "[...]" @convert("Image", simple=True) -def _(_: np.ndarray, _dtype: DType) -> str: +def _(_: Union[np.ndarray, list], _dtype: dtypes.DType) -> str: return "[...]" @convert("Image", simple=True) -def _(_: PIL.Image.Image, _dtype: DType) -> str: +def _(_: PIL.Image.Image, _dtype: dtypes.DType) -> str: return "" @@ -437,7 +439,7 @@ def _(_: PIL.Image.Image, _dtype: DType) -> str: @convert("Audio", simple=True) @convert("Video", simple=True) @convert("Mesh", simple=True) -def _(value: Union[str, np.str_], _: DType) -> str: +def _(value: Union[str, np.str_], _: dtypes.DType) -> str: return str(value) @@ -445,19 +447,19 @@ def _(value: Union[str, np.str_], _: DType) -> str: @convert("Audio", simple=True) @convert("Video", simple=True) @convert("Mesh", simple=True) -def _(_: Union[bytes, np.bytes_], _dtype: DType) -> str: +def _(_: Union[bytes, np.bytes_], _dtype: dtypes.DType) -> str: return "" # this should not be necessary @convert("Mesh", simple=True) # type: ignore -def _(_: trimesh.Trimesh, _dtype: DType) -> str: +def _(_: trimesh.Trimesh, _dtype: dtypes.DType) -> str: return "" def read_external_value( path_or_url: Optional[str], - dtype: DType, + dtype: dtypes.DType, target_format: Optional[str] = None, workdir: PathType = ".", ) -> Optional[np.void]: @@ -494,7 +496,7 @@ def 
prepare_path_or_url(path_or_url: PathOrUrlType, workdir: PathType) -> str: def _decode_external_value( path_or_url: PathOrUrlType, - dtype: DType, + dtype: dtypes.DType, target_format: Optional[str] = None, workdir: PathType = ".", ) -> np.void: @@ -524,7 +526,7 @@ def _decode_external_value( # Convert all other formats/codecs to flac. output_format, output_codec = "flac", "flac" else: - output_format, output_codec = Audio.get_format_codec(target_format) + output_format, output_codec = media.Audio.get_format_codec(target_format) if output_format == input_format and output_codec == input_codec: # Nothing to transcode if isinstance(file, str): @@ -550,10 +552,10 @@ def _decode_external_value( ): return np.void(file.read()) # `image/tiff`s become blank in frontend, so convert them too. - return Image.from_file(file).encode(target_format) + return media.Image.from_file(file).encode(target_format) if dtypes.is_mesh_dtype(dtype): - return Mesh.from_file(path_or_url).encode(target_format) + return media.Mesh.from_file(path_or_url).encode(target_format) if dtypes.is_video_dtype(dtype): - return Video.from_file(path_or_url).encode(target_format) + return media.Video.from_file(path_or_url).encode(target_format) assert False diff --git a/renumics/spotlight/io/__init__.py b/renumics/spotlight/io/__init__.py index 2d2a6d26..8162a843 100644 --- a/renumics/spotlight/io/__init__.py +++ b/renumics/spotlight/io/__init__.py @@ -1,6 +1,9 @@ """ Reading and writing of different data formats. """ +import ast +from contextlib import suppress +from typing import Any from .audio import ( get_format_codec, @@ -19,6 +22,8 @@ decode_gltf_arrays, encode_gltf_array, ) +from .huggingface import prepare_hugging_face_dict + __all__ = [ "get_format_codec", @@ -34,4 +39,15 @@ "check_gltf", "decode_gltf_arrays", "encode_gltf_array", + "prepare_hugging_face_dict", + "try_literal_eval", ] + + +def try_literal_eval(x: str) -> Any: + """ + Try to evaluate a literal expression, otherwise return value as is. + """ + with suppress(Exception): + return ast.literal_eval(x) + return x diff --git a/renumics/spotlight/io/huggingface.py b/renumics/spotlight/io/huggingface.py new file mode 100644 index 00000000..06c0d441 --- /dev/null +++ b/renumics/spotlight/io/huggingface.py @@ -0,0 +1,16 @@ +""" +Helpers for HuggingFace formats. +""" +from typing import Any, Dict + + +def prepare_hugging_face_dict(x: Dict) -> Any: + """ + Prepare HuggingFace format for files to be used in Spotlight. 
+ """ + if x.keys() != {"bytes", "path"}: + return x + blob = x["bytes"] + if blob is not None: + return blob + return x["path"] diff --git a/renumics/spotlight/server.py b/renumics/spotlight/server.py index 83d688db..bf56918d 100644 --- a/renumics/spotlight/server.py +++ b/renumics/spotlight/server.py @@ -41,6 +41,8 @@ class Server: process: Optional[subprocess.Popen] _startup_event: threading.Event + _update_complete_event: threading.Event + _update_error: Optional[Exception] connection: Optional[multiprocessing.connection.Connection] _connection_message_queue: Queue @@ -77,7 +79,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8000) -> None: ) self._startup_event = threading.Event() - self._startup_complete_event = threading.Event() + self._update_complete_event = threading.Event() self._connection_thread_online = threading.Event() self._connection_thread = threading.Thread( @@ -151,7 +153,6 @@ def start(self, config: AppConfig) -> None: command.append("--reload") # start uvicorn - self.process = subprocess.Popen( command, env=env, @@ -164,7 +165,7 @@ def start(self, config: AppConfig) -> None: ) if platform.system() != "Windows": sock.close() - self._startup_complete_event.wait(timeout=120) + self._wait_for_update() def stop(self) -> None: """ @@ -197,7 +198,6 @@ def stop(self) -> None: self._port = self._requested_port self._startup_event.clear() - self._startup_complete_event.clear() @property def running(self) -> bool: @@ -217,9 +217,21 @@ def update(self, config: AppConfig) -> None: """ Update app config """ + self._update(config) + self._wait_for_update() + + def _update(self, config: AppConfig) -> None: self._app_config = config self.send({"kind": "update", "data": config}) + def _wait_for_update(self) -> None: + self._update_complete_event.wait(timeout=120) + self._update_complete_event.clear() + err = self._update_error + self._update_error = None + if err: + raise err + def get_df(self) -> Optional[pd.DataFrame]: """ Request and return the current DafaFrame from the server process (if possible) @@ -236,9 +248,10 @@ def _handle_message(self, message: Any) -> None: if kind == "startup": self._startup_event.set() - self.update(self._app_config) - elif kind == "startup_complete": - self._startup_complete_event.set() + self._update(self._app_config) + elif kind == "update_complete": + self._update_error = message.get("error") + self._update_complete_event.set() elif kind == "frontend_connected": self.connected_frontends = message["data"] self._all_frontends_disconnected.clear() diff --git a/renumics/spotlight/viewer.py b/renumics/spotlight/viewer.py index b7630b17..817f31b2 100644 --- a/renumics/spotlight/viewer.py +++ b/renumics/spotlight/viewer.py @@ -176,7 +176,11 @@ def show( if self not in _VIEWERS: _VIEWERS.append(self) else: - self._server.update(config) + try: + self._server.update(config) + except Exception as e: + self.close() + raise e if not no_browser and self._server.connected_frontends == 0: self.open_browser() @@ -220,7 +224,7 @@ def open_browser(self) -> None: """ if not self.port: return - launch_browser_in_thread(self.host, self.port) + launch_browser_in_thread("localhost", self.port) def refresh(self) -> None: """ diff --git a/renumics/spotlight_plugins/core/api/table.py b/renumics/spotlight_plugins/core/api/table.py index 1446bb74..807fdc40 100644 --- a/renumics/spotlight_plugins/core/api/table.py +++ b/renumics/spotlight_plugins/core/api/table.py @@ -73,7 +73,7 @@ def get_table(request: Request) -> ORJSONResponse: columns = [] for column_name in 
data_store.column_names: dtype = data_store.dtypes[column_name] - values = data_store.get_converted_values(column_name, simple=True) + values = data_store.get_converted_values(column_name, simple=True, check=False) meta = data_store.get_column_metadata(column_name) column = Column( name=column_name, diff --git a/renumics/spotlight_plugins/core/pandas_data_source.py b/renumics/spotlight_plugins/core/pandas_data_source.py index 430a69ca..4404a14d 100644 --- a/renumics/spotlight_plugins/core/pandas_data_source.py +++ b/renumics/spotlight_plugins/core/pandas_data_source.py @@ -7,14 +7,9 @@ import numpy as np import pandas as pd import datasets -from renumics.spotlight import dtypes -from renumics.spotlight.io.pandas import ( - infer_dtype, - prepare_hugging_face_dict, - stringify_columns, - try_literal_eval, -) +from renumics.spotlight import dtypes +from renumics.spotlight.io import prepare_hugging_face_dict, try_literal_eval from renumics.spotlight.data_source import ( datasource, ColumnMetadata, @@ -23,7 +18,6 @@ from renumics.spotlight.backend.exceptions import DatasetColumnsNotUnique from renumics.spotlight.dataset.exceptions import ColumnNotExistsError from renumics.spotlight.data_source.exceptions import InvalidDataSource -from renumics.spotlight.dtypes import DTypeMap @datasource(pd.DataFrame) @@ -41,6 +35,7 @@ class PandasDataSource(DataSource): _uid: str _df: pd.DataFrame _name: str + _intermediate_dtypes: dtypes.DTypeMap def __init__(self, source: Union[Path, pd.DataFrame]): if isinstance(source, Path): @@ -108,7 +103,7 @@ def __init__(self, source: Union[Path, pd.DataFrame]): @property def column_names(self) -> List[str]: - return stringify_columns(self._df) + return [str(column) for column in self._df.columns] @property def df(self) -> pd.DataFrame: @@ -118,18 +113,15 @@ def df(self) -> pd.DataFrame: return self._df.copy() @property - def intermediate_dtypes(self) -> DTypeMap: + def intermediate_dtypes(self) -> dtypes.DTypeMap: return self._intermediate_dtypes def __len__(self) -> int: return len(self._df) @property - def semantic_dtypes(self) -> DTypeMap: - return { - str(column_name): infer_dtype(self.df[column_name]) - for column_name in self.df - } + def semantic_dtypes(self) -> dtypes.DTypeMap: + return {} def get_generation_id(self) -> int: return self._generation_id @@ -167,12 +159,14 @@ def get_column_values( if pd.api.types.is_categorical_dtype(column): return column.cat.codes if pd.api.types.is_string_dtype(column): - values = column.to_numpy() - na_mask = column.isna() - values[na_mask] = None - return values + column = column.astype(object).mask(column.isna(), None) + str_mask = column.map(type) == str + column[str_mask] = column[str_mask].apply(try_literal_eval) + dict_mask = column.map(type) == dict + column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict) + return column.to_numpy() if pd.api.types.is_object_dtype(column): - column = column.mask(column.isna(), None) + column = column.astype(object).mask(column.isna(), None) str_mask = column.map(type) == str column[str_mask] = column[str_mask].apply(try_literal_eval) dict_mask = column.map(type) == dict @@ -222,5 +216,4 @@ def _determine_intermediate_dtype(column: pd.Series) -> dtypes.DType: return dtypes.datetime_dtype if pd.api.types.is_string_dtype(column): return dtypes.str_dtype - else: - return dtypes.mixed_dtype + return dtypes.mixed_dtype diff --git a/src/lenses/SpectrogramLens/SpectrogramWorker.ts b/src/lenses/SpectrogramLens/SpectrogramWorker.ts index 55cad783..ff47fcf3 100644 --- 
diff --git a/src/lenses/SpectrogramLens/SpectrogramWorker.ts b/src/lenses/SpectrogramLens/SpectrogramWorker.ts
index 55cad783..ff47fcf3 100644
--- a/src/lenses/SpectrogramLens/SpectrogramWorker.ts
+++ b/src/lenses/SpectrogramLens/SpectrogramWorker.ts
@@ -233,12 +233,7 @@ const calculateFrequencies = (
             tables.cosTable
         );
 
-        const array = new Uint8Array(fftSamples / 2);
-        let j;
-        for (j = 0; j < fftSamples / 2; j++) {
-            array[j] = Math.max(-255, Math.log10(spectrum[j]) * 45);
-        }
-        channelFreq.push(array);
+        channelFreq.push(spectrum);
 
         currentOffset += fftSamples - noverlap;
     }
     frequencies.push(channelFreq);
diff --git a/src/lenses/SpectrogramLens/index.tsx b/src/lenses/SpectrogramLens/index.tsx
index 60522df1..70e48ba1 100644
--- a/src/lenses/SpectrogramLens/index.tsx
+++ b/src/lenses/SpectrogramLens/index.tsx
@@ -10,7 +10,6 @@ import { ColorsState, useColors } from '../../stores/colors';
 import { Lens } from '../../types';
 import useSetting from '../useSetting';
 import MenuBar from './MenuBar';
-import chroma from 'chroma-js';
 import { fixWindow, freqType, unitType, amplitudeToDb } from './Spectrogram';
 
 const Container = tw.div`flex flex-col w-full h-full items-stretch justify-center`;
@@ -209,7 +208,8 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => {
             const widthScale = d3.scaleLinear([0, width], [0, frequenciesData.length]);
 
             let drawData = [];
-            let colorScale: chroma.Scale;
+            let min = 0;
+            let max = 0;
 
             if (ampScale === 'decibel') {
                 let ref = 0;
@@ -223,9 +223,6 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => {
                 //const top_db = 80;
                 const amin = 1e-5;
 
-                let log_spec_max = 0;
-                let log_spec_min = 0;
-
                 // Convert amplitudes to decibels
                 for (let i = 0; i < frequenciesData.length; i++) {
                     const col = [];
@@ -234,23 +231,33 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => {
                         const amplitude = frequenciesData[i][j];
                         col[j] = amplitudeToDb(amplitude, ref, amin);
 
-                        if (col[j] > log_spec_max) {
-                            log_spec_max = col[j];
+                        if (col[j] > max) {
+                            max = col[j];
                         }
-                        if (col[j] < log_spec_min) {
-                            log_spec_min = col[j];
+                        if (col[j] < min) {
+                            min = col[j];
                         }
                     }
                     drawData[i] = col;
                 }
-                colorScale = colorPalette.scale().domain([log_spec_min, log_spec_max]);
             } else {
                 // ampScale === 'linear'
-                colorScale = colorPalette.scale().domain([0, 256]);
+                for (let i = 0; i < frequenciesData.length; i++) {
+                    const maxI = Math.max(...frequenciesData[i]);
+                    const minI = Math.min(...frequenciesData[i]);
+
+                    if (maxI > max) {
+                        max = maxI;
+                    }
+                    if (minI < min) {
+                        min = minI;
+                    }
+                }
                 drawData = frequenciesData;
             }
 
+            const colorScale = colorPalette.scale().domain([min, max]);
+
             for (let y = 0; y < height; y++) {
                 let value = 0;
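With the worker now emitting raw float spectra instead of pre-quantized `Uint8Array`s, the lens derives the color-scale domain from the observed value range rather than a fixed `[0, 256]`. A dependency-free sketch of that normalization (function names are illustrative, not Spotlight API):

```ts
// derive the color-scale domain from the data itself
function dataDomain(frequencies: Float32Array[]): [number, number] {
    let min = 0;
    let max = 0;
    for (const column of frequencies) {
        for (const value of column) {
            if (value > max) max = value;
            if (value < min) min = value;
        }
    }
    return [min, max];
}

// map a raw amplitude onto [0, 1] for a color lookup
function normalize(value: number, [min, max]: [number, number]): number {
    return max === min ? 0 : (value - min) / (max - min);
}

const spectra = [new Float32Array([0.1, 2.5]), new Float32Array([0.7, 1.2])];
const domain = dataDomain(spectra); // [0, 2.5]
console.log(normalize(1.25, domain)); // 0.5
```

One design note: the linear branch above spreads each column into `Math.max(...)`; a plain loop like the sketch avoids engine argument-length limits when columns grow large.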
diff --git a/src/services/data.ts b/src/services/data.ts
index 16aa82f5..069e01dc 100644
--- a/src/services/data.ts
+++ b/src/services/data.ts
@@ -1,7 +1,5 @@
-import { useDataset } from '../stores/dataset';
-import { useLayout } from '../stores/layout';
 import { IndexArray } from '../types';
-import { v4 as uuidv4 } from 'uuid';
+import taskService from './task';
 
 export const umapMetricNames = [
     'euclidean',
@@ -10,11 +8,8 @@ export const umapMetricNames = [
     'cosine',
     'mahalanobis',
 ] as const;
-
 export type UmapMetric = typeof umapMetricNames[number];
-
 export const pcaNormalizations = ['none', 'standardize', 'robust standardize'] as const;
-
 export type PCANormalization = typeof pcaNormalizations[number];
 
 interface ReductionResult {
@@ -22,84 +17,7 @@ interface ReductionResult {
     indices: IndexArray;
 }
 
-const MAX_QUEUED_MESSAGES = 16;
-
-class Connection {
-    url: string;
-    socket?: WebSocket;
-    messageQueue: string[];
-    onmessage?: (data: unknown) => void;
-
-    constructor(host: string, port: string) {
-        this.messageQueue = [];
-        if (globalThis.location.protocol === 'https:') {
-            this.url = `wss://${host}:${port}/api/ws`;
-        } else {
-            this.url = `ws://${host}:${port}/api/ws`;
-        }
-        this.#connect();
-    }
-
-    send(message: unknown): void {
-        const data = JSON.stringify(message);
-        if (this.socket) {
-            this.socket.send(data);
-        } else {
-            if (this.messageQueue.length > MAX_QUEUED_MESSAGES) {
-                this.messageQueue.shift();
-            }
-            this.messageQueue.push(data);
-        }
-    }
-
-    #connect() {
-        const webSocket = new WebSocket(this.url);
-
-        webSocket.onopen = () => {
-            this.socket = webSocket;
-            this.messageQueue.forEach((message) => this.socket?.send(message));
-            this.messageQueue.length = 0;
-        };
-        webSocket.onmessage = (event) => {
-            const message = JSON.parse(event.data);
-            this.onmessage?.(message);
-        };
-        webSocket.onerror = () => {
-            webSocket.close();
-        };
-        webSocket.onclose = () => {
-            this.socket = undefined;
-            setTimeout(() => {
-                this.#connect();
-            }, 500);
-        };
-    }
-}
-
 export class DataService {
-    connection: Connection;
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    dispatchTable: Map<string, (message: any) => void>;
-
-    constructor(host: string, port: string) {
-        this.connection = new Connection(host, port);
-        // eslint-disable-next-line @typescript-eslint/no-explicit-any
-        this.connection.onmessage = (message: any) => {
-            if (message.uid) {
-                this.dispatchTable.get(message.uid)?.(message);
-                return;
-            }
-            if (message.type === 'refresh') {
-                useDataset.getState().refresh();
-            } else if (message.type === 'resetLayout') {
-                useLayout.getState().reset();
-            } else if (message.type === 'issuesUpdated') {
-                useDataset.getState().fetchIssues();
-            }
-        };
-        this.dispatchTable = new Map();
-    }
-
     async computeUmap(
         widgetId: string,
         columnNames: string[],
@@ -108,34 +26,13 @@ export class DataService {
         metric: UmapMetric,
         min_dist: number
    ): Promise<ReductionResult> {
-        const messageId = uuidv4();
-
-        const promise = new Promise<ReductionResult>((resolve) => {
-            // eslint-disable-next-line @typescript-eslint/no-explicit-any
-            this.dispatchTable.set(messageId, (message: any) => {
-                const result = {
-                    points: message.data.points,
-                    indices: message.data.indices,
-                };
-                resolve(result);
-            });
-        });
-
-        this.connection.send({
-            type: 'umap',
-            widget_id: widgetId,
-            uid: messageId,
-            generation_id: useDataset.getState().generationID,
-            data: {
-                indices: Array.from(indices),
-                columns: columnNames,
-                n_neighbors: n_neighbors,
-                metric: metric,
-                min_dist: min_dist,
-            },
+        return await taskService.run('umap', widgetId, {
+            column_names: columnNames,
+            indices: Array.from(indices),
+            n_neighbors,
+            metric,
+            min_dist,
         });
-
-        return promise;
     }
 
     async computePCA(
         widgetId: string,
         columnNames: string[],
@@ -144,38 +41,13 @@ export class DataService {
         indices: IndexArray,
         pcaNormalization: PCANormalization
    ): Promise<ReductionResult> {
-        const messageId = uuidv4();
-
-        const promise = new Promise<ReductionResult>((resolve) => {
-            // eslint-disable-next-line @typescript-eslint/no-explicit-any
-            this.dispatchTable.set(messageId, (message: any) => {
-                const result = {
-                    points: message.data.points,
-                    indices: message.data.indices,
-                };
-                resolve(result);
-            });
-        });
-
-        this.connection.send({
-            type: 'pca',
-            widget_id: widgetId,
-            uid: messageId,
-            generation_id: useDataset.getState().generationID,
-            data: {
-                indices: Array.from(indices),
-                columns: columnNames,
-                normalization: pcaNormalization,
-            },
+        return await taskService.run('pca', widgetId, {
+            indices: Array.from(indices),
+            column_names: columnNames,
+            normalization: pcaNormalization,
         });
-
-        return promise;
     }
 }
 
-const dataService = new DataService(
-    globalThis.location.hostname,
-    globalThis.location.port
-);
-
+const dataService = new DataService();
 export default dataService;
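`DataService` now returns the task service's promise directly, so a failed reduction rejects with the backend's problem payload instead of leaving the promise pending. A hypothetical caller, not from this diff, might look like the following (it assumes `IndexArray` accepts a plain number array):

```ts
import dataService from '../services/data';

// illustrative widget-side helper, not part of this PR
async function computeProjection(widgetId: string, columns: string[]) {
    try {
        // rejects with the task.error payload if the backend task fails
        return await dataService.computeUmap(
            widgetId,
            columns,
            [0, 1, 2], // indices
            20, // n_neighbors
            'euclidean',
            0.1 // min_dist
        );
    } catch (problem) {
        console.error(problem); // e.g. a Problem-like { title, detail }
        return undefined;
    }
}
```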
diff --git a/src/services/task.ts b/src/services/task.ts
new file mode 100644
index 00000000..f95a5d5e
--- /dev/null
+++ b/src/services/task.ts
@@ -0,0 +1,47 @@
+import { useDataset } from '../lib';
+import websocketService, { WebsocketService } from './websocket';
+
+interface ResponseHandler {
+    resolve: (result: unknown) => void;
+    reject: (error: unknown) => void;
+}
+
+// service handling the execution of remote tasks
+class TaskService {
+    // dispatch table for task responses/errors
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    dispatchTable: Map<string, ResponseHandler>;
+    websocketService: WebsocketService;
+
+    constructor(websocketService: WebsocketService) {
+        this.dispatchTable = new Map();
+        this.websocketService = websocketService;
+        this.websocketService.registerMessageHandler('task.result', (data) => {
+            this.dispatchTable.get(data.task_id)?.resolve(data.result);
+        });
+        this.websocketService.registerMessageHandler('task.error', (data) => {
+            this.dispatchTable.get(data.task_id)?.reject(data.error);
+        });
+    }
+
+    async run(task: string, name: string, args: unknown) {
+        // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        return new Promise<any>((resolve, reject) => {
+            const task_id = crypto.randomUUID();
+            this.dispatchTable.set(task_id, { resolve, reject });
+            websocketService.send({
+                type: 'task',
+                data: {
+                    task: task,
+                    widget_id: name,
+                    task_id,
+                    generation_id: useDataset.getState().generationID,
+                    args,
+                },
+            });
+        });
+    }
+}
+
+const taskService = new TaskService(websocketService);
+export default taskService;
diff --git a/src/services/websocket/connection.ts b/src/services/websocket/connection.ts
new file mode 100644
index 00000000..84452ca7
--- /dev/null
+++ b/src/services/websocket/connection.ts
@@ -0,0 +1,61 @@
+import { Message } from './types';
+
+// the maximum number of outgoing messages that are queued
+// while the connection is down
+const MAX_QUEUED_MESSAGES = 16;
+
+// a websocket connection to the spotlight backend
+// automatically reconnects when necessary
+class Connection {
+    url: string;
+    socket?: WebSocket;
+    messageQueue: string[];
+    onmessage?: (data: unknown) => void;
+
+    constructor(host: string, port: string) {
+        this.messageQueue = [];
+        if (globalThis.location.protocol === 'https:') {
+            this.url = `wss://${host}:${port}/api/ws`;
+        } else {
+            this.url = `ws://${host}:${port}/api/ws`;
+        }
+        this.#connect();
+    }
+
+    send(message: Message): void {
+        const data = JSON.stringify(message);
+        if (this.socket) {
+            this.socket.send(data);
+        } else {
+            if (this.messageQueue.length > MAX_QUEUED_MESSAGES) {
+                this.messageQueue.shift();
+            }
+            this.messageQueue.push(data);
+        }
+    }
+
+    #connect() {
+        const webSocket = new WebSocket(this.url);
+
+        webSocket.onopen = () => {
+            this.socket = webSocket;
+            this.messageQueue.forEach((message) => this.socket?.send(message));
+            this.messageQueue.length = 0;
+        };
+        webSocket.onmessage = (event) => {
+            const message = JSON.parse(event.data);
+            this.onmessage?.(message);
+        };
+        webSocket.onerror = () => {
+            webSocket.close();
+        };
+        webSocket.onclose = () => {
+            this.socket = undefined;
+            setTimeout(() => {
+                this.#connect();
+            }, 500);
+        };
+    }
+}
+
+export default Connection;
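The queue in `Connection.send` keeps only the most recent messages while the socket is down. A tiny standalone sketch of those semantics:

```ts
// bounded send-queue: when full, the oldest message is dropped first
const MAX_QUEUED = 3;
const queue: string[] = [];

function enqueue(message: string) {
    if (queue.length > MAX_QUEUED) {
        queue.shift(); // drop the oldest message
    }
    queue.push(message);
}

['a', 'b', 'c', 'd', 'e', 'f'].forEach(enqueue);
console.log(queue); // ['c', 'd', 'e', 'f']
```

Note that the `>` comparison admits one extra entry, so the queue can briefly hold `MAX_QUEUED_MESSAGES + 1` items; `>=` would enforce the exact bound.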
diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts
new file mode 100644
index 00000000..135ea37c
--- /dev/null
+++ b/src/services/websocket/index.ts
@@ -0,0 +1,45 @@
+import { notifyProblem } from '../../notify';
+import { Problem } from '../../types';
+import Connection from './connection';
+import { Message, MessageHandler } from './types';
+
+// this service provides a websocket connection to the spotlight backend
+// and handles incoming and outgoing messages
+export class WebsocketService {
+    connection: Connection;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    messageHandlers: Map<string, MessageHandler>;
+
+    constructor(host: string, port: string) {
+        this.messageHandlers = new Map();
+        this.connection = new Connection(host, port);
+        // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        this.connection.onmessage = (message: any) => {
+            const messageHandler = this.messageHandlers.get(message.type);
+            if (messageHandler) {
+                messageHandler(message.data);
+            } else {
+                console.error(`Unknown websocket message: ${message.type}`);
+            }
+        };
+    }
+
+    registerMessageHandler(messageType: string, handler: MessageHandler): void {
+        this.messageHandlers.set(messageType, handler);
+    }
+
+    send(message: Message): void {
+        this.connection.send(message);
+    }
+}
+
+const websocketService = new WebsocketService(
+    globalThis.location.hostname,
+    globalThis.location.port
+);
+
+websocketService.registerMessageHandler('error', (problem: Problem) => {
+    notifyProblem(problem);
+});
+
+export default websocketService;
diff --git a/src/services/websocket/types.ts b/src/services/websocket/types.ts
new file mode 100644
index 00000000..18840150
--- /dev/null
+++ b/src/services/websocket/types.ts
@@ -0,0 +1,7 @@
+export interface Message {
+    type: string;
+    data: unknown;
+}
+
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+export type MessageHandler = (data: any) => void;
diff --git a/src/stores/dataset/dataset.ts b/src/stores/dataset/dataset.ts
index 0cd3e3c0..4916c015 100644
--- a/src/stores/dataset/dataset.ts
+++ b/src/stores/dataset/dataset.ts
@@ -22,6 +22,7 @@ import { notifyAPIError, notifyError } from '../../notify';
 import { makeColumnsColorTransferFunctions } from './colorTransferFunctionFactory';
 import { makeColumn } from './columnFactory';
 import { makeColumnsStats } from './statisticsFactory';
+import websocketService from '../../services/websocket';
 
 export type CallbackOrData<T> = ((data: T) => T) | T;
 
@@ -540,3 +541,11 @@ useDataset.subscribe(
 );
 
 useColors.subscribe(() => useDataset.getState().recomputeColorTransferFunctions());
+
+websocketService.registerMessageHandler('refresh', () => {
+    useDataset.getState().refresh();
+});
+
+websocketService.registerMessageHandler('issuesUpdated', () => {
+    useDataset.getState().fetchIssues();
+});
diff --git a/src/stores/layout.ts b/src/stores/layout.ts
index 8eb5f517..53d4a993 100644
--- a/src/stores/layout.ts
+++ b/src/stores/layout.ts
@@ -2,6 +2,7 @@ import { create } from 'zustand';
 import { AppLayout } from '../types';
 import api from '../api';
 import { saveAs } from 'file-saver';
+import websocketService from '../services/websocket';
 
 export interface State {
     layout: AppLayout;
@@ -41,3 +42,7 @@ export const useLayout = create<State>((set) => ({
         reader.readAsText(file);
     },
 }));
+
+websocketService.registerMessageHandler('resetLayout', () => {
+    useLayout.getState().reset();
+});
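The stores now register their own websocket handlers, which keeps the websocket service free of store imports. The dispatch semantics are one handler per message type, with the last registration winning. A compact sketch of that pattern:

```ts
// one handler per message type; re-registering replaces the old handler
type Handler = (data: unknown) => void;
const handlers = new Map<string, Handler>();

function register(type: string, handler: Handler) {
    handlers.set(type, handler);
}

function dispatch(message: { type: string; data: unknown }) {
    const handler = handlers.get(message.type);
    if (handler) {
        handler(message.data);
    } else {
        console.error(`Unknown websocket message: ${message.type}`);
    }
}

register('refresh', () => console.log('refetching dataset'));
dispatch({ type: 'refresh', data: null }); // refetching dataset
dispatch({ type: 'nope', data: null }); // logs an error
```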
diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx
index 193a46f8..023c4dd1 100644
--- a/src/widgets/SimilarityMap/SimilarityMap.tsx
+++ b/src/widgets/SimilarityMap/SimilarityMap.tsx
@@ -23,6 +23,7 @@ import {
     IndexArray,
     isNumberColumn,
     NumberColumn,
+    Problem,
     TableData,
 } from '../../types';
 import { createSizeTransferFunction } from '../../dataformat';
@@ -220,8 +221,8 @@ const SimilarityMap: Widget = () => {
     const placeByColumns = useDataset(placeByColumnsSelector, shallow);
 
     const [positions, setPositions] = useState<[number, number][]>([]);
-
     const [visibleIndices, setVisibleIndices] = useState<IndexArray>([]);
+    const [problem, setProblem] = useState<Problem | null>(null);
 
     const selected = useMemo(() => {
         const selected: boolean[] = [];
@@ -275,13 +276,14 @@ const SimilarityMap: Widget = () => {
     const widgetId = useMemo(() => uuidv4(), []);
 
     useEffect(() => {
+        setVisibleIndices([]);
+        setPositions([]);
+        setProblem(null);
+
         if (!indices.length || !placeByColumnKeys.length) {
-            setVisibleIndices([]);
-            setPositions([]);
             setIsComputing(false);
             return;
         }
-
         setIsComputing(true);
 
         const reductionPromise =
@@ -302,12 +304,18 @@ const SimilarityMap: Widget = () => {
         );
 
         let cancelled = false;
-        reductionPromise.then(({ points, indices }) => {
-            if (cancelled) return;
-            setVisibleIndices(indices);
-            setPositions(points);
-            setIsComputing(false);
-        });
+        reductionPromise
+            .then(({ points, indices }) => {
+                if (cancelled) return;
+                setVisibleIndices(indices);
+                setPositions(points);
+                setIsComputing(false);
+            })
+            .catch((problem: Problem) => {
+                if (cancelled) return;
+                setProblem(problem);
+                setIsComputing(false);
+            });
 
         return () => {
             cancelled = true;
@@ -443,23 +451,36 @@ const SimilarityMap: Widget = () => {
     const areColumnsSelected = !!placeByColumnKeys.length;
     const hasVisibleRows = !!visibleIndices.length;
 
-    return (
-        <WidgetContainer>
-            {isComputing && <LoadingIndicator />}
-            {!areColumnsSelected && !isComputing && (
-                <Info>
-                    <span>
-                        Select columns and show at least two rows to display similarity
-                        map
-                    </span>
-                </Info>
-            )}
-            {areColumnsSelected && !hasVisibleRows && !isComputing && (
-                <Info>
-                    <span>Not enough rows</span>
-                </Info>
-            )}
-            {areColumnsSelected && hasVisibleRows && !isComputing && (
+    let content: JSX.Element;
+    if (problem) {
+        content = (
+            <div tw="flex h-full w-full items-center justify-center">
+                <div tw="text-center">
+                    <div tw="font-bold">{problem.title}</div>
+                    <div tw="text-sm">{problem.detail}</div>
+                </div>
+            </div>
+        );
+    } else if (isComputing) {
+        content = <LoadingIndicator />;
+    } else if (!areColumnsSelected) {
+        content = (
+            <Info>
+                <span>
+                    Select columns and show at least two rows to display similarity map
+                </span>
+            </Info>
+        );
+    } else if (!hasVisibleRows) {
+        content = (
+            <Info>
+                <span>Not enough rows</span>
+            </Info>
+        );
+    } else {
+        content = (
+            <>
                 {
                 />
                 )}
-            )}
+
-            {!isComputing && (
                 {visibleIndices.length} of {indices.length} rows
-            )}
+            </>
+        );
+    }
+    return (
+        <WidgetContainer>
+            {content}
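Taken together, these files imply a small request/response protocol over the websocket. The message shapes below are inferred from the code in this diff; the `{ title, detail }` error payload is an assumption based on the `Problem` type, and the explicit cleanup of settled entries is an addition that the `TaskService` above does not perform:

```ts
interface TaskRequest {
    type: 'task';
    data: {
        task: string; // e.g. 'umap' or 'pca'
        widget_id: string;
        task_id: string;
        generation_id: number;
        args: unknown;
    };
}

interface TaskResult {
    type: 'task.result';
    data: { task_id: string; result: unknown };
}

interface TaskError {
    type: 'task.error';
    data: { task_id: string; error: { title: string; detail: string } }; // assumed shape
}

interface PendingTask {
    resolve: (result: unknown) => void;
    reject: (error: unknown) => void;
}

// dispatch table: task_id -> pending promise callbacks
const pending = new Map<string, PendingTask>();

// settle the matching promise and drop it from the dispatch table
function onTaskMessage(message: TaskResult | TaskError): void {
    const entry = pending.get(message.data.task_id);
    if (!entry) return;
    if (message.type === 'task.result') {
        entry.resolve(message.data.result);
    } else {
        entry.reject(message.data.error);
    }
    pending.delete(message.data.task_id);
}
```

Deleting settled entries keeps the table from growing without bound over a long session; whether Spotlight needs that depends on how many tasks a page issues, so treat it as a suggestion rather than a required fix.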