From c7a5293ddc811aff8e27b3c4bf86a7c945c39a6a Mon Sep 17 00:00:00 2001 From: Tarek Date: Thu, 10 Aug 2023 15:29:41 +0200 Subject: [PATCH 01/31] add mel scale for spectrogram --- src/lenses/SpectrogramLens/Spectrogram.tsx | 38 +++++++++++++++++++++- src/lenses/SpectrogramLens/index.tsx | 33 +++++++++++++++++-- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/src/lenses/SpectrogramLens/Spectrogram.tsx b/src/lenses/SpectrogramLens/Spectrogram.tsx index 8dd7f5e0..bfd77463 100644 --- a/src/lenses/SpectrogramLens/Spectrogram.tsx +++ b/src/lenses/SpectrogramLens/Spectrogram.tsx @@ -51,4 +51,40 @@ const amplitudeToDb = (amplitude: number, ref: number, amin: number) => { return log_spec; }; -export { unitType, freqType, fixWindow, amplitudeToDb }; +const hzToMel = (freq: number) => { + // Fill in the linear part + const f_min = 0.0; + const f_sp = 200.0 / 3; + + let mel = (freq - f_min) / f_sp; + + // Fill in the log-scale part + const min_log_hz = 1000.0; // beginning of log region (Hz) + const min_log_mel = (min_log_hz - f_min) / f_sp; // same (Mels) + const logstep = Math.log(6.4) / 27.0; // step size for log region + + if (freq >= min_log_hz) { + mel = min_log_mel + Math.log(freq / min_log_hz) / logstep; + } + + return mel; +}; + +const melToHz = (mel: number) => { + const f_min = 0.0; + const f_sp = 200.0 / 3; + let freq = f_min + f_sp * mel; + + // And now the nonlinear scale + const min_log_hz = 1000.0; // beginning of log region (Hz) + const min_log_mel = (min_log_hz - f_min) / f_sp; // same (Mels) + const logstep = Math.log(6.4) / 27.0; // step size for log region + + if (mel >= min_log_mel) { + // If we have scalar data, check directly + freq = min_log_hz * Math.exp(logstep * (mel - min_log_mel)); + } + return freq; +}; + +export { unitType, freqType, fixWindow, amplitudeToDb, hzToMel, melToHz }; diff --git a/src/lenses/SpectrogramLens/index.tsx b/src/lenses/SpectrogramLens/index.tsx index 60522df1..13ecb298 100644 --- a/src/lenses/SpectrogramLens/index.tsx +++ b/src/lenses/SpectrogramLens/index.tsx @@ -11,7 +11,14 @@ import { Lens } from '../../types'; import useSetting from '../useSetting'; import MenuBar from './MenuBar'; import chroma from 'chroma-js'; -import { fixWindow, freqType, unitType, amplitudeToDb } from './Spectrogram'; +import { + fixWindow, + freqType, + unitType, + amplitudeToDb, + hzToMel, + melToHz, +} from './Spectrogram'; const Container = tw.div`flex flex-col w-full h-full items-stretch justify-center`; const EmptyNote = styled.p` @@ -246,6 +253,28 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { drawData[i] = col; } colorScale = colorPalette.scale().domain([log_spec_min, log_spec_max]); + } else if (ampScale === 'mel') { + let mel_spec_min = 0; + let mel_spec_max = 0; + + for (let i = 0; i < frequenciesData.length; i++) { + const col = []; + + for (let j = 0; j < frequenciesData[i].length; j++) { + const amplitude = frequenciesData[i][j]; + col[j] = hzToMel(amplitude); + + if (col[j] > mel_spec_max) { + mel_spec_max = col[j]; + } + + if (col[j] < mel_spec_min) { + mel_spec_min = col[j]; + } + } + drawData[i] = col; + } + colorScale = colorPalette.scale().domain([mel_spec_min, mel_spec_max]); } else { // ampScale === 'linear' colorScale = colorPalette.scale().domain([0, 256]); @@ -441,7 +470,7 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { availableFreqScales={['linear', 'logarithmic']} freqScale={freqScale} onChangeFreqScale={handleFreqScaleChange} - availableAmpScales={['decibel', 'linear']} + 
availableAmpScales={['decibel', 'linear', 'mel']} ampScale={ampScale} onChangeAmpScale={handleAmpScaleChange} /> From 29ce91bbbde79b7585a6d41a1e2b9761829ae31a Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 16 Oct 2023 15:21:38 +0200 Subject: [PATCH 02/31] Remove unused API model --- renumics/spotlight/data_source/data_source.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/renumics/spotlight/data_source/data_source.py b/renumics/spotlight/data_source/data_source.py index d06133f3..695acb5d 100644 --- a/renumics/spotlight/data_source/data_source.py +++ b/renumics/spotlight/data_source/data_source.py @@ -6,7 +6,6 @@ import pandas as pd import numpy as np -from pydantic.dataclasses import dataclass from renumics.spotlight.dataset.exceptions import ( ColumnExistsError, @@ -30,17 +29,6 @@ class ColumnMetadata: tags: List[str] = dataclasses.field(default_factory=list) -@dataclass -class CellsUpdate: - """ - A dataset's cell update. - """ - - value: Any - author: str - edited_at: str - - class DataSource(ABC): """abstract base class for different data sources""" @@ -61,7 +49,7 @@ def column_names(self) -> List[str]: @abstractmethod def intermediate_dtypes(self) -> DTypeMap: """ - The dtypes of intermediate values + The dtypes of intermediate values. Values for all columns must be filled. """ @property @@ -94,7 +82,7 @@ def check_generation_id(self, generation_id: int) -> None: @abstractmethod def semantic_dtypes(self) -> DTypeMap: """ - Semantic dtypes for viewer. + Semantic dtypes for viewer. Some values may be not present. """ @abstractmethod From f056ab8da32839cae4c451cf5f3a8f5af4a8b852 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Tue, 17 Oct 2023 09:56:30 +0200 Subject: [PATCH 03/31] Infer pandas dtypes in data store --- renumics/spotlight/data_store.py | 64 ++++++++++++------- renumics/spotlight/io/pandas.py | 11 ---- .../core/pandas_data_source.py | 31 +++++---- 3 files changed, 57 insertions(+), 49 deletions(-) diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index 42867e5a..416d592c 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -19,13 +19,16 @@ DType, DTypeMap, EmbeddingDType, + array_dtype, is_array_dtype, is_audio_dtype, is_category_dtype, + is_embedding_dtype, is_file_dtype, is_str_dtype, is_mixed_dtype, is_bytes_dtype, + is_window_dtype, str_dtype, audio_dtype, image_dtype, @@ -130,6 +133,8 @@ def get_waveform(self, column_name: str, index: int) -> Optional[np.ndarray]: def _update_dtypes(self) -> None: guessed_dtypes = self._data_source.semantic_dtypes.copy() + print(self._data_source.intermediate_dtypes) + print(guessed_dtypes) # guess missing dtypes from intermediate dtypes for col, dtype in self._data_source.intermediate_dtypes.items(): @@ -171,33 +176,32 @@ def _guess_dtype(self, col: str) -> DType: return semantic_dtype sample_values = self._data_source.get_column_values(col, slice(10)) - sample_dtypes = [_guess_value_dtype(value) for value in sample_values] - - try: - mode_dtype = statistics.mode(sample_dtypes) - except statistics.StatisticsError: + sample_dtypes: List[DType] = [] + for value in sample_values: + guessed_dtype = _guess_value_dtype(value) + if guessed_dtype is not None: + sample_dtypes.append(guessed_dtype) + if not sample_dtypes: return semantic_dtype - return mode_dtype or semantic_dtype + mode_dtype = statistics.mode(sample_dtypes) + # For windows and embeddings, at least sample values must be aligned. 
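+        # (A mix of window and non-window samples, or embeddings of differing
+        # lengths, cannot be represented by one typed column, so we fall back
+        # to the generic array dtype below.)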
+ if is_window_dtype(mode_dtype) and any( + not is_window_dtype(dtype) for dtype in sample_dtypes + ): + return array_dtype + if is_embedding_dtype(mode_dtype) and any( + (not is_embedding_dtype(dtype)) or dtype.length != mode_dtype.length + for dtype in sample_dtypes + ): + return array_dtype + + return mode_dtype def _intermediate_to_semantic_dtype(intermediate_dtype: DType) -> DType: if is_array_dtype(intermediate_dtype): - if intermediate_dtype.shape is None: - return intermediate_dtype - if intermediate_dtype.shape == (2,): - return window_dtype - if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is not None: - return EmbeddingDType(intermediate_dtype.shape[0]) - if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is None: - return sequence_1d_dtype - if intermediate_dtype.ndim == 2 and ( - intermediate_dtype.shape[0] == 2 or intermediate_dtype.shape[1] == 2 - ): - return sequence_1d_dtype - if intermediate_dtype.ndim == 3 and intermediate_dtype.shape[-1] in (1, 3, 4): - return image_dtype - return intermediate_dtype + return _guess_array_dtype(intermediate_dtype) if is_file_dtype(intermediate_dtype): return str_dtype if is_mixed_dtype(intermediate_dtype): @@ -248,5 +252,21 @@ def _guess_value_dtype(value: Any) -> Optional[DType]: except (TypeError, ValueError): pass else: - return ArrayDType(value.shape) + return _guess_array_dtype(ArrayDType(value.shape)) return None + + +def _guess_array_dtype(dtype: ArrayDType) -> DType: + if dtype.shape is None: + return dtype + if dtype.shape == (2,): + return window_dtype + if dtype.ndim == 1 and dtype.shape[0] is not None: + return EmbeddingDType(dtype.shape[0]) + if dtype.ndim == 1 and dtype.shape[0] is None: + return sequence_1d_dtype + if dtype.ndim == 2 and (dtype.shape[0] == 2 or dtype.shape[1] == 2): + return sequence_1d_dtype + if dtype.ndim == 3 and dtype.shape[-1] in (1, 3, 4): + return image_dtype + return dtype diff --git a/renumics/spotlight/io/pandas.py b/renumics/spotlight/io/pandas.py index 4cf84f9e..5663fc9a 100644 --- a/renumics/spotlight/io/pandas.py +++ b/renumics/spotlight/io/pandas.py @@ -58,17 +58,6 @@ def create_typed_series( return pd.Series([] if values is None else values, dtype=pandas_dtype) -def is_empty(value: Any) -> bool: - """ - Check if value is `NA` or an empty string. - """ - if is_iterable(value): - # `pd.isna` with an iterable argument returns an iterable result. But - # an iterable cannot be NA or empty string by default. - return False - return pd.isna(value) or value == "" - - def try_literal_eval(x: str) -> Any: """ Try to evaluate a literal expression, otherwise return value as is. 
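As a reference for the shape heuristics above, here is a minimal standalone sketch of the rules `_guess_array_dtype` applies. Plain strings stand in for Spotlight's dtype objects, so the sketch runs without the package:

```python
from typing import Optional, Tuple

Shape = Optional[Tuple[Optional[int], ...]]


def guess_array_dtype(shape: Shape) -> str:
    """Mirror of _guess_array_dtype above, with string dtype names."""
    if shape is None:
        return "array"
    if shape == (2,):
        return "window"  # a (start, end) pair
    if len(shape) == 1 and shape[0] is not None:
        return f"embedding[{shape[0]}]"  # fixed-length 1-D vector
    if len(shape) == 1:
        return "sequence_1d"  # unknown-length 1-D data
    if len(shape) == 2 and 2 in (shape[0], shape[1]):
        return "sequence_1d"  # (x, y) pairs as rows or columns
    if len(shape) == 3 and shape[-1] in (1, 3, 4):
        return "image"  # grayscale, RGB or RGBA channels


    return "array"


assert guess_array_dtype((2,)) == "window"
assert guess_array_dtype((512,)) == "embedding[512]"
assert guess_array_dtype((100, 2)) == "sequence_1d"
assert guess_array_dtype((64, 64, 3)) == "image"
```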
diff --git a/renumics/spotlight_plugins/core/pandas_data_source.py b/renumics/spotlight_plugins/core/pandas_data_source.py index 430a69ca..7a02c94d 100644 --- a/renumics/spotlight_plugins/core/pandas_data_source.py +++ b/renumics/spotlight_plugins/core/pandas_data_source.py @@ -7,10 +7,9 @@ import numpy as np import pandas as pd import datasets -from renumics.spotlight import dtypes +from renumics.spotlight import dtypes from renumics.spotlight.io.pandas import ( - infer_dtype, prepare_hugging_face_dict, stringify_columns, try_literal_eval, @@ -23,7 +22,6 @@ from renumics.spotlight.backend.exceptions import DatasetColumnsNotUnique from renumics.spotlight.dataset.exceptions import ColumnNotExistsError from renumics.spotlight.data_source.exceptions import InvalidDataSource -from renumics.spotlight.dtypes import DTypeMap @datasource(pd.DataFrame) @@ -41,6 +39,7 @@ class PandasDataSource(DataSource): _uid: str _df: pd.DataFrame _name: str + _intermediate_dtypes: dtypes.DTypeMap def __init__(self, source: Union[Path, pd.DataFrame]): if isinstance(source, Path): @@ -99,7 +98,9 @@ def __init__(self, source: Union[Path, pd.DataFrame]): raise DatasetColumnsNotUnique() self._generation_id = 0 self._uid = str(id(df)) + print(df.dtypes) self._df = df.convert_dtypes() + print(self._df.dtypes) self._intermediate_dtypes = { # TODO: convert column name col: _determine_intermediate_dtype(self._df[col]) @@ -118,18 +119,15 @@ def df(self) -> pd.DataFrame: return self._df.copy() @property - def intermediate_dtypes(self) -> DTypeMap: + def intermediate_dtypes(self) -> dtypes.DTypeMap: return self._intermediate_dtypes def __len__(self) -> int: return len(self._df) @property - def semantic_dtypes(self) -> DTypeMap: - return { - str(column_name): infer_dtype(self.df[column_name]) - for column_name in self.df - } + def semantic_dtypes(self) -> dtypes.DTypeMap: + return {} def get_generation_id(self) -> int: return self._generation_id @@ -167,12 +165,14 @@ def get_column_values( if pd.api.types.is_categorical_dtype(column): return column.cat.codes if pd.api.types.is_string_dtype(column): - values = column.to_numpy() - na_mask = column.isna() - values[na_mask] = None - return values + column = column.astype(object).mask(column.isna(), None) + str_mask = column.map(type) == str + column[str_mask] = column[str_mask].apply(try_literal_eval) + dict_mask = column.map(type) == dict + column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict) + return column.to_numpy() if pd.api.types.is_object_dtype(column): - column = column.mask(column.isna(), None) + column = column.astype(object).mask(column.isna(), None) str_mask = column.map(type) == str column[str_mask] = column[str_mask].apply(try_literal_eval) dict_mask = column.map(type) == dict @@ -222,5 +222,4 @@ def _determine_intermediate_dtype(column: pd.Series) -> dtypes.DType: return dtypes.datetime_dtype if pd.api.types.is_string_dtype(column): return dtypes.str_dtype - else: - return dtypes.mixed_dtype + return dtypes.mixed_dtype From 44c86e566dd0a29882986f5c0b92e90d2fe98a3f Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Tue, 17 Oct 2023 10:00:57 +0200 Subject: [PATCH 04/31] Remove debug prints --- renumics/spotlight/data_store.py | 2 -- renumics/spotlight_plugins/core/pandas_data_source.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index 416d592c..6f2bd862 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -133,8 +133,6 @@ def 
get_waveform(self, column_name: str, index: int) -> Optional[np.ndarray]: def _update_dtypes(self) -> None: guessed_dtypes = self._data_source.semantic_dtypes.copy() - print(self._data_source.intermediate_dtypes) - print(guessed_dtypes) # guess missing dtypes from intermediate dtypes for col, dtype in self._data_source.intermediate_dtypes.items(): diff --git a/renumics/spotlight_plugins/core/pandas_data_source.py b/renumics/spotlight_plugins/core/pandas_data_source.py index 7a02c94d..3eee6ebe 100644 --- a/renumics/spotlight_plugins/core/pandas_data_source.py +++ b/renumics/spotlight_plugins/core/pandas_data_source.py @@ -98,9 +98,7 @@ def __init__(self, source: Union[Path, pd.DataFrame]): raise DatasetColumnsNotUnique() self._generation_id = 0 self._uid = str(id(df)) - print(df.dtypes) self._df = df.convert_dtypes() - print(self._df.dtypes) self._intermediate_dtypes = { # TODO: convert column name col: _determine_intermediate_dtype(self._df[col]) From 33492b5bcd46ef1c39b4fc922013c39d95f8ea01 Mon Sep 17 00:00:00 2001 From: Tarek Date: Tue, 17 Oct 2023 14:35:39 +0200 Subject: [PATCH 05/31] remove duplicate decibel calculation --- .../SpectrogramLens/SpectrogramWorker.ts | 7 +- src/lenses/SpectrogramLens/index.tsx | 64 ++++++++++--------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/lenses/SpectrogramLens/SpectrogramWorker.ts b/src/lenses/SpectrogramLens/SpectrogramWorker.ts index 55cad783..ff47fcf3 100644 --- a/src/lenses/SpectrogramLens/SpectrogramWorker.ts +++ b/src/lenses/SpectrogramLens/SpectrogramWorker.ts @@ -233,12 +233,7 @@ const calculateFrequencies = ( tables.cosTable ); - const array = new Uint8Array(fftSamples / 2); - let j; - for (j = 0; j < fftSamples / 2; j++) { - array[j] = Math.max(-255, Math.log10(spectrum[j]) * 45); - } - channelFreq.push(array); + channelFreq.push(spectrum); currentOffset += fftSamples - noverlap; } frequencies.push(channelFreq); diff --git a/src/lenses/SpectrogramLens/index.tsx b/src/lenses/SpectrogramLens/index.tsx index 13ecb298..9643ce99 100644 --- a/src/lenses/SpectrogramLens/index.tsx +++ b/src/lenses/SpectrogramLens/index.tsx @@ -10,15 +10,7 @@ import { ColorsState, useColors } from '../../stores/colors'; import { Lens } from '../../types'; import useSetting from '../useSetting'; import MenuBar from './MenuBar'; -import chroma from 'chroma-js'; -import { - fixWindow, - freqType, - unitType, - amplitudeToDb, - hzToMel, - melToHz, -} from './Spectrogram'; +import { fixWindow, freqType, unitType, amplitudeToDb, hzToMel } from './Spectrogram'; const Container = tw.div`flex flex-col w-full h-full items-stretch justify-center`; const EmptyNote = styled.p` @@ -174,15 +166,23 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { return; } + //console.log(JSON.stringify((buffer.getChannelData(0).slice(start, end))) + //const slice = buffer.getChannelData(0).slice(start, end); + const arr = buffer.getChannelData(0).slice(start, end); + console.log(Math.max(...arr), Math.min(...arr)); const frequenciesData = await worker( FFT_SAMPLES, backend.windowFunc, backend.alpha, width, FFT_SAMPLES, - buffer.getChannelData(0).slice(start, end) + arr ); + console.log(frequenciesData); + console.log(frequenciesData.length); + console.log(frequenciesData[0].length); + setIsComputing(false); // Get the canvas render context 2D @@ -216,7 +216,10 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { const widthScale = d3.scaleLinear([0, width], [0, frequenciesData.length]); let drawData = []; - let colorScale: chroma.Scale; + 
//let colorScale: chroma.Scale; + + let min = 0; + let max = 0; if (ampScale === 'decibel') { let ref = 0; @@ -230,9 +233,6 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { //const top_db = 80; const amin = 1e-5; - let log_spec_max = 0; - let log_spec_min = 0; - // Convert amplitudes to decibels for (let i = 0; i < frequenciesData.length; i++) { const col = []; @@ -241,45 +241,51 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { const amplitude = frequenciesData[i][j]; col[j] = amplitudeToDb(amplitude, ref, amin); - if (col[j] > log_spec_max) { - log_spec_max = col[j]; + if (col[j] > max) { + max = col[j]; } - if (col[j] < log_spec_min) { - log_spec_min = col[j]; + if (col[j] < min) { + min = col[j]; } } drawData[i] = col; } - colorScale = colorPalette.scale().domain([log_spec_min, log_spec_max]); } else if (ampScale === 'mel') { - let mel_spec_min = 0; - let mel_spec_max = 0; - for (let i = 0; i < frequenciesData.length; i++) { const col = []; for (let j = 0; j < frequenciesData[i].length; j++) { const amplitude = frequenciesData[i][j]; - col[j] = hzToMel(amplitude); + col[j] = hzToMel(amplitude ** 2); - if (col[j] > mel_spec_max) { - mel_spec_max = col[j]; + if (col[j] > max) { + max = col[j]; } - if (col[j] < mel_spec_min) { - mel_spec_min = col[j]; + if (col[j] < min) { + min = col[j]; } } drawData[i] = col; } - colorScale = colorPalette.scale().domain([mel_spec_min, mel_spec_max]); } else { // ampScale === 'linear' - colorScale = colorPalette.scale().domain([0, 256]); + for (let i = 0; i < frequenciesData.length; i++) { + const maxI = Math.max(...frequenciesData[i]); + const minI = Math.min(...frequenciesData[i]); + + if (maxI > max) { + max = maxI; + } + if (minI < min) { + min = minI; + } + } drawData = frequenciesData; } + const colorScale = colorPalette.scale().domain([min, max]); for (let y = 0; y < height; y++) { let value = 0; From e57108d6ec3fdaa6495e8b37b9f774144e41218e Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 23 Oct 2023 10:42:32 +0200 Subject: [PATCH 06/31] Implement __eq__ and __hash__ for dtypes --- renumics/spotlight/dtypes/__init__.py | 44 +++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/renumics/spotlight/dtypes/__init__.py b/renumics/spotlight/dtypes/__init__.py index 0e24ea10..ae96cd7e 100644 --- a/renumics/spotlight/dtypes/__init__.py +++ b/renumics/spotlight/dtypes/__init__.py @@ -36,6 +36,14 @@ def __init__(self, name: str): def __str__(self) -> str: return self.name + def __eq__(self, other: Any) -> bool: + if isinstance(other, DType): + return other._name == self._name + return False + + def __hash__(self) -> int: + return hash(self._name) + @property def name(self) -> str: return self._name @@ -53,8 +61,10 @@ def __init__( self, categories: Optional[Union[Iterable[str], Dict[str, int]]] = None ): super().__init__("Category") - if isinstance(categories, dict) or categories is None: - self._categories = categories + if isinstance(categories, dict): + self._categories = dict(sorted(categories.items(), key=lambda x: x[1])) + elif categories is None: + self._categories = None else: self._categories = { category: code for code, category in enumerate(categories) @@ -71,6 +81,20 @@ def __init__( category: code for code, category in self._inverted_categories.items() } + def __eq__(self, other: Any) -> bool: + if isinstance(other, CategoryDType): + return other._name == self._name and other._categories == self._categories + return False + + def __hash__(self) -> int: + if 
self._categories is None: + return hash(self._name) ^ hash(None) + return ( + hash(self._name) + ^ hash(tuple(self._categories.keys())) + ^ hash(tuple(self._categories.values())) + ) + @property def categories(self) -> Optional[Dict[str, int]]: return self._categories @@ -91,6 +115,14 @@ def __init__(self, shape: Optional[Tuple[Optional[int], ...]] = None): super().__init__("array") self.shape = shape + def __eq__(self, other: Any) -> bool: + if isinstance(other, ArrayDType): + return other._name == self._name and other.shape == self.shape + return False + + def __hash__(self) -> int: + return hash(self._name) ^ hash(self.shape) + @property def ndim(self) -> int: if self.shape is None: @@ -111,6 +143,14 @@ def __init__(self, length: Optional[int] = None): raise ValueError(f"Length must be non-negative, but {length} received.") self.length = length + def __eq__(self, other: Any) -> bool: + if isinstance(other, EmbeddingDType): + return other._name == self._name and other.length == self.length + return False + + def __hash__(self) -> int: + return hash(self._name) ^ hash(self.length) + class Sequence1DDType(DType): """ From 01eccccc9c1debfabddb61833de2b151df1f7436 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 23 Oct 2023 13:40:46 +0200 Subject: [PATCH 07/31] Move `spotlight.io.pandas` to `spotlight.dataset.pandas` --- renumics/spotlight/dataset/__init__.py | 10 +- renumics/spotlight/dtypes/__init__.py | 16 +- renumics/spotlight/io/__init__.py | 16 + renumics/spotlight/io/pandas.py | 316 ------------------ .../core/pandas_data_source.py | 8 +- tests/integration/dataset/test_dataset.py | 2 +- 6 files changed, 34 insertions(+), 334 deletions(-) delete mode 100644 renumics/spotlight/io/pandas.py diff --git a/renumics/spotlight/dataset/__init__.py b/renumics/spotlight/dataset/__init__.py index 99828ddf..e586499d 100644 --- a/renumics/spotlight/dataset/__init__.py +++ b/renumics/spotlight/dataset/__init__.py @@ -32,12 +32,7 @@ from typing_extensions import TypeGuard from renumics.spotlight.__version__ import __version__ -from renumics.spotlight.io.pandas import ( - infer_dtypes, - prepare_column, - is_string_mask, - stringify_columns, -) +from .pandas import create_typed_series, infer_dtypes, is_string_mask, prepare_column from renumics.spotlight.typing import ( BoolType, IndexType, @@ -47,7 +42,6 @@ is_integer, is_iterable, ) -from renumics.spotlight.io.pandas import create_typed_series from renumics.spotlight.dtypes.conversion import prepare_path_or_url from renumics.spotlight import dtypes as spotlight_dtypes @@ -738,7 +732,7 @@ def from_pandas( df = df.reset_index(level=df.index.names) # type: ignore else: df = df.copy() - df.columns = pd.Index(stringify_columns(df)) + df.columns = pd.Index([str(column) for column in df.columns]) if dtypes is None: dtypes = {} diff --git a/renumics/spotlight/dtypes/__init__.py b/renumics/spotlight/dtypes/__init__.py index ae96cd7e..63910215 100644 --- a/renumics/spotlight/dtypes/__init__.py +++ b/renumics/spotlight/dtypes/__init__.py @@ -9,6 +9,8 @@ __all__ = [ "CategoryDType", + "ArrayDType", + "EmbeddingDType", "Sequence1DDType", "bool_dtype", "int_dtype", @@ -83,7 +85,7 @@ def __init__( def __eq__(self, other: Any) -> bool: if isinstance(other, CategoryDType): - return other._name == self._name and other._categories == self._categories + return other._categories == self._categories return False def __hash__(self) -> int: @@ -117,7 +119,7 @@ def __init__(self, shape: Optional[Tuple[Optional[int], ...]] = None): def __eq__(self, other: Any) -> 
bool: if isinstance(other, ArrayDType): - return other._name == self._name and other.shape == self.shape + return other.shape == self.shape return False def __hash__(self) -> int: @@ -145,7 +147,7 @@ def __init__(self, length: Optional[int] = None): def __eq__(self, other: Any) -> bool: if isinstance(other, EmbeddingDType): - return other._name == self._name and other.length == self.length + return other.length == self.length return False def __hash__(self) -> int: @@ -165,6 +167,14 @@ def __init__(self, x_label: str = "x", y_label: str = "y"): self.x_label = x_label self.y_label = y_label + def __eq__(self, other: Any) -> bool: + if isinstance(other, Sequence1DDType): + return other.x_label == self.x_label and other.y_label == self.y_label + return False + + def __hash__(self) -> int: + return hash(self._name) ^ hash(self.x_label) ^ hash(self.y_label) + ALIASES: Dict[Any, DType] = {} diff --git a/renumics/spotlight/io/__init__.py b/renumics/spotlight/io/__init__.py index 2d2a6d26..8162a843 100644 --- a/renumics/spotlight/io/__init__.py +++ b/renumics/spotlight/io/__init__.py @@ -1,6 +1,9 @@ """ Reading and writing of different data formats. """ +import ast +from contextlib import suppress +from typing import Any from .audio import ( get_format_codec, @@ -19,6 +22,8 @@ decode_gltf_arrays, encode_gltf_array, ) +from .huggingface import prepare_hugging_face_dict + __all__ = [ "get_format_codec", @@ -34,4 +39,15 @@ "check_gltf", "decode_gltf_arrays", "encode_gltf_array", + "prepare_hugging_face_dict", + "try_literal_eval", ] + + +def try_literal_eval(x: str) -> Any: + """ + Try to evaluate a literal expression, otherwise return value as is. + """ + with suppress(Exception): + return ast.literal_eval(x) + return x diff --git a/renumics/spotlight/io/pandas.py b/renumics/spotlight/io/pandas.py deleted file mode 100644 index 5663fc9a..00000000 --- a/renumics/spotlight/io/pandas.py +++ /dev/null @@ -1,316 +0,0 @@ -""" -This module contains helpers for importing `pandas.DataFrame`s. 
-""" - -import ast -import os.path -import statistics -from contextlib import suppress -from typing import Any, Dict, List, Optional, Sequence, Union - -import PIL.Image -import filetype -import trimesh -import numpy as np -import pandas as pd - -from renumics.spotlight.dtypes import ( - Audio, - Embedding, - Image, - Mesh, - Sequence1D, - Video, -) -from renumics.spotlight.media.exceptions import UnsupportedDType -from renumics.spotlight.typing import is_iterable, is_pathtype -from renumics.spotlight import dtypes - - -def create_typed_series( - dtype: dtypes.DType, values: Optional[Union[Sequence, np.ndarray]] = None -) -> pd.Series: - if dtypes.is_category_dtype(dtype): - if values is None or len(values) == 0: - return pd.Series( - dtype=pd.CategoricalDtype( - [] if not dtype.categories else list(dtype.categories.keys()) - ) - ) - if dtype.inverted_categories is None: - return pd.Series([None] * len(values), dtype=pd.CategoricalDtype()) - return pd.Series( - [dtype.inverted_categories.get(code) for code in values], - dtype=pd.CategoricalDtype(), - ) - if dtypes.is_bool_dtype(dtype): - pandas_dtype = "boolean" - elif dtypes.is_int_dtype(dtype): - pandas_dtype = "Int64" - elif dtypes.is_float_dtype(dtype): - pandas_dtype = "float" - elif dtypes.is_str_dtype(dtype): - pandas_dtype = "string" - elif dtypes.is_datetime_dtype(dtype): - pandas_dtype = "datetime64[ns]" - else: - pandas_dtype = "object" - return pd.Series([] if values is None else values, dtype=pandas_dtype) - - -def try_literal_eval(x: str) -> Any: - """ - Try to evaluate a literal expression, otherwise return value as is. - """ - with suppress(Exception): - return ast.literal_eval(x) - return x - - -def stringify_columns(df: pd.DataFrame) -> List[str]: - """ - Convert `pandas.DataFrame`'s column names to strings, no matter which index - is used. - """ - return [str(column_name) for column_name in df.columns] - - -def infer_dtype(column: pd.Series) -> dtypes.DType: - """ - Get an equivalent Spotlight data type for a `pandas` column, if possible. - - At the moment, only scalar data types can be inferred. - - Nullable boolean and integer `pandas` dtypes have no equivalent Spotlight - data type and will be read as strings. - - Float, string, and category data types are allowed to have `NaN`s. - - Args: - column: A `pandas` column to infer dtype from. - - Returns: - Inferred dtype. - - Raises: - ValueError: If dtype cannot be inferred automatically. 
- """ - - if pd.api.types.is_bool_dtype(column): - return dtypes.bool_dtype - if pd.api.types.is_categorical_dtype(column): - return dtypes.CategoryDType( - {category: code for code, category in enumerate(column.cat.categories)} - ) - if pd.api.types.is_integer_dtype(column): - return dtypes.int_dtype - if pd.api.types.is_float_dtype(column): - return dtypes.float_dtype - if pd.api.types.is_datetime64_any_dtype(column): - return dtypes.datetime_dtype - - column = column.copy() - str_mask = is_string_mask(column) - column[str_mask] = column[str_mask].replace("", None) - - column = column[~column.isna()] - if len(column) == 0: - return dtypes.str_dtype - - column_head = column.iloc[:10] - head_dtypes = column_head.apply(infer_value_dtype).to_list() # type: ignore - dtype_mode = statistics.mode(head_dtypes) - - if dtype_mode is None: - return dtypes.str_dtype - if dtype_mode in [dtypes.window_dtype, dtypes.embedding_dtype]: - column = column.astype(object) - str_mask = is_string_mask(column) - x = column[str_mask].apply(try_literal_eval) - column[str_mask] = x - dict_mask = column.map(type) == dict - column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict) - try: - np.asarray(column.to_list(), dtype=float) - except (TypeError, ValueError): - return dtypes.sequence_1d_dtype - return dtype_mode - return dtype_mode - - -def infer_value_dtype(value: Any) -> Optional[dtypes.DType]: - """ - Infer dtype for value - """ - if isinstance(value, Embedding): - return dtypes.embedding_dtype - if isinstance(value, Sequence1D): - return dtypes.sequence_1d_dtype - if isinstance(value, Image): - return dtypes.image_dtype - if isinstance(value, Audio): - return dtypes.audio_dtype - if isinstance(value, Video): - return dtypes.video_dtype - if isinstance(value, Mesh): - return dtypes.mesh_dtype - if isinstance(value, PIL.Image.Image): - return dtypes.image_dtype - if isinstance(value, trimesh.Trimesh): - return dtypes.mesh_dtype - if isinstance(value, np.ndarray): - return infer_array_dtype(value) - - # When `pandas` reads a csv, arrays and lists are read as literal strings, - # try to interpret them. - value = try_literal_eval(value) - if isinstance(value, dict): - value = prepare_hugging_face_dict(value) - if isinstance(value, bytes) or (is_pathtype(value) and os.path.isfile(value)): - kind = filetype.guess(value) - if kind is not None: - mime_group = kind.mime.split("/")[0] - if mime_group == "image": - return dtypes.image_dtype - if mime_group == "audio": - return dtypes.audio_dtype - if mime_group == "video": - return dtypes.video_dtype - return None - if is_iterable(value): - try: - value = np.asarray(value, dtype=float) - except (TypeError, ValueError): - pass - else: - return infer_array_dtype(value) - return None - - -def infer_array_dtype(value: np.ndarray) -> dtypes.DType: - """ - Infer dtype of a numpy array - """ - if value.ndim == 3: - if value.shape[-1] in (1, 3, 4): - return dtypes.image_dtype - elif value.ndim == 2: - if value.shape[0] == 2 or value.shape[1] == 2: - return dtypes.sequence_1d_dtype - elif value.ndim == 1: - if len(value) == 2: - return dtypes.window_dtype - return dtypes.embedding_dtype - return dtypes.array_dtype - - -def infer_dtypes(df: pd.DataFrame, dtype: Optional[dtypes.DTypeMap]) -> dtypes.DTypeMap: - """ - Check column types from the given `dtype` and complete it with auto inferred - column types for the given `pandas.DataFrame`. 
- """ - inferred_dtype = dtype or {} - for column_index in df: - if column_index not in inferred_dtype: - try: - column_type = infer_dtype(df[column_index]) - except UnsupportedDType: - column_type = dtypes.str_dtype - inferred_dtype[str(column_index)] = column_type - return inferred_dtype - - -def is_string_mask(column: pd.Series) -> pd.Series: - """ - Return mask of column's elements of type string. - """ - if len(column) == 0: - return pd.Series([], dtype=bool) - return column.map(type) == str - - -def to_categorical(column: pd.Series, str_categories: bool = False) -> pd.Series: - """ - Convert a `pandas` column to categorical dtype. - - Args: - column: A `pandas` column. - str_categories: Replace all categories with their string representations. - - Returns: - categorical `pandas` column. - """ - column = column.mask(column.isna(), None).astype("category") # type: ignore - if str_categories: - return column.cat.rename_categories(column.cat.categories.astype(str)) - return column - - -def prepare_hugging_face_dict(x: Dict) -> Any: - """ - Prepare HuggingFace format for files to be used in Spotlight. - """ - if x.keys() != {"bytes", "path"}: - return x - blob = x["bytes"] - if blob is not None: - return blob - return x["path"] - - -def prepare_column(column: pd.Series, dtype: dtypes.DType) -> pd.Series: - """ - Convert a `pandas` column to the desired `dtype` and prepare some values, - but still as `pandas` column. - - Args: - column: A `pandas` column to prepare. - dtype: Target data type. - - Returns: - Prepared `pandas` column. - - Raises: - TypeError: If `dtype` is not a Spotlight data type. - """ - column = column.copy() - - if dtypes.is_category_dtype(dtype): - # We only support string/`NA` categories, but `pandas` can more, so - # force categories to be strings (does not affect `NA`s). - return to_categorical(column, str_categories=True) - - if dtypes.is_datetime_dtype(dtype): - # `errors="coerce"` will produce `NaT`s instead of fail. - return pd.to_datetime(column, errors="coerce") - - if dtypes.is_str_dtype(dtype): - # Allow `NA`s, convert all other elements to strings. - return column.astype(str).mask(column.isna(), None) # type: ignore - - if dtypes.is_bool_dtype(dtype): - return column.astype(bool) - - if dtypes.is_int_dtype(dtype): - return column.astype(int) - - if dtypes.is_float_dtype(dtype): - return column.astype(float) - - # We explicitely don't want to change the original `DataFrame`. - with pd.option_context("mode.chained_assignment", None): - # We consider empty strings as `NA`s. - str_mask = is_string_mask(column) - column[str_mask] = column[str_mask].replace("", None) - na_mask = column.isna() - - # When `pandas` reads a csv, arrays and lists are read as literal strings, - # try to interpret them. 
- str_mask = is_string_mask(column) - column[str_mask] = column[str_mask].apply(try_literal_eval) - - if dtypes.is_filebased_dtype(dtype): - dict_mask = column.map(type) == dict - column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict) - - return column.mask(na_mask, None) # type: ignore diff --git a/renumics/spotlight_plugins/core/pandas_data_source.py b/renumics/spotlight_plugins/core/pandas_data_source.py index 3eee6ebe..4404a14d 100644 --- a/renumics/spotlight_plugins/core/pandas_data_source.py +++ b/renumics/spotlight_plugins/core/pandas_data_source.py @@ -9,11 +9,7 @@ import datasets from renumics.spotlight import dtypes -from renumics.spotlight.io.pandas import ( - prepare_hugging_face_dict, - stringify_columns, - try_literal_eval, -) +from renumics.spotlight.io import prepare_hugging_face_dict, try_literal_eval from renumics.spotlight.data_source import ( datasource, ColumnMetadata, @@ -107,7 +103,7 @@ def __init__(self, source: Union[Path, pd.DataFrame]): @property def column_names(self) -> List[str]: - return stringify_columns(self._df) + return [str(column) for column in self._df.columns] @property def df(self) -> pd.DataFrame: diff --git a/tests/integration/dataset/test_dataset.py b/tests/integration/dataset/test_dataset.py index 10f7ed75..fe28c8fc 100644 --- a/tests/integration/dataset/test_dataset.py +++ b/tests/integration/dataset/test_dataset.py @@ -26,7 +26,7 @@ from renumics.spotlight.dataset import escape_dataset_name, unescape_dataset_name from renumics.spotlight import dtypes from renumics.spotlight.dataset.typing import OutputType -from renumics.spotlight.io.pandas import infer_dtype +from renumics.spotlight.dataset.pandas import infer_dtype from .conftest import ColumnData from .helpers import get_append_column_fn_name from ..helpers import approx From a4d95a2b7dcce013c93bb287f1807f7023c169ca Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 23 Oct 2023 13:41:16 +0200 Subject: [PATCH 08/31] Track forgotten files --- renumics/spotlight/dataset/pandas.py | 275 +++++++++++++++++++++++++++ renumics/spotlight/io/huggingface.py | 16 ++ 2 files changed, 291 insertions(+) create mode 100644 renumics/spotlight/dataset/pandas.py create mode 100644 renumics/spotlight/io/huggingface.py diff --git a/renumics/spotlight/dataset/pandas.py b/renumics/spotlight/dataset/pandas.py new file mode 100644 index 00000000..dfa133e3 --- /dev/null +++ b/renumics/spotlight/dataset/pandas.py @@ -0,0 +1,275 @@ +import os.path +import statistics +from typing import Any, Optional, Sequence, Union + +import PIL.Image +import filetype +import numpy as np +import pandas as pd +import trimesh + +from renumics.spotlight import dtypes +from renumics.spotlight.io import prepare_hugging_face_dict, try_literal_eval +from renumics.spotlight.media import Audio, Embedding, Image, Mesh, Sequence1D, Video +from renumics.spotlight.typing import is_iterable, is_pathtype +from .exceptions import InvalidDTypeError + + +def create_typed_series( + dtype: dtypes.DType, values: Optional[Union[Sequence, np.ndarray]] = None +) -> pd.Series: + if dtypes.is_category_dtype(dtype): + if values is None or len(values) == 0: + return pd.Series( + dtype=pd.CategoricalDtype( + [] if not dtype.categories else list(dtype.categories.keys()) + ) + ) + if dtype.inverted_categories is None: + return pd.Series([None] * len(values), dtype=pd.CategoricalDtype()) + return pd.Series( + [dtype.inverted_categories.get(code) for code in values], + dtype=pd.CategoricalDtype(), + ) + if dtypes.is_bool_dtype(dtype): + 
pandas_dtype = "boolean" + elif dtypes.is_int_dtype(dtype): + pandas_dtype = "Int64" + elif dtypes.is_float_dtype(dtype): + pandas_dtype = "float" + elif dtypes.is_str_dtype(dtype): + pandas_dtype = "string" + elif dtypes.is_datetime_dtype(dtype): + pandas_dtype = "datetime64[ns]" + else: + pandas_dtype = "object" + return pd.Series([] if values is None else values, dtype=pandas_dtype) + + +def prepare_column(column: pd.Series, dtype: dtypes.DType) -> pd.Series: + """ + Convert a `pandas` column to the desired `dtype` and prepare some values, + but still as `pandas` column. + + Args: + column: A `pandas` column to prepare. + dtype: Target data type. + + Returns: + Prepared `pandas` column. + + Raises: + TypeError: If `dtype` is not a Spotlight data type. + """ + column = column.copy() + + if dtypes.is_category_dtype(dtype): + # We only support string/`NA` categories, but `pandas` can more, so + # force categories to be strings (does not affect `NA`s). + return to_categorical(column, str_categories=True) + + if dtypes.is_datetime_dtype(dtype): + # `errors="coerce"` will produce `NaT`s instead of fail. + return pd.to_datetime(column, errors="coerce") + + if dtypes.is_str_dtype(dtype): + # Allow `NA`s, convert all other elements to strings. + return column.astype(str).mask(column.isna(), None) # type: ignore + + if dtypes.is_bool_dtype(dtype): + return column.astype(bool) + + if dtypes.is_int_dtype(dtype): + return column.astype(int) + + if dtypes.is_float_dtype(dtype): + return column.astype(float) + + # We explicitely don't want to change the original `DataFrame`. + with pd.option_context("mode.chained_assignment", None): + # We consider empty strings as `NA`s. + str_mask = is_string_mask(column) + column[str_mask] = column[str_mask].replace("", None) + na_mask = column.isna() + + # When `pandas` reads a csv, arrays and lists are read as literal strings, + # try to interpret them. + str_mask = is_string_mask(column) + column[str_mask] = column[str_mask].apply(try_literal_eval) + + if dtypes.is_filebased_dtype(dtype): + dict_mask = column.map(type) == dict + column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict) + + return column.mask(na_mask, None) # type: ignore + + +def infer_dtype(column: pd.Series) -> dtypes.DType: + """ + Get an equivalent Spotlight data type for a `pandas` column, if possible. + + At the moment, only scalar data types can be inferred. + + Nullable boolean and integer `pandas` dtypes have no equivalent Spotlight + data type and will be read as strings. + + Float, string, and category data types are allowed to have `NaN`s. + + Args: + column: A `pandas` column to infer dtype from. + + Returns: + Inferred dtype. + + Raises: + ValueError: If dtype cannot be inferred automatically. 
+ """ + + if pd.api.types.is_bool_dtype(column): + return dtypes.bool_dtype + if pd.api.types.is_categorical_dtype(column): + return dtypes.CategoryDType( + {category: code for code, category in enumerate(column.cat.categories)} + ) + if pd.api.types.is_integer_dtype(column): + return dtypes.int_dtype + if pd.api.types.is_float_dtype(column): + return dtypes.float_dtype + if pd.api.types.is_datetime64_any_dtype(column): + return dtypes.datetime_dtype + + column = column.copy() + str_mask = is_string_mask(column) + column[str_mask] = column[str_mask].replace("", None) + + column = column[~column.isna()] + if len(column) == 0: + return dtypes.str_dtype + + column_head = column.iloc[:10] + head_dtypes = column_head.apply(infer_value_dtype).to_list() # type: ignore + dtype_mode = statistics.mode(head_dtypes) + + if dtype_mode is None: + return dtypes.str_dtype + if dtype_mode in [dtypes.window_dtype, dtypes.embedding_dtype]: + column = column.astype(object) + str_mask = is_string_mask(column) + x = column[str_mask].apply(try_literal_eval) + column[str_mask] = x + dict_mask = column.map(type) == dict + column[dict_mask] = column[dict_mask].apply(prepare_hugging_face_dict) + try: + np.asarray(column.to_list(), dtype=float) + except (TypeError, ValueError): + return dtypes.sequence_1d_dtype + return dtype_mode + return dtype_mode + + +def infer_value_dtype(value: Any) -> Optional[dtypes.DType]: + """ + Infer dtype for value + """ + if isinstance(value, Embedding): + return dtypes.embedding_dtype + if isinstance(value, Sequence1D): + return dtypes.sequence_1d_dtype + if isinstance(value, Image): + return dtypes.image_dtype + if isinstance(value, Audio): + return dtypes.audio_dtype + if isinstance(value, Video): + return dtypes.video_dtype + if isinstance(value, Mesh): + return dtypes.mesh_dtype + if isinstance(value, PIL.Image.Image): + return dtypes.image_dtype + if isinstance(value, trimesh.Trimesh): + return dtypes.mesh_dtype + if isinstance(value, np.ndarray): + return infer_array_dtype(value) + + # When `pandas` reads a csv, arrays and lists are read as literal strings, + # try to interpret them. + value = try_literal_eval(value) + if isinstance(value, dict): + value = prepare_hugging_face_dict(value) + if isinstance(value, bytes) or (is_pathtype(value) and os.path.isfile(value)): + kind = filetype.guess(value) + if kind is not None: + mime_group = kind.mime.split("/")[0] + if mime_group == "image": + return dtypes.image_dtype + if mime_group == "audio": + return dtypes.audio_dtype + if mime_group == "video": + return dtypes.video_dtype + return None + if is_iterable(value): + try: + value = np.asarray(value, dtype=float) + except (TypeError, ValueError): + pass + else: + return infer_array_dtype(value) + return None + + +def infer_array_dtype(value: np.ndarray) -> dtypes.DType: + """ + Infer dtype of a numpy array + """ + if value.ndim == 3: + if value.shape[-1] in (1, 3, 4): + return dtypes.image_dtype + elif value.ndim == 2: + if value.shape[0] == 2 or value.shape[1] == 2: + return dtypes.sequence_1d_dtype + elif value.ndim == 1: + if len(value) == 2: + return dtypes.window_dtype + return dtypes.embedding_dtype + return dtypes.array_dtype + + +def infer_dtypes(df: pd.DataFrame, dtype: Optional[dtypes.DTypeMap]) -> dtypes.DTypeMap: + """ + Check column types from the given `dtype` and complete it with auto inferred + column types for the given `pandas.DataFrame`. 
+ """ + inferred_dtype = dtype or {} + for column_index in df: + if column_index not in inferred_dtype: + try: + column_type = infer_dtype(df[column_index]) + except InvalidDTypeError: + column_type = dtypes.str_dtype + inferred_dtype[str(column_index)] = column_type + return inferred_dtype + + +def is_string_mask(column: pd.Series) -> pd.Series: + """ + Return mask of column's elements of type string. + """ + if len(column) == 0: + return pd.Series([], dtype=bool) + return column.map(type) == str + + +def to_categorical(column: pd.Series, str_categories: bool = False) -> pd.Series: + """ + Convert a `pandas` column to categorical dtype. + + Args: + column: A `pandas` column. + str_categories: Replace all categories with their string representations. + + Returns: + categorical `pandas` column. + """ + column = column.mask(column.isna(), None).astype("category") # type: ignore + if str_categories: + return column.cat.rename_categories(column.cat.categories.astype(str)) + return column diff --git a/renumics/spotlight/io/huggingface.py b/renumics/spotlight/io/huggingface.py new file mode 100644 index 00000000..06c0d441 --- /dev/null +++ b/renumics/spotlight/io/huggingface.py @@ -0,0 +1,16 @@ +""" +Helpers for HuggingFace formats. +""" +from typing import Any, Dict + + +def prepare_hugging_face_dict(x: Dict) -> Any: + """ + Prepare HuggingFace format for files to be used in Spotlight. + """ + if x.keys() != {"bytes", "path"}: + return x + blob = x["bytes"] + if blob is not None: + return blob + return x["path"] From d172873800a820e32092316198efa4d2da73a44c Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 23 Oct 2023 13:57:08 +0200 Subject: [PATCH 09/31] Add docstring --- renumics/spotlight/dataset/pandas.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/renumics/spotlight/dataset/pandas.py b/renumics/spotlight/dataset/pandas.py index dfa133e3..75ccf00f 100644 --- a/renumics/spotlight/dataset/pandas.py +++ b/renumics/spotlight/dataset/pandas.py @@ -1,3 +1,7 @@ +""" +Helper for conversion between H5 dataset and `pandas.DataFrame`. 
+""" + import os.path import statistics from typing import Any, Optional, Sequence, Union From 472937840913ef27a11a0a1a66544e8d52477b1b Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 23 Oct 2023 14:02:26 +0200 Subject: [PATCH 10/31] Fix checking scalar dtypes for data alignment before embedding --- renumics/spotlight/backend/tasks/reduction.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index 00207c38..9a8d71c1 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -10,7 +10,7 @@ from renumics.spotlight.dataset.exceptions import ColumnNotExistsError from renumics.spotlight.data_store import DataStore -from renumics.spotlight.dtypes import is_category_dtype, is_embedding_dtype +from renumics.spotlight import dtypes SEED = 42 @@ -35,7 +35,7 @@ def align_data( for column_name in column_names: dtype = data_store.dtypes[column_name] column_values = data_store.get_converted_values(column_name, indices) - if is_embedding_dtype(dtype): + if dtypes.is_embedding_dtype(dtype): embedding_length = max( 0 if x is None else len(cast(np.ndarray, x)) for x in column_values ) @@ -49,17 +49,19 @@ def align_data( ] ) ) - elif is_category_dtype(dtype): + elif dtypes.is_category_dtype(dtype): na_mask = np.array(column_values) == -1 one_hot_values = preprocessing.label_binarize( column_values, classes=sorted(set(column_values).difference({-1})) # type: ignore ).astype(float) one_hot_values[na_mask] = np.nan aligned_values.append(one_hot_values) - elif dtype in (int, bool, float): + elif dtypes.is_scalar_dtype(dtype): aligned_values.append(np.array(column_values, dtype=float)) else: - raise ColumnNotEmbeddable + raise ColumnNotEmbeddable( + "Column '{column_name}' of type {dtype} is not embeddable." + ) data = np.hstack([col.reshape((len(indices), -1)) for col in aligned_values]) mask = ~pd.isna(data).any(axis=1) From f114895485b4d98289517fbfa839816371c736c5 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Mon, 23 Oct 2023 17:15:38 +0200 Subject: [PATCH 11/31] feat: transport Exceptions from tasks over our websocket connection --- renumics/spotlight/backend/tasks/reduction.py | 2 + renumics/spotlight/backend/websockets.py | 24 ++++++ src/services/data.ts | 12 ++- src/widgets/SimilarityMap/SimilarityMap.tsx | 84 ++++++++++++------- 4 files changed, 91 insertions(+), 31 deletions(-) diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index 00207c38..bb14096a 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -78,6 +78,8 @@ def compute_umap( Prepare data from table and compute U-Map on them. """ + raise Exception("Ooops") + try: data, indices = align_data(data_store, column_names, indices) except (ColumnNotExistsError, ColumnNotEmbeddable): diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 8be62f4a..9fa6e9d6 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -57,6 +57,19 @@ class ReductionMessage(Message): generation_id: int +class TaskErrorMessage(Message): + """ + Common message model for task errors + """ + + widget_id: str + uid: str + error: str + title: str + detail: Optional[str] = None + data: None = None + + class ReductionRequestData(BaseModel): """ Base data reduction request payload. 
@@ -155,6 +168,7 @@ async def handle_message(request: Message, connection: "WebsocketConnection") -> New message types should be registered by decorating with `@handle_message.register`. """ + # TODO: add generic error handler raise NotImplementedError @@ -326,6 +340,16 @@ async def _(request: UMapRequest, connection: "WebsocketConnection") -> None: ) except TaskCancelled: ... + except Exception as e: + msg = TaskErrorMessage( + type="tasks.error", + widget_id=request.widget_id, + uid=request.uid, + error=type(e).__name__, + title=type(e).__name__, + detail=type(e).__doc__, + ) + await connection.send_async(msg) else: response = ReductionResponse( type="umap_result", diff --git a/src/services/data.ts b/src/services/data.ts index 16aa82f5..4884e365 100644 --- a/src/services/data.ts +++ b/src/services/data.ts @@ -2,6 +2,7 @@ import { useDataset } from '../stores/dataset'; import { useLayout } from '../stores/layout'; import { IndexArray } from '../types'; import { v4 as uuidv4 } from 'uuid'; +import { Problem } from '../types'; export const umapMetricNames = [ 'euclidean', @@ -110,9 +111,18 @@ export class DataService { ): Promise { const messageId = uuidv4(); - const promise = new Promise((resolve) => { + const promise = new Promise((resolve, reject) => { // eslint-disable-next-line @typescript-eslint/no-explicit-any this.dispatchTable.set(messageId, (message: any) => { + if (message.type === 'tasks.error') { + const error: Problem = { + type: message.error, + title: message.title, + detail: message.detail, + }; + reject(error); + return; + } const result = { points: message.data.points, indices: message.data.indices, diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index 193a46f8..df9e602b 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -1,4 +1,5 @@ import SimilaritiesIcon from '../../icons/Bubbles'; +import WarningIcon from '../../icons/Warning'; import LoadingIndicator from '../../components/LoadingIndicator'; import Plot, { MergeStrategy, @@ -23,6 +24,7 @@ import { IndexArray, isNumberColumn, NumberColumn, + Problem, TableData, } from '../../types'; import { createSizeTransferFunction } from '../../dataformat'; @@ -220,8 +222,8 @@ const SimilarityMap: Widget = () => { const placeByColumns = useDataset(placeByColumnsSelector, shallow); const [positions, setPositions] = useState<[number, number][]>([]); - const [visibleIndices, setVisibleIndices] = useState([]); + const [problem, setProblem] = useState(null); const selected = useMemo(() => { const selected: boolean[] = []; @@ -275,13 +277,14 @@ const SimilarityMap: Widget = () => { const widgetId = useMemo(() => uuidv4(), []); useEffect(() => { + setVisibleIndices([]); + setPositions([]); + setProblem(null); + if (!indices.length || !placeByColumnKeys.length) { - setVisibleIndices([]); - setPositions([]); setIsComputing(false); return; } - setIsComputing(true); const reductionPromise = @@ -302,12 +305,18 @@ const SimilarityMap: Widget = () => { ); let cancelled = false; - reductionPromise.then(({ points, indices }) => { - if (cancelled) return; - setVisibleIndices(indices); - setPositions(points); - setIsComputing(false); - }); + reductionPromise + .then(({ points, indices }) => { + if (cancelled) return; + setVisibleIndices(indices); + setPositions(points); + setIsComputing(false); + }) + .catch((problem: Problem) => { + if (cancelled) return; + setProblem(problem); + setIsComputing(false); + }); return () => { cancelled = 
true; @@ -443,23 +452,35 @@ const SimilarityMap: Widget = () => { const areColumnsSelected = !!placeByColumnKeys.length; const hasVisibleRows = !!visibleIndices.length; - return ( - - {isComputing && } - {!areColumnsSelected && !isComputing && ( - - - Select columns and show at least two rows to display similarity - map - - - )} - {areColumnsSelected && !hasVisibleRows && !isComputing && ( - - Not enough rows - - )} - {areColumnsSelected && hasVisibleRows && !isComputing && ( + let content: JSX.Element; + if (problem) { + content = ( +
+            <EmptyNote>
+                <WarningIcon /> {problem.title}
+                {problem.detail}
+            </EmptyNote>
+        );
+    } else if (isComputing) {
+        content = <LoadingIndicator />;
+    } else if (!areColumnsSelected) {
+        content = (
+            <EmptyNote>
+                Select columns and show at least two rows to display similarity
+                map
+            </EmptyNote>
+        );
+    } else if (!hasVisibleRows) {
+        content = <EmptyNote>Not enough rows</EmptyNote>;
+    } else {
+        content = (
+            <>
                 { /> )}
-            )}
-            {!isComputing && (
{visibleIndices.length} of {indices.length} rows
-            )}
+            </>
+        );
+    }
+
+    return (
+        {content}

From f7af6f162ee35b8f4ad4b92237b2911eff5c460b Mon Sep 17 00:00:00 2001
From: Markus Stoll
Date: Tue, 24 Oct 2023 07:54:04 +0200
Subject: [PATCH 12/31] Format md files

---
 CONTRIBUTING.md |  8 ++++----
 README.md       | 15 +++++++--------
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index d6628c4f..e71d164e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -11,10 +11,10 @@ Technical details on how to contribute can be found in our [documentation](https

 There are several ways you can contribute to Spotlight:

-* Fix outstanding issues.
-* Implement new features.
-* Submit issues related to bugs or desired new features.
-* Share your use case
+- Fix outstanding issues.
+- Implement new features.
+- Submit issues related to bugs or desired new features.
+- Share your use case

 If you don't know where to start, you might want to have a look at [hacktoberfest issues](https://github.com/Renumics/spotlight/issues?q=is%3Aissue+is%3Aopen+label%3Ahacktoberfest) and our guide on how to create a [new Lens](https://renumics.com/docs/development/lenses).

diff --git a/README.md b/README.md
index 6c68d669..320f9e13 100644
--- a/README.md
+++ b/README.md
@@ -17,9 +17,10 @@

-Spotlight helps you to **understand unstructured datasets** fast. You can quickly create **interactive visualizations** and leverage data enrichments (e.g. embeddings, prediction, uncertainties) to **identify critical clusters** in your data. +Spotlight helps you to **understand unstructured datasets** fast. You can quickly create **interactive visualizations** and leverage data enrichments (e.g. embeddings, prediction, uncertainties) to **identify critical clusters** in your data. Spotlight supports most unstructured data types including **images, audio, text, videos, time-series and geometric data**. You can start from your existing dataframe: +

And start Spotlight with just a few lines of code: @@ -49,7 +50,7 @@ Machine learning and engineering teams use Spotlight to understand and communica [Classification] Find Issues in Any Image Classification Dataset 👨‍💻 📝 🕹️ - + Find data issues in the CIFAR-100 image dataset 🕹️ @@ -91,7 +92,6 @@ Machine learning and engineering teams use Spotlight to understand and communica - ## ⏱️ Quickstart Get started by installing Spotlight and loading your first dataset. @@ -132,12 +132,11 @@ ds = datasets.load_dataset('renumics/emodb-enriched', split='all') layout= spotlight.layouts.debug_classification(label='gender', prediction='m1_gender_prediction', embedding='m1_embedding', features=['age', 'emotion']) spotlight.show(ds, layout=layout) ``` + Here, the data types are discovered automatically from the dataset and we use a pre-defined layout for model debugging. Custom layouts can be built programmatically or via the UI. > The `datasets[audio]` package can be installed via pip. - - #### Usage Tracking We have added crash report and performance collection. We do NOT collect user data other than an anonymized Machine Id obtained by py-machineid, and only log our own actions. We do NOT collect folder names, dataset names, or row data of any kind only aggregate performance statistics like total time of a table_load, crash data, etc. Collecting Spotlight crashes will help us improve stability. To opt out of the crash report collection define an environment variable called `SPOTLIGHT_OPT_OUT` and set it to true. e.G.`export SPOTLIGHT_OPT_OUT=true` @@ -150,9 +149,9 @@ We have added crash report and performance collection. We do NOT collect user da ## Learn more about unstructured data workflows -- 🤗 [Huggingface](https://huggingface.co/renumics) example spaces and datasets -- 🏀 [Playbook](https://renumics.com/docs/playbook/) for data-centric AI workflows -- 🍰 [Sliceguard](https://github.com/Renumics/sliceguard) library for automatic slice detection +- 🤗 [Huggingface](https://huggingface.co/renumics) example spaces and datasets +- 🏀 [Playbook](https://renumics.com/docs/playbook/) for data-centric AI workflows +- 🍰 [Sliceguard](https://github.com/Renumics/sliceguard) library for automatic slice detection ## Contribute From 27ed3e76f9221cfa42fe77c24d48958d841ec7af Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Tue, 24 Oct 2023 08:30:39 +0200 Subject: [PATCH 13/31] Add deprecation warning to descriptors module --- renumics/spotlight/dataset/descriptors/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/renumics/spotlight/dataset/descriptors/__init__.py b/renumics/spotlight/dataset/descriptors/__init__.py index bb5cf837..3f52372d 100644 --- a/renumics/spotlight/dataset/descriptors/__init__.py +++ b/renumics/spotlight/dataset/descriptors/__init__.py @@ -1,5 +1,6 @@ """make descriptor methods more available """ +import warnings from typing import Optional, Tuple import numpy as np @@ -11,6 +12,13 @@ from renumics.spotlight.dataset.exceptions import ColumnExistsError, InvalidDTypeError from .data_alignment import align_column_data +warnings.warn( + "`renumics.spotlight.dataset.descriptors` module is deprecated and will " + "be removed in future versions.", + DeprecationWarning, + stacklevel=2, +) + def pca( dataset: Dataset, From 18522daeb39b3754fcc5a707f56fe51875df35bc Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Tue, 24 Oct 2023 10:08:29 +0200 Subject: [PATCH 14/31] feat: style error in simmap --- src/widgets/SimilarityMap/SimilarityMap.tsx | 7 ++++--- 1 file changed, 4 
From 18522daeb39b3754fcc5a707f56fe51875df35bc Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Tue, 24 Oct 2023 10:08:29 +0200 Subject: [PATCH 14/31] feat: style error in simmap --- src/widgets/SimilarityMap/SimilarityMap.tsx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index df9e602b..5c42ebc9 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -456,10 +456,11 @@ const SimilarityMap: Widget = () => { if (problem) { content = (
-
- {problem.title} +
+
+
{problem.title}
+
{problem.detail}
-
{problem.detail}
); } else if (isComputing) { From 9225698eb0c57eb4254c2a19a8c2e1fb5005354c Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Tue, 24 Oct 2023 13:27:21 +0200 Subject: [PATCH 15/31] make embeddings to C layout after PCA --- renumics/spotlight/backend/tasks/reduction.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index 9a8d71c1..7de7158f 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd -from sklearn import preprocessing from renumics.spotlight.dataset.exceptions import ColumnNotExistsError from renumics.spotlight.data_store import DataStore @@ -27,6 +26,7 @@ def align_data( """ Align data from table's columns, remove `NaN`'s. """ + from sklearn import preprocessing if not column_names or not indices: return np.empty(0, np.float64), [] @@ -60,7 +60,7 @@ def align_data( aligned_values.append(np.array(column_values, dtype=float)) else: raise ColumnNotEmbeddable( - "Column '{column_name}' of type {dtype} is not embeddable." + f"Column '{column_name}' of type {dtype} is not embeddable." ) data = np.hstack([col.reshape((len(indices), -1)) for col in aligned_values]) @@ -120,7 +120,7 @@ def compute_pca( try: data, indices = align_data(data_store, column_names, indices) - except (ColumnNotExistsError, ValueError): + except (ColumnNotExistsError, ColumnNotEmbeddable): return np.empty(0, np.float64), [] if data.size == 0: return np.empty(0, np.float64), [] @@ -131,5 +131,6 @@ def compute_pca( elif normalization == "robust standardize": data = preprocessing.RobustScaler(copy=False).fit_transform(data) reducer = decomposition.PCA(n_components=2, copy=False, random_state=SEED) - embeddings = reducer.fit_transform(data) + # `fit_transform` returns Fortran-ordered array. 
+ embeddings = np.ascontiguousarray(reducer.fit_transform(data)) return embeddings, indices From 36e522d47b229d4881a7c81416d4a66f8945851b Mon Sep 17 00:00:00 2001 From: Tarek Date: Tue, 24 Oct 2023 14:40:42 +0200 Subject: [PATCH 16/31] remove mel calculation for separate fix --- src/lenses/SpectrogramLens/index.tsx | 34 +++------------------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/src/lenses/SpectrogramLens/index.tsx b/src/lenses/SpectrogramLens/index.tsx index 9643ce99..70e48ba1 100644 --- a/src/lenses/SpectrogramLens/index.tsx +++ b/src/lenses/SpectrogramLens/index.tsx @@ -10,7 +10,7 @@ import { ColorsState, useColors } from '../../stores/colors'; import { Lens } from '../../types'; import useSetting from '../useSetting'; import MenuBar from './MenuBar'; -import { fixWindow, freqType, unitType, amplitudeToDb, hzToMel } from './Spectrogram'; +import { fixWindow, freqType, unitType, amplitudeToDb } from './Spectrogram'; const Container = tw.div`flex flex-col w-full h-full items-stretch justify-center`; const EmptyNote = styled.p` @@ -166,23 +166,15 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { return; } - //console.log(JSON.stringify((buffer.getChannelData(0).slice(start, end))) - //const slice = buffer.getChannelData(0).slice(start, end); - const arr = buffer.getChannelData(0).slice(start, end); - console.log(Math.max(...arr), Math.min(...arr)); const frequenciesData = await worker( FFT_SAMPLES, backend.windowFunc, backend.alpha, width, FFT_SAMPLES, - arr + buffer.getChannelData(0).slice(start, end) ); - console.log(frequenciesData); - console.log(frequenciesData.length); - console.log(frequenciesData[0].length); - setIsComputing(false); // Get the canvas render context 2D @@ -216,8 +208,6 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { const widthScale = d3.scaleLinear([0, width], [0, frequenciesData.length]); let drawData = []; - //let colorScale: chroma.Scale; - let min = 0; let max = 0; @@ -250,24 +240,6 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { } } - drawData[i] = col; - } - } else if (ampScale === 'mel') { - for (let i = 0; i < frequenciesData.length; i++) { - const col = []; - - for (let j = 0; j < frequenciesData[i].length; j++) { - const amplitude = frequenciesData[i][j]; - col[j] = hzToMel(amplitude ** 2); - - if (col[j] > max) { - max = col[j]; - } - - if (col[j] < min) { - min = col[j]; - } - } drawData[i] = col; } } else { @@ -476,7 +448,7 @@ const SpectrogramLens: Lens = ({ columns, urls, values }) => { availableFreqScales={['linear', 'logarithmic']} freqScale={freqScale} onChangeFreqScale={handleFreqScaleChange} - availableAmpScales={['decibel', 'linear', 'mel']} + availableAmpScales={['decibel', 'linear']} ampScale={ampScale} onChangeAmpScale={handleAmpScaleChange} /> From 4b46a758f74191521757ce81c0a18ab62d72121f Mon Sep 17 00:00:00 2001 From: Tarek Date: Tue, 24 Oct 2023 14:44:23 +0200 Subject: [PATCH 17/31] remove mel/hz calculation for separate fix --- src/lenses/SpectrogramLens/Spectrogram.tsx | 38 +--------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/src/lenses/SpectrogramLens/Spectrogram.tsx b/src/lenses/SpectrogramLens/Spectrogram.tsx index bfd77463..8dd7f5e0 100644 --- a/src/lenses/SpectrogramLens/Spectrogram.tsx +++ b/src/lenses/SpectrogramLens/Spectrogram.tsx @@ -51,40 +51,4 @@ const amplitudeToDb = (amplitude: number, ref: number, amin: number) => { return log_spec; }; -const hzToMel = (freq: number) => { - // Fill in the linear part - const 
f_min = 0.0; - const f_sp = 200.0 / 3; - - let mel = (freq - f_min) / f_sp; - - // Fill in the log-scale part - const min_log_hz = 1000.0; // beginning of log region (Hz) - const min_log_mel = (min_log_hz - f_min) / f_sp; // same (Mels) - const logstep = Math.log(6.4) / 27.0; // step size for log region - - if (freq >= min_log_hz) { - mel = min_log_mel + Math.log(freq / min_log_hz) / logstep; - } - - return mel; -}; - -const melToHz = (mel: number) => { - const f_min = 0.0; - const f_sp = 200.0 / 3; - let freq = f_min + f_sp * mel; - - // And now the nonlinear scale - const min_log_hz = 1000.0; // beginning of log region (Hz) - const min_log_mel = (min_log_hz - f_min) / f_sp; // same (Mels) - const logstep = Math.log(6.4) / 27.0; // step size for log region - - if (mel >= min_log_mel) { - // If we have scalar data, check directly - freq = min_log_hz * Math.exp(logstep * (mel - min_log_mel)); - } - return freq; -}; - -export { unitType, freqType, fixWindow, amplitudeToDb, hzToMel, melToHz }; +export { unitType, freqType, fixWindow, amplitudeToDb }; From 5c41864d3cf72945ee5e31b07f8d1c6d1028b2ca Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 12:49:39 +0200 Subject: [PATCH 18/31] refactor: setup explicit ws message handlers in backend --- renumics/spotlight/backend/tasks/reduction.py | 2 - .../spotlight/backend/tasks/task_manager.py | 11 +- renumics/spotlight/backend/websockets.py | 271 ++++++------------ src/services/data.ts | 162 +---------- src/services/task.ts | 46 +++ src/services/websocket/connection.ts | 61 ++++ src/services/websocket/index.ts | 49 ++++ src/services/websocket/types.ts | 6 + 8 files changed, 270 insertions(+), 338 deletions(-) create mode 100644 src/services/task.ts create mode 100644 src/services/websocket/connection.ts create mode 100644 src/services/websocket/index.ts create mode 100644 src/services/websocket/types.ts diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index bb14096a..00207c38 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -78,8 +78,6 @@ def compute_umap( Prepare data from table and compute U-Map on them. 
""" - raise Exception("Ooops") - try: data, indices = align_data(data_store, column_names, indices) except (ColumnNotExistsError, ColumnNotEmbeddable): diff --git a/renumics/spotlight/backend/tasks/task_manager.py b/renumics/spotlight/backend/tasks/task_manager.py index e120fa88..8c392a55 100644 --- a/renumics/spotlight/backend/tasks/task_manager.py +++ b/renumics/spotlight/backend/tasks/task_manager.py @@ -6,7 +6,7 @@ import multiprocessing from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool -from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union from .exceptions import TaskCancelled from .task import Task @@ -30,16 +30,20 @@ def create_task( self, func: Callable, args: Sequence[Any], + kwargs: Optional[Dict[str, Any]] = None, name: Optional[str] = None, tag: Optional[Union[str, int]] = None, ) -> Task: """ create and launch a new task """ + if kwargs is None: + kwargs = {} + # cancel running task with same name self.cancel(name=name) - future = self.pool.submit(func, *args) + future = self.pool.submit(func, *args, **kwargs) task = Task(name, tag, future) self.tasks.append(task) @@ -59,6 +63,7 @@ async def run_async( self, func: Callable[..., T], args: Sequence[Any], + kwargs: Optional[Dict[str, Any]] = None, name: Optional[str] = None, tag: Optional[Union[str, int]] = None, ) -> T: @@ -66,7 +71,7 @@ async def run_async( Launch a new task. Await and return result. """ - task = self.create_task(func, args, name, tag) + task = self.create_task(func, args=args, kwargs=kwargs, name=name, tag=tag) try: return await asyncio.wrap_future(task.future) except BrokenProcessPool as e: diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 9fa6e9d6..30d15652 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -3,11 +3,20 @@ """ import asyncio -import functools -import orjson -from typing import Any, List, Optional, Set, Callable +from typing import ( + Any, + Coroutine, + Dict, + Optional, + Set, + Callable, + Tuple, + Type, + cast, +) import numpy as np +import orjson from fastapi import WebSocket, WebSocketDisconnect from loguru import logger from pydantic import BaseModel @@ -17,7 +26,7 @@ from .tasks import TaskManager, TaskCancelled from .tasks.reduction import compute_umap, compute_pca -from .exceptions import GenerationIDMismatch +from .exceptions import GenerationIDMismatch, Problem class Message(BaseModel): @@ -47,99 +56,12 @@ class ResetLayoutMessage(Message): data: Any = None -class ReductionMessage(Message): - """ - Common data reduction message model. - """ - +class TaskData(BaseModel): + task: str widget_id: str - uid: str - generation_id: int - - -class TaskErrorMessage(Message): - """ - Common message model for task errors - """ - - widget_id: str - uid: str - error: str - title: str - detail: Optional[str] = None - data: None = None - - -class ReductionRequestData(BaseModel): - """ - Base data reduction request payload. - """ - - indices: List[int] - columns: List[str] - - -class UMapRequestData(ReductionRequestData): - """ - U-Map request payload. - """ - - n_neighbors: int - metric: str - min_dist: float - - -class PCARequestData(ReductionRequestData): - """ - PCA request payload. - """ - - normalization: str - - -class UMapRequest(ReductionMessage): - """ - U-Map request model. 
- """ - - data: UMapRequestData - - -class PCARequest(ReductionMessage): - """ - PCA request model. - """ - - data: PCARequestData - - -class ReductionResponseData(BaseModel): - """ - Data reduction response payload. - """ - - indices: List[int] - points: np.ndarray - - class Config: - arbitrary_types_allowed = True - - -class ReductionResponse(ReductionMessage): - """ - Data reduction response model. - """ - - data: ReductionResponseData - - -MESSAGE_BY_TYPE = { - "umap": UMapRequest, - "umap_result": ReductionResponse, - "pca": PCARequest, - "pca_result": ReductionResponse, - "refresh": RefreshMessage, -} + task_id: str + generation_id: Optional[int] + args: Any class UnknownMessageType(Exception): @@ -148,28 +70,26 @@ class UnknownMessageType(Exception): """ -def parse_message(raw_message: str) -> Message: - """ - Parse a websocket message from a raw text. - """ - json_message = orjson.loads(raw_message) - message_type = json_message["type"] - message_class = MESSAGE_BY_TYPE.get(message_type) - if message_class is None: - raise UnknownMessageType(f"Message type {message_type} is unknown.") - return message_class(**json_message) +PayloadType = Type[BaseModel] +MessageHandler = Callable[[Any, "WebsocketConnection"], Coroutine[Any, Any, Any]] +MessageHandlerSpec = Tuple[PayloadType, MessageHandler] +MESSAGE_HANDLERS: Dict[str, MessageHandlerSpec] = {} -@functools.singledispatch -async def handle_message(request: Message, connection: "WebsocketConnection") -> None: - """ - Handle incoming messages. +def register_message_handler( + message_type: str, handler_spec: MessageHandlerSpec +) -> None: + MESSAGE_HANDLERS[message_type] = handler_spec - New message types should be registered by decorating with `@handle_message.register`. - """ - # TODO: add generic error handler - raise NotImplementedError +def message_handler( + message_type: str, payload_type: Type[BaseModel] +) -> Callable[[MessageHandler], MessageHandler]: + def decorator(handler: MessageHandler) -> MessageHandler: + register_message_handler(message_type, (payload_type, handler)) + return handler + + return decorator class WebsocketConnection: @@ -218,12 +138,17 @@ async def listen(self) -> None: try: while True: try: - message = parse_message(await self.websocket.receive_text()) + raw_message = await self.websocket.receive_text() + message = Message(**orjson.loads(raw_message)) except UnknownMessageType as e: logger.warning(str(e)) else: logger.info(f"WS message with type {message.type} received.") - asyncio.create_task(handle_message(message, self)) + if handler_spec := MESSAGE_HANDLERS.get(message.type): + payload_type, handler = handler_spec + asyncio.create_task(handler(payload_type(**message.data), self)) + else: + logger.error(f"Unknown message received: {message.type}") except WebSocketDisconnect: self._on_disconnect() @@ -314,83 +239,63 @@ def on_disconnect(self, connection: WebsocketConnection) -> None: callback(len(self.connections)) -@handle_message.register -async def _(request: UMapRequest, connection: "WebsocketConnection") -> None: +TASK_FUNCS = {"umap": compute_umap, "pca": compute_pca} + + +@message_handler("task", TaskData) +async def _(data: TaskData, connection: WebsocketConnection) -> None: data_store: Optional[DataStore] = connection.websocket.app.data_store if data_store is None: return None - try: - data_store.check_generation_id(request.generation_id) - except GenerationIDMismatch: - return + + if data.generation_id: + try: + data_store.check_generation_id(data.generation_id) + except 
GenerationIDMismatch: + return try: - points, valid_indices = await connection.task_manager.run_async( - compute_umap, - ( - data_store, - request.data.columns, - request.data.indices, - request.data.n_neighbors, - request.data.metric, - request.data.min_dist, - ), - name=request.widget_id, + task_func = TASK_FUNCS[data.task] + print(data.args) + result = await connection.task_manager.run_async( + task_func, # type: ignore + args=(data_store,), + kwargs=data.args, + name=data.widget_id, tag=id(connection), ) + points = cast(np.ndarray, result[0]) + valid_indices = cast(np.ndarray, result[1]) except TaskCancelled: - ... - except Exception as e: - msg = TaskErrorMessage( - type="tasks.error", - widget_id=request.widget_id, - uid=request.uid, - error=type(e).__name__, - title=type(e).__name__, - detail=type(e).__doc__, + pass + except Problem as e: + msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": type(e).__name__, + "detail": type(e).__doc__, + }, + }, ) await connection.send_async(msg) - else: - response = ReductionResponse( - type="umap_result", - widget_id=request.widget_id, - uid=request.uid, - generation_id=request.generation_id, - data=ReductionResponseData(indices=valid_indices, points=points), - ) - await connection.send_async(response) - - -@handle_message.register -async def _(request: PCARequest, connection: "WebsocketConnection") -> None: - data_store: Optional[DataStore] = connection.websocket.app.data_store - if data_store is None: - return None - try: - data_store.check_generation_id(request.generation_id) - except GenerationIDMismatch: - return - - try: - points, valid_indices = await connection.task_manager.run_async( - compute_pca, - ( - data_store, - request.data.columns, - request.data.indices, - request.data.normalization, - ), - name=request.widget_id, - tag=id(connection), + except Exception as e: + msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": type(e).__name__, + "detail": type(e).__doc__, + }, + }, ) - except TaskCancelled: - ... 
+ await connection.send_async(msg) else: - response = ReductionResponse( - type="pca_result", - widget_id=request.widget_id, - uid=request.uid, - generation_id=request.generation_id, - data=ReductionResponseData(indices=valid_indices, points=points), + msg = Message( + type="task.result", data={"points": points, "indices": valid_indices} ) - await connection.send_async(response) + await connection.send_async(msg) diff --git a/src/services/data.ts b/src/services/data.ts index 4884e365..069e01dc 100644 --- a/src/services/data.ts +++ b/src/services/data.ts @@ -1,8 +1,5 @@ -import { useDataset } from '../stores/dataset'; -import { useLayout } from '../stores/layout'; import { IndexArray } from '../types'; -import { v4 as uuidv4 } from 'uuid'; -import { Problem } from '../types'; +import taskService from './task'; export const umapMetricNames = [ 'euclidean', @@ -11,11 +8,8 @@ export const umapMetricNames = [ 'cosine', 'mahalanobis', ] as const; - export type UmapMetric = typeof umapMetricNames[number]; - export const pcaNormalizations = ['none', 'standardize', 'robust standardize'] as const; - export type PCANormalization = typeof pcaNormalizations[number]; interface ReductionResult { @@ -23,84 +17,7 @@ interface ReductionResult { indices: IndexArray; } -const MAX_QUEUED_MESSAGES = 16; - -class Connection { - url: string; - socket?: WebSocket; - messageQueue: string[]; - onmessage?: (data: unknown) => void; - - constructor(host: string, port: string) { - this.messageQueue = []; - if (globalThis.location.protocol === 'https:') { - this.url = `wss://${host}:${port}/api/ws`; - } else { - this.url = `ws://${host}:${port}/api/ws`; - } - this.#connect(); - } - - send(message: unknown): void { - const data = JSON.stringify(message); - if (this.socket) { - this.socket.send(data); - } else { - if (this.messageQueue.length > MAX_QUEUED_MESSAGES) { - this.messageQueue.shift(); - } - this.messageQueue.push(data); - } - } - - #connect() { - const webSocket = new WebSocket(this.url); - - webSocket.onopen = () => { - this.socket = webSocket; - this.messageQueue.forEach((message) => this.socket?.send(message)); - this.messageQueue.length = 0; - }; - webSocket.onmessage = (event) => { - const message = JSON.parse(event.data); - this.onmessage?.(message); - }; - webSocket.onerror = () => { - webSocket.close(); - }; - webSocket.onclose = () => { - this.socket = undefined; - setTimeout(() => { - this.#connect(); - }, 500); - }; - } -} - export class DataService { - connection: Connection; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - dispatchTable: Map void>; - - constructor(host: string, port: string) { - this.connection = new Connection(host, port); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.connection.onmessage = (message: any) => { - if (message.uid) { - this.dispatchTable.get(message.uid)?.(message); - return; - } - if (message.type === 'refresh') { - useDataset.getState().refresh(); - } else if (message.type === 'resetLayout') { - useLayout.getState().reset(); - } else if (message.type === 'issuesUpdated') { - useDataset.getState().fetchIssues(); - } - }; - this.dispatchTable = new Map(); - } - async computeUmap( widgetId: string, columnNames: string[], @@ -109,43 +26,13 @@ export class DataService { metric: UmapMetric, min_dist: number ): Promise { - const messageId = uuidv4(); - - const promise = new Promise((resolve, reject) => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.dispatchTable.set(messageId, (message: any) => { - if 
(message.type === 'tasks.error') { - const error: Problem = { - type: message.error, - title: message.title, - detail: message.detail, - }; - reject(error); - return; - } - const result = { - points: message.data.points, - indices: message.data.indices, - }; - resolve(result); - }); - }); - - this.connection.send({ - type: 'umap', - widget_id: widgetId, - uid: messageId, - generation_id: useDataset.getState().generationID, - data: { - indices: Array.from(indices), - columns: columnNames, - n_neighbors: n_neighbors, - metric: metric, - min_dist: min_dist, - }, + return await taskService.run('umap', widgetId, { + column_names: columnNames, + indices: Array.from(indices), + n_neighbors, + metric, + min_dist, }); - - return promise; } async computePCA( @@ -154,38 +41,13 @@ export class DataService { indices: IndexArray, pcaNormalization: PCANormalization ): Promise { - const messageId = uuidv4(); - - const promise = new Promise((resolve) => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.dispatchTable.set(messageId, (message: any) => { - const result = { - points: message.data.points, - indices: message.data.indices, - }; - resolve(result); - }); - }); - - this.connection.send({ - type: 'pca', - widget_id: widgetId, - uid: messageId, - generation_id: useDataset.getState().generationID, - data: { - indices: Array.from(indices), - columns: columnNames, - normalization: pcaNormalization, - }, + return await taskService.run('pca', widgetId, { + indices: Array.from(indices), + column_names: columnNames, + normalization: pcaNormalization, }); - - return promise; } } -const dataService = new DataService( - globalThis.location.hostname, - globalThis.location.port -); - +const dataService = new DataService(); export default dataService; diff --git a/src/services/task.ts b/src/services/task.ts new file mode 100644 index 00000000..0a74d58e --- /dev/null +++ b/src/services/task.ts @@ -0,0 +1,46 @@ +import { useDataset } from '../lib'; +import websocketService, { WebsocketService } from './websocket'; + +interface ResponseHandler { + resolve: (result: unknown) => void; + reject: (error: unknown) => void; +} + +// service handling the execution of remote tasks +class TaskService { + // dispatch table for task responses/errors + // eslint-disable-next-line @typescript-eslint/no-explicit-any + dispatchTable: Map; + websocketService: WebsocketService; + + constructor(websocketService: WebsocketService) { + this.dispatchTable = new Map(); + this.websocketService = websocketService; + this.websocketService.registerMessageHandler('task.result', (message: any) => { + this.dispatchTable.get(message.data.task_id)?.resolve(message.data.result); + }); + this.websocketService.registerMessageHandler('task.error', (message: any) => { + this.dispatchTable.get(message.data.task_id)?.reject(message.data.error); + }); + } + + async run(task: string, name: string, args: unknown) { + return new Promise((resolve, reject) => { + const task_id = crypto.randomUUID(); + this.dispatchTable.set(task_id, { resolve, reject }); + websocketService.send({ + type: 'task', + data: { + task: task, + widget_id: name, + task_id, + generation_id: useDataset.getState().generationID, + args, + }, + }); + }); + } +} + +const taskService = new TaskService(websocketService); +export default taskService; diff --git a/src/services/websocket/connection.ts b/src/services/websocket/connection.ts new file mode 100644 index 00000000..84452ca7 --- /dev/null +++ b/src/services/websocket/connection.ts @@ -0,0 +1,61 @@ +// the maximum 
number of outgoing messages that are queued, +// while the connection is down +import { Message } from './types'; + +const MAX_QUEUED_MESSAGES = 16; + +// a websocket connection to the spotlight backend +// automatically reconnects when necessary +class Connection { + url: string; + socket?: WebSocket; + messageQueue: string[]; + onmessage?: (data: unknown) => void; + + constructor(host: string, port: string) { + this.messageQueue = []; + if (globalThis.location.protocol === 'https:') { + this.url = `wss://${host}:${port}/api/ws`; + } else { + this.url = `ws://${host}:${port}/api/ws`; + } + this.#connect(); + } + + send(message: Message): void { + const data = JSON.stringify(message); + if (this.socket) { + this.socket.send(data); + } else { + if (this.messageQueue.length > MAX_QUEUED_MESSAGES) { + this.messageQueue.shift(); + } + this.messageQueue.push(data); + } + } + + #connect() { + const webSocket = new WebSocket(this.url); + + webSocket.onopen = () => { + this.socket = webSocket; + this.messageQueue.forEach((message) => this.socket?.send(message)); + this.messageQueue.length = 0; + }; + webSocket.onmessage = (event) => { + const message = JSON.parse(event.data); + this.onmessage?.(message); + }; + webSocket.onerror = () => { + webSocket.close(); + }; + webSocket.onclose = () => { + this.socket = undefined; + setTimeout(() => { + this.#connect(); + }, 500); + }; + } +} + +export default Connection; diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts new file mode 100644 index 00000000..9183ded2 --- /dev/null +++ b/src/services/websocket/index.ts @@ -0,0 +1,49 @@ +import Connection from './connection'; +import { Message, MessageHandler } from './types'; + +// this service provides a websocket connection to the backend +// and handles incoming and outgoing messages +export class WebsocketService { + connection: Connection; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + messageHandlers: Map; + + constructor(host: string, port: string) { + this.messageHandlers = new Map(); + this.connection = new Connection(host, port); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + this.connection.onmessage = (message: any) => { + const messageHandler = this.messageHandlers.get(message.type); + if (messageHandler) { + messageHandler(message); + } else { + console.error(`Unknown websocket message: ${message.type}`); + } + }; + } + + registerMessageHandler(messageType: string, handler: MessageHandler): void { + this.messageHandlers.set(messageType, handler); + } + + send(message: Message): void { + this.connection.send(message); + } +} + +const websocketService = new WebsocketService( + globalThis.location.hostname, + globalThis.location.port +); +export default websocketService; + +// TODO: reimplement/move +/* + if (message.type === 'refresh') { + useDataset.getState().refresh(); + } else if (message.type === 'resetLayout') { + useLayout.getState().reset(); + } else if (message.type === 'issuesUpdated') { + useDataset.getState().fetchIssues(); + } +*/ diff --git a/src/services/websocket/types.ts b/src/services/websocket/types.ts new file mode 100644 index 00000000..f33f9eee --- /dev/null +++ b/src/services/websocket/types.ts @@ -0,0 +1,6 @@ +export interface Message { + type: string; + data: any; +} + +export type MessageHandler = (message: Message) => void;
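With the registry in place, adding a new backend message type boils down to declaring a pydantic payload model and decorating a coroutine. A minimal sketch against the module above — the `ping`/`pong` types and the `PingData` model are invented here for illustration, not part of the patch:

```python
from pydantic import BaseModel

from renumics.spotlight.backend.websockets import (
    Message,
    WebsocketConnection,
    message_handler,
)


class PingData(BaseModel):
    payload: str


# registers (PingData, handle_ping) under the "ping" message type
@message_handler("ping", PingData)
async def handle_ping(data: PingData, connection: WebsocketConnection) -> None:
    # the listener validates message.data into PingData before dispatching here
    await connection.send_async(Message(type="pong", data=data.payload))
```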
From 6fd6f8cac0ee08c0deefc2cbf8b5e73d05387053 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 12:58:11 +0200 Subject: [PATCH 19/31] chore: reimplement missing message handlers in fe --- src/services/websocket/index.ts | 6 +----- src/stores/dataset/dataset.ts | 9 +++++++++ src/stores/layout.ts | 5 +++++ 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts index 9183ded2..efc344e1 100644 --- a/src/services/websocket/index.ts +++ b/src/services/websocket/index.ts @@ -37,13 +37,9 @@ const websocketService = new WebsocketService( ); export default websocketService; -// TODO: reimplement/move +// TODO: these functions should probably be somewhere else /* - if (message.type === 'refresh') { - useDataset.getState().refresh(); } else if (message.type === 'resetLayout') { useLayout.getState().reset(); - } else if (message.type === 'issuesUpdated') { - useDataset.getState().fetchIssues(); } */ diff --git a/src/stores/dataset/dataset.ts b/src/stores/dataset/dataset.ts index 0cd3e3c0..4916c015 100644 --- a/src/stores/dataset/dataset.ts +++ b/src/stores/dataset/dataset.ts @@ -22,6 +22,7 @@ import { notifyAPIError, notifyError } from '../../notify'; import { makeColumnsColorTransferFunctions } from './colorTransferFunctionFactory'; import { makeColumn } from './columnFactory'; import { makeColumnsStats } from './statisticsFactory'; +import websocketService from '../../services/websocket'; export type CallbackOrData = ((data: T) => T) | T; @@ -540,3 +541,11 @@ useDataset.subscribe( ); useColors.subscribe(() => useDataset.getState().recomputeColorTransferFunctions()); + +websocketService.registerMessageHandler('refresh', () => { + useDataset.getState().refresh(); +}); + +websocketService.registerMessageHandler('issuesUpdated', () => { + useDataset.getState().fetchIssues(); +}); diff --git a/src/stores/layout.ts b/src/stores/layout.ts index 8eb5f517..53d4a993 100644 --- a/src/stores/layout.ts +++ b/src/stores/layout.ts @@ -2,6 +2,7 @@ import { create } from 'zustand'; import { AppLayout } from '../types'; import api from '../api'; import { saveAs } from 'file-saver'; +import websocketService from '../services/websocket'; export interface State { layout: AppLayout; @@ -41,3 +42,7 @@ export const useLayout = create((set) => ({ reader.readAsText(file); }, })); + +websocketService.registerMessageHandler('resetLayout', () => { + useLayout.getState().reset(); +}); From a1d7cb919723be515a6c8429935510f92de3c7dd Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 13:23:12 +0200 Subject: [PATCH 20/31] feat: handle errors during message serialization --- renumics/spotlight/backend/tasks/reduction.py | 5 +--- renumics/spotlight/backend/websockets.py | 27 ++++++++++++++----- src/services/task.ts | 23 +++++++++++----- src/services/websocket/index.ts | 13 +++++---- 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index 00207c38..2b6f3180 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -116,10 +116,7 @@ def compute_pca( from sklearn import preprocessing, decomposition - try: - data, indices = align_data(data_store, column_names, indices) - except (ColumnNotExistsError, ValueError): - return np.empty(0, np.float64), [] + data, indices = align_data(data_store, column_names, indices) if data.size == 0: return np.empty(0, np.float64), [] if data.shape[1] == 1: diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 30d15652..99187dd8 100644 ---
a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -108,11 +108,25 @@ async def send_async(self, message: Message) -> None: """ Send a message async. """ + try: - message_data = message.dict() - await self.websocket.send_text( - orjson.dumps(message_data, option=orjson.OPT_SERIALIZE_NUMPY).decode() + json_text = orjson.dumps( + message.dict(), option=orjson.OPT_SERIALIZE_NUMPY + ).decode() + except TypeError as e: + error_message = Message( + type="error", + data={ + "type": type(e).__name__, + "title": "Failed to serialize message", + "detail": str(e), + }, ) + json_text = orjson.dumps( + error_message.dict(), option=orjson.OPT_SERIALIZE_NUMPY + ).decode() + try: + await self.websocket.send_text(json_text) except WebSocketDisconnect: self._on_disconnect() except RuntimeError: @@ -256,7 +270,6 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: try: task_func = TASK_FUNCS[data.task] - print(data.args) result = await connection.task_manager.run_async( task_func, # type: ignore args=(data_store,), @@ -267,6 +280,7 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: points = cast(np.ndarray, result[0]) valid_indices = cast(np.ndarray, result[1]) except TaskCancelled: + print("task cancelled") pass except Problem as e: msg = Message( @@ -275,13 +289,14 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: "task_id": data.task_id, "error": { "type": type(e).__name__, - "title": type(e).__name__, - "detail": type(e).__doc__, + "title": e.title, + "detail": e.detail, }, }, ) await connection.send_async(msg) except Exception as e: + print("task failed") msg = Message( type="task.error", data={ diff --git a/src/services/task.ts b/src/services/task.ts index 0a74d58e..2d16c45f 100644 --- a/src/services/task.ts +++ b/src/services/task.ts @@ -1,5 +1,6 @@ import { useDataset } from '../lib'; import websocketService, { WebsocketService } from './websocket'; +import { Message } from './websocket/types'; interface ResponseHandler { resolve: (result: unknown) => void; @@ -16,12 +17,22 @@ class TaskService { constructor(websocketService: WebsocketService) { this.dispatchTable = new Map(); this.websocketService = websocketService; - this.websocketService.registerMessageHandler('task.result', (message: any) => { - this.dispatchTable.get(message.data.task_id)?.resolve(message.data.result); - }); - this.websocketService.registerMessageHandler('task.error', (message: any) => { - this.dispatchTable.get(message.data.task_id)?.reject(message.data.error); - }); + this.websocketService.registerMessageHandler( + 'task.result', + (message: Message) => { + this.dispatchTable + .get(message.data.task_id) + ?.resolve(message.data.result); + } + ); + this.websocketService.registerMessageHandler( + 'task.error', + (message: Message) => { + this.dispatchTable + .get(message.data.task_id) + ?.reject(message.data.error); + } + ); } async run(task: string, name: string, args: unknown) { diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts index efc344e1..2d5e241a 100644 --- a/src/services/websocket/index.ts +++ b/src/services/websocket/index.ts @@ -1,3 +1,4 @@ +import { notifyProblem } from '../../notify'; import Connection from './connection'; import { Message, MessageHandler } from './types'; @@ -35,11 +36,9 @@ const websocketService = new WebsocketService( globalThis.location.hostname, globalThis.location.port ); -export default websocketService; -// TODO: these functions should probably be 
somewhere else /* - } else if (message.type === 'resetLayout') { - useLayout.getState().reset(); - } -*/ +websocketService.registerMessageHandler('error', (message: Message) => { + notifyProblem(message.data); +}); + +export default websocketService; From 480c9009479a2d180a4e09c2346edb8424bc20c1 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 13:31:46 +0200 Subject: [PATCH 21/31] feat: raise serialization errors and handle at call site --- renumics/spotlight/backend/websockets.py | 36 +++++++++++++++--------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 99187dd8..34a39d77 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -70,6 +70,12 @@ class UnknownMessageType(Exception): """ +class SerializationError(Exception): + """ + Failed to serialize the WS message + """ + + PayloadType = Type[BaseModel] MessageHandler = Callable[[Any, "WebsocketConnection"], Coroutine[Any, Any, Any]] MessageHandlerSpec = Tuple[PayloadType, MessageHandler] @@ -114,17 +120,7 @@ async def send_async(self, message: Message) -> None: message.dict(), option=orjson.OPT_SERIALIZE_NUMPY ).decode() except TypeError as e: - error_message = Message( - type="error", - data={ - "type": type(e).__name__, - "title": "Failed to serialize message", - "detail": str(e), - }, - ) - json_text = orjson.dumps( - error_message.dict(), option=orjson.OPT_SERIALIZE_NUMPY - ).decode() + raise SerializationError(str(e)) try: await self.websocket.send_text(json_text) except WebSocketDisconnect: self._on_disconnect() @@ -280,7 +276,6 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: points = cast(np.ndarray, result[0]) valid_indices = cast(np.ndarray, result[1]) except TaskCancelled: - print("task cancelled") pass except Problem as e: msg = Message( @@ -296,7 +291,6 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: ) await connection.send_async(msg) except Exception as e: - print("task failed") msg = Message( type="task.error", data={ @@ -313,4 +307,18 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: msg = Message( type="task.result", data={"points": points, "indices": valid_indices} ) - await connection.send_async(msg) + try: + await connection.send_async(msg) + except SerializationError as e: + error_msg = Message( + type="task.error", + data={ + "task_id": data.task_id, + "error": { + "type": type(e).__name__, + "title": "Serialization Error", + "detail": str(e), + }, + }, + ) + await connection.send_async(error_msg) From fd6c71e232aa59351f05392abb6d6cc308c2ff54 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 13:53:42 +0200 Subject: [PATCH 22/31] refactor: change format of task result --- renumics/spotlight/backend/websockets.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/renumics/spotlight/backend/websockets.py b/renumics/spotlight/backend/websockets.py index 34a39d77..aed9a1ef 100644 --- a/renumics/spotlight/backend/websockets.py +++ b/renumics/spotlight/backend/websockets.py @@ -305,7 +305,11 @@ async def _(data: TaskData, connection: WebsocketConnection) -> None: msg = Message( type="task.result", data={ + "task_id": data.task_id, + "result": {"points": points, "indices": valid_indices}, }, ) try: await connection.send_async(msg)
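For context on why `send_async` can fail in the first place: `orjson` only serializes numpy arrays when the `OPT_SERIALIZE_NUMPY` flag is passed, and it raises a `TypeError` subclass for unsupported objects. A small standalone illustration (not part of the patch):

```python
import numpy as np
import orjson

# fine: the backend passes OPT_SERIALIZE_NUMPY for numpy payloads
orjson.dumps({"points": np.zeros((2, 2))}, option=orjson.OPT_SERIALIZE_NUMPY)

# not fine: arbitrary objects raise a TypeError subclass,
# which send_async now wraps in SerializationError
try:
    orjson.dumps({"bad": object()})
except TypeError as error:
    print(type(error).__name__, error)
```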
From 2c0ef37abe2fda813f809c1ec9e1774df1c77061 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 15:03:00 +0200 Subject: [PATCH 23/31] chore: fix linter errors --- src/services/task.ts | 24 ++++++--------------- src/services/websocket/index.ts | 5 +++-- src/services/websocket/types.ts | 5 +++-- src/widgets/SimilarityMap/SimilarityMap.tsx | 1 - 4 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/services/task.ts b/src/services/task.ts index 2d16c45f..f95a5d5e 100644 --- a/src/services/task.ts +++ b/src/services/task.ts @@ -1,6 +1,5 @@ import { useDataset } from '../lib'; import websocketService, { WebsocketService } from './websocket'; -import { Message } from './websocket/types'; interface ResponseHandler { resolve: (result: unknown) => void; reject: (error: unknown) => void; @@ -17,25 +16,16 @@ class TaskService { constructor(websocketService: WebsocketService) { this.dispatchTable = new Map(); this.websocketService = websocketService; - this.websocketService.registerMessageHandler( - 'task.result', - (message: Message) => { - this.dispatchTable - .get(message.data.task_id) - ?.resolve(message.data.result); - } - ); - this.websocketService.registerMessageHandler( - 'task.error', - (message: Message) => { - this.dispatchTable - .get(message.data.task_id) - ?.reject(message.data.error); - } - ); + this.websocketService.registerMessageHandler('task.result', (data) => { + this.dispatchTable.get(data.task_id)?.resolve(data.result); + }); + this.websocketService.registerMessageHandler('task.error', (data) => { + this.dispatchTable.get(data.task_id)?.reject(data.error); + }); } async run(task: string, name: string, args: unknown) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any return new Promise((resolve, reject) => { const task_id = crypto.randomUUID(); this.dispatchTable.set(task_id, { resolve, reject }); diff --git a/src/services/websocket/index.ts b/src/services/websocket/index.ts index 2d5e241a..e722e1de 100644 --- a/src/services/websocket/index.ts +++ b/src/services/websocket/index.ts @@ -1,4 +1,5 @@ import { notifyProblem } from '../../notify'; +import { Problem } from '../../types'; import Connection from './connection'; import { Message, MessageHandler } from './types'; @@ -37,8 +38,8 @@ const websocketService = new WebsocketService( globalThis.location.port ); -websocketService.registerMessageHandler('error', (message: Message) => { - notifyProblem(message.data); +websocketService.registerMessageHandler('error', (problem: Problem) => { + notifyProblem(problem); }); export default websocketService; diff --git a/src/services/websocket/types.ts b/src/services/websocket/types.ts index f33f9eee..18840150 100644 --- a/src/services/websocket/types.ts +++ b/src/services/websocket/types.ts @@ -1,6 +1,7 @@ export interface Message { type: string; - data: any; + data: unknown; } -export type MessageHandler = (message: Message) => void; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type MessageHandler = (data: any) => void; diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index 5c42ebc9..023c4dd1 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -1,5 +1,4 @@ import SimilaritiesIcon from '../../icons/Bubbles'; -import WarningIcon from '../../icons/Warning'; import LoadingIndicator from '../../components/LoadingIndicator'; import Plot, { MergeStrategy, From a0f6fa41580e0d014597b76d3921faf82306ca4b Mon Sep 17 00:00:00 2001 From: Dominik Haentsch
Date: Wed, 25 Oct 2023 15:39:18 +0200 Subject: [PATCH 24/31] fix: always open browser on localhost --- renumics/spotlight/viewer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/renumics/spotlight/viewer.py b/renumics/spotlight/viewer.py index b7630b17..8f859ad1 100644 --- a/renumics/spotlight/viewer.py +++ b/renumics/spotlight/viewer.py @@ -220,7 +220,7 @@ def open_browser(self) -> None: """ if not self.port: return - launch_browser_in_thread(self.host, self.port) + launch_browser_in_thread("localhost", self.port) def refresh(self) -> None: """ From bec35978598332766942e5a4d2ed777455b51c5e Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Wed, 25 Oct 2023 15:48:57 +0200 Subject: [PATCH 25/31] Do not fail when simple value conversion fails in `get_table` request --- renumics/spotlight_plugins/core/api/table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/renumics/spotlight_plugins/core/api/table.py b/renumics/spotlight_plugins/core/api/table.py index 1446bb74..807fdc40 100644 --- a/renumics/spotlight_plugins/core/api/table.py +++ b/renumics/spotlight_plugins/core/api/table.py @@ -73,7 +73,7 @@ def get_table(request: Request) -> ORJSONResponse: columns = [] for column_name in data_store.column_names: dtype = data_store.dtypes[column_name] - values = data_store.get_converted_values(column_name, simple=True) + values = data_store.get_converted_values(column_name, simple=True, check=False) meta = data_store.get_column_metadata(column_name) column = Column( name=column_name, From 6d53da275c437ff59760aecc684484fc77cdd366 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Wed, 25 Oct 2023 15:49:38 +0200 Subject: [PATCH 26/31] Return `` for failed simple conversion of complex values --- renumics/spotlight/dtypes/conversion.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/renumics/spotlight/dtypes/conversion.py b/renumics/spotlight/dtypes/conversion.py index 48c90dde..21f129b4 100644 --- a/renumics/spotlight/dtypes/conversion.py +++ b/renumics/spotlight/dtypes/conversion.py @@ -37,8 +37,8 @@ import trimesh import PIL.Image import validators -from renumics.spotlight import dtypes +from renumics.spotlight import dtypes from renumics.spotlight.typing import PathOrUrlType, PathType from renumics.spotlight.cache import external_data_cache from renumics.spotlight.io import audio @@ -221,15 +221,19 @@ def convert_to_dtype( except (TypeError, ValueError) as e: if check: raise ConversionFailed(value, dtype) from e - else: - return None - - if check: - if last_conversion_error: - raise ConversionFailed(value, dtype, last_conversion_error.reason) - else: + else: + if check: + if last_conversion_error: + raise ConversionFailed(value, dtype, last_conversion_error.reason) raise NoConverterAvailable(value, dtype) + if simple and ( + dtypes.is_array_dtype(dtype) + or dtypes.is_embedding_dtype(dtype) + or dtypes.is_sequence_1d_dtype(dtype) + or dtypes.is_filebased_dtype(dtype) + ): + return "" return None From 723092a9f1ac0e8f1c53fc06e1b5faafed00b277 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Wed, 25 Oct 2023 15:53:27 +0200 Subject: [PATCH 27/31] Standardize imports --- renumics/spotlight/dtypes/conversion.py | 118 +++++++++++------------- 1 file changed, 56 insertions(+), 62 deletions(-) diff --git a/renumics/spotlight/dtypes/conversion.py b/renumics/spotlight/dtypes/conversion.py index 21f129b4..1488b79f 100644 --- a/renumics/spotlight/dtypes/conversion.py +++ b/renumics/spotlight/dtypes/conversion.py @@ -38,22 +38,13 
@@ import PIL.Image import validators -from renumics.spotlight import dtypes +from renumics.spotlight import dtypes, media from renumics.spotlight.typing import PathOrUrlType, PathType from renumics.spotlight.cache import external_data_cache from renumics.spotlight.io import audio from renumics.spotlight.io.file import as_file from renumics.spotlight.media.exceptions import InvalidFile from renumics.spotlight.backend.exceptions import Problem -from renumics.spotlight.dtypes import ( - CategoryDType, - DType, - audio_dtype, - image_dtype, - mesh_dtype, - video_dtype, -) -from renumics.spotlight.media import Sequence1D, Image, Audio, Video, Mesh NormalizedValue = Union[ @@ -96,7 +87,7 @@ class ConversionFailed(Problem): def __init__( self, value: NormalizedValue, - dtype: DType, + dtype: dtypes.DType, reason: Optional[str] = None, ) -> None: super().__init__( @@ -116,14 +107,14 @@ class NoConverterAvailable(Problem): No matching converter could be applied """ - def __init__(self, value: NormalizedValue, dtype: DType) -> None: + def __init__(self, value: NormalizedValue, dtype: dtypes.DType) -> None: msg = f"No Converter for {type(value)} -> {dtype}" super().__init__(title="No matching converter", detail=msg, status_code=422) N = TypeVar("N", bound=NormalizedValue) -Converter = Callable[[N, DType], ConvertedValue] +Converter = Callable[[N, dtypes.DType], ConvertedValue] _converters_table: Dict[ Type[NormalizedValue], Dict[str, List[Converter]] ] = defaultdict(lambda: defaultdict(list)) @@ -176,7 +167,10 @@ def _decorate(func: Converter[N]) -> Converter[N]: def convert_to_dtype( - value: NormalizedValue, dtype: DType, simple: bool = False, check: bool = True + value: NormalizedValue, + dtype: dtypes.DType, + simple: bool = False, + check: bool = True, ) -> ConvertedValue: """ Convert normalized type from data source to internal Spotlight DType @@ -238,24 +232,24 @@ def convert_to_dtype( @convert("datetime") -def _(value: datetime.datetime, _: DType) -> datetime.datetime: +def _(value: datetime.datetime, _: dtypes.DType) -> datetime.datetime: return value @convert("datetime") -def _(value: Union[str, np.str_], _: DType) -> Optional[datetime.datetime]: +def _(value: Union[str, np.str_], _: dtypes.DType) -> Optional[datetime.datetime]: if value == "": return None return datetime.datetime.fromisoformat(value) @convert("datetime") -def _(value: np.datetime64, _: DType) -> Optional[datetime.datetime]: +def _(value: np.datetime64, _: dtypes.DType) -> Optional[datetime.datetime]: return value.tolist() @convert("Category") -def _(value: Union[str, np.str_], dtype: CategoryDType) -> int: +def _(value: Union[str, np.str_], dtype: dtypes.CategoryDType) -> int: categories = dtype.categories if not categories: return -1 @@ -263,7 +257,7 @@ def _(value: Union[str, np.str_], dtype: CategoryDType) -> int: @convert("Category") -def _(_: None, _dtype: CategoryDType) -> int: +def _(_: None, _dtype: dtypes.CategoryDType) -> int: return -1 @@ -280,23 +274,23 @@ def _( np.uint32, np.uint64, ], - _: CategoryDType, + _: dtypes.CategoryDType, ) -> int: return int(value) @convert("Window") -def _(value: list, _: DType) -> np.ndarray: +def _(value: list, _: dtypes.DType) -> np.ndarray: return np.array(value, dtype=np.float64) @convert("Window") -def _(value: np.ndarray, _: DType) -> np.ndarray: +def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray: return value.astype(np.float64) @convert("Window") -def _(value: Union[str, np.str_], _: DType) -> np.ndarray: +def _(value: Union[str, np.str_], _: dtypes.DType) -> 
np.ndarray: try: obj = ast.literal_eval(value) array = np.array(obj, dtype=np.float64) @@ -308,17 +302,17 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray: @convert("Embedding", simple=False) -def _(value: list, _: DType) -> np.ndarray: +def _(value: list, _: dtypes.DType) -> np.ndarray: return np.array(value, dtype=np.float64) @convert("Embedding", simple=False) -def _(value: np.ndarray, _: DType) -> np.ndarray: +def _(value: np.ndarray, _: dtypes.DType) -> np.ndarray: return value.astype(np.float64) @convert("Embedding", simple=False) -def _(value: Union[str, np.str_], _: DType) -> np.ndarray: +def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray: try: obj = ast.literal_eval(value) array = np.array(obj, dtype=np.float64) @@ -330,23 +324,23 @@ def _(value: Union[str, np.str_], _: DType) -> np.ndarray: @convert("Sequence1D", simple=False) -def _(value: Union[np.ndarray, list], _: DType) -> np.ndarray: - return Sequence1D(value).encode() +def _(value: Union[np.ndarray, list], _: dtypes.DType) -> np.ndarray: + return media.Sequence1D(value).encode() @convert("Sequence1D", simple=False) -def _(value: Union[str, np.str_], _: DType) -> np.ndarray: +def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray: try: obj = ast.literal_eval(value) - return Sequence1D(obj).encode() + return media.Sequence1D(obj).encode() except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): raise ConversionError("Cannot interpret string as a 1D sequence") @convert("Image", simple=False) -def _(value: Union[str, np.str_], _: DType) -> bytes: +def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: - if data := read_external_value(value, image_dtype): + if data := read_external_value(value, dtypes.image_dtype): return data.tolist() except InvalidFile: raise ConversionError() @@ -354,26 +348,26 @@ def _(value: Union[str, np.str_], _: DType) -> bytes: @convert("Image", simple=False) -def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: - return Image.from_bytes(value).encode().tolist() +def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: + return media.Image.from_bytes(value).encode().tolist() @convert("Image", simple=False) -def _(value: np.ndarray, _: DType) -> bytes: - return Image(value).encode().tolist() +def _(value: np.ndarray, _: dtypes.DType) -> bytes: + return media.Image(value).encode().tolist() @convert("Image", simple=False) -def _(value: PIL.Image.Image, _: DType) -> bytes: +def _(value: PIL.Image.Image, _: dtypes.DType) -> bytes: buffer = io.BytesIO() value.save(buffer, format="PNG") return buffer.getvalue() @convert("Audio", simple=False) -def _(value: Union[str, np.str_], _: DType) -> bytes: +def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: - if data := read_external_value(value, audio_dtype): + if data := read_external_value(value, dtypes.audio_dtype): return data.tolist() except (InvalidFile, IndexError, ValueError): raise ConversionError() @@ -381,14 +375,14 @@ def _(value: Union[str, np.str_], _: DType) -> bytes: @convert("Audio", simple=False) -def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: - return Audio.from_bytes(value).encode().tolist() +def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: + return media.Audio.from_bytes(value).encode().tolist() @convert("Video", simple=False) -def _(value: Union[str, np.str_], _: DType) -> bytes: +def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: - if data := read_external_value(value, video_dtype): + if data := 
read_external_value(value, dtypes.video_dtype): return data.tolist() except InvalidFile: raise ConversionError() @@ -396,14 +390,14 @@ def _(value: Union[str, np.str_], _: DType) -> bytes: @convert("Video", simple=False) -def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: - return Video.from_bytes(value).encode().tolist() +def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: + return media.Video.from_bytes(value).encode().tolist() @convert("Mesh", simple=False) -def _(value: Union[str, np.str_], _: DType) -> bytes: +def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: - if data := read_external_value(value, mesh_dtype): + if data := read_external_value(value, dtypes.mesh_dtype): return data.tolist() except InvalidFile: raise ConversionError() @@ -411,29 +405,29 @@ def _(value: Union[str, np.str_], _: DType) -> bytes: @convert("Mesh", simple=False) -def _(value: Union[bytes, np.bytes_], _: DType) -> bytes: +def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: return value # this should not be necessary @convert("Mesh", simple=False) # type: ignore -def _(value: trimesh.Trimesh, _: DType) -> bytes: - return Mesh.from_trimesh(value).encode().tolist() +def _(value: trimesh.Trimesh, _: dtypes.DType) -> bytes: + return media.Mesh.from_trimesh(value).encode().tolist() @convert("Embedding", simple=True) @convert("Sequence1D", simple=True) -def _(_: Union[np.ndarray, list, str, np.str_], _dtype: DType) -> str: +def _(_: Union[np.ndarray, list, str, np.str_], _dtype: dtypes.DType) -> str: return "[...]" @convert("Image", simple=True) -def _(_: np.ndarray, _dtype: DType) -> str: +def _(_: np.ndarray, _dtype: dtypes.DType) -> str: return "[...]" @convert("Image", simple=True) -def _(_: PIL.Image.Image, _dtype: DType) -> str: +def _(_: PIL.Image.Image, _dtype: dtypes.DType) -> str: return "" @@ -441,7 +435,7 @@ def _(_: PIL.Image.Image, _dtype: DType) -> str: @convert("Audio", simple=True) @convert("Video", simple=True) @convert("Mesh", simple=True) -def _(value: Union[str, np.str_], _: DType) -> str: +def _(value: Union[str, np.str_], _: dtypes.DType) -> str: return str(value) @@ -449,19 +443,19 @@ def _(value: Union[str, np.str_], _: DType) -> str: @convert("Audio", simple=True) @convert("Video", simple=True) @convert("Mesh", simple=True) -def _(_: Union[bytes, np.bytes_], _dtype: DType) -> str: +def _(_: Union[bytes, np.bytes_], _dtype: dtypes.DType) -> str: return "" # this should not be necessary @convert("Mesh", simple=True) # type: ignore -def _(_: trimesh.Trimesh, _dtype: DType) -> str: +def _(_: trimesh.Trimesh, _dtype: dtypes.DType) -> str: return "" def read_external_value( path_or_url: Optional[str], - dtype: DType, + dtype: dtypes.DType, target_format: Optional[str] = None, workdir: PathType = ".", ) -> Optional[np.void]: @@ -498,7 +492,7 @@ def prepare_path_or_url(path_or_url: PathOrUrlType, workdir: PathType) -> str: def _decode_external_value( path_or_url: PathOrUrlType, - dtype: DType, + dtype: dtypes.DType, target_format: Optional[str] = None, workdir: PathType = ".", ) -> np.void: @@ -528,7 +522,7 @@ def _decode_external_value( # Convert all other formats/codecs to flac. 
output_format, output_codec = "flac", "flac" else: - output_format, output_codec = Audio.get_format_codec(target_format) + output_format, output_codec = media.Audio.get_format_codec(target_format) if output_format == input_format and output_codec == input_codec: # Nothing to transcode if isinstance(file, str): @@ -554,10 +548,10 @@ def _decode_external_value( ): return np.void(file.read()) # `image/tiff`s become blank in frontend, so convert them too. - return Image.from_file(file).encode(target_format) + return media.Image.from_file(file).encode(target_format) if dtypes.is_mesh_dtype(dtype): - return Mesh.from_file(path_or_url).encode(target_format) + return media.Mesh.from_file(path_or_url).encode(target_format) if dtypes.is_video_dtype(dtype): - return Video.from_file(path_or_url).encode(target_format) + return media.Video.from_file(path_or_url).encode(target_format) assert False From 63a28459cb6ff65b721e5db0b5a032b22fc81bb5 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Wed, 25 Oct 2023 16:03:05 +0200 Subject: [PATCH 28/31] Support array-likes and literal strings in image conversion --- renumics/spotlight/dtypes/conversion.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/renumics/spotlight/dtypes/conversion.py b/renumics/spotlight/dtypes/conversion.py index 1488b79f..c5b7cc12 100644 --- a/renumics/spotlight/dtypes/conversion.py +++ b/renumics/spotlight/dtypes/conversion.py @@ -337,13 +337,22 @@ def _(value: Union[str, np.str_], _: dtypes.DType) -> np.ndarray: raise ConversionError("Cannot interpret string as a 1D sequence") +@convert("Image", simple=False) +def _(value: Union[np.ndarray, list], _: dtypes.DType) -> bytes: + return media.Image(value).encode().tolist() + + @convert("Image", simple=False) def _(value: Union[str, np.str_], _: dtypes.DType) -> bytes: try: if data := read_external_value(value, dtypes.image_dtype): return data.tolist() except InvalidFile: - raise ConversionError() + try: + obj = ast.literal_eval(value) + return media.Image(obj).encode().tolist() + except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): + raise ConversionError() raise ConversionError() @@ -352,11 +361,6 @@ def _(value: Union[bytes, np.bytes_], _: dtypes.DType) -> bytes: return media.Image.from_bytes(value).encode().tolist() -@convert("Image", simple=False) -def _(value: np.ndarray, _: dtypes.DType) -> bytes: - return media.Image(value).encode().tolist() - - @convert("Image", simple=False) def _(value: PIL.Image.Image, _: dtypes.DType) -> bytes: buffer = io.BytesIO() @@ -422,7 +426,7 @@ def _(_: Union[np.ndarray, list, str, np.str_], _dtype: dtypes.DType) -> str: @convert("Image", simple=True) -def _(_: np.ndarray, _dtype: dtypes.DType) -> str: +def _(_: Union[np.ndarray, list], _dtype: dtypes.DType) -> str: return "[...]" From e36c79751803b5e0bff0981e0b0368fe548784a1 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 25 Oct 2023 16:08:57 +0200 Subject: [PATCH 29/31] fix: never cache frontend files --- renumics/spotlight/app.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py index 62220d2f..fc6993d4 100644 --- a/renumics/spotlight/app.py +++ b/renumics/spotlight/app.py @@ -15,6 +15,7 @@ from typing_extensions import Annotated from fastapi import Cookie, FastAPI, Request, status +from fastapi.datastructures import Headers from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response from 
From e36c79751803b5e0bff0981e0b0368fe548784a1 Mon Sep 17 00:00:00 2001
From: Dominik Haentsch
Date: Wed, 25 Oct 2023 16:08:57 +0200
Subject: [PATCH 29/31] fix: never cache frontend files

---
 renumics/spotlight/app.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py
index 62220d2f..fc6993d4 100644
--- a/renumics/spotlight/app.py
+++ b/renumics/spotlight/app.py
@@ -15,6 +15,7 @@
 from typing_extensions import Annotated
 
 from fastapi import Cookie, FastAPI, Request, status
+from fastapi.datastructures import Headers
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response
 from fastapi.staticfiles import StaticFiles
@@ -56,10 +57,23 @@
 from renumics.spotlight.dtypes import DTypeMap
 
-
 CURRENT_LAYOUT_KEY = "layout.current"
 
 
+class UncachedStaticFiles(StaticFiles):
+    """
+    FastAPI StaticFiles but without caching
+    """
+
+    def is_not_modified(
+        self, response_headers: Headers, request_headers: Headers
+    ) -> bool:
+        """
+        Never Cache
+        """
+        return False
+
+
 class IssuesUpdatedMessage(Message):
     """
     Notify about updated issues.
@@ -207,9 +221,13 @@ async def _(_: Request, problem: Problem) -> JSONResponse:
             plugin.activate(self)
 
         try:
+            # Mount frontend files as uncached,
+            # so that we always get the new frontend after updating spotlight.
+            # NOTE: we might not need this if we added a version hash
+            # to our built js files
             self.mount(
                 "/static",
-                StaticFiles(packages=["renumics.spotlight.backend"]),
+                UncachedStaticFiles(packages=["renumics.spotlight.backend"]),
                 name="assets",
             )
         except AssertionError:
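`is_not_modified` is the Starlette `StaticFiles` hook that decides whether a conditional request (`If-None-Match` / `If-Modified-Since`) is answered with `304 Not Modified`; returning `False` unconditionally means every asset request gets a fresh `200` with the full body. One way to observe the difference from the client side, assuming a running instance and an asset path like /static/main.js (both illustrative):

    # Sketch only: compare conditional requests against the mounted assets.
    import httpx

    url = "http://127.0.0.1:8000/static/main.js"  # hypothetical asset URL
    first = httpx.get(url)
    second = httpx.get(url, headers={"If-None-Match": first.headers.get("etag", "")})

    # Plain StaticFiles would answer the second request with 304;
    # UncachedStaticFiles always serves the full file with 200.
    print(first.status_code, second.status_code)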
""" - if config.project_root is not None: - self.project_root = config.project_root - if config.dtypes is not None: - self._user_dtypes = config.dtypes - if config.analyze is not None: - self.analyze_columns = config.analyze - if config.custom_issues is not None: - self.custom_issues = config.custom_issues - if config.dataset is not None: - self._dataset = config.dataset - self._data_source = create_datasource(self._dataset) - if config.layout is not None: - self._layout = config.layout or layouts.default() - if config.filebrowsing_allowed is not None: - self.filebrowsing_allowed = config.filebrowsing_allowed - - if config.dtypes is not None or config.dataset is not None: - data_source = self._data_source - assert data_source is not None - self._data_store = DataStore(data_source, self._user_dtypes) - self._broadcast(RefreshMessage()) - self._update_issues() - if config.layout is not None: - if self._data_store is not None: - dataset_uid = self._data_store.uid - future = asyncio.run_coroutine_threadsafe( - self.config.remove_all(CURRENT_LAYOUT_KEY, dataset=dataset_uid), - self._loop, - ) - future.result() - self._broadcast(ResetLayoutMessage()) + try: + if config.project_root is not None: + self.project_root = config.project_root + if config.dtypes is not None: + self._user_dtypes = config.dtypes + if config.analyze is not None: + self.analyze_columns = config.analyze + if config.custom_issues is not None: + self.custom_issues = config.custom_issues + if config.dataset is not None: + self._dataset = config.dataset + self._data_source = create_datasource(self._dataset) + if config.layout is not None: + self._layout = config.layout or layouts.default() + if config.filebrowsing_allowed is not None: + self.filebrowsing_allowed = config.filebrowsing_allowed + + if config.dtypes is not None or config.dataset is not None: + data_source = self._data_source + assert data_source is not None + self._data_store = DataStore(data_source, self._user_dtypes) + self._broadcast(RefreshMessage()) + self._update_issues() + if config.layout is not None: + if self._data_store is not None: + dataset_uid = self._data_store.uid + future = asyncio.run_coroutine_threadsafe( + self.config.remove_all(CURRENT_LAYOUT_KEY, dataset=dataset_uid), + self._loop, + ) + future.result() + self._broadcast(ResetLayoutMessage()) - for plugin in load_plugins(): - plugin.update(self, config) + for plugin in load_plugins(): + plugin.update(self, config) + except Exception as e: + self._connection.send({"kind": "update_complete", "error": e}) - if not self._startup_complete: - self._startup_complete = True - self._connection.send({"kind": "startup_complete"}) + self._connection.send({"kind": "update_complete"}) def _handle_message(self, message: Any) -> None: kind = message.get("kind") diff --git a/renumics/spotlight/server.py b/renumics/spotlight/server.py index 83d688db..bf56918d 100644 --- a/renumics/spotlight/server.py +++ b/renumics/spotlight/server.py @@ -41,6 +41,8 @@ class Server: process: Optional[subprocess.Popen] _startup_event: threading.Event + _update_complete_event: threading.Event + _update_error: Optional[Exception] connection: Optional[multiprocessing.connection.Connection] _connection_message_queue: Queue @@ -77,7 +79,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8000) -> None: ) self._startup_event = threading.Event() - self._startup_complete_event = threading.Event() + self._update_complete_event = threading.Event() self._connection_thread_online = threading.Event() self._connection_thread = 
From eca76962e2feb4de8238be6ae1bfbd19cbca1fd4 Mon Sep 17 00:00:00 2001
From: Dominik Haentsch
Date: Thu, 26 Oct 2023 11:00:05 +0200
Subject: [PATCH 31/31] feat: send viewer update errors to and re-raise in parent process

---
 renumics/spotlight/app.py    | 76 ++++++++++++++++++------------------
 renumics/spotlight/server.py | 27 +++++++++----
 renumics/spotlight/viewer.py |  6 ++-
 3 files changed, 63 insertions(+), 46 deletions(-)

diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py
index fc6993d4..5baa724d 100644
--- a/renumics/spotlight/app.py
+++ b/renumics/spotlight/app.py
@@ -89,7 +89,6 @@ class SpotlightApp(FastAPI):
     """
 
     # lifecycle
-    _startup_complete: bool
     _loop: asyncio.AbstractEventLoop
 
     # connection
@@ -120,7 +119,6 @@ class SpotlightApp(FastAPI):
     def __init__(self) -> None:
         super().__init__()
-        self._startup_complete = False
         self.task_manager = TaskManager()
         self.websocket_manager = None
         self.config = Config()
@@ -313,44 +311,46 @@ def update(self, config: AppConfig) -> None:
         """
         Update application config.
         """
-        if config.project_root is not None:
-            self.project_root = config.project_root
-        if config.dtypes is not None:
-            self._user_dtypes = config.dtypes
-        if config.analyze is not None:
-            self.analyze_columns = config.analyze
-        if config.custom_issues is not None:
-            self.custom_issues = config.custom_issues
-        if config.dataset is not None:
-            self._dataset = config.dataset
-            self._data_source = create_datasource(self._dataset)
-        if config.layout is not None:
-            self._layout = config.layout or layouts.default()
-        if config.filebrowsing_allowed is not None:
-            self.filebrowsing_allowed = config.filebrowsing_allowed
-
-        if config.dtypes is not None or config.dataset is not None:
-            data_source = self._data_source
-            assert data_source is not None
-            self._data_store = DataStore(data_source, self._user_dtypes)
-            self._broadcast(RefreshMessage())
-            self._update_issues()
-        if config.layout is not None:
-            if self._data_store is not None:
-                dataset_uid = self._data_store.uid
-                future = asyncio.run_coroutine_threadsafe(
-                    self.config.remove_all(CURRENT_LAYOUT_KEY, dataset=dataset_uid),
-                    self._loop,
-                )
-                future.result()
-                self._broadcast(ResetLayoutMessage())
+        try:
+            if config.project_root is not None:
+                self.project_root = config.project_root
+            if config.dtypes is not None:
+                self._user_dtypes = config.dtypes
+            if config.analyze is not None:
+                self.analyze_columns = config.analyze
+            if config.custom_issues is not None:
+                self.custom_issues = config.custom_issues
+            if config.dataset is not None:
+                self._dataset = config.dataset
+                self._data_source = create_datasource(self._dataset)
+            if config.layout is not None:
+                self._layout = config.layout or layouts.default()
+            if config.filebrowsing_allowed is not None:
+                self.filebrowsing_allowed = config.filebrowsing_allowed
+
+            if config.dtypes is not None or config.dataset is not None:
+                data_source = self._data_source
+                assert data_source is not None
+                self._data_store = DataStore(data_source, self._user_dtypes)
+                self._broadcast(RefreshMessage())
+                self._update_issues()
+            if config.layout is not None:
+                if self._data_store is not None:
+                    dataset_uid = self._data_store.uid
+                    future = asyncio.run_coroutine_threadsafe(
+                        self.config.remove_all(CURRENT_LAYOUT_KEY, dataset=dataset_uid),
+                        self._loop,
+                    )
+                    future.result()
+                    self._broadcast(ResetLayoutMessage())
 
-        for plugin in load_plugins():
-            plugin.update(self, config)
+            for plugin in load_plugins():
+                plugin.update(self, config)
+        except Exception as e:
+            self._connection.send({"kind": "update_complete", "error": e})
+            return
 
-        if not self._startup_complete:
-            self._startup_complete = True
-            self._connection.send({"kind": "startup_complete"})
+        self._connection.send({"kind": "update_complete"})
 
     def _handle_message(self, message: Any) -> None:
         kind = message.get("kind")
diff --git a/renumics/spotlight/server.py b/renumics/spotlight/server.py
index 83d688db..bf56918d 100644
--- a/renumics/spotlight/server.py
+++ b/renumics/spotlight/server.py
@@ -41,6 +41,8 @@ class Server:
 
     process: Optional[subprocess.Popen]
     _startup_event: threading.Event
+    _update_complete_event: threading.Event
+    _update_error: Optional[Exception]
 
     connection: Optional[multiprocessing.connection.Connection]
     _connection_message_queue: Queue
@@ -77,7 +79,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = 8000) -> None:
         )
 
         self._startup_event = threading.Event()
-        self._startup_complete_event = threading.Event()
+        self._update_complete_event = threading.Event()
 
         self._connection_thread_online = threading.Event()
         self._connection_thread = threading.Thread(
@@ -151,7 +153,6 @@ def start(self, config: AppConfig) -> None:
             command.append("--reload")
 
         # start uvicorn
-
         self.process = subprocess.Popen(
             command,
             env=env,
@@ -164,7 +165,7 @@ def start(self, config: AppConfig) -> None:
         )
         if platform.system() != "Windows":
             sock.close()
-        self._startup_complete_event.wait(timeout=120)
+        self._wait_for_update()
 
     def stop(self) -> None:
         """
@@ -197,7 +198,6 @@ def stop(self) -> None:
 
         self._port = self._requested_port
         self._startup_event.clear()
-        self._startup_complete_event.clear()
 
     @property
     def running(self) -> bool:
@@ -217,9 +217,21 @@ def update(self, config: AppConfig) -> None:
         """
         Update app config
        """
+        self._update(config)
+        self._wait_for_update()
+
+    def _update(self, config: AppConfig) -> None:
         self._app_config = config
         self.send({"kind": "update", "data": config})
 
+    def _wait_for_update(self) -> None:
+        self._update_complete_event.wait(timeout=120)
+        self._update_complete_event.clear()
+        err = self._update_error
+        self._update_error = None
+        if err:
+            raise err
+
     def get_df(self) -> Optional[pd.DataFrame]:
         """
         Request and return the current DataFrame from the server process (if possible)
@@ -236,9 +248,10 @@ def _handle_message(self, message: Any) -> None:
 
         if kind == "startup":
             self._startup_event.set()
-            self.update(self._app_config)
-        elif kind == "startup_complete":
-            self._startup_complete_event.set()
+            self._update(self._app_config)
+        elif kind == "update_complete":
+            self._update_error = message.get("error")
+            self._update_complete_event.set()
         elif kind == "frontend_connected":
             self.connected_frontends = message["data"]
             self._all_frontends_disconnected.clear()
diff --git a/renumics/spotlight/viewer.py b/renumics/spotlight/viewer.py
index 8f859ad1..817f31b2 100644
--- a/renumics/spotlight/viewer.py
+++ b/renumics/spotlight/viewer.py
@@ -176,7 +176,11 @@ def show(
             if self not in _VIEWERS:
                 _VIEWERS.append(self)
         else:
-            self._server.update(config)
+            try:
+                self._server.update(config)
+            except Exception as e:
+                self.close()
+                raise e
 
         if not no_browser and self._server.connected_frontends == 0:
             self.open_browser()
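Taken together, the three files implement one round trip: the app process catches any exception raised during `update()`, ships it over the connection as part of the `update_complete` message, and `Server._wait_for_update` re-raises it so `Viewer.show` can close the viewer and propagate the failure to the caller. A condensed sketch of that protocol, reduced to a bare `multiprocessing.Pipe` (the config check is a stand-in for the real update work):

    # Sketch only: the update_complete error protocol in miniature.
    import multiprocessing

    def apply_update(conn, config):
        try:
            if "dataset" not in config:       # stand-in for the real update logic
                raise ValueError("no dataset configured")
        except Exception as e:
            conn.send({"kind": "update_complete", "error": e})
            return
        conn.send({"kind": "update_complete"})

    parent_conn, child_conn = multiprocessing.Pipe()
    apply_update(child_conn, {})              # simulate a failing update
    reply = parent_conn.recv()
    error = reply.get("error")
    if error is not None:
        try:
            raise error                       # as Server._wait_for_update re-raises
        except ValueError as err:
            print("re-raised in parent:", err)

Note that objects sent over a `multiprocessing` connection are pickled, so only picklable exception types survive the trip intact.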