diff --git a/webviz_config/__init__.py b/webviz_config/__init__.py
index aae7b59d..fe028d84 100644
--- a/webviz_config/__init__.py
+++ b/webviz_config/__init__.py
@@ -1,9 +1,10 @@
 try:
     # Python 3.8+
-    from importlib.metadata import version, PackageNotFoundError  # type: ignore
+    from importlib.metadata import version, PackageNotFoundError, entry_points  # type: ignore
 except ModuleNotFoundError:
     # Python < 3.8
-    from importlib_metadata import version, PackageNotFoundError  # type: ignore
+    from importlib_metadata import version, PackageNotFoundError, entry_points  # type: ignore
+
 
 from ._theme_class import WebvizConfigTheme
 from ._webviz_settings_class import WebvizSettings
@@ -12,7 +13,19 @@
 from ._plugin_abc import WebvizPluginABC, EncodedFile, ZipFileMember
 from ._shared_settings_subscriptions import SHARED_SETTINGS_SUBSCRIPTIONS
 from ._oauth2 import Oauth2
+from .webviz_storage import WebvizStorageTypeRegistry, WebvizStorageType
 
+try:
+    # Python 3.10+ and importlib_metadata >= 3.6 select the group directly
+    _STORAGE_ENTRY_POINTS = entry_points(group="webviz_config_return_types")
+except TypeError:
+    # Older versions take no arguments and return a dictionary of groups
+    _STORAGE_ENTRY_POINTS = entry_points().get("webviz_config_return_types", [])
+
+for entry_point in _STORAGE_ENTRY_POINTS:
+    # Expose each loaded object as webviz_config.<entry point name>
+    globals()[entry_point.name] = entry_point.load()
+
 try:
     __version__ = version("webviz-config")
 except PackageNotFoundError:
diff --git a/webviz_config/webviz_storage/__init__.py b/webviz_config/webviz_storage/__init__.py
new file mode 100644
index 00000000..4fd681e8
--- /dev/null
+++ b/webviz_config/webviz_storage/__init__.py
@@ -0,0 +1,144 @@
+import io
+import abc
+import glob
+import shutil
+import functools
+import hashlib
+import inspect
+import pathlib
+import warnings
+from collections import defaultdict
+from typing import Callable, List, Union, Any, Type, Dict
+import weakref
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+
+class ClassProperty:
+    """ Descriptor that calls ``fget`` with the owning class on access """
+
+    def __init__(self, fget: Callable):
+        self.fget = fget
+
+    def __get__(self, owner_self: Type[Any], owner_cls: Type[Any]) -> Any:
+        return self.fget(owner_cls)
+
+
+class WebvizStorageType(abc.ABC):
+    """ Base class for a webviz storage type """
+
+    @staticmethod
+    @abc.abstractmethod
+    def get_data(path: str, **kwargs: Dict) -> Any:
+        """ Abstract method to retrieve stored data """
+
+    @staticmethod
+    @abc.abstractmethod
+    def save_data(data: Any, path: str) -> None:
+        """ Abstract method to save data to store """
+
+
+class WebvizStorageTypeRegistry:
+    """ Registry of allowed webviz storage types """
+
+    registry: Dict = {}
+
+    @classmethod
+    def register(cls, return_type: Type) -> Callable:
+        """ Decorator registering a class as the handler of return_type """
+
+        def inner_wrapper(wrapped_class: Type) -> Type:
+            if return_type in cls.registry:
+                print(f"Storage type {return_type} already exists. Will replace it")
+            cls.registry[return_type] = wrapped_class
+            return wrapped_class
+
+        return inner_wrapper
+
+    @classmethod
+    def create_storage_type(
+        cls, return_type: Type, **kwargs: Dict
+    ) -> Union[None, WebvizStorageType]:
+        """Factory command to create the storage type.
+        This method gets the appropriate WebvizStorageType class from the registry
+        and creates an instance of it, while passing in the parameters
+        given in ``kwargs``.
+        Args:
+            return_type (Type): The registered type to create a storage type for.
+        Returns:
+            An instance of the storage type that is created, or None if
+            ``return_type`` is not in the registry.
+        """
+
+        if return_type not in cls.registry:
+            print(f"Storage type {return_type} does not exist in the registry")
+            return None
+
+        exec_class = cls.registry[return_type]
+        return exec_class(**kwargs)
+
+    # pylint: disable=no-self-argument
+    @ClassProperty
+    def return_types(cls) -> List:
+        """ List of the types currently in the registry """
+        return list(cls.registry.keys())
+
+
+@WebvizStorageTypeRegistry.register(pd.DataFrame)
+class TypeDataFrame(WebvizStorageType):
+    """ Stores a pandas data frame as ``<path>.parquet`` """
+
+    @staticmethod
+    def get_data(path: str, **kwargs: Dict) -> Any:
+        return pd.read_parquet(f"{path}.parquet")
+
+    @staticmethod
+    def save_data(data: Any, path: str) -> None:
+        data.to_parquet(f"{path}.parquet")
+
+
+# pathlib.Path() instantiates PosixPath or WindowsPath, so the concrete classes
+# are registered too for lookups done on the type of a returned path object.
+@WebvizStorageTypeRegistry.register(pathlib.Path)
+@WebvizStorageTypeRegistry.register(pathlib.PosixPath)
+@WebvizStorageTypeRegistry.register(pathlib.WindowsPath)
+class TypePath(WebvizStorageType):
+    """ Stores a copy of the file as ``<path><suffix>`` """
+
+    @staticmethod
+    def get_data(path: str, **kwargs: Dict) -> Any:
+        return pathlib.Path(glob.glob(f"{path}*")[0])
+
+    @staticmethod
+    def save_data(data: Any, path: str) -> None:
+        shutil.copy(data, f"{path}{data.suffix}")
+
+
+@WebvizStorageTypeRegistry.register(list)
+@WebvizStorageTypeRegistry.register(List)
+class TypeList(WebvizStorageType):
+    """ Stores a list as ``<path>.npy`` """
+
+    @staticmethod
+    def get_data(path: str, **kwargs: Dict) -> Any:
+        return np.load(f"{path}.npy").tolist()
+
+    @staticmethod
+    def save_data(data: Any, path: str) -> None:
+        np.save(f"{path}.npy", data)
+
+
+@WebvizStorageTypeRegistry.register(io.BytesIO)
+class TypeBytesIO(WebvizStorageType):
+    """ Stores the content of a byte stream in a file at ``path`` """
+
+    @staticmethod
+    def get_data(path: str, **kwargs: Dict) -> Any:
+        # Must mirror save_data, which writes the raw bytes to ``path``
+        return io.BytesIO(pathlib.Path(path).read_bytes())
+
+    @staticmethod
+    def save_data(data: Any, path: str) -> None:
+        pathlib.Path(path).write_bytes(data.getvalue())
diff --git a/webviz_config/webviz_store.py b/webviz_config/webviz_store.py
index 940871bb..9bee375b 100644
--- a/webviz_config/webviz_store.py
+++ b/webviz_config/webviz_store.py
@@ -1,21 +1,20 @@
-import io
-import glob
-import shutil
 import functools
 import hashlib
 import inspect
 import pathlib
 import warnings
 from collections import defaultdict
-from typing import Callable, List, Union, Any
+from typing import Callable, List, Any
 
 import pandas as pd
 from tqdm import tqdm
 
+from .webviz_storage import WebvizStorageTypeRegistry
 
-class WebvizStorage:
 
-    RETURN_TYPES = [pd.DataFrame, pathlib.Path, io.BytesIO]
+class WebvizStorage:
+    # pylint: disable=unsupported-membership-test
+    RETURN_TYPES = WebvizStorageTypeRegistry.return_types
 
     def __init__(self) -> None:
         self._use_storage = False
@@ -27,7 +26,6 @@ def register_function(self, func: Callable) -> None:
         decorator @webvizstore, registering the function it decorates.
         """
         return_type = inspect.getfullargspec(func).annotations["return"]
-
         if return_type not in WebvizStorage.RETURN_TYPES:
             raise NotImplementedError(
                 f"Webviz storage type must be one of {WebvizStorage.RETURN_TYPES}"
@@ -145,26 +143,22 @@ def complete_kwargs(func: Callable, kwargs: dict) -> dict:
         return kwargs
 
     def get_stored_data(
-        self, func: Callable, *args: Any, **kwargs: Any
-    ) -> Union[pd.DataFrame, pathlib.Path, io.BytesIO]:
-
+        self, func: Callable, *args: Any, webviz_load=None, **kwargs: Any
+    ) -> Any:
+        load_args = webviz_load if webviz_load is not None else {}
         argspec = inspect.getfullargspec(func)
         for arg_name, arg in zip(argspec.args, args):
             kwargs[arg_name] = arg
 
         WebvizStorage.complete_kwargs(func, kwargs)
         return_type = inspect.getfullargspec(func).annotations["return"]
-
         path = self._unique_path(func, WebvizStorage._dict_to_tuples(kwargs))
 
+        storagetype = WebvizStorageTypeRegistry.create_storage_type(return_type)
+        if storagetype is None:
+            raise ValueError(f"Unknown return type {return_type}")
         try:
-            if return_type == pd.DataFrame:
-                return pd.read_parquet(f"{path}.parquet")
-            if return_type == pathlib.Path:
-                return pathlib.Path(glob.glob(f"{path}*")[0])
-            if return_type == io.BytesIO:
-                return io.BytesIO(pathlib.Path(path).read_bytes())
-            raise ValueError(f"Unknown return type {return_type}")
+            return storagetype.get_data(path=path, load_args=load_args)
 
         except OSError as exc:
             raise OSError(
@@ -190,15 +184,12 @@ def build_store(self) -> None:
                 for argtuples in self.storage_function_argvalues[func].values():
                     output = func(**dict(argtuples))
                     path = self._unique_path(func, argtuples)
-
-                    if isinstance(output, pd.DataFrame):
-                        output.to_parquet(f"{path}.parquet")
-                    elif isinstance(output, pathlib.Path):
-                        shutil.copy(output, f"{path}{output.suffix}")
-                    elif isinstance(output, io.BytesIO):
-                        pathlib.Path(path).write_bytes(output.getvalue())
-                    else:
-                        raise ValueError(f"Unknown return type {type(output)}")
+                    storagetype = WebvizStorageTypeRegistry.create_storage_type(
+                        type(output)
+                    )
+                    if storagetype is None:
+                        raise ValueError(f"Unknown return type {type(output)}")
+                    storagetype.save_data(data=output, path=path)
 
                     progress_bar.update()
 