diff --git a/ax/core/observation.py b/ax/core/observation.py
index a28be42520e..1fb7d811348 100644
--- a/ax/core/observation.py
+++ b/ax/core/observation.py
@@ -417,7 +417,7 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]:
     feature_cols = OBS_COLS.intersection(data.df.columns)
     # note we use this check, rather than isinstance, since
     # only some Modelbridges (e.g. MapTorchModelBridge)
-    # use observations_from_map_data, which is required
+    # use observations_from_data, which is required
     # to properly handle MapData features (e.g. fidelity).
     if is_map_data:
         data = checked_cast(MapData, data)
@@ -437,174 +437,103 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]:
 
 def observations_from_data(
     experiment: experiment.Experiment,
-    data: Data,
-    statuses_to_include: set[TrialStatus] | None = None,
-    statuses_to_include_map_metric: set[TrialStatus] | None = None,
-) -> list[Observation]:
-    """Convert Data to observations.
-
-    Converts a Data object to a list of Observation objects. Pulls arm parameters from
-    from experiment. Overrides fidelity parameters in the arm with those found in the
-    Data object.
-
-    Uses a diagonal covariance matrix across metric_names.
-
-    Args:
-        experiment: Experiment with arm parameters.
-        data: Data of observations.
-        statuses_to_include: data from non-MapMetrics will only be included for trials
-            with statuses in this set. Defaults to all statuses except abandoned.
-        statuses_to_include_map_metric: data from MapMetrics will only be included for
-            trials with statuses in this set. Defaults to completed status only.
-
-    Returns:
-        List of Observation objects.
-    """
-    if statuses_to_include is None:
-        statuses_to_include = NON_ABANDONED_STATUSES
-    if statuses_to_include_map_metric is None:
-        statuses_to_include_map_metric = {TrialStatus.COMPLETED}
-    feature_cols = get_feature_cols(data)
-    observations = []
-    arm_name_only = len(feature_cols) == 1  # there will always be an arm name
-    # One DataFrame where all rows have all features.
-    isnull = data.df[feature_cols].isnull()
-    isnull_any = isnull.any(axis=1)
-    incomplete_df_cols = isnull[isnull_any].any()
-
-    # Get the incomplete_df columns that are complete, and usable as groupby keys.
-    complete_feature_cols = list(
-        OBS_COLS.intersection(incomplete_df_cols.index[~incomplete_df_cols])
-    )
-
-    if set(feature_cols) == set(complete_feature_cols):
-        complete_df = data.df
-        incomplete_df = None
-    else:
-        # The groupby and filter is expensive, so do it only if we have to.
-        grouped = data.df.groupby(by=complete_feature_cols)
-        complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any())
-        incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any())
-
-    # Get Observations from complete_df
-    observations.extend(
-        _observations_from_dataframe(
-            experiment=experiment,
-            df=complete_df,
-            cols=feature_cols,
-            arm_name_only=arm_name_only,
-            statuses_to_include=statuses_to_include,
-            statuses_to_include_map_metric=statuses_to_include_map_metric,
-            map_keys=[],
-        )
-    )
-    if incomplete_df is not None:
-        # Get Observations from incomplete_df
-        observations.extend(
-            _observations_from_dataframe(
-                experiment=experiment,
-                df=incomplete_df,
-                cols=complete_feature_cols,
-                arm_name_only=arm_name_only,
-                statuses_to_include=statuses_to_include,
-                statuses_to_include_map_metric=statuses_to_include_map_metric,
-                map_keys=[],
-            )
-        )
-    return observations
-
-
-def observations_from_map_data(
-    experiment: experiment.Experiment,
-    map_data: MapData,
+    data: Data | MapData,
     statuses_to_include: set[TrialStatus] | None = None,
     statuses_to_include_map_metric: set[TrialStatus] | None = None,
     map_keys_as_parameters: bool = False,
     limit_rows_per_metric: int | None = None,
     limit_rows_per_group: int | None = None,
 ) -> list[Observation]:
-    """Convert MapData to observations.
+    """Convert Data (or MapData) to observations.
 
-    Converts a MapData object to a list of Observation objects. Pulls arm parameters
-    from experiment. Overrides fidelity parameters in the arm with those found in the
-    Data object.
+    Converts a Data (or MapData) object to a list of Observation objects.
+    Pulls arm parameters from experiment. Overrides fidelity parameters
+    in the arm with those found in the Data object.
 
     Uses a diagonal covariance matrix across metric_names.
 
     Args:
         experiment: Experiment with arm parameters.
-        map_data: MapData of observations.
+        data: Data (or MapData) of observations.
         statuses_to_include: data from non-MapMetrics will only be included for trials
            with statuses in this set. Defaults to all statuses except abandoned.
        statuses_to_include_map_metric: data from MapMetrics will only be included for
            trials with statuses in this set. Defaults to all statuses except abandoned.
        map_keys_as_parameters: Whether map_keys should be returned as part of
            the parameters of the Observation objects.
-        limit_rows_per_metric: If specified, uses MapData.subsample() with
+        limit_rows_per_metric: If specified, and if data is an instance of MapData,
+            uses MapData.subsample() with
            `limit_rows_per_metric` equal to the specified value on the first
            map_key (map_data.map_keys[0]) to subsample the MapData. This is
            useful in, e.g., cases where learning curves are frequently
            updated, leading to an intractable number of Observation objects
            created.
-        limit_rows_per_group: If specified, uses MapData.subsample() with
+        limit_rows_per_group: If specified, and if data is an instance of MapData,
+            uses MapData.subsample() with
            `limit_rows_per_group` equal to the specified value on the first
            map_key (map_data.map_keys[0]) to subsample the MapData.
 
     Returns:
        List of Observation objects.
     """
+    is_map_data = isinstance(data, MapData)
+
     if statuses_to_include is None:
         statuses_to_include = NON_ABANDONED_STATUSES
     if statuses_to_include_map_metric is None:
         statuses_to_include_map_metric = NON_ABANDONED_STATUSES
-    if limit_rows_per_metric is not None or limit_rows_per_group is not None:
-        map_data = map_data.subsample(
-            map_key=map_data.map_keys[0],
-            limit_rows_per_metric=limit_rows_per_metric,
-            limit_rows_per_group=limit_rows_per_group,
-            include_first_last=True,
-        )
-    feature_cols = get_feature_cols(map_data, is_map_data=True)
-    observations = []
+
+    map_keys = []
+    obs_cols = OBS_COLS
+    if is_map_data:
+        data = checked_cast(MapData, data)
+
+        if limit_rows_per_metric is not None or limit_rows_per_group is not None:
+            data = data.subsample(
+                map_key=data.map_keys[0],
+                limit_rows_per_metric=limit_rows_per_metric,
+                limit_rows_per_group=limit_rows_per_group,
+                include_first_last=True,
+            )
+
+        map_keys.extend(data.map_keys)
+        obs_cols = obs_cols.union(data.map_keys)
+        df = data.map_df
+    else:
+        df = data.df
+
+    feature_cols = get_feature_cols(data, is_map_data=is_map_data)
+
     arm_name_only = len(feature_cols) == 1  # there will always be an arm name
     # One DataFrame where all rows have all features.
-    isnull = map_data.map_df[feature_cols].isnull()
+    isnull = df[feature_cols].isnull()
     isnull_any = isnull.any(axis=1)
     incomplete_df_cols = isnull[isnull_any].any()
 
     # Get the incomplete_df columns that are complete, and usable as groupby keys.
-    obs_cols_and_map = OBS_COLS.union(map_data.map_keys)
     complete_feature_cols = list(
-        obs_cols_and_map.intersection(incomplete_df_cols.index[~incomplete_df_cols])
+        obs_cols.intersection(incomplete_df_cols.index[~incomplete_df_cols])
    )
 
     if set(feature_cols) == set(complete_feature_cols):
-        complete_df = map_data.map_df
+        complete_df = df
         incomplete_df = None
     else:
         # The groupby and filter is expensive, so do it only if we have to.
-        grouped = map_data.map_df.groupby(
-            by=(
-                complete_feature_cols
-                if len(complete_feature_cols) > 1
-                else complete_feature_cols[0]
-            )
-        )
+        grouped = df.groupby(by=complete_feature_cols)
         complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any())
         incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any())
 
     # Get Observations from complete_df
-    observations.extend(
-        _observations_from_dataframe(
-            experiment=experiment,
-            df=complete_df,
-            cols=feature_cols,
-            arm_name_only=arm_name_only,
-            map_keys=map_data.map_keys,
-            statuses_to_include=statuses_to_include,
-            statuses_to_include_map_metric=statuses_to_include_map_metric,
-            map_keys_as_parameters=map_keys_as_parameters,
-        )
+    observations = _observations_from_dataframe(
+        experiment=experiment,
+        df=complete_df,
+        cols=feature_cols,
+        arm_name_only=arm_name_only,
+        map_keys=map_keys,
+        statuses_to_include=statuses_to_include,
+        statuses_to_include_map_metric=statuses_to_include_map_metric,
+        map_keys_as_parameters=map_keys_as_parameters,
     )
     if incomplete_df is not None:
         # Get Observations from incomplete_df
@@ -614,7 +543,7 @@ def observations_from_map_data(
             df=incomplete_df,
             cols=complete_feature_cols,
             arm_name_only=arm_name_only,
-            map_keys=map_data.map_keys,
+            map_keys=map_keys,
             statuses_to_include=statuses_to_include,
             statuses_to_include_map_metric=statuses_to_include_map_metric,
             map_keys_as_parameters=map_keys_as_parameters,
diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py
index 2c304353502..849cc69ab2f 100644
--- a/ax/core/tests/test_observation.py
+++ b/ax/core/tests/test_observation.py
@@ -24,7 +24,6 @@
     ObservationData,
     ObservationFeatures,
     observations_from_data,
-    observations_from_map_data,
     recombine_observations,
     separate_observations,
 )
@@ -475,7 +474,7 @@ def test_ObservationsFromMapData(self) -> None:
                 MapKeyInfo(key="timestamp", default_value=0.0),
             ],
         )
-        observations = observations_from_map_data(experiment, data)
+        observations = observations_from_data(experiment, data)
         self.assertEqual(len(observations), 3)
 
diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py
index e60193b1e75..b54be842852 100644
--- a/ax/modelbridge/base.py
+++ b/ax/modelbridge/base.py
@@ -297,6 +297,7 @@ def _prepare_observations(
             data=data,
             statuses_to_include=self.statuses_to_fit,
             statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
+            map_keys_as_parameters=False,
         )
 
     def _transform_data(
diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py
index 04ca083d476..ef2fea3a958 100644
--- a/ax/modelbridge/map_torch.py
+++ b/ax/modelbridge/map_torch.py
@@ -19,7 +19,7 @@
     Observation,
     ObservationData,
     ObservationFeatures,
-    observations_from_map_data,
+    observations_from_data,
     separate_observations,
 )
 from ax.core.optimization_config import OptimizationConfig
@@ -242,19 +242,16 @@ def _array_to_observation_features(
     def _prepare_observations(
         self, experiment: Experiment | None, data: Data | None
     ) -> list[Observation]:
-        """The difference b/t this method and ModelBridge._prepare_observations(...)
-        is that this one uses `observations_from_map_data`.
-        """
         if experiment is None or data is None:
             return []
-        return observations_from_map_data(
+        return observations_from_data(
             experiment=experiment,
-            map_data=data,  # pyre-ignore[6]: Checked in __init__.
-            map_keys_as_parameters=True,
+            data=data,
             limit_rows_per_metric=self._map_data_limit_rows_per_metric,
             limit_rows_per_group=self._map_data_limit_rows_per_group,
             statuses_to_include=self.statuses_to_fit,
             statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
+            map_keys_as_parameters=True,
         )
 
     def _compute_in_design(
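
A minimal usage sketch (not part of the patch) of the unified `observations_from_data` entry point after this change. It assumes an already-built Ax experiment and its fetched data (e.g. from `Experiment.fetch_data()`); the helper name `to_observations` and the row limits below are illustrative values, not defaults from this diff.

# Illustrative sketch only. `experiment` and `data` are assumed to come from an
# existing Ax setup (e.g. data = experiment.fetch_data()); the subsampling
# limits are arbitrary example values.
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.observation import Observation, observations_from_data


def to_observations(experiment: Experiment, data: Data) -> list[Observation]:
    # One entry point for both cases: the MapData-specific keyword arguments
    # (map_keys_as_parameters and the subsampling limits) only take effect
    # when `data` is actually a MapData instance.
    return observations_from_data(
        experiment=experiment,
        data=data,
        map_keys_as_parameters=True,
        limit_rows_per_metric=100,
        limit_rows_per_group=10,
    )

This mirrors the two call sites updated in the diff: ModelBridge._prepare_observations now passes map_keys_as_parameters=False explicitly, while MapTorchModelBridge._prepare_observations passes map_keys_as_parameters=True together with its row limits.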