From 84b789db7f09d2d095ec34e19974fc3bc900fe2a Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 18 Sep 2023 17:00:46 +0200 Subject: [PATCH] Fix pandas infer dtype for ints and bools --- renumics/spotlight/dataset/__init__.py | 24 ++++++----------- renumics/spotlight/io/pandas.py | 37 +++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/renumics/spotlight/dataset/__init__.py b/renumics/spotlight/dataset/__init__.py index e6e5a866..fea5415b 100644 --- a/renumics/spotlight/dataset/__init__.py +++ b/renumics/spotlight/dataset/__init__.py @@ -57,9 +57,8 @@ Video, Window, ) - +from renumics.spotlight.io.pandas import create_typed_series from renumics.spotlight.dtypes.conversion import prepare_path_or_url - from renumics.spotlight.dtypes import ( CategoryDType, Sequence1DDType, @@ -887,23 +886,16 @@ def to_pandas(self) -> pd.DataFrame: df = pd.DataFrame() for column_name in self._column_names: dtype = self.get_dtype(column_name) - if ( + if is_datetime_dtype(dtype): + df[column_name] = create_typed_series(dtype, self[column_name]) + elif ( is_scalar_dtype(dtype) or is_str_dtype(dtype) - or is_datetime_dtype(dtype) + or is_category_dtype(dtype) ): - df[column_name] = self[column_name] - elif is_category_dtype(dtype): - if dtype.inverted_categories is None: - values: List[Optional[str]] = [None] * len( - self._h5_file[column_name] - ) - else: - values = [ - dtype.inverted_categories.get(code) - for code in self._h5_file[column_name] - ] - df[column_name] = pd.Categorical(values) + df[column_name] = create_typed_series( + dtype, self._h5_file[column_name][:] + ) not_exported_columns = self._column_names.difference(df.columns) if len(not_exported_columns) > 0: diff --git a/renumics/spotlight/io/pandas.py b/renumics/spotlight/io/pandas.py index c09f32f3..00466ecd 100644 --- a/renumics/spotlight/io/pandas.py +++ b/renumics/spotlight/io/pandas.py @@ -6,7 +6,7 @@ import os.path import statistics from contextlib import suppress -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Sequence, Union import PIL.Image import filetype @@ -27,6 +27,37 @@ from renumics.spotlight import dtypes +def create_typed_series( + dtype: dtypes.DType, values: Optional[Union[Sequence, np.ndarray]] = None +) -> pd.Series: + if dtypes.is_category_dtype(dtype): + if values is None or len(values) == 0: + return pd.Series( + dtype=pd.CategoricalDtype( + [] if not dtype.categories else list(dtype.categories.keys()) + ) + ) + if dtype.inverted_categories is None: + return pd.Series([None] * len(values), dtype=pd.CategoricalDtype()) + return pd.Series( + [dtype.inverted_categories.get(code) for code in values], + dtype=pd.CategoricalDtype(), + ) + if dtypes.is_bool_dtype(dtype): + pandas_dtype = "boolean" + elif dtypes.is_int_dtype(dtype): + pandas_dtype = "Int64" + elif dtypes.is_float_dtype(dtype): + pandas_dtype = "float" + elif dtypes.is_str_dtype(dtype): + pandas_dtype = "string" + elif dtypes.is_datetime_dtype(dtype): + pandas_dtype = "datetime64[ns]" + else: + pandas_dtype = "object" + return pd.Series([] if values is None else values, dtype=pandas_dtype) + + def is_empty(value: Any) -> bool: """ Check if value is `NA` or an empty string. @@ -76,7 +107,7 @@ def infer_dtype(column: pd.Series) -> dtypes.DType: ValueError: If dtype cannot be inferred automatically. """ - if pd.api.types.is_bool_dtype(column) and not column.hasnans: + if pd.api.types.is_bool_dtype(column): return dtypes.bool_dtype if pd.api.types.is_categorical_dtype(column): return dtypes.CategoryDType( @@ -85,7 +116,7 @@ def infer_dtype(column: pd.Series) -> dtypes.DType: for code, category in zip(column.cat.codes, column.cat.categories) } ) - if pd.api.types.is_integer_dtype(column) and not column.hasnans: + if pd.api.types.is_integer_dtype(column): return dtypes.int_dtype if pd.api.types.is_float_dtype(column): return dtypes.float_dtype