-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
Summary: - The only legacy model remaining is a call to `get_MOO_PAREGO` in the MOO tutorial. - I also cleaned up some imports and got rid of a few `from ax import *`. Pull Request resolved: facebook#1982 Differential Revision: D51214898 fbshipit-source-id: 9d937e2795dc7046adb22b23ae3e3f7c5dbe64fd
- Loading branch information
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Tune a CNN on MNIST\n", | ||
"\n", | ||
"This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) for a PyTorch CNN on the MNIST dataset trained using SGD with momentum.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"\n", | ||
"from ax.plot.contour import plot_contour\n", | ||
"from ax.plot.trace import optimization_trace_single_method\n", | ||
"from ax.service.managed_loop import optimize\n", | ||
"from ax.utils.notebook.plotting import init_notebook_plotting, render\n", | ||
"from ax.utils.tutorials.cnn_utils import CNN, evaluate, load_mnist, train\n", | ||
"\n", | ||
"init_notebook_plotting()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"torch.manual_seed(12345)\n", | ||
"dtype = torch.float\n", | ||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 1. Load MNIST data\n", | ||
"First, we need to load the MNIST data and partition it into training, validation, and test sets.\n", | ||
"\n", | ||
"Note: this will download the dataset if necessary." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"BATCH_SIZE = 512\n", | ||
"train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 2. Define function to optimize\n", | ||
"In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The function takes in a parameterization (set of parameter values), computes the classification accuracy, and returns a dictionary of metric name ('accuracy') to a tuple with the mean and standard error." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def train_evaluate(parameterization):\n", | ||
" net = CNN()\n", | ||
" net = train(\n", | ||
" net=net,\n", | ||
" train_loader=train_loader,\n", | ||
" parameters=parameterization,\n", | ||
" dtype=dtype,\n", | ||
" device=device,\n", | ||
" )\n", | ||
" return evaluate(\n", | ||
" net=net,\n", | ||
" data_loader=valid_loader,\n", | ||
" dtype=dtype,\n", | ||
" device=device,\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 3. Run the optimization loop\n", | ||
"Here, we set the bounds on the learning rate and momentum and set the parameter space for the learning rate to be on a log scale. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"best_parameters, values, experiment, model = optimize(\n", | ||
" parameters=[\n", | ||
" {\"name\": \"lr\", \"type\": \"range\", \"bounds\": [1e-6, 0.4], \"log_scale\": True},\n", | ||
" {\"name\": \"momentum\", \"type\": \"range\", \"bounds\": [0.0, 1.0]},\n", | ||
" ],\n", | ||
" evaluation_function=train_evaluate,\n", | ||
" objective_name=\"accuracy\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can introspect the optimal parameters and their outcomes:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"best_parameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"means, covariances = values\n", | ||
"means, covariances" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 4. Plot response surface\n", | ||
"\n", | ||
"Contour plot showing classification accuracy as a function of the two hyperparameters.\n", | ||
"\n", | ||
"The black squares show points that we have actually run, notice how they are clustered in the optimal region." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"render(\n", | ||
" plot_contour(model=model, param_x=\"lr\", param_y=\"momentum\", metric_name=\"accuracy\")\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 5. Plot best objective as function of the iteration\n", | ||
"\n", | ||
"Show the model accuracy improving as we identify better hyperparameters." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# `plot_single_method` expects a 2-d array of means, because it expects to average means from multiple\n", | ||
"# optimization runs, so we wrap out best objectives array in another array.\n", | ||
"best_objectives = np.array(\n", | ||
" [[trial.objective_mean * 100 for trial in experiment.trials.values()]]\n", | ||
")\n", | ||
"best_objective_plot = optimization_trace_single_method(\n", | ||
" y=np.maximum.accumulate(best_objectives, axis=1),\n", | ||
" title=\"Model performance vs. # of iterations\",\n", | ||
" ylabel=\"Classification Accuracy, %\",\n", | ||
")\n", | ||
"render(best_objective_plot)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 6. Train CNN with best hyperparameters and evaluate on test set\n", | ||
"Note that the resulting accuracy on the test set might not be exactly the same as the maximum accuracy achieved on the evaluation set throughout optimization. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"data = experiment.fetch_data()\n", | ||
"df = data.df\n", | ||
"best_arm_name = df.arm_name[df[\"mean\"] == df[\"mean\"].max()].values[0]\n", | ||
"best_arm = experiment.arms_by_name[best_arm_name]\n", | ||
"best_arm" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"combined_train_valid_set = torch.utils.data.ConcatDataset(\n", | ||
" [\n", | ||
" train_loader.dataset.dataset,\n", | ||
" valid_loader.dataset.dataset,\n", | ||
" ]\n", | ||
")\n", | ||
"combined_train_valid_loader = torch.utils.data.DataLoader(\n", | ||
" combined_train_valid_set,\n", | ||
" batch_size=BATCH_SIZE,\n", | ||
" shuffle=True,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"net = train(\n", | ||
" net=CNN(),\n", | ||
" train_loader=combined_train_valid_loader,\n", | ||
" parameters=best_arm.parameters,\n", | ||
" dtype=dtype,\n", | ||
" device=device,\n", | ||
")\n", | ||
"test_accuracy = evaluate(\n", | ||
" net=net,\n", | ||
" data_loader=test_loader,\n", | ||
" dtype=dtype,\n", | ||
" device=device,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"print(f\"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |