Skip to content

Commit

Permalink
SerializationMixin.deserialize_init_args signature to take decoder_re…
Browse files Browse the repository at this point in the history
…gistry and class_decoder_registry (#1939)

Summary:
Pull Request resolved: #1939

Descendents of SerializationMixin (Metrics, Runners, and Data objects) should be able to deserialize init args that are themselves Ax types. This is enabled by passing in a decoding registries, which individual deserialization functions can optionally employ. This is necessary because importing object_from_json into a Runner can create a circular dependency.

In this diff, registries are passed in but never used, so this is a no-op.

Reviewed By: lena-kashtelyan

Differential Revision: D50665111

fbshipit-source-id: ac7f5aff70e59de438b5cc953266d53d54f318dd
  • Loading branch information
bernardbeckerman authored and facebook-github-bot committed Oct 30, 2023
1 parent 94c1cbd commit 80eb384
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 12 deletions.
10 changes: 8 additions & 2 deletions ax/benchmark/problems/hpo/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# LICENSE file in the root directory of this source tree.

import os
from typing import Any, Dict
from typing import Any, Dict, Optional

from ax.benchmark.problems.hpo.pytorch_cnn import (
PyTorchCNNBenchmarkProblem,
PyTorchCNNRunner,
)
from ax.exceptions.core import UserInputError
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from ax.utils.common.typeutils import checked_cast
from torch.utils.data import TensorDataset

Expand Down Expand Up @@ -103,7 +104,12 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
return {"name": pytorch_cnn_runner.name}

@classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
def deserialize_init_args(
cls,
args: Dict[str, Any],
decoder_registry: Optional[TDecoderRegistry] = None,
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
) -> Dict[str, Any]:
name = args["name"]

dataset_fn = _REGISTRY[name]
Expand Down
8 changes: 7 additions & 1 deletion ax/benchmark/problems/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.result import Err, Ok
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from ax.utils.common.typeutils import not_none
from botorch.utils.datasets import SupervisedDataset

Expand Down Expand Up @@ -311,5 +312,10 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
return {}

@classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
def deserialize_init_args(
cls,
args: Dict[str, Any],
decoder_registry: Optional[TDecoderRegistry] = None,
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
) -> Dict[str, Any]:
return {}
9 changes: 8 additions & 1 deletion ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
extract_init_args,
SerializationMixin,
serialize_init_args,
TClassDecoderRegistry,
TDecoderRegistry,
)
from ax.utils.common.typeutils import checked_cast, not_none

Expand Down Expand Up @@ -184,7 +186,12 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
return serialize_init_args(obj=data)

@classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
def deserialize_init_args(
cls,
args: Dict[str, Any],
decoder_registry: Optional[TDecoderRegistry] = None,
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
) -> Dict[str, Any]:
"""Given a dictionary, extract the properties needed to initialize the object.
Used for storage.
"""
Expand Down
13 changes: 11 additions & 2 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from ax.utils.common.docutils import copy_doc
from ax.utils.common.equality import dataframe_equals
from ax.utils.common.logger import get_logger
from ax.utils.common.serialization import serialize_init_args
from ax.utils.common.serialization import (
serialize_init_args,
TClassDecoderRegistry,
TDecoderRegistry,
)
from ax.utils.common.typeutils import checked_cast

logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -304,7 +308,12 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
return properties

@classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
def deserialize_init_args(
cls,
args: Dict[str, Any],
decoder_registry: Optional[TDecoderRegistry] = None,
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
) -> Dict[str, Any]:
"""Given a dictionary, extract the properties needed to initialize the metric.
Used for storage.
"""
Expand Down
8 changes: 7 additions & 1 deletion ax/runners/botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ax.core.runner import Runner
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from ax.utils.common.typeutils import checked_cast
from botorch.test_functions.base import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.utils.transforms import normalize, unnormalize
Expand Down Expand Up @@ -133,7 +134,12 @@ def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
}

@classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
def deserialize_init_args(
cls,
args: Dict[str, Any],
decoder_registry: Optional[TDecoderRegistry] = None,
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
) -> Dict[str, Any]:
"""Given a dictionary, deserialize the properties needed to initialize the
runner. Used for storage.
"""
Expand Down
8 changes: 7 additions & 1 deletion ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,13 @@ def object_from_json(
"outcome_transform_options"
] = outcome_transform_options_json
elif isclass(_class) and issubclass(_class, SerializationMixin):
return _class(**_class.deserialize_init_args(args=object_json))
return _class(
**_class.deserialize_init_args(
args=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
)

return ax_class_from_json_dict(
_class=_class,
Expand Down
4 changes: 3 additions & 1 deletion ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,9 @@ def runner_from_sqa(

try:
args = runner_class.deserialize_init_args(
args=dict(runner_sqa.properties or {})
args=dict(runner_sqa.properties or {}),
decoder_registry=self.config.json_decoder_registry,
class_decoder_registry=self.config.json_class_decoder_registry,
)
args.update(runner_kwargs or {})
# pyre-ignore[45]: Cannot instantiate abstract class `Runner`.
Expand Down
17 changes: 14 additions & 3 deletions ax/utils/common/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@

import inspect
import pydoc
from abc import ABC
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Type


# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
TDecoderRegistry = Dict[str, Type]
# pyre-fixme[33]: `TClassDecoderRegistry` cannot alias to a type containing `Any`.
TClassDecoderRegistry = Dict[str, Callable[[Dict[str, Any]], Any]]


# https://stackoverflow.com/a/39235373
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
Expand Down Expand Up @@ -121,7 +127,7 @@ def extract_init_args(args: Dict[str, Any], class_: Type) -> Dict[str, Any]:
return init_args


class SerializationMixin(ABC):
class SerializationMixin:
@classmethod
def serialize_init_args(cls, obj: SerializationMixin) -> Dict[str, Any]:
"""Serialize the properties needed to initialize the object.
Expand All @@ -130,7 +136,12 @@ def serialize_init_args(cls, obj: SerializationMixin) -> Dict[str, Any]:
return serialize_init_args(obj=obj)

@classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
def deserialize_init_args(
cls,
args: Dict[str, Any],
decoder_registry: Optional[TDecoderRegistry] = None,
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
) -> Dict[str, Any]:
"""Given a dictionary, deserialize the properties needed to initialize the
object. Used for storage.
"""
Expand Down

0 comments on commit 80eb384

Please sign in to comment.