Skip to content

Commit

Permalink
Merge pull request #222 from Renumics/feature/158-dtype-classes-with-options
Browse files (browse the repository at this point in the history)

Feature/158 dtype classes with options
  • Loading branch information
neindochoh authored Sep 8, 2023
2 parents dacbb8b + 21a40ce commit 435282f
Show file tree
Hide file tree
Showing 68 changed files with 3,754 additions and 3,549 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ typecheck: ## Typecheck all source files
poetry run mypy -p renumics.spotlight
poetry run mypy -p renumics.spotlight_plugins.core
poetry run mypy scripts
poetry run mypy tests
pnpm run typecheck

.PHONY: lint
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ files = [

[tool.ruff]
line-length = 100
ignore = [
"E501"
]

[tool.mypy]
ignore_missing_imports = false
Expand Down Expand Up @@ -168,7 +171,8 @@ module = [
"cleanlab.*",
"machineid",
"filetype",
"datasets"
"datasets",
"diffimg"
]
ignore_missing_imports = true

Expand Down
5 changes: 2 additions & 3 deletions renumics/spotlight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

from .__version__ import __version__ # noqa: F401
from .dataset import Dataset # noqa: F401
from .dtypes import (
from .media import (
Audio, # noqa: F401
Category, # noqa: F401
Embedding, # noqa: F401
Image, # noqa: F401
Mesh, # noqa: F401
Sequence1D, # noqa: F401
Video, # noqa: F401
Window, # noqa: F401
)
from .dtypes.legacy import Category, Window # noqa: F401
from .viewer import Viewer, close, viewers, show
from .plugin_loader import load_plugins
from .settings import settings
Expand Down
8 changes: 5 additions & 3 deletions renumics/spotlight/analysis/analyzers/cleanlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import numpy as np
import cleanlab.outlier
from renumics.spotlight.dtypes import Embedding
from renumics.spotlight.data_store import DataStore

from renumics.spotlight.data_store import DataStore
from renumics.spotlight.dtypes import is_embedding_dtype
from ..decorator import data_analyzer
from ..typing import DataIssue

Expand All @@ -23,7 +23,9 @@ def analyze_with_cleanlab(
"""

embedding_columns = (
col for col in columns if data_store.dtypes.get(col) == Embedding
col
for col in columns
if col in data_store.dtypes and is_embedding_dtype(data_store.dtypes[col])
)

for column_name in embedding_columns:
Expand Down
5 changes: 3 additions & 2 deletions renumics/spotlight/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@
from renumics.spotlight.plugin_loader import load_plugins
from renumics.spotlight.develop.project import get_project_info
from renumics.spotlight.backend.middlewares.timing import add_timing_middleware
from renumics.spotlight.dtypes.typing import ColumnTypeMapping
from renumics.spotlight.app_config import AppConfig
from renumics.spotlight.data_source import DataSource, create_datasource
from renumics.spotlight.layout.default import DEFAULT_LAYOUT

from renumics.spotlight.data_store import DataStore

from renumics.spotlight.dtypes import DTypeMap


class IssuesUpdatedMessage(Message):
"""
Expand All @@ -79,7 +80,7 @@ class SpotlightApp(FastAPI):

# datasource
_dataset: Optional[Union[PathType, pd.DataFrame]]
_user_dtypes: ColumnTypeMapping
_user_dtypes: DTypeMap
_data_source: Optional[DataSource]
_data_store: Optional[DataStore]

Expand Down
5 changes: 2 additions & 3 deletions renumics/spotlight/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

from renumics.spotlight.layout.nodes import Layout
from renumics.spotlight.analysis.typing import DataIssue

from renumics.spotlight.dtypes.typing import ColumnTypeMapping
from renumics.spotlight.dtypes import DTypeMap


@dataclass(frozen=True)
Expand All @@ -22,7 +21,7 @@ class AppConfig:

# dataset
dataset: Optional[Union[Path, pd.DataFrame]] = None
dtypes: Optional[ColumnTypeMapping] = None
dtypes: Optional[DTypeMap] = None
project_root: Optional[Path] = None

# data analysis
Expand Down
7 changes: 4 additions & 3 deletions renumics/spotlight/backend/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Exceptions to be raised from backend.
"""

from typing import Any, Optional, Type
from typing import Any, Optional
from fastapi import status

from renumics.spotlight.dtypes.typing import ColumnType
from renumics.spotlight.typing import IndexType, PathOrUrlType, PathType

from renumics.spotlight.dtypes import DType


class Problem(Exception):
"""
Expand Down Expand Up @@ -127,7 +128,7 @@ class ConversionFailed(Problem):
Value cannot be converted to the desired dtype.
"""

def __init__(self, dtype: Type[ColumnType], value: Any) -> None:
def __init__(self, dtype: DType, value: Any) -> None:
self.dtype = dtype
self.value = value
super().__init__(
Expand Down
7 changes: 3 additions & 4 deletions renumics/spotlight/backend/tasks/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from sklearn import preprocessing

from renumics.spotlight.dataset.exceptions import ColumnNotExistsError
from renumics.spotlight.dtypes import Category, Embedding

from renumics.spotlight.data_store import DataStore
from renumics.spotlight.dtypes import is_category_dtype, is_embedding_dtype

SEED = 42

Expand All @@ -36,7 +35,7 @@ def align_data(
for column_name in column_names:
dtype = data_store.dtypes[column_name]
column_values = data_store.get_converted_values(column_name, indices)
if dtype is Embedding:
if is_embedding_dtype(dtype):
embedding_length = max(
0 if x is None else len(cast(np.ndarray, x)) for x in column_values
)
Expand All @@ -50,7 +49,7 @@ def align_data(
]
)
)
elif dtype is Category:
elif is_category_dtype(dtype):
na_mask = np.array(column_values) == -1
one_hot_values = preprocessing.label_binarize(
column_values, classes=sorted(set(column_values).difference({-1})) # type: ignore
Expand Down
26 changes: 8 additions & 18 deletions renumics/spotlight/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,36 @@
import platform
import signal
import sys
from typing import Optional, Tuple, Union, List
from typing import Dict, Optional, Tuple, Union, List
from pathlib import Path

import click

from renumics import spotlight
from renumics.spotlight.dtypes.typing import COLUMN_TYPES_BY_NAME, ColumnTypeMapping

from renumics.spotlight import logging


def cli_dtype_callback(
_ctx: click.Context, _param: click.Option, value: Tuple[str, ...]
) -> Optional[ColumnTypeMapping]:
) -> Optional[Dict[str, str]]:
"""
Parse column types from multiple strings in format
`COLUMN_NAME=DTYPE` to a dict.
"""
if not value:
return None
dtype = {}
dtypes: Dict[str, str] = {}
for mapping in value:
try:
column_name, dtype_name = mapping.split("=")
column_name, dtype = mapping.split("=")
except ValueError as e:
raise click.BadParameter(
"Column type setting separator '=' not specified or specified "
"more than once."
) from e
try:
column_type = COLUMN_TYPES_BY_NAME[dtype_name]
except KeyError as e:
raise click.BadParameter(
f"Column types from {list(COLUMN_TYPES_BY_NAME.keys())} "
f"expected, but value '{dtype_name}' recived."
) from e
dtype[column_name] = column_type
return dtype
dtypes[column_name] = dtype
return dtypes


@click.command() # type: ignore
Expand Down Expand Up @@ -84,9 +76,7 @@ def cli_dtype_callback(
type=click.UNPROCESSED,
callback=cli_dtype_callback,
multiple=True,
help="Custom column types setting (use COLUMN_NAME={"
+ "|".join(sorted(COLUMN_TYPES_BY_NAME.keys()))
+ "} notation). Multiple settings allowed.",
help="Custom column types setting (use COLUMN_NAME=DTYPE notation). Multiple settings allowed.",
)
@click.option(
"--no-browser",
Expand Down Expand Up @@ -119,7 +109,7 @@ def main(
host: str,
port: Union[int, str],
layout: Optional[str],
dtype: Optional[ColumnTypeMapping],
dtype: Optional[Dict[str, str]],
no_browser: bool,
filebrowsing: bool,
analyze: List[str],
Expand Down
15 changes: 4 additions & 11 deletions renumics/spotlight/data_source/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
from abc import ABC, abstractmethod
from typing import Dict, Optional, List, Any, Union
from typing import Optional, List, Any, Union

import pandas as pd
import numpy as np
Expand All @@ -12,11 +12,10 @@
ColumnExistsError,
ColumnNotExistsError,
)
from renumics.spotlight.dtypes.typing import (
ColumnTypeMapping,
)
from renumics.spotlight.backend.exceptions import GenerationIDMismatch, NoRowFound

from renumics.spotlight.dtypes import DTypeMap


@dataclasses.dataclass
class ColumnMetadata:
Expand Down Expand Up @@ -84,7 +83,7 @@ def check_generation_id(self, generation_id: int) -> None:
raise GenerationIDMismatch()

@abstractmethod
def guess_dtypes(self) -> ColumnTypeMapping:
def guess_dtypes(self) -> DTypeMap:
"""
Guess data source's dtypes.
"""
Expand Down Expand Up @@ -117,12 +116,6 @@ def get_column_metadata(self, column_name: str) -> ColumnMetadata:
Get extra info of a column.
"""

@abstractmethod
def get_column_categories(self, column_name: str) -> Dict[str, int]:
"""
Get column categories (for categorical dtype)
"""

def _assert_index_exists(self, index: int) -> None:
if index < -len(self) or index >= len(self):
raise NoRowFound(index)
Expand Down
67 changes: 37 additions & 30 deletions renumics/spotlight/data_store.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,51 @@
import hashlib
import io
from typing import List, Optional, Union
from typing import List, Optional, Set, Union, cast
import numpy as np

from renumics.spotlight.cache import external_data_cache
from renumics.spotlight.dtypes.typing import ColumnTypeMapping
from renumics.spotlight.data_source import DataSource
from renumics.spotlight.dtypes import Audio, Category
from renumics.spotlight.dtypes.conversion import (
ConvertedValue,
DTypeOptions,
convert_to_dtype,
)
from renumics.spotlight.dtypes.conversion import ConvertedValue, convert_to_dtype
from renumics.spotlight.data_source.data_source import ColumnMetadata

from renumics.spotlight.io import audio
from renumics.spotlight.dtypes import (
CategoryDType,
DTypeMap,
is_audio_dtype,
is_category_dtype,
is_str_dtype,
str_dtype,
)


class DataStore:
_dtypes: ColumnTypeMapping
_dtypes: DTypeMap
_data_source: DataSource

def __init__(self, data_source: DataSource, user_dtypes: ColumnTypeMapping) -> None:
def __init__(self, data_source: DataSource, user_dtypes: DTypeMap) -> None:
self._data_source = data_source
dtypes = self._data_source.guess_dtypes()
dtypes.update(
{
column_name: column_type
for column_name, column_type in user_dtypes.items()
if column_name in dtypes
}
)
guessed_dtypes = self._data_source.guess_dtypes()
dtypes = {
**guessed_dtypes,
**{
column_name: dtype
for column_name, dtype in user_dtypes.items()
if column_name in guessed_dtypes
},
}
for column_name, dtype in dtypes.items():
if (
is_category_dtype(dtype)
and dtype.categories is None
and is_str_dtype(guessed_dtypes[column_name])
):
normalized_values = self._data_source.get_column_values(column_name)
converted_values = [
convert_to_dtype(value, str_dtype, simple=True, check=True)
for value in normalized_values
]
category_names = sorted(cast(Set[str], set(converted_values)))
dtypes[column_name] = CategoryDType(category_names)
self._dtypes = dtypes

def __len__(self) -> int:
Expand All @@ -53,7 +68,7 @@ def column_names(self) -> List[str]:
return self._data_source.column_names

@property
def dtypes(self) -> ColumnTypeMapping:
def dtypes(self) -> DTypeMap:
return self._dtypes

def check_generation_id(self, generation_id: int) -> None:
Expand All @@ -71,16 +86,8 @@ def get_converted_values(
) -> List[ConvertedValue]:
dtype = self._dtypes[column_name]
normalized_values = self._data_source.get_column_values(column_name, indices)
if dtype is Category:
dtype_options = DTypeOptions(
categories=self._data_source.get_column_categories(column_name)
)
else:
dtype_options = DTypeOptions()
converted_values = [
convert_to_dtype(
value, dtype, dtype_options=dtype_options, simple=simple, check=check
)
convert_to_dtype(value, dtype, simple=simple, check=check)
for value in normalized_values
]
return converted_values
Expand All @@ -94,7 +101,7 @@ def get_waveform(self, column_name: str, index: int) -> Optional[np.ndarray]:
"""
return the waveform of an audio cell
"""
assert self._dtypes[column_name] is Audio
assert is_audio_dtype(self._dtypes[column_name])

blob = self.get_converted_value(column_name, index, simple=False)
if blob is None:
Expand Down
Loading

0 comments on commit 435282f

Please sign in to comment.