diff --git a/ax/benchmark/problems/hpo/torchvision.py b/ax/benchmark/problems/hpo/torchvision.py index 0269ea30640..564db2039fa 100644 --- a/ax/benchmark/problems/hpo/torchvision.py +++ b/ax/benchmark/problems/hpo/torchvision.py @@ -4,13 +4,14 @@ # LICENSE file in the root directory of this source tree. import os -from typing import Any, Dict +from typing import Any, Dict, Optional from ax.benchmark.problems.hpo.pytorch_cnn import ( PyTorchCNNBenchmarkProblem, PyTorchCNNRunner, ) from ax.exceptions.core import UserInputError +from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry from ax.utils.common.typeutils import checked_cast from torch.utils.data import TensorDataset @@ -103,7 +104,12 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]: return {"name": pytorch_cnn_runner.name} @classmethod - def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]: + def deserialize_init_args( + cls, + args: Dict[str, Any], + decoder_registry: Optional[TDecoderRegistry] = None, + class_decoder_registry: Optional[TClassDecoderRegistry] = None, + ) -> Dict[str, Any]: name = args["name"] dataset_fn = _REGISTRY[name] diff --git a/ax/benchmark/problems/surrogate.py b/ax/benchmark/problems/surrogate.py index 923ac673ac7..915ddb1b2db 100644 --- a/ax/benchmark/problems/surrogate.py +++ b/ax/benchmark/problems/surrogate.py @@ -27,6 +27,7 @@ from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker from ax.utils.common.result import Err, Ok +from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry from ax.utils.common.typeutils import not_none from botorch.utils.datasets import SupervisedDataset @@ -311,5 +312,10 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]: return {} @classmethod - def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]: + def deserialize_init_args( + cls, + args: Dict[str, Any], + decoder_registry: Optional[TDecoderRegistry] = None, + class_decoder_registry: Optional[TClassDecoderRegistry] = None, + ) -> Dict[str, Any]: return {} diff --git a/ax/core/data.py b/ax/core/data.py index 833dda06153..3a3c843487d 100644 --- a/ax/core/data.py +++ b/ax/core/data.py @@ -20,6 +20,8 @@ extract_init_args, SerializationMixin, serialize_init_args, + TClassDecoderRegistry, + TDecoderRegistry, ) from ax.utils.common.typeutils import checked_cast, not_none @@ -184,7 +186,12 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]: return serialize_init_args(obj=data) @classmethod - def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]: + def deserialize_init_args( + cls, + args: Dict[str, Any], + decoder_registry: Optional[TDecoderRegistry] = None, + class_decoder_registry: Optional[TClassDecoderRegistry] = None, + ) -> Dict[str, Any]: """Given a dictionary, extract the properties needed to initialize the object. Used for storage. """ diff --git a/ax/core/map_data.py b/ax/core/map_data.py index d144da4bcbc..430270c92d1 100644 --- a/ax/core/map_data.py +++ b/ax/core/map_data.py @@ -18,7 +18,11 @@ from ax.utils.common.docutils import copy_doc from ax.utils.common.equality import dataframe_equals from ax.utils.common.logger import get_logger -from ax.utils.common.serialization import serialize_init_args +from ax.utils.common.serialization import ( + serialize_init_args, + TClassDecoderRegistry, + TDecoderRegistry, +) from ax.utils.common.typeutils import checked_cast logger: Logger = get_logger(__name__) @@ -304,7 +308,12 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]: return properties @classmethod - def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]: + def deserialize_init_args( + cls, + args: Dict[str, Any], + decoder_registry: Optional[TDecoderRegistry] = None, + class_decoder_registry: Optional[TClassDecoderRegistry] = None, + ) -> Dict[str, Any]: """Given a dictionary, extract the properties needed to initialize the metric. Used for storage. """ diff --git a/ax/runners/botorch_test_problem.py b/ax/runners/botorch_test_problem.py index cdff8c36561..f6553a7933c 100644 --- a/ax/runners/botorch_test_problem.py +++ b/ax/runners/botorch_test_problem.py @@ -11,6 +11,7 @@ from ax.core.runner import Runner from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker +from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry from ax.utils.common.typeutils import checked_cast from botorch.test_functions.base import BaseTestProblem, ConstrainedBaseTestProblem from botorch.utils.transforms import normalize, unnormalize @@ -133,7 +134,12 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]: } @classmethod - def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]: + def deserialize_init_args( + cls, + args: Dict[str, Any], + decoder_registry: Optional[TDecoderRegistry] = None, + class_decoder_registry: Optional[TClassDecoderRegistry] = None, + ) -> Dict[str, Any]: """Given a dictionary, deserialize the properties needed to initialize the runner. Used for storage. """ diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 2fc89c570fd..1de3b72ed28 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -238,7 +238,13 @@ def object_from_json( "outcome_transform_options" ] = outcome_transform_options_json elif isclass(_class) and issubclass(_class, SerializationMixin): - return _class(**_class.deserialize_init_args(args=object_json)) + return _class( + **_class.deserialize_init_args( + args=object_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + ) return ax_class_from_json_dict( _class=_class, diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 1984983e533..adeccd4c94b 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -801,7 +801,9 @@ def runner_from_sqa( try: args = runner_class.deserialize_init_args( - args=dict(runner_sqa.properties or {}) + args=dict(runner_sqa.properties or {}), + decoder_registry=self.config.json_decoder_registry, + class_decoder_registry=self.config.json_class_decoder_registry, ) args.update(runner_kwargs or {}) # pyre-ignore[45]: Cannot instantiate abstract class `Runner`. diff --git a/ax/utils/common/serialization.py b/ax/utils/common/serialization.py index fb902fbab12..b66a19c0712 100644 --- a/ax/utils/common/serialization.py +++ b/ax/utils/common/serialization.py @@ -8,11 +8,17 @@ import inspect import pydoc -from abc import ABC from types import FunctionType from typing import Any, Callable, Dict, List, Optional, Type +# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to +# avoid runtime subscripting errors. +TDecoderRegistry = Dict[str, Type] +# pyre-fixme[33]: `TClassDecoderRegistry` cannot alias to a type containing `Any`. +TClassDecoderRegistry = Dict[str, Callable[[Dict[str, Any]], Any]] + + # https://stackoverflow.com/a/39235373 # pyre-fixme[3]: Return annotation cannot be `Any`. # pyre-fixme[2]: Parameter annotation cannot be `Any`. @@ -121,7 +127,7 @@ def extract_init_args(args: Dict[str, Any], class_: Type) -> Dict[str, Any]: return init_args -class SerializationMixin(ABC): +class SerializationMixin: @classmethod def serialize_init_args(cls, obj: SerializationMixin) -> Dict[str, Any]: """Serialize the properties needed to initialize the object. @@ -130,7 +136,12 @@ def serialize_init_args(cls, obj: SerializationMixin) -> Dict[str, Any]: return serialize_init_args(obj=obj) @classmethod - def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]: + def deserialize_init_args( + cls, + args: Dict[str, Any], + decoder_registry: Optional[TDecoderRegistry] = None, + class_decoder_registry: Optional[TClassDecoderRegistry] = None, + ) -> Dict[str, Any]: """Given a dictionary, deserialize the properties needed to initialize the object. Used for storage. """