Deprecate use_update as discussed (#1895)
Summary:

The main use case for `use_update` in a generation strategy is a stateful model, where we really don't want to restore the state on each call to `fit` and instead expect a call to `update` with only the new data. At this point, supporting `use_update` adds a lot of complexity to generation strategy, and it will be great to see that complexity gone (this diff shows it in the deleted code).

Reviewed By: saitcakmak

Differential Revision: D49465655
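
To make the fit-versus-update distinction concrete, here is a minimal, self-contained sketch; `StatefulModel` and its methods are illustrative stand-ins, not Ax's actual API. `fit` re-optimizes from scratch on all data, while `update` appends only the new data and warm-starts from the previously learned hyperparameters:

from typing import List, Optional

class StatefulModel:
    """Toy model contrasting fit-from-scratch with incremental update."""

    def __init__(self) -> None:
        self.train_data: List[float] = []
        self.hyperparams: Optional[float] = None

    def _optimize_hyperparams(self, data: List[float], init: Optional[float]) -> float:
        # Stand-in for hyperparameter optimization: warm-start from `init`
        # when available, otherwise start from a default value.
        start = init if init is not None else 0.0
        return start + sum(data) / max(len(data), 1)

    def fit(self, all_data: List[float]) -> None:
        # What generation strategy now always does: refit from scratch,
        # discarding any previously learned state.
        self.train_data = list(all_data)
        self.hyperparams = self._optimize_hyperparams(self.train_data, init=None)

    def update(self, new_data: List[float]) -> None:
        # What `use_update=True` enabled: append only the new data and
        # reuse the previous hyperparameters as the starting point.
        self.train_data.extend(new_data)
        self.hyperparams = self._optimize_hyperparams(self.train_data, init=self.hyperparams)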
lena-kashtelyan authored and facebook-github-bot committed Oct 6, 2023
1 parent dd10bee commit 1e96293
Showing 11 changed files with 56 additions and 352 deletions.
28 changes: 3 additions & 25 deletions ax/modelbridge/dispatch_utils.py
@@ -14,7 +14,6 @@
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import MODEL_KEY_TO_MODEL_SETUP, Models
from ax.modelbridge.transforms.base import Transform
@@ -80,7 +79,6 @@ def _make_botorch_step(
disable_progbar: Optional[bool] = None,
jit_compile: Optional[bool] = None,
derelativize_with_raw_status_quo: bool = False,
use_update: bool = False,
) -> GenerationStep:
"""Shortcut for creating a BayesOpt generation step."""
model_kwargs = model_kwargs or {}
@@ -130,7 +128,6 @@
min_trials_observed=min_trials_observed or ceil(num_trials / 2),
enforce_num_trials=enforce_num_trials,
max_parallelism=max_parallelism,
use_update=use_update,
# `model_kwargs` should default to `None` if empty
model_kwargs=model_kwargs if len(model_kwargs) > 0 else None,
should_deduplicate=should_deduplicate,
@@ -306,7 +303,6 @@ def choose_generation_strategy(
disable_progbar: Optional[bool] = None,
jit_compile: Optional[bool] = None,
experiment: Optional[Experiment] = None,
use_update: bool = False,
) -> GenerationStrategy:
"""Select an appropriate generation strategy based on the properties of
the search space and expected settings of the experiment, such as number of
@@ -402,27 +398,10 @@
strategy with a given experiment before it's first used to ``gen`` with
that experiment). Can also provide `optimization_config` if it is not
provided as an arg to this function.
use_update: Whether to use ``ModelBridge.update`` to update the model with
new data rather than fitting it from scratch.
This changes the behavior of how the model is updated to incorporate the
new data before candidate generation. When ``use_update=False``, we fit a
new model from scratch. When ``use_update=True``, we update the training
data of the model and re-use the hyper-parameters from the previously
fitted model. Depending on the ``refit_model`` flag (defaults to True,
not exposed in this API), we may further train the hyper-parameters while
using the previous values as the starting conditions for optimization.
use_update: DEPRECATED.
"""
if experiment is not None:
if optimization_config is None:
optimization_config = experiment.optimization_config
metrics_available_while_running = any(
m.is_available_while_running() for m in experiment.metrics.values()
)
if metrics_available_while_running and use_update is True:
raise UnsupportedError(
"Got `use_update=True` but the experiment has metrics that are "
"available while running. Set `use_update=False`."
)
if experiment is not None and optimization_config is None:
optimization_config = experiment.optimization_config

suggested_model = _suggest_gp_model(
search_space=search_space,
@@ -539,7 +518,6 @@ def choose_generation_strategy(
verbose=verbose,
disable_progbar=disable_progbar,
jit_compile=jit_compile,
use_update=use_update,
),
)
gs = GenerationStrategy(steps=steps)
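
With `use_update` removed from `choose_generation_strategy` above, callers simply drop the keyword. A hedged usage sketch (the one-parameter search space here is illustrative, not from this diff):

from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.dispatch_utils import choose_generation_strategy

search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )
    ]
)
# `use_update` is no longer an accepted keyword; the model is refit from
# scratch before each round of candidate generation instead.
gs = choose_generation_strategy(search_space=search_space)
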
106 changes: 3 additions & 103 deletions ax/modelbridge/generation_node.py
@@ -95,7 +95,6 @@ class GenerationNode:

# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
use_update: bool = False
_transition_criteria: Optional[Sequence[TransitionCriterion]]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
@@ -110,7 +109,6 @@ def __init__(
model_specs: List[ModelSpec],
best_model_selector: Optional[BestModelSelector] = None,
should_deduplicate: bool = False,
use_update: bool = False,
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
) -> None:
self._node_name = node_name
@@ -122,7 +120,6 @@ def __init__(
self.model_specs = model_specs
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self.use_update = use_update
self._transition_criteria = transition_criteria

@property
@@ -240,10 +237,6 @@ def fit(
**kwargs,
)

def update(self, experiment: Experiment, new_data: Data) -> None:
"""Updates the specified models on the given experiment + new data."""
raise NotImplementedError("`update` is not supported yet.")

def gen(
self,
n: Optional[int] = None,
@@ -343,96 +336,6 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
)
return self.model_specs[best_model_index]

def get_data_for_update(
self, passed_in_data: Optional[Data], newly_completed_trials: Set[int]
) -> Optional[Data]:
"""
Get the data that will be used to update the model. This is used if
`use_update=True` for this GenerationNode. Only the new data since the
last model update / gen call should be used for the update.
Args:
passed_in_data: An optional data object for fitting the model
for this GenerationNode. When omitted, data will be retrieved
using `experiment.lookup_data`.
newly_completed_trials: Indices of trials that have been completed or
updated with data since the last call to `GenerationStrategy.gen`.
Only the data for these trials are used when updating the model.
Returns:
Data: Data for updating the fitted model for this GenerationNode.
"""
if len(newly_completed_trials) == 0:
logger.debug(
"There were no newly completed trials since last model update."
)
return None

if passed_in_data is None:
new_data = self.experiment.lookup_data(trial_indices=newly_completed_trials)
if new_data.df.empty:
logger.info(
"No new data is attached to experiment; no need for model update."
)
return None
return new_data

elif passed_in_data.df.empty:
logger.info("Manually supplied data is empty; no need for model update.")
return None

return Data(
df=passed_in_data.df.loc[
passed_in_data.df.trial_index.isin(newly_completed_trials)
]
)

def get_data_for_fit(
self,
passed_in_data: Optional[Data],
) -> Data:
"""
Fetches data given this generation node configuration, and checks for invalid
data states before returning it.
Args:
passed_in_data: An optional provided Data object for fitting the model for
this generation node
Returns:
Data: Data for fitting a model to generate this generation node
"""
if passed_in_data is None:
if self.use_update:
# If this node is using `update`, it's important to instantiate
# the model with data for completed trials only, so later we can
# update it with data for new trials as they become completed.
# `experiment.lookup_data` can lookup all available data, including
# for non-completed trials (depending on how the experiment's metrics
# implement `fetch_experiment_data`). We avoid fetching data for
# trials with statuses other than `COMPLETED`, by fetching specifically
# for `COMPLETED` trials.
avail_while_running_metrics = {
m.name
for m in self.experiment.metrics.values()
if m.is_available_while_running()
}
if avail_while_running_metrics:
raise NotImplementedError(
f"Metrics {avail_while_running_metrics} are available while "
"trial is running, but use of `update` functionality in "
"generation node relies on new data being available upon "
"trial completion."
)
return self.experiment.lookup_data(
trial_indices=self.experiment.trial_indices_by_status[
TrialStatus.COMPLETED
]
)
else:
return self.experiment.lookup_data()
return passed_in_data

def __repr__(self) -> str:
"String representation of this GenerationNode"
# add model specs
@@ -475,11 +378,7 @@ class GenerationStep(GenerationNode, SortableBase):
to `generation_strategy.gen` will fail with a `MaxParallelismReached
Exception`, indicating that more trials need to be completed before
generating and running next trials.
use_update: Whether to use `model_bridge.update` instead or reinstantiating
model + bridge on every call to `gen` within a single generation step.
NOTE: use of `update` on stateful models that do not implement `_get_state`
may result in inability to correctly resume a generation strategy from
a serialized state.
use_update: DEPRECATED.
enforce_num_trials: Whether to enforce that only `num_trials` are generated
from the given step. If False and `num_trials` have been generated, but
`min_trials_observed` have not been completed, `generation_strategy.gen`
@@ -545,6 +444,8 @@ class GenerationStep(GenerationNode, SortableBase):
model_name: str = field(default_factory=str)

def __post_init__(self) -> None:
if self.use_update:
raise DeprecationWarning("`GenerationStep.use_update` is deprecated.")
if (
self.enforce_num_trials
and (self.num_trials >= 0)
@@ -603,7 +504,6 @@ def __post_init__(self) -> None:
node_name=f"GenerationStep_{str(self.index)}",
model_specs=[model_spec],
should_deduplicate=self.should_deduplicate,
use_update=self.use_update,
transition_criteria=transition_criteria,
)
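
Since the new `__post_init__` check raises rather than silently ignoring the flag, existing callers fail fast at construction time. A small sketch of the expected behavior, assuming the dataclass still accepts the field (as the check above implies):

import pytest

from ax.modelbridge.generation_strategy import GenerationStep
from ax.modelbridge.registry import Models

# Passing the deprecated flag now raises at construction time.
with pytest.raises(DeprecationWarning):
    GenerationStep(model=Models.SOBOL, num_trials=5, use_update=True)

# Omitting it yields the standard behavior: refit on every `gen` call.
step = GenerationStep(model=Models.SOBOL, num_trials=5)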
