From 1e96293e8352067e6bdf3e22f6293070a0414ab4 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Fri, 6 Oct 2023 09:23:41 -0700 Subject: [PATCH] Deprecate `use_update` as discussed (#1895) Summary: The main use case for `use_update` in generation strategy is a stateful model, where we really don't want to restore the state on each call to `fit` and instead want to expect a call to `update` with only the new data. I think at this point supporting `use_update` just adds lots of complexity to generation strategy, and it'll be great to see that complexity gone (this diff shows it in the deleted code). Reviewed By: saitcakmak Differential Revision: D49465655 --- ax/modelbridge/dispatch_utils.py | 28 +--- ax/modelbridge/generation_node.py | 106 +----------- ax/modelbridge/generation_strategy.py | 154 +++++------------- ax/modelbridge/tests/test_dispatch_utils.py | 29 ---- .../tests/test_generation_strategy.py | 66 -------- ax/service/ax_client.py | 12 +- ax/service/scheduler.py | 2 +- ax/service/tests/test_ax_client.py | 4 +- ax/service/tests/test_scheduler.py | 4 +- ax/service/utils/best_point.py | 2 +- ax/storage/json_store/decoder.py | 1 - 11 files changed, 56 insertions(+), 352 deletions(-) diff --git a/ax/modelbridge/dispatch_utils.py b/ax/modelbridge/dispatch_utils.py index 48e60e6ba28..059527e7189 100644 --- a/ax/modelbridge/dispatch_utils.py +++ b/ax/modelbridge/dispatch_utils.py @@ -14,7 +14,6 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace -from ax.exceptions.core import UnsupportedError from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import MODEL_KEY_TO_MODEL_SETUP, Models from ax.modelbridge.transforms.base import Transform @@ -80,7 +79,6 @@ def _make_botorch_step( disable_progbar: Optional[bool] = None, jit_compile: Optional[bool] = None, derelativize_with_raw_status_quo: bool = False, - use_update: bool = False, ) -> GenerationStep: """Shortcut for creating a BayesOpt generation step.""" model_kwargs = model_kwargs or {} @@ -130,7 +128,6 @@ def _make_botorch_step( min_trials_observed=min_trials_observed or ceil(num_trials / 2), enforce_num_trials=enforce_num_trials, max_parallelism=max_parallelism, - use_update=use_update, # `model_kwargs` should default to `None` if empty model_kwargs=model_kwargs if len(model_kwargs) > 0 else None, should_deduplicate=should_deduplicate, @@ -306,7 +303,6 @@ def choose_generation_strategy( disable_progbar: Optional[bool] = None, jit_compile: Optional[bool] = None, experiment: Optional[Experiment] = None, - use_update: bool = False, ) -> GenerationStrategy: """Select an appropriate generation strategy based on the properties of the search space and expected settings of the experiment, such as number of @@ -402,27 +398,10 @@ def choose_generation_strategy( strategy with a given experiment before it's first used to ``gen`` with that experiment). Can also provide `optimization_config` if it is not provided as an arg to this function. - use_update: Whether to use ``ModelBridge.update`` to update the model with - new data rather than fitting it from scratch. - This changes the behavior of how the model is updated to incorporate the - new data before candidate generation. When ``use_update=False``, we fit a - new model from scratch. When ``use_update=True``, we update the training - data of the model and re-use the hyper-parameters from the previously - fitted model. Depending on the ``refit_model`` flag (defaults to True, - not exposed in this API), we may further train the hyper-parameters while - using the previous values as the starting conditions for optimization. + use_update: DEPRECATED. """ - if experiment is not None: - if optimization_config is None: - optimization_config = experiment.optimization_config - metrics_available_while_running = any( - m.is_available_while_running() for m in experiment.metrics.values() - ) - if metrics_available_while_running and use_update is True: - raise UnsupportedError( - "Got `use_update=True` but the experiment has metrics that are " - "available while running. Set `use_update=False`." - ) + if experiment is not None and optimization_config is None: + optimization_config = experiment.optimization_config suggested_model = _suggest_gp_model( search_space=search_space, @@ -539,7 +518,6 @@ def choose_generation_strategy( verbose=verbose, disable_progbar=disable_progbar, jit_compile=jit_compile, - use_update=use_update, ), ) gs = GenerationStrategy(steps=steps) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index d649b40a806..95fa965c6fe 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -95,7 +95,6 @@ class GenerationNode: # Optional specifications _model_spec_to_gen_from: Optional[ModelSpec] = None - use_update: bool = False _transition_criteria: Optional[Sequence[TransitionCriterion]] # [TODO] Handle experiment passing more eloquently by enforcing experiment @@ -110,7 +109,6 @@ def __init__( model_specs: List[ModelSpec], best_model_selector: Optional[BestModelSelector] = None, should_deduplicate: bool = False, - use_update: bool = False, transition_criteria: Optional[Sequence[TransitionCriterion]] = None, ) -> None: self._node_name = node_name @@ -122,7 +120,6 @@ def __init__( self.model_specs = model_specs self.best_model_selector = best_model_selector self.should_deduplicate = should_deduplicate - self.use_update = use_update self._transition_criteria = transition_criteria @property @@ -240,10 +237,6 @@ def fit( **kwargs, ) - def update(self, experiment: Experiment, new_data: Data) -> None: - """Updates the specified models on the given experiment + new data.""" - raise NotImplementedError("`update` is not supported yet.") - def gen( self, n: Optional[int] = None, @@ -343,96 +336,6 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec: ) return self.model_specs[best_model_index] - def get_data_for_update( - self, passed_in_data: Optional[Data], newly_completed_trials: Set[int] - ) -> Optional[Data]: - """ - Get the data that will be used to update the model. This is used if - `use_update=True` for this GenerationNode. Only the new data since the - last model update / gen call should be used for the update. - - Args: - passed_in_data: An optional data object for fitting the model - for this GenerationNode. When omitted, data will be retrieved - using `experiment.lookup_data`. - newly_completed_trials: Indices of trials that have been completed or - updated with data since the last call to `GenerationStrategy.gen`. - Only the data for these trials are used when updating the model. - - Returns: - Data: Data for updating the fitted model for this GenerationNode. - """ - if len(newly_completed_trials) == 0: - logger.debug( - "There were no newly completed trials since last model update." - ) - return None - - if passed_in_data is None: - new_data = self.experiment.lookup_data(trial_indices=newly_completed_trials) - if new_data.df.empty: - logger.info( - "No new data is attached to experiment; no need for model update." - ) - return None - return new_data - - elif passed_in_data.df.empty: - logger.info("Manually supplied data is empty; no need for model update.") - return None - - return Data( - df=passed_in_data.df.loc[ - passed_in_data.df.trial_index.isin(newly_completed_trials) - ] - ) - - def get_data_for_fit( - self, - passed_in_data: Optional[Data], - ) -> Data: - """ - Fetches data given this generation node configuration, and checks for invalid - data states before returning it. - - Args: - passed_in_data: An optional provided Data object for fitting the model for - this generation node - - Returns: - Data: Data for fitting a model to generate this generation node - """ - if passed_in_data is None: - if self.use_update: - # If this node is using `update`, it's important to instantiate - # the model with data for completed trials only, so later we can - # update it with data for new trials as they become completed. - # `experiment.lookup_data` can lookup all available data, including - # for non-completed trials (depending on how the experiment's metrics - # implement `fetch_experiment_data`). We avoid fetching data for - # trials with statuses other than `COMPLETED`, by fetching specifically - # for `COMPLETED` trials. - avail_while_running_metrics = { - m.name - for m in self.experiment.metrics.values() - if m.is_available_while_running() - } - if avail_while_running_metrics: - raise NotImplementedError( - f"Metrics {avail_while_running_metrics} are available while " - "trial is running, but use of `update` functionality in " - "generation node relies on new data being available upon " - "trial completion." - ) - return self.experiment.lookup_data( - trial_indices=self.experiment.trial_indices_by_status[ - TrialStatus.COMPLETED - ] - ) - else: - return self.experiment.lookup_data() - return passed_in_data - def __repr__(self) -> str: "String representation of this GenerationNode" # add model specs @@ -475,11 +378,7 @@ class GenerationStep(GenerationNode, SortableBase): to `generation_strategy.gen` will fail with a `MaxParallelismReached Exception`, indicating that more trials need to be completed before generating and running next trials. - use_update: Whether to use `model_bridge.update` instead or reinstantiating - model + bridge on every call to `gen` within a single generation step. - NOTE: use of `update` on stateful models that do not implement `_get_state` - may result in inability to correctly resume a generation strategy from - a serialized state. + use_update: DEPRECATED. enforce_num_trials: Whether to enforce that only `num_trials` are generated from the given step. If False and `num_trials` have been generated, but `min_trials_observed` have not been completed, `generation_strategy.gen` @@ -545,6 +444,8 @@ class GenerationStep(GenerationNode, SortableBase): model_name: str = field(default_factory=str) def __post_init__(self) -> None: + if self.use_update: + raise DeprecationWarning("`GenerationStep.use_update` is deprecated.") if ( self.enforce_num_trials and (self.num_trials >= 0) @@ -603,7 +504,6 @@ def __post_init__(self) -> None: node_name=f"GenerationStep_{str(self.index)}", model_specs=[model_spec], should_deduplicate=self.should_deduplicate, - use_update=self.use_update, transition_criteria=transition_criteria, ) diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 1a3441e403b..902211d164d 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -9,10 +9,9 @@ from copy import deepcopy from logging import Logger -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Tuple import pandas as pd -from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun @@ -258,9 +257,7 @@ def gen( will be used to produce the generator run returned from this method. data: Optional data to be passed to the underlying model's `gen`, which is called within this method and actually produces the resulting - generator run. By default, data is all data on the `experiment` if - `use_update` is False and only the new data since the last call to - this method if `use_update` is True. + generator run. By default, data is all data on the `experiment`. n: Integer representing how many arms should be in the generator run produced by this method. NOTE: Some underlying models may ignore the `n` and produce a model-determined number of arms. In that @@ -304,8 +301,12 @@ def current_generator_run_limit( def clone_reset(self) -> GenerationStrategy: """Copy this generation strategy without it's state.""" - # TODO[drfreund]: unset `_generation_strategy`` from steps - return GenerationStrategy(name=self.name, steps=deepcopy(self._steps)) + steps = deepcopy(self._steps) + for s in steps: + # Unset the generation strategy back-pointer, so the steps are not + # associated with any generation strategy. + s._generation_strategy = None + return GenerationStrategy(name=self.name, steps=steps) def _unset_non_persistent_state_fields(self) -> None: """Utility for testing convenience: unset fields of generation strategy @@ -367,9 +368,7 @@ def _gen_multiple( will be used to produce the generator run returned from this method. data: Optional data to be passed to the underlying model's `gen`, which is called within this method and actually produces the resulting - generator run. By default, data is all data on the `experiment` if - `use_update` is False and only the new data since the last call to - this method if `use_update` is True. + generator run. By default, data is all data on the `experiment`. n: Integer representing how many arms should be in the generator run produced by this method. NOTE: Some underlying models may ignore the ``n`` and produce a model-determined number of arms. In that @@ -384,7 +383,7 @@ def _gen_multiple( """ self.experiment = experiment self._maybe_move_to_next_step() - self._fit_or_update_current_model(data=data) + self._fit_current_model(data=data) # Make sure to not make too many generator runs and # exceed maximum allowed paralellism for the step. @@ -435,7 +434,7 @@ def _gen_multiple( # ------------------------- Model selection logic helpers. ------------------------- - def _fit_or_update_current_model(self, data: Optional[Data]) -> None: + def _fit_current_model(self, data: Optional[Data]) -> None: """Fits or update the model on the current generation step (does not move between generation steps). @@ -443,36 +442,19 @@ def _fit_or_update_current_model(self, data: Optional[Data]) -> None: data: Optional ``Data`` to fit or update with; if not specified, generation strategy will obtain the data via ``experiment.lookup_data``. """ - if self._model is not None and self._curr.use_update: - newly_completed_trials = self._find_trials_completed_since_last_gen() - new_data = self._curr.get_data_for_update( - passed_in_data=data, newly_completed_trials=newly_completed_trials - ) - if new_data is not None: - self._update_current_model(new_data=new_data) - else: - processed_data = self._curr.get_data_for_fit( - passed_in_data=data, - ) - self._fit_current_model(data=processed_data) - previous_step_req_observations = ( - self._curr.index > 0 - and self._steps[self._curr.index - 1].min_trials_observed > 0 - ) - # If previous step required observed data, we should raise an error even if - # enough trials were completed. Such an empty data case does indicate an - # invalid state; this check is to improve the experience of detecting and - # debugging the invalid state that led to this. - if processed_data.df.empty and previous_step_req_observations: - raise NoDataError( - f"Observed data is required for GenerationNode {self._curr.index}," - f"(model {self._curr.model_to_gen_from_name}), but fetched data" - "was empty. Something is wrong with experiment setup -- likely " - "metrics do not implement fetching logic (check your metrics) or" - "no data was attached to experiment for completed trials." - ) + data = self.experiment.lookup_data() if data is None else data + # If last generator run's index matches the current step, extract + # model state from last generator run and pass it to the model + # being instantiated in this function. + model_state_on_lgr = self._get_model_state_from_last_generator_run() - self._save_seen_trial_indices() + if not data.df.empty: + trial_indices_in_data = sorted(data.df["trial_index"].unique()) + logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}") + + self._curr.fit(experiment=self.experiment, data=data, **model_state_on_lgr) + self._model = self._curr.model_spec.fitted_model + self._check_previous_required_observation(data=data) def _maybe_move_to_next_step(self, raise_data_required_error: bool = True) -> bool: """Moves this generation strategy to next step if the current step is completed, @@ -505,21 +487,17 @@ def _maybe_move_to_next_step(self, raise_data_required_error: bool = True) -> bo # new step's model will be initialized for the first time, so we don't # try to `update` it but rather initialize with all the data even if # `use_update` is true for the new generation step; this is done in - # `self._fit_or_update_current_model). + # `self._fit_current_model). self._model = None return move_to_next_step - def _fit_current_model(self, data: Data) -> None: - """Instantiate the current model with all available data.""" - # If last generator run's index matches the current step, extract - # model state from last generator run and pass it to the model - # being instantiated in this function. + def _get_model_state_from_last_generator_run(self) -> Dict[str, Any]: lgr = self.last_generator_run # NOTE: This will not be easily compatible with `GenerationNode`; # will likely need to find last generator run per model. Not a problem # for now though as GS only allows `GenerationStep`-s for now. - # Potential solution: store generator runs on `GenerationStep`-s and + # Potential solution: store generator runs on `GenerationNode`-s and # split them per-model there. model_state_on_lgr = {} model_on_curr = self._curr.model @@ -542,71 +520,23 @@ def _fit_current_model(self, data: Data) -> None: generator_run=lgr, model_class=model_cls, ) - else: - logger.warning( - "While model state after last call to `gen` was recorded on the " - "las generator run produced by this generation strategy, it could" - " not be applied because model for this generation step is defined" - f" via factory function: {self._curr.model}. Generation strategies" - " with factory functions do not support reloading from a stored " - "state." - ) - - if not data.df.empty: - trial_indices_in_data = sorted(data.df["trial_index"].unique()) - logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}") - self._curr.fit(experiment=self.experiment, data=data, **model_state_on_lgr) - self._model = self._curr.model_spec.fitted_model + return model_state_on_lgr - def _update_current_model(self, new_data: Data) -> None: - """Update the current model with new data (data for trials that have been - completed since the last call to `GenerationStrategy.gen`). - """ - if self._model is None: # Should not be reachable. - raise ValueError("Cannot update if no model instantiated.") - trial_indices_in_new_data = sorted(new_data.df["trial_index"].unique()) - logger.info(f"Updating model with data for trials: {trial_indices_in_new_data}") - # TODO[drfreund]: Switch to `self._curr.update` once `GenerationNode` supports - not_none(self._model).update(experiment=self.experiment, new_data=new_data) - - # ------------------------- State-tracking helpers. ------------------------- - - def _save_seen_trial_indices(self) -> None: - """Saves Experiment's `trial_indices_by_status` at the time of the model's - last `gen` (so these `trial_indices_by_status` reflect which trials model - has seen the data for). Useful when `use_update=True` for a given - generation step. - """ - self._seen_trial_indices_by_status = deepcopy( - self.experiment.trial_indices_by_status + def _check_previous_required_observation(self, data: Data) -> None: + previous_step_req_observations = ( + self._curr.index > 0 + and self._steps[self._curr.index - 1].min_trials_observed > 0 ) - - def _find_trials_completed_since_last_gen(self) -> Set[int]: - """Retrieves indices of trials that have been completed or updated with data - since the last call to `GenerationStrategy.gen`. - """ - completed_now = self.experiment.trial_indices_by_status[TrialStatus.COMPLETED] - if self._seen_trial_indices_by_status is None: - return completed_now - - completed_before = not_none(self._seen_trial_indices_by_status)[ - TrialStatus.COMPLETED - ] - return completed_now.difference(completed_before) - - def _register_trial_data_update(self, trial: BaseTrial) -> None: - """Registers that a given trial has new data even though it's a trial that has - been completed before. Useful only for generation steps that have `use_update= - True`, as the information registered by this function is used for identifying - new data since last call to `GenerationStrategy.gen`. - """ - # TODO[T65857344]: store information about trial update to pass with `new_data` - # to `model_update`. This information does not need to be stored, since when - # restoring generation strategy from serialized form, all data will is - # refetched and the underlying model is re-fit. - if any(s.use_update for s in self._steps): - raise NotImplementedError( - "Updating completed trials with new data is not yet supported for " - "generation strategies that leverage `model.update` functionality." + # If previous step required observed data, we should raise an error even if + # enough trials were completed. Such an empty data case does indicate an + # invalid state; this check is to improve the experience of detecting and + # debugging the invalid state that led to this. + if data.df.empty and previous_step_req_observations: + raise NoDataError( + f"Observed data is required for generation node {self._curr.node_name}," + f"(model {self._curr.model_to_gen_from_name}), but fetched data was " + "empty. Something is wrong with experiment setup -- likely metrics " + "do not implement fetching logic (check your metrics) or no data " + "was attached to experiment for completed trials." ) diff --git a/ax/modelbridge/tests/test_dispatch_utils.py b/ax/modelbridge/tests/test_dispatch_utils.py index 565a1f347af..e16b74aed05 100644 --- a/ax/modelbridge/tests/test_dispatch_utils.py +++ b/ax/modelbridge/tests/test_dispatch_utils.py @@ -11,7 +11,6 @@ import torch from ax.core.objective import MultiObjective from ax.core.optimization_config import MultiObjectiveOptimizationConfig -from ax.exceptions.core import UnsupportedError from ax.modelbridge.dispatch_utils import ( _make_botorch_step, calculate_num_initialization_trials, @@ -26,7 +25,6 @@ from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import ( - get_branin_experiment, get_branin_search_space, get_discrete_search_space, get_experiment, @@ -692,33 +690,6 @@ def test_calculate_num_initialization_trials(self) -> None: 5, ) - @fast_botorch_optimize - def test_use_update(self) -> None: - search_space = get_branin_search_space() - # Defaults to False. - gs = choose_generation_strategy(search_space=search_space) - self.assertFalse(gs._steps[1].use_update) - run_branin_experiment_with_generation_strategy(generation_strategy=gs) - # Pass in True. - gs = choose_generation_strategy(search_space=search_space, use_update=True) - self.assertTrue(gs._steps[1].use_update) - with self.assertRaisesRegex( - NotImplementedError, "use of `update` functionality" - ): - run_branin_experiment_with_generation_strategy(generation_strategy=gs) - # Metrics available while running. - experiment = get_branin_experiment() - gs = choose_generation_strategy( - search_space=search_space, experiment=experiment, use_saasbo=True - ) - # Default to False. - self.assertFalse(gs._steps[1].use_update) - # Error with True. - with self.assertRaisesRegex(UnsupportedError, "use_update"): - choose_generation_strategy( - search_space=search_space, experiment=experiment, use_update=True - ) - @fast_botorch_optimize def test_jit_compile(self) -> None: for jit_compile in (True, False): diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index aedf7cef919..f0eb7dac71f 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -367,7 +367,6 @@ def test_sobol_strategy(self) -> None: model=Models.SOBOL, num_trials=5, max_parallelism=10, - use_update=False, enforce_num_trials=False, ) ] @@ -547,71 +546,6 @@ def test_max_parallelism_reached(self) -> None: with self.assertRaises(MaxParallelismReachedException): sobol_generation_strategy.gen(experiment=exp) - @patch(f"{RandomModelBridge.__module__}.RandomModelBridge.update") - @patch(f"{Experiment.__module__}.Experiment.lookup_data") - def test_use_update( - self, mock_lookup_data: MagicMock, mock_update: MagicMock - ) -> None: - exp = get_branin_experiment() - sobol_gs_with_update = GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=-1, use_update=True)] - ) - sobol_gs_with_update._experiment = exp - self.assertEqual( - sobol_gs_with_update._find_trials_completed_since_last_gen(), - set(), - ) - with self.assertRaises(NotImplementedError): - # `BraninMetric` is available while running by default, which should - # raise an error when use with `use_update=True` on a generation step, as we - # have not yet properly addressed that edge case (for lack of use case). - sobol_gs_with_update.gen(experiment=exp) - - core_stubs_module = get_branin_experiment.__module__ - with patch( - f"{core_stubs_module}.BraninMetric.is_available_while_running", - return_value=False, - ): - # Try without passing data (GS looks up data on experiment). - trial = exp.new_trial( - generator_run=sobol_gs_with_update.gen(experiment=exp) - ) - mock_update.assert_not_called() - trial._status = TrialStatus.COMPLETED - for i in range(3): - gr = sobol_gs_with_update.gen(experiment=exp) - self.assertEqual( - mock_lookup_data.call_args[1].get("trial_indices"), {i} - ) - trial = exp.new_trial(generator_run=gr) - trial._status = TrialStatus.COMPLETED - # `_seen_trial_indices_by_status` is set during `gen`, to the experiment's - # `trial_indices_by_Status` at the time of candidate generation. - self.assertNotEqual( - sobol_gs_with_update._seen_trial_indices_by_status, - exp.trial_indices_by_status, - ) - # Try with passing data. - sobol_gs_with_update.gen( - experiment=exp, data=get_branin_data(trial_indices=range(4)) - ) - # Now `_seen_trial_indices_by_status` should be set to experiment's, - self.assertEqual( - sobol_gs_with_update._seen_trial_indices_by_status, - exp.trial_indices_by_status, - ) - # Only the data for the last completed trial should be considered new and passed - # to `update`. - self.assertEqual( - set(mock_update.call_args[1].get("new_data").df["trial_index"].values), {3} - ) - # Try with passing same data as before; no update should be performed. - with patch.object(sobol_gs_with_update, "_update_current_model") as mock_update: - sobol_gs_with_update.gen( - experiment=exp, data=get_branin_data(trial_indices=range(4)) - ) - mock_update.assert_not_called() - def test_deduplication(self) -> None: tiny_parameters = [ FixedParameter( diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index c93c7331409..81e1c750702 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -818,14 +818,6 @@ def update_trial_data( sample_size=sample_size, combine_with_last_data=True, ) - # Registering trial data update is needed for generation strategies that - # leverage the `update` functionality of model and bridge setup and therefore - # need to be aware of new data added to experiment. Usually this happends - # seamlessly, by looking at newly completed trials, but in this case trial - # status does not change, so we manually register the new data. - # Currently this call will only result in a `NotImplementedError` if generation - # strategy uses `update` (`GenerationStep.use_update` is False by default). - self.generation_strategy._register_trial_data_update(trial=trial) logger.info(f"Added data: {data_update_repr} to trial {trial.index}.") def log_trial_failure( @@ -1273,7 +1265,7 @@ def fit_model(self) -> None: "At least one trial must be completed with data to instantiate " "a model." ) - self.generation_strategy._fit_or_update_current_model(data=None) + self.generation_strategy._fit_current_model(data=None) logger.info("Successfully instantiated a model for the first time.") # Model update is normally tied to the GenerationStrategy.gen() call, @@ -1281,7 +1273,7 @@ def fit_model(self) -> None: # can be performed without the need to call get_next_trial(), we update the # model with all attached data. Note that this method keeps track of previously # seen trials and will update the model if there is newly attached data. - self.generation_strategy._fit_or_update_current_model(data=None) + self.generation_strategy._fit_current_model(data=None) def verify_trial_parameterization( self, trial_index: int, parameterization: TParameterization diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 5366e9f0435..e265474d556 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -1850,6 +1850,6 @@ def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge: model_bridge = gs.model # Optional[ModelBridge] if model_bridge is None: # Need to re-fit the model. data = scheduler.experiment.fetch_data() - gs._fit_or_update_current_model(data=data) + gs._fit_current_model(data=data) model_bridge = cast(ModelBridge, gs.model) return model_bridge diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index aba1ac0db28..1eb6ed4b093 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -2339,7 +2339,7 @@ def helper_test_get_pareto_optimal_points( num_trials=20, outcome_constraints=outcome_constraints ) ax_client.generation_strategy._maybe_move_to_next_step() - ax_client.generation_strategy._fit_or_update_current_model( + ax_client.generation_strategy._fit_current_model( data=ax_client.experiment.lookup_data() ) self.assertEqual(ax_client.generation_strategy._curr.model_name, "BoTorch") @@ -2519,7 +2519,7 @@ def test_get_pareto_optimal_points_objective_threshold_inference( num_trials=20, include_objective_thresholds=False ) ax_client.generation_strategy._maybe_move_to_next_step() - ax_client.generation_strategy._fit_or_update_current_model( + ax_client.generation_strategy._fit_current_model( data=ax_client.experiment.lookup_data() ) diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index 6596340ed03..01e8f2121ff 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -1262,8 +1262,8 @@ def test_get_fitted_model_bridge(self) -> None: ): with patch.object( GenerationStrategy, - "_fit_or_update_current_model", - wraps=scheduler.generation_strategy._fit_or_update_current_model, + "_fit_current_model", + wraps=scheduler.generation_strategy._fit_current_model, ) as fit_model: get_fitted_model_bridge(scheduler) fit_model.assert_called_once() diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 46dd02fd1dc..21810f08dda 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -553,7 +553,7 @@ def get_pareto_optimal_parameters( and checked_cast(TorchModelBridge, modelbridge).is_moo_problem ) if is_moo_modelbridge: - generation_strategy._fit_or_update_current_model(data=None) + generation_strategy._fit_current_model(data=None) else: modelbridge = Models.MOO( experiment=experiment, diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index c267dd74685..2fc89c570fd 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -709,7 +709,6 @@ def generation_step_from_json( if completion_criteria is not None else [], max_parallelism=(generation_step_json.pop("max_parallelism", None)), - use_update=generation_step_json.pop("use_update", False), enforce_num_trials=generation_step_json.pop("enforce_num_trials", True), model_kwargs=_decode_callables_from_references( object_from_json(