
Update Ax tutorials to stop using legacy Ax models (facebook#1982)
Summary:
- The only legacy model remaining is a call to `get_MOO_PAREGO` in the MOO tutorial.
- I also cleaned up some imports and got rid of a few `from ax import *` (see the sketch below).
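
A minimal sketch of that import cleanup, for illustration only; the explicit imports shown are ones the `tune_cnn` tutorial actually uses, not lines taken from the diff:

```python
# Before: a star import pulls every public Ax name into the namespace
# from ax import *

# After: import only the names the tutorial actually uses
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import init_notebook_plotting, render
```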

Pull Request resolved: facebook#1982

Differential Revision: D51214898

fbshipit-source-id: 9d937e2795dc7046adb22b23ae3e3f7c5dbe64fd
dme65 authored and facebook-github-bot committed Nov 10, 2023
1 parent 1186380 commit f55e089
Showing 18 changed files with 10,741 additions and 21,940 deletions.
2,220 changes: 1,116 additions & 1,104 deletions tutorials/Setup_and_Usage_of_BoTorch_Models_in_Ax.ipynb


1 change: 1 addition & 0 deletions tutorials/ax_client_snapshot.json


1,028 changes: 452 additions & 576 deletions tutorials/generation_strategy.ipynb


1,418 changes: 710 additions & 708 deletions tutorials/gpei_hartmann_developer.ipynb


1,262 changes: 233 additions & 1,029 deletions tutorials/gpei_hartmann_loop.ipynb


962 changes: 481 additions & 481 deletions tutorials/gpei_hartmann_service.ipynb


1,100 changes: 540 additions & 560 deletions tutorials/gss.ipynb


2,351 changes: 1,176 additions & 1,175 deletions tutorials/modular_botax.ipynb


1,164 changes: 582 additions & 582 deletions tutorials/multi_task.ipynb


2,039 changes: 1,021 additions & 1,018 deletions tutorials/multiobjective_optimization.ipynb


641 changes: 317 additions & 324 deletions tutorials/raytune_pytorch_cnn.ipynb


757 changes: 369 additions & 388 deletions tutorials/saasbo.ipynb


1,417 changes: 692 additions & 725 deletions tutorials/saasbo_nehvi.ipynb


1,811 changes: 912 additions & 899 deletions tutorials/scheduler.ipynb


2,243 changes: 609 additions & 1,634 deletions tutorials/sebo.ipynb


286 changes: 286 additions & 0 deletions tutorials/tune_cnn.ipynb
@@ -0,0 +1,286 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tune a CNN on MNIST\n",
"\n",
"This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) for a PyTorch CNN on the MNIST dataset trained using SGD with momentum.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"from ax.plot.contour import plot_contour\n",
"from ax.plot.trace import optimization_trace_single_method\n",
"from ax.service.managed_loop import optimize\n",
"from ax.utils.notebook.plotting import init_notebook_plotting, render\n",
"from ax.utils.tutorials.cnn_utils import CNN, evaluate, load_mnist, train\n",
"\n",
"init_notebook_plotting()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(12345)\n",
"dtype = torch.float\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Load MNIST data\n",
"First, we need to load the MNIST data and partition it into training, validation, and test sets.\n",
"\n",
"Note: this will download the dataset if necessary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"BATCH_SIZE = 512\n",
"train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Define function to optimize\n",
"In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The function takes in a parameterization (set of parameter values), computes the classification accuracy, and returns a dictionary of metric name ('accuracy') to a tuple with the mean and standard error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_evaluate(parameterization):\n",
" net = CNN()\n",
" net = train(\n",
" net=net,\n",
" train_loader=train_loader,\n",
" parameters=parameterization,\n",
" dtype=dtype,\n",
" device=device,\n",
" )\n",
" return evaluate(\n",
" net=net,\n",
" data_loader=valid_loader,\n",
" dtype=dtype,\n",
" device=device,\n",
" )"
]
},
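{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration only: a minimal sketch of the return-value convention described above. The variable `example_result` and its numbers are made up, not output from this run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical example of the evaluation-function contract: a dictionary\n",
"# mapping the metric name to a (mean, standard error) tuple. The numbers\n",
"# here are invented for illustration.\n",
"example_result = {\"accuracy\": (0.94, 0.005)}\n",
"example_result"
]
},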
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Run the optimization loop\n",
"Here, we set the bounds on the learning rate and momentum and set the parameter space for the learning rate to be on a log scale. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"best_parameters, values, experiment, model = optimize(\n",
" parameters=[\n",
" {\"name\": \"lr\", \"type\": \"range\", \"bounds\": [1e-6, 0.4], \"log_scale\": True},\n",
" {\"name\": \"momentum\", \"type\": \"range\", \"bounds\": [0.0, 1.0]},\n",
" ],\n",
" evaluation_function=train_evaluate,\n",
" objective_name=\"accuracy\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can introspect the optimal parameters and their outcomes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"best_parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"means, covariances = values\n",
"means, covariances"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Plot response surface\n",
"\n",
"Contour plot showing classification accuracy as a function of the two hyperparameters.\n",
"\n",
"The black squares show points that we have actually run, notice how they are clustered in the optimal region."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"render(\n",
" plot_contour(model=model, param_x=\"lr\", param_y=\"momentum\", metric_name=\"accuracy\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Plot best objective as function of the iteration\n",
"\n",
"Show the model accuracy improving as we identify better hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `plot_single_method` expects a 2-d array of means, because it expects to average means from multiple\n",
"# optimization runs, so we wrap out best objectives array in another array.\n",
"best_objectives = np.array(\n",
" [[trial.objective_mean * 100 for trial in experiment.trials.values()]]\n",
")\n",
"best_objective_plot = optimization_trace_single_method(\n",
" y=np.maximum.accumulate(best_objectives, axis=1),\n",
" title=\"Model performance vs. # of iterations\",\n",
" ylabel=\"Classification Accuracy, %\",\n",
")\n",
"render(best_objective_plot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Train CNN with best hyperparameters and evaluate on test set\n",
"Note that the resulting accuracy on the test set might not be exactly the same as the maximum accuracy achieved on the evaluation set throughout optimization. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = experiment.fetch_data()\n",
"df = data.df\n",
"best_arm_name = df.arm_name[df[\"mean\"] == df[\"mean\"].max()].values[0]\n",
"best_arm = experiment.arms_by_name[best_arm_name]\n",
"best_arm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"combined_train_valid_set = torch.utils.data.ConcatDataset(\n",
" [\n",
" train_loader.dataset.dataset,\n",
" valid_loader.dataset.dataset,\n",
" ]\n",
")\n",
"combined_train_valid_loader = torch.utils.data.DataLoader(\n",
" combined_train_valid_set,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = train(\n",
" net=CNN(),\n",
" train_loader=combined_train_valid_loader,\n",
" parameters=best_arm.parameters,\n",
" dtype=dtype,\n",
" device=device,\n",
")\n",
"test_accuracy = evaluate(\n",
" net=net,\n",
" data_loader=test_loader,\n",
" dtype=dtype,\n",
" device=device,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}