From 3ebc24caa69ccf89e2aafec4f0b5d62a8c73ca7d Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Fri, 17 May 2024 11:37:14 -0700 Subject: [PATCH] Update MBM tutorial with modern, non-deprecated functionality and accurate description of defaults (#2466) Summary: * Use log acquisition functions * Remove references to deprecated fixed-noise acquisition functions (now absorbed into more general acqfs) * Correct inaccuracies in description of MBM defaults * Removed unused imports Reviewed By: mgarrard Differential Revision: D57443923 --- tutorials/modular_botax.ipynb | 1739 ++++++++++++++++----------------- 1 file changed, 869 insertions(+), 870 deletions(-) diff --git a/tutorials/modular_botax.ipynb b/tutorials/modular_botax.ipynb index b9dd8778699..58af1347b8b 100644 --- a/tutorials/modular_botax.ipynb +++ b/tutorials/modular_botax.ipynb @@ -1,872 +1,871 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "about-preview", - "metadata": { - "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" - }, - "outputs": [], - "source": [ - "from typing import Any, Dict, Optional, Tuple, Type\n", - "\n", - "from ax.modelbridge.registry import Cont_X_trans, Models, Y_trans\n", - "\n", - "# Ax data tranformation layer\n", - "from ax.modelbridge.torch import TorchModelBridge\n", - "from ax.models.torch.botorch_modular.acquisition import Acquisition\n", - "\n", - "# Ax wrappers for BoTorch components\n", - "from ax.models.torch.botorch_modular.model import BoTorchModel\n", - "from ax.models.torch.botorch_modular.surrogate import Surrogate\n", - "\n", - "# Experiment examination utilities\n", - "from ax.service.utils.report_utils import exp_to_df\n", - "\n", - "# Test Ax objects\n", - "from ax.utils.testing.core_stubs import (\n", - " get_branin_data,\n", - " get_branin_data_multi_objective,\n", - " get_branin_experiment,\n", - " get_branin_experiment_with_multi_objective,\n", - ")\n", - "from botorch.acquisition.monte_carlo import (\n", - " qExpectedImprovement,\n", - " qNoisyExpectedImprovement,\n", - ")\n", - "from botorch.models.gp_regression import FixedNoiseGP\n", - "\n", - "# BoTorch components\n", - "from botorch.models.model import Model\n", - "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood" - ] - }, - { - "cell_type": "markdown", - "id": "northern-affairs", - "metadata": { - "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" - }, - "source": [ - "# Setup and Usage of BoTorch Models in Ax\n", - "\n", - "Ax provides a set of flexible wrapper abstractions to mix-and-match BoTorch components like `Model` and `AcquisitionFunction` and combine them into a single `Model` object in Ax. The wrapper abstractions: `Surrogate`, `Acquisition`, and `BoTorchModel` – are located in `ax/models/torch/botorch_modular` directory and aim to encapsulate boilerplate code that interfaces between Ax and BoTorch. This functionality is in beta-release and still evolving.\n", - "\n", - "This tutorial walks through setting up a custom combination of BoTorch components in Ax in following steps:\n", - "\n", - "1. **Quick-start example of `BoTorchModel` use**\n", - "1. **`BoTorchModel` = `Surrogate` + `Acquisition` (overview)**\n", - " 1. Example with minimal options that uses the defaults\n", - " 2. Example showing all possible options\n", - " 3. Surrogate and Acquisition Q&A\n", - "2. **I know which Botorch Model and AcquisitionFunction I'd like to combine in Ax. How do set this up?**\n", - " 1. Making a `Surrogate` from BoTorch `Model`\n", - " 2. Using an arbitrary BoTorch `AcquisitionFunction` in Ax\n", - "3. **Using `Models.BOTORCH_MODULAR`** (convenience wrapper that enables storage and resumability)\n", - "4. **Utilizing `BoTorchModel` in generation strategies** (abstraction that allows to chain models together and use them in Ax Service API etc.)\n", - " 1. Specifying `pending_observations` to avoid the model re-suggesting points that are part of `RUNNING` or `ABANDONED` trials.\n", - "5. **Customizing a `Surrogate` or `Acquisition`** (for cases where existing subcomponent classes are not sufficient)" - ] - }, - { - "cell_type": "markdown", - "id": "pending-support", - "metadata": { - "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" - }, - "source": [ - "## 1. Quick-start example\n", - "\n", - "Here we set up a `BoTorchModel` with `FixedNoiseGP` with `qNoisyExpectedImprovement`, one of the most popular combinations in Ax:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "parental-sending", - "metadata": { - "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" - }, - "outputs": [], - "source": [ - "experiment = get_branin_experiment(with_trial=True)\n", - "data = get_branin_data(trials=[experiment.trials[0]])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "rough-somerset", - "metadata": { - "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" - }, - "outputs": [], - "source": [ - "# `Models` automatically selects a model + model bridge combination.\n", - "# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.\n", - "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", - " experiment=experiment,\n", - " data=data,\n", - " surrogate=Surrogate(FixedNoiseGP), # Optional, will use default if unspecified\n", - " botorch_acqf_class=qNoisyExpectedImprovement, # Optional, will use default if unspecified\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "hairy-wiring", - "metadata": { - "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" - }, - "source": [ - "Now we can use this model to generate candidates (`gen`), predict outcome at a point (`predict`), or evaluate acquisition function value at a given point (`evaluate_acquisition_function`)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "consecutive-summary", - "metadata": { - "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" - }, - "outputs": [], - "source": [ - "generator_run = model_bridge_with_GPEI.gen(n=1)\n", - "generator_run.arms[0]" - ] - }, - { - "cell_type": "markdown", - "id": "diverse-richards", - "metadata": { - "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" - }, - "source": [ - "-----\n", - "Before you read the rest of this tutorial:\n", - "\n", - "- Note that the concept of ‘model’ is Ax is somewhat a misnomer; we use ['model'](https://ax.dev/docs/glossary.html#model) to refer to an optimization setup capable of producing candidate points for optimization (and often capable of being fit to data, with exception for quasi-random generators). See [Models documentation page](https://ax.dev/docs/models.html) for more information.\n", - "- Learn about `ModelBridge` in Ax, as users should rarely be interacting with a `Model` object directly (more about ModelBridge, a data transformation layer in Ax, [here](https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack))." - ] - }, - { - "cell_type": "markdown", - "id": "grand-committee", - "metadata": { - "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" - }, - "source": [ - "## 2. BoTorchModel = Surrogate + Acquisition\n", - "\n", - "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." - ] - }, - { - "cell_type": "markdown", - "id": "thousand-blanket", - "metadata": { - "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" - }, - "source": [ - "### 2A. Example that uses defaults and requires no options\n", - "\n", - "BoTorchModel does not always require surrogate and acquisition specification. If instantiated without one or both components specified, defaults are selected based on properties of experiment and data (see Appendix 2 for auto-selection logic)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "changing-xerox", - "metadata": { - "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" - }, - "outputs": [], - "source": [ - "# The surrogate is not specified, so it will be auto-selected\n", - "# during `model.fit`.\n", - "GPEI_model = BoTorchModel(botorch_acqf_class=qExpectedImprovement)\n", - "\n", - "# The acquisition class is not specified, so it will be\n", - "# auto-selected during `model.gen` or `model.evaluate_acquisition`\n", - "GPEI_model = BoTorchModel(surrogate=Surrogate(FixedNoiseGP))\n", - "\n", - "# Both the surrogate and acquisition class will be auto-selected.\n", - "GPEI_model = BoTorchModel()" - ] - }, - { - "cell_type": "markdown", - "id": "lovely-mechanics", - "metadata": { - "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" - }, - "source": [ - "### 2B. Example with all the options\n", - "Below are the full set of configurable settings of a `BoTorchModel` with their descriptions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "twenty-greek", - "metadata": { - "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" - }, - "outputs": [], - "source": [ - "model = BoTorchModel(\n", - " # Optional `Surrogate` specification to use instead of default\n", - " surrogate=Surrogate(\n", - " # BoTorch `Model` type\n", - " botorch_model_class=FixedNoiseGP,\n", - " # Optional, MLL class with which to optimize model parameters\n", - " mll_class=ExactMarginalLogLikelihood,\n", - " # Optional, dictionary of keyword arguments to underlying\n", - " # BoTorch `Model` constructor\n", - " model_options={},\n", - " ),\n", - " # Optional BoTorch `AcquisitionFunction` to use instead of default\n", - " botorch_acqf_class=qExpectedImprovement,\n", - " # Optional dict of keyword arguments, passed to the input\n", - " # constructor for the given BoTorch `AcquisitionFunction`\n", - " acquisition_options={},\n", - " # Optional Ax `Acquisition` subclass (if the given BoTorch\n", - " # `AcquisitionFunction` requires one, which is rare)\n", - " acquisition_class=None,\n", - " # Less common model settings shown with default values, refer\n", - " # to `BoTorchModel` documentation for detail\n", - " refit_on_cv=False,\n", - " warm_start_refit=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "fourth-material", - "metadata": { - "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" - }, - "source": [ - "## 2C. `Surrogate` and `Acquisition` Q&A\n", - "\n", - "**Why is the `surrogate` argument expected to be an instance, but `botorch_acqf_class` –– a class?** Because a BoTorch `AcquisitionFunction` object (and therefore its Ax wrapper, `Acquisition`) is ephemeral: it is constructed, immediately used, and destroyed during `BoTorchModel.gen`, so there is no reason to keep around an `Acquisition` instance. A `Surrogate`, on another hand, is kept in memory as long as its parent `BoTorchModel` is.\n", - "\n", - "**How to know when to use specify acquisition_class (and thereby a non-default Acquisition type) instead of just passing in botorch_acqf_class?** In short, custom `Acquisition` subclasses are needed when a given `AcquisitionFunction` in BoTorch needs some non-standard subcomponents or inputs (e.g. a custom BoTorch `MCAcquisitionObjective`). \n", - "\n", - "**Please post any other questions you have to our dedicated issue on Github: https://github.com/facebook/Ax/issues/363.** This functionality is in beta-release and your feedback will be of great help to us!" - ] - }, - { - "cell_type": "markdown", - "id": "violent-course", - "metadata": { - "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" - }, - "source": [ - "## 3. I know which Botorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do set this up?" - ] - }, - { - "cell_type": "markdown", - "id": "unlike-football", - "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", - "showInput": false - }, - "source": [ - "### 3a. Making a `Surrogate` from BoTorch `Model`:\n", - "Most models should work with base `Surrogate` in Ax, except for BoTorch `ModelListGP`. `ModelListGP` is a special case because its purpose is to combine multiple sub-models into a single `Model` in BoTorch. It is most commonly used for multi-objective and constrained optimization. Whether or not `ModelListGP` is used is determined automatically based on the `Model` class and the data being used via the `ax.models.torch.botorch_modular.utils.use_model_list` function.\n", - "\n", - "If your `Model` is not a `ModelListGP`, the steps to set it up as a `Surrogate` are:\n", - "1. Implement a [`construct_inputs` class method](https://github.com/pytorch/botorch/blob/main/botorch/models/model.py#L143). The purpose of this method is to produce arguments to a particular model from a standardized set of inputs passed to BoTorch `Model`-s from [`Surrogate.construct`](https://github.com/facebook/Ax/blob/main/ax/models/torch/botorch_modular/surrogate.py#L148) in Ax. It should accept training data in form of a `SupervisedDataset` container and optionally other keyword arguments and produce a dictionary of arguments to `__init__` of the `Model`. See [`SingleTaskMultiFidelityGP.construct_inputs`](https://github.com/pytorch/botorch/blob/5b3172f3daa22f6ea2f6f4d1d0a378a9518dcd8d/botorch/models/gp_regression_fidelity.py#L131) for an example.\n", - "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via `model_options` argument to `Surrogate`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dynamic-university", - "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" - }, - "outputs": [], - "source": [ - "from botorch.models.model import Model\n", - "from botorch.utils.datasets import SupervisedDataset\n", - "\n", - "\n", - "class MyModelClass(Model):\n", - "\n", - " ... # Implementation of `MyModelClass`\n", - "\n", - " @classmethod\n", - " def construct_inputs(\n", - " cls, training_data: SupervisedDataset, **kwargs\n", - " ) -> Dict[str, Any]:\n", - " fidelity_features = kwargs.get(\"fidelity_features\")\n", - " if fidelity_features is None:\n", - " raise ValueError(f\"Fidelity features required for {cls.__name__}.\")\n", - "\n", - " return {\n", - " **super().construct_inputs(training_data=training_data, **kwargs),\n", - " \"fidelity_features\": fidelity_features,\n", - " }\n", - "\n", - "\n", - "surrogate = Surrogate(\n", - " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", - " # Optional dict of additional keyword arguments to `MyModelClass`\n", - " model_options={},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "otherwise-context", - "metadata": { - "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" - }, - "source": [ - "NOTE: if you run into a case where base `Surrogate` does not work with your BoTorch `Model`, please let us know in this Github issue: https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment this tutorial." - ] - }, - { - "cell_type": "markdown", - "id": "northern-invite", - "metadata": { - "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" - }, - "source": [ - "### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax" - ] - }, - { - "cell_type": "markdown", - "id": "surrounded-denial", - "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", - "showInput": false - }, - "source": [ - "Steps to set up any `AcquisitionFunction` in Ax are:\n", - "1. Define an input constructor function. The purpose of this method is to produce arguments to a acquisition function from a standardized set of inputs passed to BoTorch `AcquisitionFunction`-s from `Acquisition.__init__` in Ax. For example, see [`construct_inputs_qEHVI`](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L477), which creates a fairly complex set of arguments needed by `qExpectedHypervolumeImprovement` –– a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: [botorch/acquisition/input_constructors.py](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py) \n", - " 1. Note that the new input constructor needs to be decorated with `@acqf_input_constructor(AcquisitionFunctionClass)` to register it.\n", - "2. (Optional) If a given `AcquisitionFunction` requires specific options passed to the BoTorch `optimize_acqf`, it's possible to add default optimizer options for a given `AcquisitionFunction` to avoid always manually passing them via `acquisition_options`.\n", - "3. Specify the BoTorch `AcquisitionFunction` class as `botorch_acqf_class` to `BoTorchModel`\n", - "4. (Optional) Pass any additional keyword arguments to acquisition function constructor or to the optimizer function via `acquisition_options` argument to `BoTorchModel`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "interested-search", - "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" - }, - "outputs": [], - "source": [ - "from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse\n", - "from botorch.acquisition.acquisition import AcquisitionFunction\n", - "from botorch.acquisition.input_constructors import MaybeDict, acqf_input_constructor\n", - "from botorch.utils.datasets import SupervisedDataset\n", - "from torch import Tensor\n", - "\n", - "\n", - "class MyAcquisitionFunctionClass(AcquisitionFunction):\n", - " ... # Actual contents of the acquisition function class.\n", - "\n", - "\n", - "# 1. Add input constructor\n", - "@acqf_input_constructor(MyAcquisitionFunctionClass)\n", - "def construct_inputs_my_acqf(\n", - " model: Model,\n", - " training_data: MaybeDict[SupervisedDataset],\n", - " objective_thresholds: Tensor,\n", - " **kwargs: Any,\n", - ") -> Dict[str, Any]:\n", - " pass\n", - "\n", - "\n", - "# 2. Register default optimizer options\n", - "@optimizer_argparse.register(MyAcquisitionFunctionClass)\n", - "def _argparse_my_acqf(\n", - " acqf: MyAcquisitionFunctionClass, sequential: bool = True\n", - ") -> dict:\n", - " return {\n", - " \"sequential\": sequential\n", - " } # default to sequentially optimizing batches of queries\n", - "\n", - "\n", - "# 3-4. Specifying `botorch_acqf_class` and `acquisition_options`\n", - "BoTorchModel(\n", - " botorch_acqf_class=MyAcquisitionFunctionClass,\n", - " acquisition_options={\n", - " \"alpha\": 10**-6,\n", - " # The sub-dict by the key \"optimizer_options\" can be passed\n", - " # to propagate options to `optimize_acqf`, used in\n", - " # `Acquisition.optimize`, to add/override the default\n", - " # optimizer options registered above.\n", - " \"optimizer_options\": {\"sequential\": False},\n", - " },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "metallic-imaging", - "metadata": { - "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" - }, - "source": [ - "See section 2A for combining the resulting `Surrogate` instance and `Acquisition` type into a `BoTorchModel`. You can also leverage `Models.BOTORCH_MODULAR` for ease of use; more on it in section 4 below or in section 1 quick-start example." - ] - }, - { - "cell_type": "markdown", - "id": "descending-australian", - "metadata": { - "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" - }, - "source": [ - "## 4. Using `Models.BOTORCH_MODULAR` \n", - "\n", - "To simplify the instantiation of an Ax ModelBridge and its undelying Model, Ax provides a [`Models` registry enum](https://github.com/facebook/Ax/blob/main/ax/modelbridge/registry.py#L355). When calling entries of that enum (e.g. `Models.BOTORCH_MODULAR(experiment, data)`), the inputs are automatically distributed between a `Model` and a `ModelBridge` for a given setup. A call to a `Model` enum member yields a model bridge with an underlying model, ready for use to generate candidates.\n", - "\n", - "Here we use `Models.BOTORCH_MODULAR` to set up a model with all-default subcomponents:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "attached-border", - "metadata": { - "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" - }, - "outputs": [], - "source": [ - "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", - " experiment=experiment,\n", - " data=data,\n", - ")\n", - "model_bridge_with_GPEI.gen(1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "powerful-gamma", - "metadata": { - "originalKey": "89930a31-e058-434b-b587-181931e247b6" - }, - "outputs": [], - "source": [ - "model_bridge_with_GPEI.model.botorch_acqf_class" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "improved-replication", - "metadata": { - "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" - }, - "outputs": [], - "source": [ - "model_bridge_with_GPEI.model.surrogate.botorch_model_class" - ] - }, - { - "cell_type": "markdown", - "id": "connected-sheet", - "metadata": { - "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" - }, - "source": [ - "We can use the same `Models.BOTORCH_MODULAR` to set up a model for multi-objective optimization:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "documentary-jurisdiction", - "metadata": { - "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" - }, - "outputs": [], - "source": [ - "model_bridge_with_EHVI = Models.BOTORCH_MODULAR(\n", - " experiment=get_branin_experiment_with_multi_objective(\n", - " has_objective_thresholds=True, with_batch=True\n", - " ),\n", - " data=get_branin_data_multi_objective(),\n", - ")\n", - "model_bridge_with_EHVI.gen(1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "changed-maintenance", - "metadata": { - "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" - }, - "outputs": [], - "source": [ - "model_bridge_with_EHVI.model.botorch_acqf_class" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "operating-shelf", - "metadata": { - "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" - }, - "outputs": [], - "source": [ - "model_bridge_with_EHVI.model.surrogate.botorch_model_class" - ] - }, - { - "cell_type": "markdown", - "id": "fatal-butterfly", - "metadata": { - "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" - }, - "source": [ - "Furthermore, the quick-start example at the top of this tutorial shows how to specify surrogate and acquisition subcomponents to `Models.BOTORCH_MODULAR`. " - ] - }, - { - "cell_type": "markdown", - "id": "hearing-interface", - "metadata": { - "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" - }, - "source": [ - "## 5. Utilizing `BoTorchModel` in generation strategies\n", - "\n", - "Generation strategy is a key concept in Ax, enabling use of Service API (a.k.a. `AxClient`) and many other higher-level abstractions. A `GenerationStrategy` allows to chain multiple models in Ax and thereby automate candidate generation. Refer to the \"Generation Strategy\" tutorial for more detail in generation strategies.\n", - "\n", - "An example generation stategy with the modular `BoTorchModel` would look like this:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "received-registration", - "metadata": { - "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" - }, - "outputs": [], - "source": [ - "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", - "from botorch.acquisition import UpperConfidenceBound\n", - "from ax.modelbridge.modelbridge_utils import get_pending_observation_features\n", - "\n", - "gs = GenerationStrategy(\n", - " steps=[\n", - " GenerationStep( # Initialization step\n", - " # Which model to use for this step\n", - " model=Models.SOBOL,\n", - " # How many generator runs (each of which is then made a trial)\n", - " # to produce with this step\n", - " num_trials=5,\n", - " # How many trials generated from this step must be `COMPLETED`\n", - " # before the next one\n", - " min_trials_observed=5,\n", - " ),\n", - " GenerationStep( # BayesOpt step\n", - " model=Models.BOTORCH_MODULAR,\n", - " # No limit on how many generator runs will be produced\n", - " num_trials=-1,\n", - " model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`\n", - " \"surrogate\": Surrogate(FixedNoiseGP),\n", - " \"botorch_acqf_class\": qNoisyExpectedImprovement,\n", - " },\n", - " ),\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "logical-windsor", - "metadata": { - "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" - }, - "source": [ - "Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "viral-cheese", - "metadata": { - "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" - }, - "outputs": [], - "source": [ - "experiment = get_branin_experiment(minimize=True)\n", - "\n", - "assert len(experiment.trials) == 0\n", - "experiment.search_space" - ] - }, - { - "cell_type": "markdown", - "id": "incident-newspaper", - "metadata": { - "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" - }, - "source": [ - "## 5a. Specifying `pending_observations`\n", - "Note that it's important to **specify pending observations** to the call to `gen` to avoid getting the same points re-suggested. Without `pending_observations` argument, Ax models are not aware of points that should be excluded from generation. Points are considered \"pending\" when they belong to `STAGED`, `RUNNING`, or `ABANDONED` trials (with the latter included so model does not re-suggest points that are considered \"bad\" and should not be re-suggested).\n", - "\n", - "If the call to `get_pending_observation_features` becomes slow in your setup (since it performs data-fetching etc.), you can opt for `get_pending_observation_features_based_on_trial_status` (also from `ax.modelbridge.modelbridge_utils`), but note the limitations of that utility (detailed in its docstring)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "casual-spread", - "metadata": { - "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" - }, - "outputs": [], - "source": [ - "for _ in range(10):\n", - " # Produce a new generator run and attach it to experiment as a trial\n", - " generator_run = gs.gen(\n", - " experiment=experiment,\n", - " n=1,\n", - " pending_observations=get_pending_observation_features(experiment=experiment),\n", - " )\n", - " trial = experiment.new_trial(generator_run)\n", - "\n", - " # Mark the trial as 'RUNNING' so we can mark it 'COMPLETED' later\n", - " trial.mark_running(no_runner_required=True)\n", - "\n", - " # Attach data for the new trial and mark it 'COMPLETED'\n", - " experiment.attach_data(get_branin_data(trials=[trial]))\n", - " trial.mark_completed()\n", - "\n", - " print(f\"Completed trial #{trial.index}, suggested by {generator_run._model_key}.\")" - ] - }, - { - "cell_type": "markdown", - "id": "circular-vermont", - "metadata": { - "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" - }, - "source": [ - "Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "significant-particular", - "metadata": { - "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" - }, - "outputs": [], - "source": [ - "exp_to_df(experiment)" - ] - }, - { - "cell_type": "markdown", - "id": "obvious-transparency", - "metadata": { - "originalKey": "c25da720-6d3d-4f16-b878-24f2d2755783" - }, - "source": [ - "## 6. Customizing a `Surrogate` or `Acquisition`\n", - "\n", - "We expect the base `Surrogate` and `Acquisition` classes to work with most BoTorch components, but there could be a case where you would need to subclass one of aforementioned abstractions to handle a given BoTorch component. If you run into a case like this, feel free to open an issue on our [Github issues page](https://github.com/facebook/Ax/issues) –– it would be very useful for us to know \n", - "\n", - "One such example would be a need for a custom `MCAcquisitionObjective` or posterior transform. To subclass `Acquisition` accordingly, one would override the `get_botorch_objective_and_transform` method:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "organizational-balance", - "metadata": { - "originalKey": "e7f8e413-f01e-4f9d-82c1-4912097637af" - }, - "outputs": [], - "source": [ - "from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform\n", - "from botorch.acquisition.risk_measures import RiskMeasureMCObjective\n", - "\n", - "class CustomObjectiveAcquisition(Acquisition):\n", - " def get_botorch_objective_and_transform(\n", - " self,\n", - " botorch_acqf_class: Type[AcquisitionFunction],\n", - " model: Model,\n", - " objective_weights: Tensor,\n", - " objective_thresholds: Optional[Tensor] = None,\n", - " outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,\n", - " X_observed: Optional[Tensor] = None,\n", - " risk_measure: Optional[RiskMeasureMCObjective] = None,\n", - " ) -> Tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]:\n", - " ... # Produce the desired `MCAcquisitionObjective` and `PosteriorTransform` instead of the default" - ] - }, - { - "cell_type": "markdown", - "id": "theoretical-horizon", - "metadata": { - "originalKey": "7299f0fc-e19e-4383-99de-ef7a9a987fe9" - }, - "source": [ - "Then to use the new subclass in `BoTorchModel`, just specify `acquisition_class` argument along with `botorch_acqf_class` (to `BoTorchModel` directly or to `Models.BOTORCH_MODULAR`, which just passes the relevant arguments to `BoTorchModel` under the hood, as discussed in section 4):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "approximate-rolling", - "metadata": { - "originalKey": "07fe169a-78de-437e-9857-7c99cc48eedc" - }, - "outputs": [], - "source": [ - "Models.BOTORCH_MODULAR(\n", - " experiment=experiment,\n", - " data=data,\n", - " acquisition_class=CustomObjectiveAcquisition,\n", - " botorch_acqf_class=MyAcquisitionFunctionClass,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "representative-implement", - "metadata": { - "originalKey": "608d5f0d-4528-4aa6-869d-db38fcbfb256" - }, - "source": [ - "To use a custom `Surrogate` subclass, pass the `surrogate` argument of that type:\n", - "```\n", - "Models.BOTORCH_MODULAR(\n", - " experiment=experiment, \n", - " data=data,\n", - " surrogate=CustomSurrogate(botorch_model_class=MyModelClass),\n", - ")\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "framed-intermediate", - "metadata": { - "originalKey": "64f1289e-73c7-4cc5-96ee-5091286a8361" - }, - "source": [ - "------" - ] - }, - { - "cell_type": "markdown", - "id": "metropolitan-feedback", - "metadata": { - "originalKey": "d1e37569-dd0d-4561-b890-2f0097a345e0" - }, - "source": [ - "## Appendix 1: Methods available on `BoTorchModel`\n", - "\n", - "Note that usually all these methods are used through `ModelBridge` –– a convertion and transformation layer that adapts Ax abstractions to inputs required by the given model.\n", - "\n", - "**Core methods on `BoTorchModel`:**\n", - "* `fit` selects a surrogate if needed and fits the surrogate model to data via `Surrogate.fit`,\n", - "* `predict` estimates metric values at a given point via `Surrogate.predict`,\n", - "* `gen` instantiates an acquisition function via `Acquisition.__init__` and optimizes it to generate candidates.\n", - "\n", - "**Other methods on `BoTorchModel`:**\n", - "* `update` updates surrogate model with training data and optionally reoptimizes model parameters via `Surrogate.update`,\n", - "* `cross_validate` re-fits the surrogate model to subset of training data and makes predictions for test data,\n", - "* `evaluate_acquisition_function` instantiates an acquisition function and evaluates it for a given point.\n", - "------\n" - ] - }, - { - "cell_type": "markdown", - "id": "possible-transsexual", - "metadata": { - "originalKey": "b02f928c-57d9-4b2a-b4fe-c6d28d368b12" - }, - "source": [ - "## Appendix 2: Default surrogate models and acquisition functions\n", - "\n", - "By default, the chosen surrogate model will be:\n", - "* if fidelity parameters are present in search space: `FixedNoiseMultiFidelityGP` (if [SEM](https://ax.dev/docs/glossary.html#sem)s are known on observations) and `SingleTaskMultiFidelityGP` (if variance unknown and needs to be inferred),\n", - "* if task parameters are present: a set of `FixedNoiseMultiTaskGP` (if known variance) or `MultiTaskGP` (if unknown variance), wrapped in a `ModelListGP` and each modeling one task,\n", - "* `FixedNoiseGP` (known variance) and `SingleTaskGP` (unknown variance) otherwise.\n", - "\n", - "The chosen acquisition function will be:\n", - "* for multi-objective settings: `qExpectedHypervolumeImprovement`,\n", - "* `qExpectedImprovement` (known variance) and `qNoisyExpectedImprovement` (unknown variance) otherwise.\n", - "----" - ] - }, - { - "cell_type": "markdown", - "id": "continuous-strain", - "metadata": { - "originalKey": "76ae9852-9d21-43d6-bf75-bb087a474dd6" - }, - "source": [ - "## Appendix 3: Handling storage errors that arise from objects that don't have serialization logic in A\n", - "\n", - "Attempting to store a generator run produced via `Models.BOTORCH_MODULAR` instance that included options without serization logic with will produce an error like: `\"Object passed to 'object_to_json' (of type ) is not registered with a corresponding encoder in ENCODER_REGISTRY.\"`" - ] - }, - { - "cell_type": "markdown", - "id": "broadband-voice", - "metadata": { - "originalKey": "6487b68e-b808-4372-b6ba-ab02ce4826bc" - }, - "source": [ - "The two options for handling this error are:\n", - "1. disabling storage of `BoTorchModel`'s options by passing `no_model_options_storage=True` to `Models.BOTORCH_MODULAR(...)` call –– this will prevent model options from being stored on the generator run, so a generator run can be saved but cannot be used to restore the model that produced it,\n", - "2. specifying serialization logic for a given object that needs to occur among the `Model` or `AcquisitionFunction` options. Tutorial for this is in the works, but in the meantime you can [post an issue on the Ax GitHub](https://github.com/facebook/Ax/issues) to get help with this." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "about-preview", + "metadata": { + "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" + }, + "outputs": [], + "source": [ + "from typing import Any, Dict, Optional, Tuple, Type\n", + "\n", + "from ax.modelbridge.registry import Models\n", + "\n", + "# Ax data tranformation layer\n", + "from ax.models.torch.botorch_modular.acquisition import Acquisition\n", + "\n", + "# Ax wrappers for BoTorch components\n", + "from ax.models.torch.botorch_modular.model import BoTorchModel\n", + "from ax.models.torch.botorch_modular.surrogate import Surrogate\n", + "\n", + "# Experiment examination utilities\n", + "from ax.service.utils.report_utils import exp_to_df\n", + "\n", + "# Test Ax objects\n", + "from ax.utils.testing.core_stubs import (\n", + " get_branin_data,\n", + " get_branin_data_multi_objective,\n", + " get_branin_experiment,\n", + " get_branin_experiment_with_multi_objective,\n", + ")\n", + "from botorch.acquisition.logei import (\n", + " qLogExpectedImprovement,\n", + " qLogNoisyExpectedImprovement,\n", + ")\n", + "from botorch.models.gp_regression import SingleTaskGP\n", + "\n", + "# BoTorch components\n", + "from botorch.models.model import Model\n", + "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood" + ] + }, + { + "cell_type": "markdown", + "id": "northern-affairs", + "metadata": { + "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" + }, + "source": [ + "# Setup and Usage of BoTorch Models in Ax\n", + "\n", + "Ax provides a set of flexible wrapper abstractions to mix-and-match BoTorch components like `Model` and `AcquisitionFunction` and combine them into a single `Model` object in Ax. The wrapper abstractions: `Surrogate`, `Acquisition`, and `BoTorchModel` – are located in `ax/models/torch/botorch_modular` directory and aim to encapsulate boilerplate code that interfaces between Ax and BoTorch. This functionality is in beta-release and still evolving.\n", + "\n", + "This tutorial walks through setting up a custom combination of BoTorch components in Ax in following steps:\n", + "\n", + "1. **Quick-start example of `BoTorchModel` use**\n", + "1. **`BoTorchModel` = `Surrogate` + `Acquisition` (overview)**\n", + " 1. Example with minimal options that uses the defaults\n", + " 2. Example showing all possible options\n", + " 3. Surrogate and Acquisition Q&A\n", + "2. **I know which Botorch Model and AcquisitionFunction I'd like to combine in Ax. How do set this up?**\n", + " 1. Making a `Surrogate` from BoTorch `Model`\n", + " 2. Using an arbitrary BoTorch `AcquisitionFunction` in Ax\n", + "3. **Using `Models.BOTORCH_MODULAR`** (convenience wrapper that enables storage and resumability)\n", + "4. **Utilizing `BoTorchModel` in generation strategies** (abstraction that allows to chain models together and use them in Ax Service API etc.)\n", + " 1. Specifying `pending_observations` to avoid the model re-suggesting points that are part of `RUNNING` or `ABANDONED` trials.\n", + "5. **Customizing a `Surrogate` or `Acquisition`** (for cases where existing subcomponent classes are not sufficient)" + ] + }, + { + "cell_type": "markdown", + "id": "pending-support", + "metadata": { + "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" + }, + "source": [ + "## 1. Quick-start example\n", + "\n", + "Here we set up a `BoTorchModel` with `SingleTaskGP` with `qLogNoisyExpectedImprovement`, one of the most popular combinations in Ax:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "parental-sending", + "metadata": { + "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" + }, + "outputs": [], + "source": [ + "experiment = get_branin_experiment(with_trial=True)\n", + "data = get_branin_data(trials=[experiment.trials[0]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rough-somerset", + "metadata": { + "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" + }, + "outputs": [], + "source": [ + "# `Models` automatically selects a model + model bridge combination.\n", + "# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.\n", + "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", + " experiment=experiment,\n", + " data=data,\n", + " surrogate=Surrogate(SingleTaskGP), # Optional, will use default if unspecified\n", + " botorch_acqf_class=qLogNoisyExpectedImprovement, # Optional, will use default if unspecified\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "hairy-wiring", + "metadata": { + "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" + }, + "source": [ + "Now we can use this model to generate candidates (`gen`), predict outcome at a point (`predict`), or evaluate acquisition function value at a given point (`evaluate_acquisition_function`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "consecutive-summary", + "metadata": { + "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" + }, + "outputs": [], + "source": [ + "generator_run = model_bridge_with_GPEI.gen(n=1)\n", + "generator_run.arms[0]" + ] + }, + { + "cell_type": "markdown", + "id": "diverse-richards", + "metadata": { + "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" + }, + "source": [ + "-----\n", + "Before you read the rest of this tutorial:\n", + "\n", + "- Note that the concept of ‘model’ is Ax is somewhat a misnomer; we use ['model'](https://ax.dev/docs/glossary.html#model) to refer to an optimization setup capable of producing candidate points for optimization (and often capable of being fit to data, with exception for quasi-random generators). See [Models documentation page](https://ax.dev/docs/models.html) for more information.\n", + "- Learn about `ModelBridge` in Ax, as users should rarely be interacting with a `Model` object directly (more about ModelBridge, a data transformation layer in Ax, [here](https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack))." + ] + }, + { + "cell_type": "markdown", + "id": "grand-committee", + "metadata": { + "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" + }, + "source": [ + "## 2. BoTorchModel = Surrogate + Acquisition\n", + "\n", + "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." + ] + }, + { + "cell_type": "markdown", + "id": "thousand-blanket", + "metadata": { + "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" + }, + "source": [ + "### 2A. Example that uses defaults and requires no options\n", + "\n", + "BoTorchModel does not always require surrogate and acquisition specification. If instantiated without one or both components specified, defaults are selected based on properties of experiment and data (see Appendix 2 for auto-selection logic)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "changing-xerox", + "metadata": { + "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" + }, + "outputs": [], + "source": [ + "# The surrogate is not specified, so it will be auto-selected\n", + "# during `model.fit`.\n", + "GPEI_model = BoTorchModel(botorch_acqf_class=qLogExpectedImprovement)\n", + "\n", + "# The acquisition class is not specified, so it will be\n", + "# auto-selected during `model.gen` or `model.evaluate_acquisition`\n", + "GPEI_model = BoTorchModel(surrogate=Surrogate(SingleTaskGP))\n", + "\n", + "# Both the surrogate and acquisition class will be auto-selected.\n", + "GPEI_model = BoTorchModel()" + ] + }, + { + "cell_type": "markdown", + "id": "lovely-mechanics", + "metadata": { + "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" + }, + "source": [ + "### 2B. Example with all the options\n", + "Below are the full set of configurable settings of a `BoTorchModel` with their descriptions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "twenty-greek", + "metadata": { + "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" + }, + "outputs": [], + "source": [ + "model = BoTorchModel(\n", + " # Optional `Surrogate` specification to use instead of default\n", + " surrogate=Surrogate(\n", + " # BoTorch `Model` type\n", + " botorch_model_class=SingleTaskGP,\n", + " # Optional, MLL class with which to optimize model parameters\n", + " mll_class=ExactMarginalLogLikelihood,\n", + " # Optional, dictionary of keyword arguments to underlying\n", + " # BoTorch `Model` constructor\n", + " model_options={},\n", + " ),\n", + " # Optional BoTorch `AcquisitionFunction` to use instead of default\n", + " botorch_acqf_class=qLogExpectedImprovement,\n", + " # Optional dict of keyword arguments, passed to the input\n", + " # constructor for the given BoTorch `AcquisitionFunction`\n", + " acquisition_options={},\n", + " # Optional Ax `Acquisition` subclass (if the given BoTorch\n", + " # `AcquisitionFunction` requires one, which is rare)\n", + " acquisition_class=None,\n", + " # Less common model settings shown with default values, refer\n", + " # to `BoTorchModel` documentation for detail\n", + " refit_on_cv=False,\n", + " warm_start_refit=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fourth-material", + "metadata": { + "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" + }, + "source": [ + "## 2C. `Surrogate` and `Acquisition` Q&A\n", + "\n", + "**Why is the `surrogate` argument expected to be an instance, but `botorch_acqf_class` –– a class?** Because a BoTorch `AcquisitionFunction` object (and therefore its Ax wrapper, `Acquisition`) is ephemeral: it is constructed, immediately used, and destroyed during `BoTorchModel.gen`, so there is no reason to keep around an `Acquisition` instance. A `Surrogate`, on another hand, is kept in memory as long as its parent `BoTorchModel` is.\n", + "\n", + "**How to know when to use specify acquisition_class (and thereby a non-default Acquisition type) instead of just passing in botorch_acqf_class?** In short, custom `Acquisition` subclasses are needed when a given `AcquisitionFunction` in BoTorch needs some non-standard subcomponents or inputs (e.g. a custom BoTorch `MCAcquisitionObjective`). \n", + "\n", + "**Please post any other questions you have to our dedicated issue on Github: https://github.com/facebook/Ax/issues/363.** This functionality is in beta-release and your feedback will be of great help to us!" + ] + }, + { + "cell_type": "markdown", + "id": "violent-course", + "metadata": { + "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" + }, + "source": [ + "## 3. I know which Botorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do set this up?" + ] + }, + { + "cell_type": "markdown", + "id": "unlike-football", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", + "showInput": false + }, + "source": [ + "### 3a. Making a `Surrogate` from BoTorch `Model`:\n", + "Most models should work with base `Surrogate` in Ax, except for BoTorch `ModelListGP`. `ModelListGP` is a special case because its purpose is to combine multiple sub-models into a single `Model` in BoTorch. It is most commonly used for multi-objective and constrained optimization. Whether or not `ModelListGP` is used is determined automatically based on the `Model` class and the data being used via the `ax.models.torch.botorch_modular.utils.use_model_list` function.\n", + "\n", + "If your `Model` is not a `ModelListGP`, the steps to set it up as a `Surrogate` are:\n", + "1. Implement a [`construct_inputs` class method](https://github.com/pytorch/botorch/blob/main/botorch/models/model.py#L143). The purpose of this method is to produce arguments to a particular model from a standardized set of inputs passed to BoTorch `Model`-s from [`Surrogate.construct`](https://github.com/facebook/Ax/blob/main/ax/models/torch/botorch_modular/surrogate.py#L148) in Ax. It should accept training data in form of a `SupervisedDataset` container and optionally other keyword arguments and produce a dictionary of arguments to `__init__` of the `Model`. See [`SingleTaskMultiFidelityGP.construct_inputs`](https://github.com/pytorch/botorch/blob/5b3172f3daa22f6ea2f6f4d1d0a378a9518dcd8d/botorch/models/gp_regression_fidelity.py#L131) for an example.\n", + "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via `model_options` argument to `Surrogate`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dynamic-university", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" + }, + "outputs": [], + "source": [ + "from botorch.models.model import Model\n", + "from botorch.utils.datasets import SupervisedDataset\n", + "\n", + "\n", + "class MyModelClass(Model):\n", + "\n", + " ... # Implementation of `MyModelClass`\n", + "\n", + " @classmethod\n", + " def construct_inputs(\n", + " cls, training_data: SupervisedDataset, **kwargs\n", + " ) -> Dict[str, Any]:\n", + " fidelity_features = kwargs.get(\"fidelity_features\")\n", + " if fidelity_features is None:\n", + " raise ValueError(f\"Fidelity features required for {cls.__name__}.\")\n", + "\n", + " return {\n", + " **super().construct_inputs(training_data=training_data, **kwargs),\n", + " \"fidelity_features\": fidelity_features,\n", + " }\n", + "\n", + "\n", + "surrogate = Surrogate(\n", + " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", + " # Optional dict of additional keyword arguments to `MyModelClass`\n", + " model_options={},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "otherwise-context", + "metadata": { + "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" + }, + "source": [ + "NOTE: if you run into a case where base `Surrogate` does not work with your BoTorch `Model`, please let us know in this Github issue: https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment this tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "northern-invite", + "metadata": { + "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" + }, + "source": [ + "### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax" + ] + }, + { + "cell_type": "markdown", + "id": "surrounded-denial", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", + "showInput": false + }, + "source": [ + "Steps to set up any `AcquisitionFunction` in Ax are:\n", + "1. Define an input constructor function. The purpose of this method is to produce arguments to a acquisition function from a standardized set of inputs passed to BoTorch `AcquisitionFunction`-s from `Acquisition.__init__` in Ax. For example, see [`construct_inputs_qEHVI`](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L477), which creates a fairly complex set of arguments needed by `qExpectedHypervolumeImprovement` –– a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: [botorch/acquisition/input_constructors.py](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py) \n", + " 1. Note that the new input constructor needs to be decorated with `@acqf_input_constructor(AcquisitionFunctionClass)` to register it.\n", + "2. (Optional) If a given `AcquisitionFunction` requires specific options passed to the BoTorch `optimize_acqf`, it's possible to add default optimizer options for a given `AcquisitionFunction` to avoid always manually passing them via `acquisition_options`.\n", + "3. Specify the BoTorch `AcquisitionFunction` class as `botorch_acqf_class` to `BoTorchModel`\n", + "4. (Optional) Pass any additional keyword arguments to acquisition function constructor or to the optimizer function via `acquisition_options` argument to `BoTorchModel`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "interested-search", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" + }, + "outputs": [], + "source": [ + "from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse\n", + "from botorch.acquisition.acquisition import AcquisitionFunction\n", + "from botorch.acquisition.input_constructors import acqf_input_constructor, MaybeDict\n", + "from botorch.utils.datasets import SupervisedDataset\n", + "from torch import Tensor\n", + "\n", + "\n", + "class MyAcquisitionFunctionClass(AcquisitionFunction):\n", + " ... # Actual contents of the acquisition function class.\n", + "\n", + "\n", + "# 1. Add input constructor\n", + "@acqf_input_constructor(MyAcquisitionFunctionClass)\n", + "def construct_inputs_my_acqf(\n", + " model: Model,\n", + " training_data: MaybeDict[SupervisedDataset],\n", + " objective_thresholds: Tensor,\n", + " **kwargs: Any,\n", + ") -> Dict[str, Any]:\n", + " pass\n", + "\n", + "\n", + "# 2. Register default optimizer options\n", + "@optimizer_argparse.register(MyAcquisitionFunctionClass)\n", + "def _argparse_my_acqf(\n", + " acqf: MyAcquisitionFunctionClass, sequential: bool = True\n", + ") -> dict:\n", + " return {\n", + " \"sequential\": sequential\n", + " } # default to sequentially optimizing batches of queries\n", + "\n", + "\n", + "# 3-4. Specifying `botorch_acqf_class` and `acquisition_options`\n", + "BoTorchModel(\n", + " botorch_acqf_class=MyAcquisitionFunctionClass,\n", + " acquisition_options={\n", + " \"alpha\": 10**-6,\n", + " # The sub-dict by the key \"optimizer_options\" can be passed\n", + " # to propagate options to `optimize_acqf`, used in\n", + " # `Acquisition.optimize`, to add/override the default\n", + " # optimizer options registered above.\n", + " \"optimizer_options\": {\"sequential\": False},\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "metallic-imaging", + "metadata": { + "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" + }, + "source": [ + "See section 2A for combining the resulting `Surrogate` instance and `Acquisition` type into a `BoTorchModel`. You can also leverage `Models.BOTORCH_MODULAR` for ease of use; more on it in section 4 below or in section 1 quick-start example." + ] + }, + { + "cell_type": "markdown", + "id": "descending-australian", + "metadata": { + "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" + }, + "source": [ + "## 4. Using `Models.BOTORCH_MODULAR` \n", + "\n", + "To simplify the instantiation of an Ax ModelBridge and its undelying Model, Ax provides a [`Models` registry enum](https://github.com/facebook/Ax/blob/main/ax/modelbridge/registry.py#L355). When calling entries of that enum (e.g. `Models.BOTORCH_MODULAR(experiment, data)`), the inputs are automatically distributed between a `Model` and a `ModelBridge` for a given setup. A call to a `Model` enum member yields a model bridge with an underlying model, ready for use to generate candidates.\n", + "\n", + "Here we use `Models.BOTORCH_MODULAR` to set up a model with all-default subcomponents:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "attached-border", + "metadata": { + "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" + }, + "outputs": [], + "source": [ + "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", + " experiment=experiment,\n", + " data=data,\n", + ")\n", + "model_bridge_with_GPEI.gen(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "powerful-gamma", + "metadata": { + "originalKey": "89930a31-e058-434b-b587-181931e247b6" + }, + "outputs": [], + "source": [ + "model_bridge_with_GPEI.model.botorch_acqf_class" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "improved-replication", + "metadata": { + "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" + }, + "outputs": [], + "source": [ + "model_bridge_with_GPEI.model.surrogate.botorch_model_class" + ] + }, + { + "cell_type": "markdown", + "id": "connected-sheet", + "metadata": { + "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" + }, + "source": [ + "We can use the same `Models.BOTORCH_MODULAR` to set up a model for multi-objective optimization:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "documentary-jurisdiction", + "metadata": { + "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" + }, + "outputs": [], + "source": [ + "model_bridge_with_EHVI = Models.BOTORCH_MODULAR(\n", + " experiment=get_branin_experiment_with_multi_objective(\n", + " has_objective_thresholds=True, with_batch=True\n", + " ),\n", + " data=get_branin_data_multi_objective(),\n", + ")\n", + "model_bridge_with_EHVI.gen(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "changed-maintenance", + "metadata": { + "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" + }, + "outputs": [], + "source": [ + "model_bridge_with_EHVI.model.botorch_acqf_class" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "operating-shelf", + "metadata": { + "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" + }, + "outputs": [], + "source": [ + "model_bridge_with_EHVI.model.surrogate.botorch_model_class" + ] + }, + { + "cell_type": "markdown", + "id": "fatal-butterfly", + "metadata": { + "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" + }, + "source": [ + "Furthermore, the quick-start example at the top of this tutorial shows how to specify surrogate and acquisition subcomponents to `Models.BOTORCH_MODULAR`. " + ] + }, + { + "cell_type": "markdown", + "id": "hearing-interface", + "metadata": { + "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" + }, + "source": [ + "## 5. Utilizing `BoTorchModel` in generation strategies\n", + "\n", + "Generation strategy is a key concept in Ax, enabling use of Service API (a.k.a. `AxClient`) and many other higher-level abstractions. A `GenerationStrategy` allows to chain multiple models in Ax and thereby automate candidate generation. Refer to the \"Generation Strategy\" tutorial for more detail in generation strategies.\n", + "\n", + "An example generation stategy with the modular `BoTorchModel` would look like this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "received-registration", + "metadata": { + "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" + }, + "outputs": [], + "source": [ + "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", + "from ax.modelbridge.modelbridge_utils import get_pending_observation_features\n", + "\n", + "gs = GenerationStrategy(\n", + " steps=[\n", + " GenerationStep( # Initialization step\n", + " # Which model to use for this step\n", + " model=Models.SOBOL,\n", + " # How many generator runs (each of which is then made a trial)\n", + " # to produce with this step\n", + " num_trials=5,\n", + " # How many trials generated from this step must be `COMPLETED`\n", + " # before the next one\n", + " min_trials_observed=5,\n", + " ),\n", + " GenerationStep( # BayesOpt step\n", + " model=Models.BOTORCH_MODULAR,\n", + " # No limit on how many generator runs will be produced\n", + " num_trials=-1,\n", + " model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`\n", + " \"surrogate\": Surrogate(SingleTaskGP),\n", + " \"botorch_acqf_class\": qLogNoisyExpectedImprovement,\n", + " },\n", + " ),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "logical-windsor", + "metadata": { + "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" + }, + "source": [ + "Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "viral-cheese", + "metadata": { + "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" + }, + "outputs": [], + "source": [ + "experiment = get_branin_experiment(minimize=True)\n", + "\n", + "assert len(experiment.trials) == 0\n", + "experiment.search_space" + ] + }, + { + "cell_type": "markdown", + "id": "incident-newspaper", + "metadata": { + "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" + }, + "source": [ + "## 5a. Specifying `pending_observations`\n", + "Note that it's important to **specify pending observations** to the call to `gen` to avoid getting the same points re-suggested. Without `pending_observations` argument, Ax models are not aware of points that should be excluded from generation. Points are considered \"pending\" when they belong to `STAGED`, `RUNNING`, or `ABANDONED` trials (with the latter included so model does not re-suggest points that are considered \"bad\" and should not be re-suggested).\n", + "\n", + "If the call to `get_pending_observation_features` becomes slow in your setup (since it performs data-fetching etc.), you can opt for `get_pending_observation_features_based_on_trial_status` (also from `ax.modelbridge.modelbridge_utils`), but note the limitations of that utility (detailed in its docstring)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "casual-spread", + "metadata": { + "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" + }, + "outputs": [], + "source": [ + "for _ in range(10):\n", + " # Produce a new generator run and attach it to experiment as a trial\n", + " generator_run = gs.gen(\n", + " experiment=experiment,\n", + " n=1,\n", + " pending_observations=get_pending_observation_features(experiment=experiment),\n", + " )\n", + " trial = experiment.new_trial(generator_run)\n", + "\n", + " # Mark the trial as 'RUNNING' so we can mark it 'COMPLETED' later\n", + " trial.mark_running(no_runner_required=True)\n", + "\n", + " # Attach data for the new trial and mark it 'COMPLETED'\n", + " experiment.attach_data(get_branin_data(trials=[trial]))\n", + " trial.mark_completed()\n", + "\n", + " print(f\"Completed trial #{trial.index}, suggested by {generator_run._model_key}.\")" + ] + }, + { + "cell_type": "markdown", + "id": "circular-vermont", + "metadata": { + "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" + }, + "source": [ + "Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "significant-particular", + "metadata": { + "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" + }, + "outputs": [], + "source": [ + "exp_to_df(experiment)" + ] + }, + { + "cell_type": "markdown", + "id": "obvious-transparency", + "metadata": { + "originalKey": "c25da720-6d3d-4f16-b878-24f2d2755783" + }, + "source": [ + "## 6. Customizing a `Surrogate` or `Acquisition`\n", + "\n", + "We expect the base `Surrogate` and `Acquisition` classes to work with most BoTorch components, but there could be a case where you would need to subclass one of aforementioned abstractions to handle a given BoTorch component. If you run into a case like this, feel free to open an issue on our [Github issues page](https://github.com/facebook/Ax/issues) –– it would be very useful for us to know \n", + "\n", + "One such example would be a need for a custom `MCAcquisitionObjective` or posterior transform. To subclass `Acquisition` accordingly, one would override the `get_botorch_objective_and_transform` method:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "organizational-balance", + "metadata": { + "originalKey": "e7f8e413-f01e-4f9d-82c1-4912097637af" + }, + "outputs": [], + "source": [ + "from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform\n", + "from botorch.acquisition.risk_measures import RiskMeasureMCObjective\n", + "\n", + "\n", + "class CustomObjectiveAcquisition(Acquisition):\n", + " def get_botorch_objective_and_transform(\n", + " self,\n", + " botorch_acqf_class: Type[AcquisitionFunction],\n", + " model: Model,\n", + " objective_weights: Tensor,\n", + " objective_thresholds: Optional[Tensor] = None,\n", + " outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,\n", + " X_observed: Optional[Tensor] = None,\n", + " risk_measure: Optional[RiskMeasureMCObjective] = None,\n", + " ) -> Tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]:\n", + " ... # Produce the desired `MCAcquisitionObjective` and `PosteriorTransform` instead of the default" + ] + }, + { + "cell_type": "markdown", + "id": "theoretical-horizon", + "metadata": { + "originalKey": "7299f0fc-e19e-4383-99de-ef7a9a987fe9" + }, + "source": [ + "Then to use the new subclass in `BoTorchModel`, just specify `acquisition_class` argument along with `botorch_acqf_class` (to `BoTorchModel` directly or to `Models.BOTORCH_MODULAR`, which just passes the relevant arguments to `BoTorchModel` under the hood, as discussed in section 4):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "approximate-rolling", + "metadata": { + "originalKey": "07fe169a-78de-437e-9857-7c99cc48eedc" + }, + "outputs": [], + "source": [ + "Models.BOTORCH_MODULAR(\n", + " experiment=experiment,\n", + " data=data,\n", + " acquisition_class=CustomObjectiveAcquisition,\n", + " botorch_acqf_class=MyAcquisitionFunctionClass,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "representative-implement", + "metadata": { + "originalKey": "608d5f0d-4528-4aa6-869d-db38fcbfb256" + }, + "source": [ + "To use a custom `Surrogate` subclass, pass the `surrogate` argument of that type:\n", + "```\n", + "Models.BOTORCH_MODULAR(\n", + " experiment=experiment, \n", + " data=data,\n", + " surrogate=CustomSurrogate(botorch_model_class=MyModelClass),\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "framed-intermediate", + "metadata": { + "originalKey": "64f1289e-73c7-4cc5-96ee-5091286a8361" + }, + "source": [ + "------" + ] + }, + { + "cell_type": "markdown", + "id": "metropolitan-feedback", + "metadata": { + "originalKey": "d1e37569-dd0d-4561-b890-2f0097a345e0" + }, + "source": [ + "## Appendix 1: Methods available on `BoTorchModel`\n", + "\n", + "Note that usually all these methods are used through `ModelBridge` –– a convertion and transformation layer that adapts Ax abstractions to inputs required by the given model.\n", + "\n", + "**Core methods on `BoTorchModel`:**\n", + "* `fit` selects a surrogate if needed and fits the surrogate model to data via `Surrogate.fit`,\n", + "* `predict` estimates metric values at a given point via `Surrogate.predict`,\n", + "* `gen` instantiates an acquisition function via `Acquisition.__init__` and optimizes it to generate candidates.\n", + "\n", + "**Other methods on `BoTorchModel`:**\n", + "* `update` updates surrogate model with training data and optionally reoptimizes model parameters via `Surrogate.update`,\n", + "* `cross_validate` re-fits the surrogate model to subset of training data and makes predictions for test data,\n", + "* `evaluate_acquisition_function` instantiates an acquisition function and evaluates it for a given point.\n", + "------\n" + ] + }, + { + "cell_type": "markdown", + "id": "possible-transsexual", + "metadata": { + "originalKey": "b02f928c-57d9-4b2a-b4fe-c6d28d368b12" + }, + "source": [ + "## Appendix 2: Default surrogate models and acquisition functions\n", + "\n", + "By default, the chosen surrogate model will be:\n", + "* if fidelity parameters are present in search space: `SingleTaskMultiFidelityGP`,\n", + "* if task parameters are present: a set of `MultiTaskGP` wrapped in a `ModelListGP` and each modeling one task,\n", + "* `SingleTaskGP` otherwise.\n", + "\n", + "The chosen acquisition function will be:\n", + "* for multi-objective settings: `qLogExpectedHypervolumeImprovement`,\n", + "* for single-objective settings: `qLogNoisyExpectedImprovement`.\n", + "----" + ] + }, + { + "cell_type": "markdown", + "id": "continuous-strain", + "metadata": { + "originalKey": "76ae9852-9d21-43d6-bf75-bb087a474dd6" + }, + "source": [ + "## Appendix 3: Handling storage errors that arise from objects that don't have serialization logic in A\n", + "\n", + "Attempting to store a generator run produced via `Models.BOTORCH_MODULAR` instance that included options without serization logic with will produce an error like: `\"Object passed to 'object_to_json' (of type ) is not registered with a corresponding encoder in ENCODER_REGISTRY.\"`" + ] + }, + { + "cell_type": "markdown", + "id": "broadband-voice", + "metadata": { + "originalKey": "6487b68e-b808-4372-b6ba-ab02ce4826bc" + }, + "source": [ + "The two options for handling this error are:\n", + "1. disabling storage of `BoTorchModel`'s options by passing `no_model_options_storage=True` to `Models.BOTORCH_MODULAR(...)` call –– this will prevent model options from being stored on the generator run, so a generator run can be saved but cannot be used to restore the model that produced it,\n", + "2. specifying serialization logic for a given object that needs to occur among the `Model` or `AcquisitionFunction` options. Tutorial for this is in the works, but in the meantime you can [post an issue on the Ax GitHub](https://github.com/facebook/Ax/issues) to get help with this." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 }