From 42930139268258cf9cbaa52013cd2bee2b7600cb Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Mon, 18 Sep 2023 16:39:25 +0200 Subject: [PATCH] correctly convert pd.CategoricalDtype --- renumics/spotlight/io/pandas.py | 5 +---- renumics/spotlight_plugins/core/pandas_data_source.py | 6 +----- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/renumics/spotlight/io/pandas.py b/renumics/spotlight/io/pandas.py index c09f32f3..b16101b1 100644 --- a/renumics/spotlight/io/pandas.py +++ b/renumics/spotlight/io/pandas.py @@ -80,10 +80,7 @@ def infer_dtype(column: pd.Series) -> dtypes.DType: return dtypes.bool_dtype if pd.api.types.is_categorical_dtype(column): return dtypes.CategoryDType( - { - category: code - for code, category in zip(column.cat.codes, column.cat.categories) - } + {category: code for code, category in enumerate(column.cat.categories)} ) if pd.api.types.is_integer_dtype(column) and not column.hasnans: return dtypes.int_dtype diff --git a/renumics/spotlight_plugins/core/pandas_data_source.py b/renumics/spotlight_plugins/core/pandas_data_source.py index 6ee5e38e..4d2bf172 100644 --- a/renumics/spotlight_plugins/core/pandas_data_source.py +++ b/renumics/spotlight_plugins/core/pandas_data_source.py @@ -66,7 +66,6 @@ def __init__(self, source: Union[Path, pd.DataFrame]): for feature_name, feature_type in hf_dataset[ splits[0] ].features.items(): - print(feature_type) if isinstance(feature_type, datasets.ClassLabel): try: df[feature_name] = pd.Categorical.from_codes( @@ -212,10 +211,7 @@ def _determine_intermediate_dtype(column: pd.Series) -> dtypes.DType: return dtypes.bool_dtype if pd.api.types.is_categorical_dtype(column): return dtypes.CategoryDType( - { - category: code - for code, category in zip(column.cat.codes, column.cat.categories) - } + {category: code for code, category in enumerate(column.cat.categories)} ) if pd.api.types.is_integer_dtype(column): return dtypes.int_dtype