Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/test hf serving #251

Merged
merged 8 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
install-dependencies: false
- name: Cache pip cache folder
- name: '♻️ Cache pip cache folder'
uses: actions/cache@v3
with:
path: ${{ steps.setup-poetry.outputs.pip-cache-dir }}
Expand Down
50 changes: 27 additions & 23 deletions renumics/spotlight/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
from renumics.spotlight.data_source.data_source import ColumnMetadata
from renumics.spotlight.io import audio
from renumics.spotlight.dtypes import (
ArrayDType,
CategoryDType,
DType,
DTypeMap,
EmbeddingDType,
is_array_dtype,
is_audio_dtype,
is_category_dtype,
is_file_dtype,
Expand All @@ -29,7 +32,6 @@
video_dtype,
mesh_dtype,
embedding_dtype,
array_dtype,
window_dtype,
sequence_1d_dtype,
)
Expand Down Expand Up @@ -163,20 +165,39 @@ def _update_dtypes(self) -> None:

def _guess_dtype(self, col: str) -> DType:
intermediate_dtype = self._data_source.intermediate_dtypes[col]
fallback_dtype = _intermediate_to_semantic_dtype(intermediate_dtype)
semantic_dtype = _intermediate_to_semantic_dtype(intermediate_dtype)

if is_array_dtype(intermediate_dtype):
return semantic_dtype

sample_values = self._data_source.get_column_values(col, slice(10))
sample_dtypes = [_guess_value_dtype(value) for value in sample_values]

try:
mode_dtype = statistics.mode(sample_dtypes)
except statistics.StatisticsError:
return fallback_dtype
return semantic_dtype

return mode_dtype or fallback_dtype
return mode_dtype or semantic_dtype


def _intermediate_to_semantic_dtype(intermediate_dtype: DType) -> DType:
if is_array_dtype(intermediate_dtype):
if intermediate_dtype.shape is None:
return intermediate_dtype
if intermediate_dtype.shape == (2,):
return window_dtype
if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is not None:
return EmbeddingDType(intermediate_dtype.shape[0])
if intermediate_dtype.ndim == 1 and intermediate_dtype.shape[0] is None:
return sequence_1d_dtype
if intermediate_dtype.ndim == 2 and (
intermediate_dtype.shape[0] == 2 or intermediate_dtype.shape[1] == 2
):
return sequence_1d_dtype
if intermediate_dtype.ndim == 3 and intermediate_dtype.shape[-1] in (1, 3, 4):
return image_dtype
return intermediate_dtype
if is_file_dtype(intermediate_dtype):
return str_dtype
if is_mixed_dtype(intermediate_dtype):
Expand Down Expand Up @@ -208,7 +229,7 @@ def _guess_value_dtype(value: Any) -> Optional[DType]:
if isinstance(value, trimesh.Trimesh):
return mesh_dtype
if isinstance(value, np.ndarray):
return _infer_array_dtype(value)
return ArrayDType(value.shape)

if isinstance(value, bytes) or (is_pathtype(value) and os.path.isfile(value)):
kind = filetype.guess(value)
Expand All @@ -227,22 +248,5 @@ def _guess_value_dtype(value: Any) -> Optional[DType]:
except (TypeError, ValueError):
pass
else:
return _infer_array_dtype(value)
return ArrayDType(value.shape)
return None


def _infer_array_dtype(value: np.ndarray) -> DType:
"""
Infer dtype of a numpy array
"""
if value.ndim == 3:
if value.shape[-1] in (1, 3, 4):
return image_dtype
elif value.ndim == 2:
if value.shape[0] == 2 or value.shape[1] == 2:
return sequence_1d_dtype
elif value.ndim == 1:
if len(value) == 2:
return window_dtype
return embedding_dtype
return array_dtype
42 changes: 37 additions & 5 deletions renumics/spotlight/dtypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Dict, Iterable, Optional, Union
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -80,6 +80,38 @@ def inverted_categories(self) -> Optional[Dict[int, str]]:
return self._inverted_categories


class ArrayDType(DType):
"""
Array dtype with optional shape.
"""

shape: Optional[Tuple[Optional[int], ...]]

def __init__(self, shape: Optional[Tuple[Optional[int], ...]] = None):
super().__init__("array")
self.shape = shape

@property
def ndim(self) -> int:
if self.shape is None:
return 0
return len(self.shape)


class EmbeddingDType(DType):
"""
Embedding dtype with optional length.
"""

length: Optional[int]

def __init__(self, length: Optional[int] = None):
super().__init__("Embedding")
if length is not None and length < 0:
raise ValueError(f"Length must be non-negative, but {length} received.")
self.length = length


class Sequence1DDType(DType):
"""
1D-sequence dtype with predefined axis labels.
Expand Down Expand Up @@ -131,10 +163,10 @@ def register_dtype(dtype: DType, aliases: Optional[list] = None) -> None:
window_dtype = DType("Window")
"""Window dtype"""
register_dtype(window_dtype, [Window])
embedding_dtype = DType("Embedding")
embedding_dtype = EmbeddingDType()
"""Embedding dtype"""
register_dtype(embedding_dtype, [Embedding])
array_dtype = DType("array")
array_dtype = ArrayDType()
"""numpy array dtype"""
register_dtype(array_dtype, [np.ndarray])
image_dtype = DType("Image")
Expand Down Expand Up @@ -195,15 +227,15 @@ def is_category_dtype(dtype: DType) -> TypeGuard[CategoryDType]:
return dtype.name == "Category"


def is_array_dtype(dtype: DType) -> bool:
def is_array_dtype(dtype: DType) -> TypeGuard[ArrayDType]:
return dtype.name == "array"


def is_window_dtype(dtype: DType) -> bool:
return dtype.name == "Window"


def is_embedding_dtype(dtype: DType) -> bool:
def is_embedding_dtype(dtype: DType) -> TypeGuard[EmbeddingDType]:
return dtype.name == "Embedding"


Expand Down
Loading