From bb7d77080667151e51284ec52c53c2641d748e53 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Tue, 13 Feb 2024 20:41:27 +0100 Subject: [PATCH] Use `isinstance(dtype, pd.CategoricalDtype)` to check if a Pandas column is categorical because of deprecation in late versions --- renumics/spotlight/dataset/pandas.py | 2 +- renumics/spotlight_plugins/core/pandas_data_source.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/renumics/spotlight/dataset/pandas.py b/renumics/spotlight/dataset/pandas.py index 222ee66f..47701c9b 100644 --- a/renumics/spotlight/dataset/pandas.py +++ b/renumics/spotlight/dataset/pandas.py @@ -132,7 +132,7 @@ def infer_dtype(column: pd.Series) -> dtypes.DType: if pd.api.types.is_bool_dtype(column): return dtypes.bool_dtype - if pd.api.types.is_categorical_dtype(column): + if isinstance(column.dtype, pd.CategoricalDtype): return dtypes.CategoryDType( {category: code for code, category in enumerate(column.cat.categories)} ) diff --git a/renumics/spotlight_plugins/core/pandas_data_source.py b/renumics/spotlight_plugins/core/pandas_data_source.py index d11c4128..d5b727e2 100644 --- a/renumics/spotlight_plugins/core/pandas_data_source.py +++ b/renumics/spotlight_plugins/core/pandas_data_source.py @@ -160,7 +160,7 @@ def get_column_values( return values if pd.api.types.is_datetime64_any_dtype(column): return column.dt.tz_localize(None).to_numpy() - if pd.api.types.is_categorical_dtype(column): + if isinstance(column.dtype, pd.CategoricalDtype): return column.cat.codes.to_numpy() if pd.api.types.is_string_dtype(column): column = column.astype(object).mask(column.isna(), None) @@ -200,7 +200,7 @@ def _get_column_index(self, column_name: str) -> int: def _determine_intermediate_dtype(column: pd.Series) -> dtypes.DType: if pd.api.types.is_bool_dtype(column): return dtypes.bool_dtype - if pd.api.types.is_categorical_dtype(column): + if isinstance(column.dtype, pd.CategoricalDtype): return dtypes.CategoryDType( {category: code for code, category in enumerate(column.cat.categories)} )