def rebuild(self) -> None:
    """
    Re-create columns that are stored in an old-style layout.

    A column is re-created when it is either

    * an embedding column whose HDF5 dtype is not a variable-length dtype
      with a float, signed integer or unsigned integer base, or
    * an array, 1D-sequence or file-based column whose HDF5 dtype is not a
      string dtype.

    Every such column is appended again under a temporary name (the
    original name with underscores added until it is unused), the original
    column is deleted, and the copy is renamed back to the original name.

    This can take some time and memory. Running `prune` after `rebuild`
    is useful.
    """
    self._assert_is_writable()

    def uses_old_layout(column_name: str) -> bool:
        # Decide from the raw HDF5 dtype whether the column must be rebuilt.
        h5_column: h5py.Dataset = self._h5_file[column_name]
        column_dtype = self._get_dtype(h5_column)
        if spotlight_dtypes.is_embedding_dtype(column_dtype):
            base_dtype = h5py.check_vlen_dtype(h5_column.dtype)
            # `isinstance` is already false for `None`, so it also covers
            # the case where the column is not variable-length at all.
            has_numeric_vlen_base = (
                isinstance(base_dtype, np.dtype) and base_dtype.kind in "fiu"
            )
            return not has_numeric_vlen_base
        is_complex_dtype = (
            spotlight_dtypes.is_array_dtype(column_dtype)
            or spotlight_dtypes.is_sequence_1d_dtype(column_dtype)
            or spotlight_dtypes.is_filebased_dtype(column_dtype)
        )
        return is_complex_dtype and h5py.check_string_dtype(h5_column.dtype) is None

    # Collect all candidates first; the column set is modified below.
    outdated_columns = [
        column_name
        for column_name in self._column_names
        if uses_old_layout(column_name)
    ]

    for name in outdated_columns:
        # Find a free temporary name by appending underscores.
        temporary_name = name
        while temporary_name in self._column_names:
            temporary_name += "_"

        column_dtype = self.get_dtype(name)
        column_values = self[name]
        column_attributes = self.get_column_attributes(name)
        self.append_column(
            temporary_name, column_dtype, column_values, **column_attributes
        )

        # Swap the fresh copy in place of the old column.
        del self[name]
        self.rename_column(temporary_name, name)
        logger.info(f"Column {name} rebuilt")