Skip to content

Commit

Permalink
Merge pull request #298 from Renumics/fix/234-scalar-columns-not-embeddable
Browse files Browse the repository at this point in the history

Fix checking scalar dtypes for data alignment before embedding
  • Loading branch information
druzsan authored Oct 24, 2023
2 parents 6780ff9 + 9225698 commit 0bb6ce5
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions renumics/spotlight/backend/tasks/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@

import numpy as np
import pandas as pd
from sklearn import preprocessing

from renumics.spotlight.dataset.exceptions import ColumnNotExistsError
from renumics.spotlight.data_store import DataStore
from renumics.spotlight.dtypes import is_category_dtype, is_embedding_dtype
from renumics.spotlight import dtypes

SEED = 42

Expand All @@ -27,6 +26,7 @@ def align_data(
"""
Align data from table's columns, remove `NaN`'s.
"""
from sklearn import preprocessing

if not column_names or not indices:
return np.empty(0, np.float64), []
Expand All @@ -35,7 +35,7 @@ def align_data(
for column_name in column_names:
dtype = data_store.dtypes[column_name]
column_values = data_store.get_converted_values(column_name, indices)
if is_embedding_dtype(dtype):
if dtypes.is_embedding_dtype(dtype):
embedding_length = max(
0 if x is None else len(cast(np.ndarray, x)) for x in column_values
)
Expand All @@ -49,17 +49,19 @@ def align_data(
]
)
)
elif is_category_dtype(dtype):
elif dtypes.is_category_dtype(dtype):
na_mask = np.array(column_values) == -1
one_hot_values = preprocessing.label_binarize(
column_values, classes=sorted(set(column_values).difference({-1})) # type: ignore
).astype(float)
one_hot_values[na_mask] = np.nan
aligned_values.append(one_hot_values)
elif dtype in (int, bool, float):
elif dtypes.is_scalar_dtype(dtype):
aligned_values.append(np.array(column_values, dtype=float))
else:
raise ColumnNotEmbeddable
raise ColumnNotEmbeddable(
f"Column '{column_name}' of type {dtype} is not embeddable."
)

data = np.hstack([col.reshape((len(indices), -1)) for col in aligned_values])
mask = ~pd.isna(data).any(axis=1)
Expand Down Expand Up @@ -118,7 +120,7 @@ def compute_pca(

try:
data, indices = align_data(data_store, column_names, indices)
except (ColumnNotExistsError, ValueError):
except (ColumnNotExistsError, ColumnNotEmbeddable):
return np.empty(0, np.float64), []
if data.size == 0:
return np.empty(0, np.float64), []
Expand All @@ -129,5 +131,6 @@ def compute_pca(
elif normalization == "robust standardize":
data = preprocessing.RobustScaler(copy=False).fit_transform(data)
reducer = decomposition.PCA(n_components=2, copy=False, random_state=SEED)
embeddings = reducer.fit_transform(data)
# `fit_transform` returns Fortran-ordered array.
embeddings = np.ascontiguousarray(reducer.fit_transform(data))
return embeddings, indices

0 comments on commit 0bb6ce5

Please sign in to comment.