From 8c2d28767a09a866336a6d8bb9bec6c952e89c1a Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 16 Oct 2023 11:30:26 -0700 Subject: [PATCH] Docstrings for MB._convert_observations & _extract_observation_data Summary: --- Differential Revision: D50333012 --- ax/modelbridge/torch.py | 53 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index 4b2b346d15e..430a38f900f 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -303,6 +303,28 @@ def _convert_observations( ]: """Converts observations to a dictionary of `Dataset` containers and (optional) candidate metadata. + + Args: + observation_data: A list of `ObservationData` from which to extract + mean `Y` and variance `Yvar` observations. Must correspond 1:1 to + the `observation_features`. + observation_features: A list of `ObservationFeatures` from which to extract + parameter values. Must correspond 1:1 to the `observation_data`. + outcomes: The names of the outcomes to extract observations for. + parameters: The names of the parameters to extract. Any observation features + that are not included in `parameters` will be ignored. + search_space_digest: An optional `SearchSpaceDigest` containing information + about the search space. This is used to convert datasets into a + `MultiTaskDataset` where applicable. + + Returns: + - A list of `Dataset` objects corresponding to each outcome. Each element + in the list corresponds to one outcome. If the outcome does not have + any observations, then the corresponding element in the list will be + `None`. + - An optional list of lists of candidate metadata. Each inner list + corresponds to one outcome. Each element in the inner list corresponds + to one observation. """ ( Xs, @@ -944,9 +966,34 @@ def _extract_observation_data( observation_data: List[ObservationData], observation_features: List[ObservationFeatures], parameters: List[str], - # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use - # `typing.Dict` to avoid runtime subscripting errors. - ) -> Tuple[Dict, Dict, Dict, Dict, bool]: + ) -> Tuple[ + Dict[str, List[Tensor]], + Dict[str, List[Tensor]], + Dict[str, List[Tensor]], + Dict[str, List[TCandidateMetadata]], + bool, + ]: + """Extract observation features & data into tensors and metadata. + + Args: + observation_data: A list of `ObservationData` from which to extract + mean `Y` and variance `Yvar` observations. Must correspond 1:1 to + the `observation_features`. + observation_features: A list of `ObservationFeatures` from which to extract + parameter values. Must correspond 1:1 to the `observation_data`. + parameters: The names of the parameters to extract. Any observation features + that are not included in `parameters` will be ignored. + + Returns: + - A dictionary mapping metric names to lists of corresponding feature + tensors `X`. + - A dictionary mapping metric names to lists of corresponding mean + observation tensors `Y`. + - A dictionary mapping metric names to lists of corresponding variance + observation tensors `Yvar`. + - A dictionary mapping metric names to lists of corresponding metadata. + - A boolean denoting whether any candidate metadata is not none. + """ Xs: Dict[str, List[Tensor]] = defaultdict(list) Ys: Dict[str, List[Tensor]] = defaultdict(list) Yvars: Dict[str, List[Tensor]] = defaultdict(list)