From f114895485b4d98289517fbfa839816371c736c5 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Mon, 23 Oct 2023 17:15:38 +0200 Subject: [PATCH 1/8] feat: transport Exceptions from tasks over our websocket connection --- renumics/spotlight/backend/tasks/reduction.py | 2 + renumics/spotlight/backend/websockets.py | 24 ++++++ src/services/data.ts | 12 ++- src/widgets/SimilarityMap/SimilarityMap.tsx | 84 ++++++++++++------- 4 files changed, 91 insertions(+), 31 deletions(-) diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index 00207c38..bb14096a 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -78,6 +78,8 @@ def compute_umap( Prepare data from table and compute U-Map on them. """ + raise Exception("Ooops") + try: data, indices = align_data(data_store, column_names, indices) except (ColumnNotExistsError, ColumnNotEmbeddable): diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 8be62f4a..9fa6e9d6 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -57,6 +57,19 @@ class ReductionMessage(Message): generation_id: int +class TaskErrorMessage(Message): + """ + Common message model for task errors + """ + + widget_id: str + uid: str + error: str + title: str + detail: Optional[str] = None + data: None = None + + class ReductionRequestData(BaseModel): """ Base data reduction request payload. @@ -155,6 +168,7 @@ async def handle_message(request: Message, connection: "WebsocketConnection") -> New message types should be registered by decorating with `@handle_message.register`. """ + # TODO: add generic error handler raise NotImplementedError @@ -326,6 +340,16 @@ async def _(request: UMapRequest, connection: "WebsocketConnection") -> None: ) except TaskCancelled: ... + except Exception as e: + msg = TaskErrorMessage( + type="tasks.error", + widget_id=request.widget_id, + uid=request.uid, + error=type(e).__name__, + title=type(e).__name__, + detail=type(e).__doc__, + ) + await connection.send_async(msg) else: response = ReductionResponse( type="umap_result", diff --git a/src/services/data.ts b/src/services/data.ts index 16aa82f5..4884e365 100644 --- a/src/services/data.ts +++ b/src/services/data.ts @@ -2,6 +2,7 @@ import { useDataset } from '../stores/dataset'; import { useLayout } from '../stores/layout'; import { IndexArray } from '../types'; import { v4 as uuidv4 } from 'uuid'; +import { Problem } from '../types'; export const umapMetricNames = [ 'euclidean', @@ -110,9 +111,18 @@ export class DataService { ): Promise { const messageId = uuidv4(); - const promise = new Promise((resolve) => { + const promise = new Promise((resolve, reject) => { // eslint-disable-next-line @typescript-eslint/no-explicit-any this.dispatchTable.set(messageId, (message: any) => { + if (message.type === 'tasks.error') { + const error: Problem = { + type: message.error, + title: message.title, + detail: message.detail, + }; + reject(error); + return; + } const result = { points: message.data.points, indices: message.data.indices, diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index 193a46f8..df9e602b 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -1,4 +1,5 @@ import SimilaritiesIcon from '../../icons/Bubbles'; +import WarningIcon from '../../icons/Warning'; import LoadingIndicator from '../../components/LoadingIndicator'; import Plot, { MergeStrategy, @@ -23,6 +24,7 @@ import { IndexArray, isNumberColumn, NumberColumn, + Problem, TableData, } from '../../types'; import { createSizeTransferFunction } from '../../dataformat'; @@ -220,8 +222,8 @@ const SimilarityMap: Widget = () => { const placeByColumns = useDataset(placeByColumnsSelector, shallow); const [positions, setPositions] = useState<[number, number][]>([]); - const [visibleIndices, setVisibleIndices] = useState([]); + const [problem, setProblem] = useState(null); const selected = useMemo(() => { const selected: boolean[] = []; @@ -275,13 +277,14 @@ const SimilarityMap: Widget = () => { const widgetId = useMemo(() => uuidv4(), []); useEffect(() => { + setVisibleIndices([]); + setPositions([]); + setProblem(null); + if (!indices.length || !placeByColumnKeys.length) { - setVisibleIndices([]); - setPositions([]); setIsComputing(false); return; } - setIsComputing(true); const reductionPromise = @@ -302,12 +305,18 @@ const SimilarityMap: Widget = () => { ); let cancelled = false; - reductionPromise.then(({ points, indices }) => { - if (cancelled) return; - setVisibleIndices(indices); - setPositions(points); - setIsComputing(false); - }); + reductionPromise + .then(({ points, indices }) => { + if (cancelled) return; + setVisibleIndices(indices); + setPositions(points); + setIsComputing(false); + }) + .catch((problem: Problem) => { + if (cancelled) return; + setProblem(problem); + setIsComputing(false); + }); return () => { cancelled = true; @@ -443,23 +452,35 @@ const SimilarityMap: Widget = () => { const areColumnsSelected = !!placeByColumnKeys.length; const hasVisibleRows = !!visibleIndices.length; - return ( - - {isComputing && } - {!areColumnsSelected && !isComputing && ( - - - Select columns and show at least two rows to display similarity - map - - - )} - {areColumnsSelected && !hasVisibleRows && !isComputing && ( - - Not enough rows - - )} - {areColumnsSelected && hasVisibleRows && !isComputing && ( + let content: JSX.Element; + if (problem) { + content = ( +
+
+ {problem.title} +
+
{problem.detail}
+
+ ); + } else if (isComputing) { + content = ; + } else if (!areColumnsSelected) { + content = ( + + + Select columns and show at least two rows to display similarity map + + + ); + } else if (!hasVisibleRows) { + content = ( + + Not enough rows + + ); + } else { + content = ( + <> { /> )} - )} - {!isComputing && (
{visibleIndices.length} of {indices.length} rows
- )} + + ); + } + return ( + + {content} Date: Tue, 24 Oct 2023 10:08:29 +0200 Subject: [PATCH 2/8] feat: style error in simmap --- src/widgets/SimilarityMap/SimilarityMap.tsx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index df9e602b..5c42ebc9 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -456,10 +456,11 @@ const SimilarityMap: Widget = () => { if (problem) { content = (
-
- {problem.title} +
+
+
{problem.title}
+
{problem.detail}
-
{problem.detail}
); } else if (isComputing) { From 5c41864d3cf72945ee5e31b07f8d1c6d1028b2ca Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 12:49:39 +0200 Subject: [PATCH 3/8] refactor: setup explicit ws message handlers in backend --- renumics/spotlight/backend/tasks/reduction.py | 2 - .../spotlight/backend/tasks/task_manager.py | 11 +- renumics/spotlight/backend/websockets.py | 271 ++++++------------ src/services/data.ts | 162 +---------- src/services/task.ts | 46 +++ src/services/websocket/connection.ts | 61 ++++ src/services/websocket/index.ts | 49 ++++ src/services/websocket/types.ts | 6 + 8 files changed, 270 insertions(+), 338 deletions(-) create mode 100644 src/services/task.ts create mode 100644 src/services/websocket/connection.ts create mode 100644 src/services/websocket/index.ts create mode 100644 src/services/websocket/types.ts diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index bb14096a..00207c38 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -78,8 +78,6 @@ def compute_umap( Prepare data from table and compute U-Map on them. """ - raise Exception("Ooops") - try: data, indices = align_data(data_store, column_names, indices) except (ColumnNotExistsError, ColumnNotEmbeddable): diff --git a/renumics/spotlight/backend/tasks/task_manager.py b/renumics/spotlight/backend/tasks/task_manager.py index e120fa88..8c392a55 100644 --- a/renumics/spotlight/backend/tasks/task_manager.py +++ b/renumics/spotlight/backend/tasks/task_manager.py @@ -6,7 +6,7 @@ import multiprocessing from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool -from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union from .exceptions import TaskCancelled from .task import Task @@ -30,16 +30,20 @@ def create_task( self, func: Callable, args: Sequence[Any], + kwargs: Optional[Dict[str, Any]] = None, name: Optional[str] = None, tag: Optional[Union[str, int]] = None, ) -> Task: """ create and launch a new task """ + if kwargs is None: + kwargs = {} + # cancel running task with same name self.cancel(name=name) - future = self.pool.submit(func, *args) + future = self.pool.submit(func, *args, **kwargs) task = Task(name, tag, future) self.tasks.append(task) @@ -59,6 +63,7 @@ async def run_async( self, func: Callable[..., T], args: Sequence[Any], + kwargs: Optional[Dict[str, Any]] = None, name: Optional[str] = None, tag: Optional[Union[str, int]] = None, ) -> T: @@ -66,7 +71,7 @@ async def run_async( Launch a new task. Await and return result. """ - task = self.create_task(func, args, name, tag) + task = self.create_task(func, args=args, kwargs=kwargs, name=name, tag=tag) try: return await asyncio.wrap_future(task.future) except BrokenProcessPool as e: diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 9fa6e9d6..30d15652 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -3,11 +3,20 @@ """ import asyncio -import functools -import orjson -from typing import Any, List, Optional, Set, Callable +from typing import ( + Any, + Coroutine, + Dict, + Optional, + Set, + Callable, + Tuple, + Type, + cast, +) import numpy as np +import orjson from fastapi import WebSocket, WebSocketDisconnect from loguru import logger from pydantic import BaseModel @@ -17,7 +26,7 @@ from .tasks import TaskManager, TaskCancelled from .tasks.reduction import compute_umap, compute_pca -from .exceptions import GenerationIDMismatch +from .exceptions import GenerationIDMismatch, Problem class Message(BaseModel): @@ -47,99 +56,12 @@ class ResetLayoutMessage(Message): data: Any = None -class ReductionMessage(Message): - """ - Common data reduction message model. - """ - +class TaskData(BaseModel): + task: str widget_id: str - uid: str - generation_id: int - - -class TaskErrorMessage(Message): - """ - Common message model for task errors - """ - - widget_id: str - uid: str - error: str - title: str - detail: Optional[str] = None - data: None = None - - -class ReductionRequestData(BaseModel): - """ - Base data reduction request payload. - """ - - indices: List[int] - columns: List[str] - - -class UMapRequestData(ReductionRequestData): - """ - U-Map request payload. - """ - - n_neighbors: int - metric: str - min_dist: float - - -class PCARequestData(ReductionRequestData): - """ - PCA request payload. - """ - - normalization: str - - -class UMapRequest(ReductionMessage): - """ - U-Map request model. - """ - - data: UMapRequestData - - -class PCARequest(ReductionMessage): - """ - PCA request model. - """ - - data: PCARequestData - - -class ReductionResponseData(BaseModel): - """ - Data reduction response payload. - """ - - indices: List[int] - points: np.ndarray - - class Config: - arbitrary_types_allowed = True - - -class ReductionResponse(ReductionMessage): - """ - Data reduction response model. - """ - - data: ReductionResponseData - - -MESSAGE_BY_TYPE = { - "umap": UMapRequest, - "umap_result": ReductionResponse, - "pca": PCARequest, - "pca_result": ReductionResponse, - "refresh": RefreshMessage, -} + task_id: str + generation_id: Optional[int] + args: Any class UnknownMessageType(Exception): @@ -148,28 +70,26 @@ class UnknownMessageType(Exception): """ -def parse_message(raw_message: str) -> Message: - """ - Parse a websocket message from a raw text. - """ - json_message = orjson.loads(raw_message) - message_type = json_message["type"] - message_class = MESSAGE_BY_TYPE.get(message_type) - if message_class is None: - raise UnknownMessageType(f"Message type {message_type} is unknown.") - return message_class(**json_message) +PayloadType = Type[BaseModel] +MessageHandler = Callable[[Any, "WebsocketConnection"], Coroutine[Any, Any, Any]] +MessageHandlerSpec = Tuple[PayloadType, MessageHandler] +MESSAGE_HANDLERS: Dict[str, MessageHandlerSpec] = {} -@functools.singledispatch -async def handle_message(request: Message, connection: "WebsocketConnection") -> None: - """ - Handle incoming messages. +def register_message_handler( + message_type: str, handler_spec: MessageHandlerSpec +) -> None: + MESSAGE_HANDLERS[message_type] = handler_spec - New message types should be registered by decorating with `@handle_message.register`. - """ - # TODO: add generic error handler - raise NotImplementedError +def message_handler( + message_type: str, payload_type: Type[BaseModel] +) -> Callable[[MessageHandler], MessageHandler]: + def decorator(handler: MessageHandler) -> MessageHandler: + register_message_handler(message_type, (payload_type, handler)) + return handler + + return decorator class WebsocketConnection: @@ -218,12 +138,17 @@ async def listen(self) -> None: try: while True: try: - message = parse_message(await self.websocket.receive_text()) + raw_message = await self.websocket.receive_text() + message = Message(**orjson.loads(raw_message)) except UnknownMessageType as e: logger.warning(str(e)) else: logger.info(f"WS message with type {message.type} received.") - asyncio.create_task(handle_message(message, self)) + if handler_spec := MESSAGE_HANDLERS.get(message.type): + payload_type, handler = handler_spec + asyncio.create_task(handler(payload_type(**message.data), self)) + else: + logger.error(f"Unknown message received: {message.type}") except WebSocketDisconnect: self._on_disconnect() @@ -314,83 +239,63 @@ def on_disconnect(self, connection: WebsocketConnection) -> None: callback(len(self.connections)) -@handle_message.register -async def _(request: UMapRequest, connection: "WebsocketConnection") -> None: +TASK_FUNCS = {"umap": compute_umap, "pca": compute_pca} + + +@message_handler("task", TaskData) +async def _(data: TaskData, connection: WebsocketConnection) -> None: data_store: Optional[DataStore] = connection.websocket.app.data_store if data_store is None: return None - try: - data_store.check_generation_id(request.generation_id) - except GenerationIDMismatch: - return + + if data.generation_id: + try: + data_store.check_generation_id(data.generation_id) + except GenerationIDMismatch: + return try: - points, valid_indices = await connection.task_manager.run_async( - compute_umap, - ( - data_store, - request.data.columns, - request.data.indices, - request.data.n_neighbors, - request.data.metric, - request.data.min_dist, - ), - name=request.widget_id, + task_func = TASK_FUNCS[data.task] + print(data.args) + result = await connection.task_manager.run_async( + task_func, # type: ignore + args=(data_store,), + kwargs=data.args, + name=data.widget_id, tag=id(connection), ) + points = cast(np.ndarray, result[0]) + valid_indices = cast(np.ndarray, result[1]) except TaskCancelled: - ... - except Exception as e: - msg = TaskErrorMessage( - type="tasks.error", - widget_id=request.widget_id, - uid=request.uid, - error=type(e).__name__, - title=type(e).__name__, - detail=type(e).__doc__, + pass + except Problem as e: + msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": type(e).__name__, + "detail": type(e).__doc__, + }, + }, ) await connection.send_async(msg) - else: - response = ReductionResponse( - type="umap_result", - widget_id=request.widget_id, - uid=request.uid, - generation_id=request.generation_id, - data=ReductionResponseData(indices=valid_indices, points=points), - ) - await connection.send_async(response) - - -@handle_message.register -async def _(request: PCARequest, connection: "WebsocketConnection") -> None: - data_store: Optional[DataStore] = connection.websocket.app.data_store - if data_store is None: - return None - try: - data_store.check_generation_id(request.generation_id) - except GenerationIDMismatch: - return - - try: - points, valid_indices = await connection.task_manager.run_async( - compute_pca, - ( - data_store, - request.data.columns, - request.data.indices, - request.data.normalization, - ), - name=request.widget_id, - tag=id(connection), + except Exception as e: + msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": type(e).__name__, + "detail": type(e).__doc__, + }, + }, ) - except TaskCancelled: - ... + await connection.send_async(msg) else: - response = ReductionResponse( - type="pca_result", - widget_id=request.widget_id, - uid=request.uid, - generation_id=request.generation_id, - data=ReductionResponseData(indices=valid_indices, points=points), + msg = Message( + type="task.result", data={"points": points, "indices": valid_indices} ) - await connection.send_async(response) + await connection.send_async(msg) diff --git a/src/services/data.ts b/src/services/data.ts index 4884e365..069e01dc 100644 --- a/src/services/data.ts +++ b/src/services/data.ts @@ -1,8 +1,5 @@ -import { useDataset } from '../stores/dataset'; -import { useLayout } from '../stores/layout'; import { IndexArray } from '../types'; -import { v4 as uuidv4 } from 'uuid'; -import { Problem } from '../types'; +import taskService from './task'; export const umapMetricNames = [ 'euclidean', @@ -11,11 +8,8 @@ export const umapMetricNames = [ 'cosine', 'mahalanobis', ] as const; - export type UmapMetric = typeof umapMetricNames[number]; - export const pcaNormalizations = ['none', 'standardize', 'robust standardize'] as const; - export type PCANormalization = typeof pcaNormalizations[number]; interface ReductionResult { @@ -23,84 +17,7 @@ interface ReductionResult { indices: IndexArray; } -const MAX_QUEUED_MESSAGES = 16; - -class Connection { - url: string; - socket?: WebSocket; - messageQueue: string[]; - onmessage?: (data: unknown) => void; - - constructor(host: string, port: string) { - this.messageQueue = []; - if (globalThis.location.protocol === 'https:') { - this.url = `wss://${host}:${port}/api/ws`; - } else { - this.url = `ws://${host}:${port}/api/ws`; - } - this.#connect(); - } - - send(message: unknown): void { - const data = JSON.stringify(message); - if (this.socket) { - this.socket.send(data); - } else { - if (this.messageQueue.length > MAX_QUEUED_MESSAGES) { - this.messageQueue.shift(); - } - this.messageQueue.push(data); - } - } - - #connect() { - const webSocket = new WebSocket(this.url); - - webSocket.onopen = () => { - this.socket = webSocket; - this.messageQueue.forEach((message) => this.socket?.send(message)); - this.messageQueue.length = 0; - }; - webSocket.onmessage = (event) => { - const message = JSON.parse(event.data); - this.onmessage?.(message); - }; - webSocket.onerror = () => { - webSocket.close(); - }; - webSocket.onclose = () => { - this.socket = undefined; - setTimeout(() => { - this.#connect(); - }, 500); - }; - } -} - export class DataService { - connection: Connection; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - dispatchTable: Map void>; - - constructor(host: string, port: string) { - this.connection = new Connection(host, port); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.connection.onmessage = (message: any) => { - if (message.uid) { - this.dispatchTable.get(message.uid)?.(message); - return; - } - if (message.type === 'refresh') { - useDataset.getState().refresh(); - } else if (message.type === 'resetLayout') { - useLayout.getState().reset(); - } else if (message.type === 'issuesUpdated') { - useDataset.getState().fetchIssues(); - } - }; - this.dispatchTable = new Map(); - } - async computeUmap( widgetId: string, columnNames: string[], @@ -109,43 +26,13 @@ export class DataService { metric: UmapMetric, min_dist: number ): Promise { - const messageId = uuidv4(); - - const promise = new Promise((resolve, reject) => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.dispatchTable.set(messageId, (message: any) => { - if (message.type === 'tasks.error') { - const error: Problem = { - type: message.error, - title: message.title, - detail: message.detail, - }; - reject(error); - return; - } - const result = { - points: message.data.points, - indices: message.data.indices, - }; - resolve(result); - }); - }); - - this.connection.send({ - type: 'umap', - widget_id: widgetId, - uid: messageId, - generation_id: useDataset.getState().generationID, - data: { - indices: Array.from(indices), - columns: columnNames, - n_neighbors: n_neighbors, - metric: metric, - min_dist: min_dist, - }, + return await taskService.run('umap', widgetId, { + column_names: columnNames, + indices: Array.from(indices), + n_neighbors, + metric, + min_dist, }); - - return promise; } async computePCA( @@ -154,38 +41,13 @@ export class DataService { indices: IndexArray, pcaNormalization: PCANormalization ): Promise { - const messageId = uuidv4(); - - const promise = new Promise((resolve) => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.dispatchTable.set(messageId, (message: any) => { - const result = { - points: message.data.points, - indices: message.data.indices, - }; - resolve(result); - }); - }); - - this.connection.send({ - type: 'pca', - widget_id: widgetId, - uid: messageId, - generation_id: useDataset.getState().generationID, - data: { - indices: Array.from(indices), - columns: columnNames, - normalization: pcaNormalization, - }, + return await taskService.run('pca', widgetId, { + indices: Array.from(indices), + column_names: columnNames, + normalization: pcaNormalization, }); - - return promise; } } -const dataService = new DataService( - globalThis.location.hostname, - globalThis.location.port -); - +const dataService = new DataService(); export default dataService; diff --git a/src/services/task.ts b/src/services/task.ts new file mode 100644 index 00000000..0a74d58e --- /dev/null +++ b/src/services/task.ts @@ -0,0 +1,46 @@ +import { useDataset } from '../lib'; +import websocketService, { WebsocketService } from './websocket'; + +interface ResponseHandler { + resolve: (result: unknown) => void; + reject: (error: unknown) => void; +} + +// service handling the execution of remote tasks +class TaskService { + // dispatch table for task responses/errors + // eslint-disable-next-line @typescript-eslint/no-explicit-any + dispatchTable: Map; + websocketService: WebsocketService; + + constructor(websocketService: WebsocketService) { + this.dispatchTable = new Map(); + this.websocketService = websocketService; + this.websocketService.registerMessageHandler('task.result', (message: any) => { + this.dispatchTable.get(message.data.task_id)?.resolve(message.data.result); + }); + this.websocketService.registerMessageHandler('task.error', (message: any) => { + this.dispatchTable.get(message.data.task_id)?.reject(message.data.error); + }); + } + + async run(task: string, name: string, args: unknown) { + return new Promise((resolve, reject) => { + const task_id = crypto.randomUUID(); + this.dispatchTable.set(task_id, { resolve, reject }); + websocketService.send({ + type: 'task', + data: { + task: task, + widget_id: name, + task_id, + generation_id: useDataset.getState().generationID, + args, + }, + }); + }); + } +} + +const taskService = new TaskService(websocketService); +export default taskService; diff --git a/src/services/websocket/connection.ts b/src/services/websocket/connection.ts new file mode 100644 index 00000000..84452ca7 --- /dev/null +++ b/src/services/websocket/connection.ts @@ -0,0 +1,61 @@ +// the maximum number of outgoing messages that are queued, +import { Message } from './types'; + +// while the connection is down +const MAX_QUEUED_MESSAGES = 16; + +// a websocket connection to the spotlight backend +// automatically reconnects when necessary +class Connection { + url: string; + socket?: WebSocket; + messageQueue: string[]; + onmessage?: (data: unknown) => void; + + constructor(host: string, port: string) { + this.messageQueue = []; + if (globalThis.location.protocol === 'https:') { + this.url = `wss://${host}:${port}/api/ws`; + } else { + this.url = `ws://${host}:${port}/api/ws`; + } + this.#connect(); + } + + send(message: Message): void { + const data = JSON.stringify(message); + if (this.socket) { + this.socket.send(data); + } else { + if (this.messageQueue.length > MAX_QUEUED_MESSAGES) { + this.messageQueue.shift(); + } + this.messageQueue.push(data); + } + } + + #connect() { + const webSocket = new WebSocket(this.url); + + webSocket.onopen = () => { + this.socket = webSocket; + this.messageQueue.forEach((message) => this.socket?.send(message)); + this.messageQueue.length = 0; + }; + webSocket.onmessage = (event) => { + const message = JSON.parse(event.data); + this.onmessage?.(message); + }; + webSocket.onerror = () => { + webSocket.close(); + }; + webSocket.onclose = () => { + this.socket = undefined; + setTimeout(() => { + this.#connect(); + }, 500); + }; + } +} + +export default Connection; diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts new file mode 100644 index 00000000..9183ded2 --- /dev/null +++ b/src/services/websocket/index.ts @@ -0,0 +1,49 @@ +import Connection from './connection'; +import { Message, MessageHandler } from './types'; + +// this service provides a websocket connection to the service +// and handles incoming and outgoing messages +export class WebsocketService { + connection: Connection; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + messageHandlers: Map; + + constructor(host: string, port: string) { + this.messageHandlers = new Map(); + this.connection = new Connection(host, port); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + this.connection.onmessage = (message: any) => { + const messageHandler = this.messageHandlers.get(message.type); + if (messageHandler) { + messageHandler(message); + } else { + console.error(`Unknown websocket message: ${message.type}`); + } + }; + } + + registerMessageHandler(messageType: string, handler: MessageHandler): void { + this.messageHandlers.set(messageType, handler); + } + + send(message: Message): void { + this.connection.send(message); + } +} + +const websocketService = new WebsocketService( + globalThis.location.hostname, + globalThis.location.port +); +export default websocketService; + +// TODO: reimplement/move +/* + if (message.type === 'refresh') { + useDataset.getState().refresh(); + } else if (message.type === 'resetLayout') { + useLayout.getState().reset(); + } else if (message.type === 'issuesUpdated') { + useDataset.getState().fetchIssues(); + } +*/ diff --git a/src/services/websocket/types.ts b/src/services/websocket/types.ts new file mode 100644 index 00000000..f33f9eee --- /dev/null +++ b/src/services/websocket/types.ts @@ -0,0 +1,6 @@ +export interface Message { + type: string; + data: any; +} + +export type MessageHandler = (message: Message) => void; From 6fd6f8cac0ee08c0deefc2cbf8b5e73d05387053 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 12:58:11 +0200 Subject: [PATCH 4/8] chore: reimplement missing message handlers in fe --- src/services/websocket/index.ts | 6 +----- src/stores/dataset/dataset.ts | 9 +++++++++ src/stores/layout.ts | 5 +++++ 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts index 9183ded2..efc344e1 100644 --- a/src/services/websocket/index.ts +++ b/src/services/websocket/index.ts @@ -37,13 +37,9 @@ const websocketService = new WebsocketService( ); export default websocketService; -// TODO: reimplement/move +// TODO: these functions should probably be somewhere else /* - if (message.type === 'refresh') { - useDataset.getState().refresh(); } else if (message.type === 'resetLayout') { useLayout.getState().reset(); - } else if (message.type === 'issuesUpdated') { - useDataset.getState().fetchIssues(); } */ diff --git a/src/stores/dataset/dataset.ts b/src/stores/dataset/dataset.ts index 0cd3e3c0..4916c015 100644 --- a/src/stores/dataset/dataset.ts +++ b/src/stores/dataset/dataset.ts @@ -22,6 +22,7 @@ import { notifyAPIError, notifyError } from '../../notify'; import { makeColumnsColorTransferFunctions } from './colorTransferFunctionFactory'; import { makeColumn } from './columnFactory'; import { makeColumnsStats } from './statisticsFactory'; +import websocketService from '../../services/websocket'; export type CallbackOrData = ((data: T) => T) | T; @@ -540,3 +541,11 @@ useDataset.subscribe( ); useColors.subscribe(() => useDataset.getState().recomputeColorTransferFunctions()); + +websocketService.registerMessageHandler('refresh', () => { + useDataset.getState().refresh(); +}); + +websocketService.registerMessageHandler('issuesUpdated', () => { + useDataset.getState().fetchIssues(); +}); diff --git a/src/stores/layout.ts b/src/stores/layout.ts index 8eb5f517..53d4a993 100644 --- a/src/stores/layout.ts +++ b/src/stores/layout.ts @@ -2,6 +2,7 @@ import { create } from 'zustand'; import { AppLayout } from '../types'; import api from '../api'; import { saveAs } from 'file-saver'; +import websocketService from '../services/websocket'; export interface State { layout: AppLayout; @@ -41,3 +42,7 @@ export const useLayout = create((set) => ({ reader.readAsText(file); }, })); + +websocketService.registerMessageHandler('resetLayout', () => { + useLayout.getState().reset(); +}); From a1d7cb919723be515a6c8429935510f92de3c7dd Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 13:23:12 +0200 Subject: [PATCH 5/8] feat: handle errors during message serialization --- renumics/spotlight/backend/tasks/reduction.py | 5 +--- renumics/spotlight/backend/websockets.py | 27 ++++++++++++++----- src/services/task.ts | 23 +++++++++++----- src/services/websocket/index.ts | 13 +++++---- 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index 00207c38..2b6f3180 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -116,10 +116,7 @@ def compute_pca( from sklearn import preprocessing, decomposition - try: - data, indices = align_data(data_store, column_names, indices) - except (ColumnNotExistsError, ValueError): - return np.empty(0, np.float64), [] + data, indices = align_data(data_store, column_names, indices) if data.size == 0: return np.empty(0, np.float64), [] if data.shape[1] == 1: diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 30d15652..99187dd8 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -108,11 +108,25 @@ async def send_async(self, message: Message) -> None: """ Send a message async. """ + try: - message_data = message.dict() - await self.websocket.send_text( - orjson.dumps(message_data, option=orjson.OPT_SERIALIZE_NUMPY).decode() + json_text = orjson.dumps( + message.dict(), option=orjson.OPT_SERIALIZE_NUMPY + ).decode() + except TypeError as e: + error_message = Message( + type="error", + data={ + "type": type(e).__name__, + "title": "Failed to serialize message", + "detail": str(e), + }, ) + json_text = orjson.dumps( + error_message.dict(), option=orjson.OPT_SERIALIZE_NUMPY + ).decode() + try: + await self.websocket.send_text(json_text) except WebSocketDisconnect: self._on_disconnect() except RuntimeError: @@ -256,7 +270,6 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: try: task_func = TASK_FUNCS[data.task] - print(data.args) result = await connection.task_manager.run_async( task_func, # type: ignore args=(data_store,), @@ -267,6 +280,7 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: points = cast(np.ndarray, result[0]) valid_indices = cast(np.ndarray, result[1]) except TaskCancelled: + print("task cancelled") pass except Problem as e: msg = Message( @@ -275,13 +289,14 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: "task_id": data.task_id, "error": { "type": type(e).__name__, - "title": type(e).__name__, - "detail": type(e).__doc__, + "title": e.title, + "detail": e.detail, }, }, ) await connection.send_async(msg) except Exception as e: + print("task failed") msg = Message( type="task.error", data={ diff --git a/src/services/task.ts b/src/services/task.ts index 0a74d58e..2d16c45f 100644 --- a/src/services/task.ts +++ b/src/services/task.ts @@ -1,5 +1,6 @@ import { useDataset } from '../lib'; import websocketService, { WebsocketService } from './websocket'; +import { Message } from './websocket/types'; interface ResponseHandler { resolve: (result: unknown) => void; @@ -16,12 +17,22 @@ class TaskService { constructor(websocketService: WebsocketService) { this.dispatchTable = new Map(); this.websocketService = websocketService; - this.websocketService.registerMessageHandler('task.result', (message: any) => { - this.dispatchTable.get(message.data.task_id)?.resolve(message.data.result); - }); - this.websocketService.registerMessageHandler('task.error', (message: any) => { - this.dispatchTable.get(message.data.task_id)?.reject(message.data.error); - }); + this.websocketService.registerMessageHandler( + 'task.result', + (message: Message) => { + this.dispatchTable + .get(message.data.task_id) + ?.resolve(message.data.result); + } + ); + this.websocketService.registerMessageHandler( + 'task.error', + (message: Message) => { + this.dispatchTable + .get(message.data.task_id) + ?.reject(message.data.error); + } + ); } async run(task: string, name: string, args: unknown) { diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts index efc344e1..2d5e241a 100644 --- a/src/services/websocket/index.ts +++ b/src/services/websocket/index.ts @@ -1,3 +1,4 @@ +import { notifyProblem } from '../../notify'; import Connection from './connection'; import { Message, MessageHandler } from './types'; @@ -35,11 +36,9 @@ const websocketService = new WebsocketService( globalThis.location.hostname, globalThis.location.port ); -export default websocketService; -// TODO: these functions should probably be somewhere else -/* - } else if (message.type === 'resetLayout') { - useLayout.getState().reset(); - } -*/ +websocketService.registerMessageHandler('error', (message: Message) => { + notifyProblem(message.data); +}); + +export default websocketService; From 480c9009479a2d180a4e09c2346edb8424bc20c1 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 13:31:46 +0200 Subject: [PATCH 6/8] feat: raise serialization errors and handle at call site --- renumics/spotlight/backend/websockets.py | 36 +++++++++++++++--------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 99187dd8..34a39d77 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -70,6 +70,12 @@ class UnknownMessageType(Exception): """ +class SerializationError(Exception): + """ + Failed to serialize the WS message + """ + + PayloadType = Type[BaseModel] MessageHandler = Callable[[Any, "WebsocketConnection"], Coroutine[Any, Any, Any]] MessageHandlerSpec = Tuple[PayloadType, MessageHandler] @@ -114,17 +120,7 @@ async def send_async(self, message: Message) -> None: message.dict(), option=orjson.OPT_SERIALIZE_NUMPY ).decode() except TypeError as e: - error_message = Message( - type="error", - data={ - "type": type(e).__name__, - "title": "Failed to serialize message", - "detail": str(e), - }, - ) - json_text = orjson.dumps( - error_message.dict(), option=orjson.OPT_SERIALIZE_NUMPY - ).decode() + raise SerializationError(str(e)) try: await self.websocket.send_text(json_text) except WebSocketDisconnect: @@ -280,7 +276,6 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: points = cast(np.ndarray, result[0]) valid_indices = cast(np.ndarray, result[1]) except TaskCancelled: - print("task cancelled") pass except Problem as e: msg = Message( @@ -296,7 +291,6 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: ) await connection.send_async(msg) except Exception as e: - print("task failed") msg = Message( type="task.error", data={ @@ -313,4 +307,18 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: msg = Message( type="task.result", data={"points": points, "indices": valid_indices} ) - await connection.send_async(msg) + try: + await connection.send_async(msg) + except SerializationError as e: + error_msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": "Serialization Error", + "detail": str(e), + }, + }, + ) + await connection.send_async(error_msg) From fd6c71e232aa59351f05392abb6d6cc308c2ff54 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 13:53:42 +0200 Subject: [PATCH 7/8] refactor: change format of task result --- renumics/spotlight/backend/websockets.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 34a39d77..aed9a1ef 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -305,7 +305,11 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: await connection.send_async(msg) else: msg = Message( - type="task.result", data={"points": points, "indices": valid_indices} + type="task.result", + data={ + "task_id": data.task_id, + "result": {"points": points, "indices": valid_indices}, + }, ) try: await connection.send_async(msg) From 2c0ef37abe2fda813f809c1ec9e1774df1c77061 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 15:03:00 +0200 Subject: [PATCH 8/8] chore: fix linter errors --- src/services/task.ts | 24 ++++++--------------- src/services/websocket/index.ts | 5 +++-- src/services/websocket/types.ts | 5 +++-- src/widgets/SimilarityMap/SimilarityMap.tsx | 1 - 4 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/services/task.ts b/src/services/task.ts index 2d16c45f..f95a5d5e 100644 --- a/src/services/task.ts +++ b/src/services/task.ts @@ -1,6 +1,5 @@ import { useDataset } from '../lib'; import websocketService, { WebsocketService } from './websocket'; -import { Message } from './websocket/types'; interface ResponseHandler { resolve: (result: unknown) => void; @@ -17,25 +16,16 @@ class TaskService { constructor(websocketService: WebsocketService) { this.dispatchTable = new Map(); this.websocketService = websocketService; - this.websocketService.registerMessageHandler( - 'task.result', - (message: Message) => { - this.dispatchTable - .get(message.data.task_id) - ?.resolve(message.data.result); - } - ); - this.websocketService.registerMessageHandler( - 'task.error', - (message: Message) => { - this.dispatchTable - .get(message.data.task_id) - ?.reject(message.data.error); - } - ); + this.websocketService.registerMessageHandler('task.result', (data) => { + this.dispatchTable.get(data.task_id)?.resolve(data.result); + }); + this.websocketService.registerMessageHandler('task.error', (data) => { + this.dispatchTable.get(data.task_id)?.reject(data.error); + }); } async run(task: string, name: string, args: unknown) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any return new Promise((resolve, reject) => { const task_id = crypto.randomUUID(); this.dispatchTable.set(task_id, { resolve, reject }); diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts index 2d5e241a..e722e1de 100644 --- a/src/services/websocket/index.ts +++ b/src/services/websocket/index.ts @@ -1,4 +1,5 @@ import { notifyProblem } from '../../notify'; +import { Problem } from '../../types'; import Connection from './connection'; import { Message, MessageHandler } from './types'; @@ -37,8 +38,8 @@ const websocketService = new WebsocketService( globalThis.location.port ); -websocketService.registerMessageHandler('error', (message: Message) => { - notifyProblem(message.data); +websocketService.registerMessageHandler('error', (problem: Problem) => { + notifyProblem(problem); }); export default websocketService; diff --git a/src/services/websocket/types.ts b/src/services/websocket/types.ts index f33f9eee..18840150 100644 --- a/src/services/websocket/types.ts +++ b/src/services/websocket/types.ts @@ -1,6 +1,7 @@ export interface Message { type: string; - data: any; + data: unknown; } -export type MessageHandler = (message: Message) => void; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type MessageHandler = (data: any) => void; diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index 5c42ebc9..023c4dd1 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -1,5 +1,4 @@ import SimilaritiesIcon from '../../icons/Bubbles'; -import WarningIcon from '../../icons/Warning'; import LoadingIndicator from '../../components/LoadingIndicator'; import Plot, { MergeStrategy,