Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docstrings for MB._convert_observations & _extract_observation_data #1908

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,28 @@ def _convert_observations(
]:
"""Converts observations to a dictionary of `Dataset` containers and (optional)
candidate metadata.

Args:
observation_data: A list of `ObservationData` from which to extract
mean `Y` and variance `Yvar` observations. Must correspond 1:1 to
the `observation_features`.
observation_features: A list of `ObservationFeatures` from which to extract
parameter values. Must correspond 1:1 to the `observation_data`.
outcomes: The names of the outcomes to extract observations for.
parameters: The names of the parameters to extract. Any observation features
that are not included in `parameters` will be ignored.
search_space_digest: An optional `SearchSpaceDigest` containing information
about the search space. This is used to convert datasets into a
`MultiTaskDataset` where applicable.

Returns:
- A list of `Dataset` objects corresponding to each outcome. Each element
in the list corresponds to one outcome. If the outcome does not have
any observations, then the corresponding element in the list will be
`None`.
- An optional list of lists of candidate metadata. Each inner list
corresponds to one outcome. Each element in the inner list corresponds
to one observation.
"""
(
Xs,
Expand Down Expand Up @@ -944,9 +966,34 @@ def _extract_observation_data(
observation_data: List[ObservationData],
observation_features: List[ObservationFeatures],
parameters: List[str],
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict` to avoid runtime subscripting errors.
) -> Tuple[Dict, Dict, Dict, Dict, bool]:
) -> Tuple[
Dict[str, List[Tensor]],
Dict[str, List[Tensor]],
Dict[str, List[Tensor]],
Dict[str, List[TCandidateMetadata]],
bool,
]:
"""Extract observation features & data into tensors and metadata.

Args:
observation_data: A list of `ObservationData` from which to extract
mean `Y` and variance `Yvar` observations. Must correspond 1:1 to
the `observation_features`.
observation_features: A list of `ObservationFeatures` from which to extract
parameter values. Must correspond 1:1 to the `observation_data`.
parameters: The names of the parameters to extract. Any observation features
that are not included in `parameters` will be ignored.

Returns:
- A dictionary mapping metric names to lists of corresponding feature
tensors `X`.
- A dictionary mapping metric names to lists of corresponding mean
observation tensors `Y`.
- A dictionary mapping metric names to lists of corresponding variance
observation tensors `Yvar`.
- A dictionary mapping metric names to lists of corresponding metadata.
- A boolean denoting whether any candidate metadata is not none.
"""
Xs: Dict[str, List[Tensor]] = defaultdict(list)
Ys: Dict[str, List[Tensor]] = defaultdict(list)
Yvars: Dict[str, List[Tensor]] = defaultdict(list)
Expand Down