From 472937840913ef27a11a0a1a66544e8d52477b1b Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 23 Oct 2023 14:02:26 +0200 Subject: [PATCH] Fix checking scalar dtypes for data alignment before embedding --- renumics/spotlight/backend/tasks/reduction.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/renumics/spotlight/backend/tasks/reduction.py b/renumics/spotlight/backend/tasks/reduction.py index 00207c38..9a8d71c1 100644 --- a/renumics/spotlight/backend/tasks/reduction.py +++ b/renumics/spotlight/backend/tasks/reduction.py @@ -10,7 +10,7 @@ from renumics.spotlight.dataset.exceptions import ColumnNotExistsError from renumics.spotlight.data_store import DataStore -from renumics.spotlight.dtypes import is_category_dtype, is_embedding_dtype +from renumics.spotlight import dtypes SEED = 42 @@ -35,7 +35,7 @@ def align_data( for column_name in column_names: dtype = data_store.dtypes[column_name] column_values = data_store.get_converted_values(column_name, indices) - if is_embedding_dtype(dtype): + if dtypes.is_embedding_dtype(dtype): embedding_length = max( 0 if x is None else len(cast(np.ndarray, x)) for x in column_values ) @@ -49,17 +49,19 @@ def align_data( ] ) ) - elif is_category_dtype(dtype): + elif dtypes.is_category_dtype(dtype): na_mask = np.array(column_values) == -1 one_hot_values = preprocessing.label_binarize( column_values, classes=sorted(set(column_values).difference({-1})) # type: ignore ).astype(float) one_hot_values[na_mask] = np.nan aligned_values.append(one_hot_values) - elif dtype in (int, bool, float): + elif dtypes.is_scalar_dtype(dtype): aligned_values.append(np.array(column_values, dtype=float)) else: - raise ColumnNotEmbeddable + raise ColumnNotEmbeddable( + "Column '{column_name}' of type {dtype} is not embeddable." + ) data = np.hstack([col.reshape((len(indices), -1)) for col in aligned_values]) mask = ~pd.isna(data).any(axis=1)