Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow plugin modules to register new webviz store return types #387

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions webviz_config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
try:
# Python 3.8+
from importlib.metadata import version, PackageNotFoundError # type: ignore
from importlib.metadata import version, PackageNotFoundError, entry_points # type: ignore
except ModuleNotFoundError:
# Python < 3.8
from importlib_metadata import version, PackageNotFoundError # type: ignore
from importlib_metadata import version, PackageNotFoundError, entry_points # type: ignore


from ._theme_class import WebvizConfigTheme
from ._webviz_settings_class import WebvizSettings
Expand All @@ -12,7 +13,11 @@
from ._plugin_abc import WebvizPluginABC, EncodedFile, ZipFileMember
from ._shared_settings_subscriptions import SHARED_SETTINGS_SUBSCRIPTIONS
from ._oauth2 import Oauth2
from .webviz_storage import WebvizStorageTypeRegistry, WebvizStorageType

for entry_point in entry_points().get("webviz_config_return_types", []):
theme = entry_point.load()
globals()[entry_point.name] = theme
try:
__version__ = version("webviz-config")
except PackageNotFoundError:
Expand Down
128 changes: 128 additions & 0 deletions webviz_config/webviz_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import io
import abc
import glob
import shutil
import functools
import hashlib
import inspect
import pathlib
import warnings
from collections import defaultdict
from typing import Callable, List, Union, Any, Type, Dict
import weakref

import numpy as np
import pandas as pd
from tqdm import tqdm


class ClassProperty:
def __init__(self, fget: Callable):
self.fget = fget

def __get__(self, owner_self: Type[Any], owner_cls: Type[Any]) -> Any:
print(type(owner_self), type(owner_cls))
print(type(self.fget(owner_cls)))
return self.fget(owner_cls)


class WebvizStorageType(abc.ABC):
""" Base class for a webviz storage type """

@staticmethod
@abc.abstractmethod
def get_data(path: str, **kwargs: Dict) -> None:
""" Abstract method to retrieve stored data """

@staticmethod
@abc.abstractmethod
def save_data(data: Any, path: str) -> Any:
""" Abstract method to save data to store """


class WebvizStorageTypeRegistry:
""" Registry of allowed webviz storage types """

registry: Dict = {}

@classmethod
def register(cls, return_type: Type) -> Callable:
def inner_wrapper(wrapped_class: Type) -> Type:
if return_type in cls.registry:
print(f"Storage type {return_type} already exists. Will replace it")
cls.registry[return_type] = wrapped_class
return wrapped_class

return inner_wrapper

@classmethod
def create_storage_type(
cls, return_type: str, **kwargs: Dict
) -> Union[None, WebvizStorageType]:
"""Factory command to create the storage type.
This method gets the appropriate WebvizStorageType class from the registry
and creates an instance of it, while passing in the parameters
given in ``kwargs``.
Args:
return_type (str): The type of the storage type to create.
Returns:
An instance of the storage type that is created.
"""

if return_type not in cls.registry:
print(f"Storage type {return_type} does not exist in the registry")
return None

exec_class = cls.registry[return_type]
return exec_class(**kwargs)

# pylint: disable=no-self-argument
@ClassProperty
def return_types(cls) -> List:
return list(cls.registry.keys())


@WebvizStorageTypeRegistry.register(pd.DataFrame)
class TypeDataFrame(WebvizStorageType):
@staticmethod
def get_data(path: str, **kwargs: Dict) -> Any:
return pd.read_parquet(f"{path}.parquet")

@staticmethod
def save_data(data: Any, path: str) -> None:
data.to_parquet(f"{path}.parquet")


@WebvizStorageTypeRegistry.register(pathlib.Path)
@WebvizStorageTypeRegistry.register(pathlib.PosixPath)
class TypePath(WebvizStorageType):
@staticmethod
def get_data(path: str, **kwargs: Dict) -> Any:
return pathlib.Path(glob.glob(f"{path}*")[0])

@staticmethod
def save_data(data: Any, path: str) -> None:
shutil.copy(data, f"{path}{data.suffix}")


@WebvizStorageTypeRegistry.register(list)
@WebvizStorageTypeRegistry.register(List)
class TypeList(WebvizStorageType):
@staticmethod
def get_data(path: str, **kwargs: Dict) -> Any:
return np.load(f"{path}.npy").tolist()

@staticmethod
def save_data(data: Any, path: str) -> None:
np.save(f"{path}.npy", data)


@WebvizStorageTypeRegistry.register(io.BytesIO)
class TypeBytesIO(WebvizStorageType):
@staticmethod
def get_data(path: str, **kwargs: Dict) -> Any:
return np.load(f"{path}.npy").tolist()

@staticmethod
def save_data(data: Any, path: str) -> None:
pathlib.Path(path).write_bytes(data.getvalue())
42 changes: 15 additions & 27 deletions webviz_config/webviz_store.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import io
import glob
import shutil
import functools
import hashlib
import inspect
import pathlib
import warnings
from collections import defaultdict
from typing import Callable, List, Union, Any
from typing import Callable, List, Any

import pandas as pd
from tqdm import tqdm

from .webviz_storage import WebvizStorageTypeRegistry, WebvizStorageType

class WebvizStorage:

RETURN_TYPES = [pd.DataFrame, pathlib.Path, io.BytesIO]
class WebvizStorage:
# pylint: disable=unsupported-membership-test
RETURN_TYPES = WebvizStorageTypeRegistry.return_types

def __init__(self) -> None:
self._use_storage = False
Expand All @@ -27,7 +26,6 @@ def register_function(self, func: Callable) -> None:
decorator @webvizstore, registering the function it decorates.
"""
return_type = inspect.getfullargspec(func).annotations["return"]

if return_type not in WebvizStorage.RETURN_TYPES:
raise NotImplementedError(
f"Webviz storage type must be one of {WebvizStorage.RETURN_TYPES}"
Expand Down Expand Up @@ -145,26 +143,21 @@ def complete_kwargs(func: Callable, kwargs: dict) -> dict:
return kwargs

def get_stored_data(
self, func: Callable, *args: Any, **kwargs: Any
) -> Union[pd.DataFrame, pathlib.Path, io.BytesIO]:

self, func: Callable, *args: Any, webviz_load=None, **kwargs: Any
) -> WebvizStorageType:
load_args = webviz_load if webviz_load is not None else {}
argspec = inspect.getfullargspec(func)
for arg_name, arg in zip(argspec.args, args):
kwargs[arg_name] = arg

WebvizStorage.complete_kwargs(func, kwargs)
return_type = inspect.getfullargspec(func).annotations["return"]

path = self._unique_path(func, WebvizStorage._dict_to_tuples(kwargs))
storagetype = WebvizStorageTypeRegistry.create_storage_type(return_type)

try:
if return_type == pd.DataFrame:
return pd.read_parquet(f"{path}.parquet")
if return_type == pathlib.Path:
return pathlib.Path(glob.glob(f"{path}*")[0])
if return_type == io.BytesIO:
return io.BytesIO(pathlib.Path(path).read_bytes())
raise ValueError(f"Unknown return type {return_type}")
return storagetype.get_data(path=path, load_args=load_args)
# raise ValueError(f"Unknown return type {return_type}")

except OSError as exc:
raise OSError(
Expand All @@ -190,15 +183,10 @@ def build_store(self) -> None:
for argtuples in self.storage_function_argvalues[func].values():
output = func(**dict(argtuples))
path = self._unique_path(func, argtuples)

if isinstance(output, pd.DataFrame):
output.to_parquet(f"{path}.parquet")
elif isinstance(output, pathlib.Path):
shutil.copy(output, f"{path}{output.suffix}")
elif isinstance(output, io.BytesIO):
pathlib.Path(path).write_bytes(output.getvalue())
else:
raise ValueError(f"Unknown return type {type(output)}")
storagetype = WebvizStorageTypeRegistry.create_storage_type(
type(output)
)
storagetype.save_data(data=output, path=path)

progress_bar.update()

Expand Down