Skip to content

Commit

Permalink
Fix checking scalar dtypes for data alignment before embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
druzsan committed Oct 23, 2023
1 parent 1ef542e commit 4729378
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions renumics/spotlight/backend/tasks/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from renumics.spotlight.dataset.exceptions import ColumnNotExistsError
from renumics.spotlight.data_store import DataStore
from renumics.spotlight.dtypes import is_category_dtype, is_embedding_dtype
from renumics.spotlight import dtypes

SEED = 42

Expand All @@ -35,7 +35,7 @@ def align_data(
for column_name in column_names:
dtype = data_store.dtypes[column_name]
column_values = data_store.get_converted_values(column_name, indices)
if is_embedding_dtype(dtype):
if dtypes.is_embedding_dtype(dtype):
embedding_length = max(
0 if x is None else len(cast(np.ndarray, x)) for x in column_values
)
Expand All @@ -49,17 +49,19 @@ def align_data(
]
)
)
elif is_category_dtype(dtype):
elif dtypes.is_category_dtype(dtype):
na_mask = np.array(column_values) == -1
one_hot_values = preprocessing.label_binarize(
column_values, classes=sorted(set(column_values).difference({-1})) # type: ignore
).astype(float)
one_hot_values[na_mask] = np.nan
aligned_values.append(one_hot_values)
elif dtype in (int, bool, float):
elif dtypes.is_scalar_dtype(dtype):
aligned_values.append(np.array(column_values, dtype=float))
else:
raise ColumnNotEmbeddable
raise ColumnNotEmbeddable(
"Column '{column_name}' of type {dtype} is not embeddable."
)

data = np.hstack([col.reshape((len(indices), -1)) for col in aligned_values])
mask = ~pd.isna(data).any(axis=1)
Expand Down

0 comments on commit 4729378

Please sign in to comment.