Put model fit data in gen_metadata (facebook#2511)
Summary: Pull Request resolved: facebook#2511

Reviewed By: saitcakmak

Differential Revision: D58261582

fbshipit-source-id: a29600fb48d3a825d2c648646a12c88702ca94b3
Daniel Cohen authored and facebook-github-bot committed Jun 13, 2024
1 parent 0dce67c commit 1139ea0
Showing 6 changed files with 218 additions and 78 deletions.
73 changes: 73 additions & 0 deletions ax/modelbridge/cross_validation.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
@@ -16,6 +17,7 @@
from logging import Logger
from numbers import Number
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
from warnings import warn

import numpy as np
from ax.core.observation import Observation, ObservationData, recombine_observations
@@ -492,6 +494,45 @@ def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
"""


def get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge: ModelBridge,
) -> Dict[str, Optional[float]]:
"""
Get model fit quality and uncertainty (std) quality metrics, along with their
generalization counterparts, from a fitted ModelBridge for analytics purposes.
Values are None if they cannot be computed.
"""
try:
model_fit_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=fitted_model_bridge,
generalization=False,
untransform=False,
)
# for uncertainty quantification, the distance of the standardized-error std from 1 is what matters
std = list(model_fit_dict["std_of_the_standardized_error"].values())

# generalization metrics
model_gen_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=fitted_model_bridge,
generalization=True,
untransform=False,
)
gen_std = list(model_gen_dict["std_of_the_standardized_error"].values())
return {
"model_fit_quality": _model_fit_metric(model_fit_dict),
"model_std_quality": _model_std_quality(np.array(std)),
"model_fit_generalization": _model_fit_metric(model_gen_dict),
"model_std_generalization": _model_std_quality(np.array(gen_std)),
}

except Exception as e:
warn("Encountered exception in computing model fit quality: " + str(e))
return {
"model_fit_quality": None,
"model_std_quality": None,
"model_fit_generalization": None,
"model_std_generalization": None,
}


def compute_model_fit_metrics_from_modelbridge(
model_bridge: ModelBridge,
fit_metrics_dict: Optional[Dict[str, ModelFitMetricProtocol]] = None,
@@ -550,6 +591,38 @@ def compute_model_fit_metrics_from_modelbridge(
)


def _model_fit_metric(metric_dict: Dict[str, Dict[str, float]]) -> float:
# We'd ideally log the entire `model_fit_dict`, since a single model fit metric
# can't capture the nuances of multiple experimental metrics, but this might
# lead to database performance issues. So instead, we take the worst
# coefficient of determination as the model fit quality and store the full data
# in Manifold (TODO).
return min(metric_dict["coefficient_of_determination"].values())


def _model_std_quality(std: np.ndarray) -> float:
"""Quantifies quality of the model uncertainty. A value of one means the
uncertainty is perfectly predictive of the true standard deviation of the error.
Values larger than one indicate over-estimation and values smaller than one indicate
under-estimation of the true standard deviation of the error. In particular, a value
of 2 (resp. 1 / 2) represents an over-estimation (resp. under-estimation) of the
true standard deviation of the error by a factor of 2.
Args:
std: The standard deviation of the standardized error.
Returns:
The factor corresponding to the worst over- or under-estimation factor of the
standard deviation of the error among all experimentally observed metrics.
"""
max_std, min_std = np.max(std), np.min(std)
# comparing worst over-estimation factor with worst under-estimation factor
inv_model_std_quality = max_std if max_std > 1 / min_std else min_std
# reciprocal so that values greater than one indicate over-estimation and
# values smaller than one indicate under-estimation of the uncertainty.
return 1 / inv_model_std_quality


def _predict_on_training_data(
model_bridge: ModelBridge,
untransform: bool = False,
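
A quick, hedged illustration of the two helpers added above (a minimal sketch, not part of the diff; it assumes `_model_fit_metric` and `_model_std_quality` are importable from ax.modelbridge.cross_validation once this change lands):

import numpy as np

from ax.modelbridge.cross_validation import _model_fit_metric, _model_std_quality

# The worst coefficient of determination across experiment metrics is what
# gets reported as model_fit_quality.
fit_dict = {"coefficient_of_determination": {"branin": 0.92, "cost": 0.71}}
assert _model_fit_metric(fit_dict) == 0.71

# A standardized-error std of 0.5 means the model over-estimates the error's
# std by a factor of 2; the quality factor is the reciprocal, 2.0.
assert np.isclose(_model_std_quality(np.array([0.5, 0.8])), 2.0)

# A standardized-error std of 2.0 means under-estimation by a factor of 2,
# which maps to a quality factor of 0.5.
assert np.isclose(_model_std_quality(np.array([2.0, 0.9])), 0.5)
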
17 changes: 16 additions & 1 deletion ax/modelbridge/model_spec.py
@@ -27,6 +27,7 @@
cross_validate,
CVDiagnostics,
CVResult,
get_fit_and_std_quality_and_generalization_dict,
)
from ax.modelbridge.registry import ModelRegistryBase
from ax.utils.common.base import SortableBase
@@ -218,7 +219,21 @@ def gen(self, **model_gen_kwargs: Any) -> GeneratorRun:
],
keywords=get_function_argument_names(fitted_model.gen),
)
return fitted_model.gen(**model_gen_kwargs)
generator_run = fitted_model.gen(
**model_gen_kwargs,
)
fit_and_std_quality_and_generalization_dict = (
get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge=self.fitted_model,
)
)
generator_run._gen_metadata = (
{} if generator_run.gen_metadata is None else generator_run.gen_metadata
)
generator_run._gen_metadata.update(
**fit_and_std_quality_and_generalization_dict
)
return generator_run

def copy(self) -> ModelSpec:
"""`ModelSpec` is both a spec and an object that performs actions.
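
A hedged usage sketch of the behavior this hunk introduces (mirroring the new tests further down; `experiment` and `data` are assumed to be a pre-existing Branin experiment and its fetched data, not defined here):

from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models

ms = ModelSpec(model_enum=Models.SOBOL)
ms.fit(experiment=experiment, data=data)  # `experiment` / `data` assumed to exist
gr = ms.gen(n=1)
# Sobol has no model fit to evaluate, so the four keys are attached with None values.
for key in (
    "model_fit_quality",
    "model_std_quality",
    "model_fit_generalization",
    "model_std_generalization",
):
    print(key, gr.gen_metadata[key])
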
78 changes: 77 additions & 1 deletion ax/modelbridge/tests/test_model_fit_metrics.py
@@ -19,6 +19,7 @@
_predict_on_cross_validation_data,
_predict_on_training_data,
compute_model_fit_metrics_from_modelbridge,
get_fit_and_std_quality_and_generalization_dict,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
@@ -27,7 +28,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from ax.utils.stats.model_fit_stats import _entropy_via_kde, entropy_of_observations
from ax.utils.testing.core_stubs import get_branin_search_space
from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_search_space

NUM_SOBOL = 5

@@ -141,3 +142,78 @@ def test_model_fit_metrics(self) -> None:
self.assertFalse(
any("Input data is not standardized" in str(w.message) for w in ws)
)


class TestGetFitAndStdQualityAndGeneralizationDict(TestCase):
def setUp(self) -> None:
super().setUp()
self.experiment = get_branin_experiment()
self.sobol = Models.SOBOL(search_space=self.experiment.search_space)

def test_it_returns_empty_data_for_sobol(self) -> None:
results = get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge=self.sobol,
)
expected = {
"model_fit_quality": None,
"model_std_quality": None,
"model_fit_generalization": None,
"model_std_generalization": None,
}
self.assertDictEqual(results, expected)

def test_it_returns_float_values_when_fit_can_be_evaluated(self) -> None:
# GIVEN we have a model whose CV can be evaluated
sobol_run = self.sobol.gen(n=20)
self.experiment.new_batch_trial().add_generator_run(
sobol_run
).run().mark_completed()
data = self.experiment.fetch_data()
botorch_modelbridge = Models.BOTORCH_MODULAR(
experiment=self.experiment, data=data
)

# WHEN we call get_fit_and_std_quality_and_generalization_dict
results = get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge=botorch_modelbridge,
)

# THEN we get expected results
# CALCULATE EXPECTED RESULTS
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=botorch_modelbridge,
generalization=False,
untransform=False,
)
# checking fit metrics
r2 = fit_metrics.get("coefficient_of_determination")
r2 = cast(Dict[str, float], r2)

std = fit_metrics.get("std_of_the_standardized_error")
std = cast(Dict[str, float], std)
std_branin = std["branin"]

model_std_quality = 1 / std_branin

# check generalization metrics
gen_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=botorch_modelbridge,
generalization=True,
untransform=False,
)
r2_gen = gen_metrics.get("coefficient_of_determination")
r2_gen = cast(Dict[str, float], r2_gen)
gen_std = gen_metrics.get("std_of_the_standardized_error")
gen_std = cast(Dict[str, float], gen_std)
gen_std_branin = gen_std["branin"]
model_std_generalization = 1 / gen_std_branin

expected = {
"model_fit_quality": min(r2.values()),
"model_std_quality": model_std_quality,
"model_fit_generalization": min(r2_gen.values()),
"model_std_generalization": model_std_generalization,
}
# END CALCULATE EXPECTED RESULTS

self.assertDictsAlmostEqual(results, expected)
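
A note on the expected values in this test: with a single experiment metric ("branin"), the worst over- and under-estimation factors coincide, so the std quality collapses to a plain reciprocal. A minimal check of that reasoning (assuming `_model_std_quality` from the cross_validation.py hunk above is importable):

import numpy as np

from ax.modelbridge.cross_validation import _model_std_quality

# With one metric, max_std == min_std, so either branch of the helper returns
# 1 / std -- which is why the test sets model_std_quality = 1 / std_branin.
std_branin = 0.8
assert np.isclose(_model_std_quality(np.array([std_branin])), 1 / std_branin)
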
21 changes: 21 additions & 0 deletions ax/modelbridge/tests/test_model_spec.py
@@ -17,6 +17,7 @@
from ax.modelbridge.modelbridge_utils import extract_search_space_digest
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import fast_botorch_optimize

@@ -148,6 +149,26 @@ def test_fixed_features(self) -> None:
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
self.assertEqual(ms.model_gen_kwargs["fixed_features"], new_features)

def test_gen_attaches_empty_model_fit_metadata_if_fit_not_applicable(self) -> None:
ms = ModelSpec(model_enum=Models.SOBOL)
ms.fit(experiment=self.experiment, data=self.data)
gr = ms.gen(n=1)
gen_metadata = not_none(gr.gen_metadata)
self.assertEqual(gen_metadata["model_fit_quality"], None)
self.assertEqual(gen_metadata["model_std_quality"], None)
self.assertEqual(gen_metadata["model_fit_generalization"], None)
self.assertEqual(gen_metadata["model_std_generalization"], None)

def test_gen_attaches_model_fit_metadata_if_applicable(self) -> None:
ms = ModelSpec(model_enum=Models.GPEI)
ms.fit(experiment=self.experiment, data=self.data)
gr = ms.gen(n=1)
gen_metadata = not_none(gr.gen_metadata)
self.assertIsInstance(gen_metadata["model_fit_quality"], float)
self.assertIsInstance(gen_metadata["model_std_quality"], float)
self.assertIsInstance(gen_metadata["model_fit_generalization"], float)
self.assertIsInstance(gen_metadata["model_std_generalization"], float)


class FactoryFunctionModelSpecTest(BaseModelSpecTest):
def test_construct(self) -> None:
83 changes: 18 additions & 65 deletions ax/telemetry/scheduler.py
@@ -11,8 +11,9 @@
from typing import Any, Dict, Optional
from warnings import warn

import numpy as np
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
from ax.modelbridge.cross_validation import (
get_fit_and_std_quality_and_generalization_dict,
)

from ax.service.scheduler import get_fitted_model_bridge, Scheduler
from ax.telemetry.common import _get_max_transformed_dimensionality
@@ -103,10 +104,10 @@ class SchedulerCompletedRecord:
experiment_completed_record: ExperimentCompletedRecord

best_point_quality: float
model_fit_quality: float
model_std_quality: float
model_fit_generalization: float
model_std_generalization: float
model_fit_quality: Optional[float]
model_std_quality: Optional[float]
model_fit_generalization: Optional[float]
model_std_generalization: Optional[float]

improvement_over_baseline: float

@@ -117,32 +118,19 @@
def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
try:
model_bridge = get_fitted_model_bridge(scheduler)
model_fit_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
generalization=False,
untransform=False,
)
model_fit_quality = _model_fit_metric(model_fit_dict)
# similar for uncertainty quantification, but distance from 1 matters
std = list(model_fit_dict["std_of_the_standardized_error"].values())
model_std_quality = _model_std_quality(np.array(std))

# generalization metrics
model_gen_dict = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
generalization=True,
untransform=False,
quality_and_generalizations_dict = (
get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge=model_bridge,
)
)
model_fit_generalization = _model_fit_metric(model_gen_dict)
gen_std = list(model_gen_dict["std_of_the_standardized_error"].values())
model_std_generalization = _model_std_quality(np.array(gen_std))

except Exception as e:
warn("Encountered exception in computing model fit quality: " + str(e))
model_fit_quality = float("nan")
model_std_quality = float("nan")
model_fit_generalization = float("nan")
model_std_generalization = float("nan")
quality_and_generalizations_dict = {
"model_fit_quality": None,
"model_std_quality": None,
"model_fit_generalization": None,
"model_std_generalization": None,
}

try:
improvement_over_baseline = scheduler.get_improvement_over_baseline()
@@ -158,13 +146,10 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
experiment=scheduler.experiment
),
best_point_quality=float("nan"), # TODO[T147907632]
model_fit_quality=model_fit_quality,
model_std_quality=model_std_quality,
model_fit_generalization=model_fit_generalization,
model_std_generalization=model_std_generalization,
improvement_over_baseline=improvement_over_baseline,
num_metric_fetch_e_encountered=scheduler._num_metric_fetch_e_encountered,
num_trials_bad_due_to_err=scheduler._num_trials_bad_due_to_err,
**quality_and_generalizations_dict,
)

def flatten(self) -> Dict[str, Any]:
@@ -179,35 +164,3 @@ def flatten(self) -> Dict[str, Any]:
**self_dict,
**experiment_completed_record_dict,
}


def _model_fit_metric(metric_dict: Dict[str, Dict[str, float]]) -> float:
# We'd ideally log the entire `model_fit_dict`, since a single model fit metric
# can't capture the nuances of multiple experimental metrics, but this might
# lead to database performance issues. So instead, we take the worst
# coefficient of determination as the model fit quality and store the full data
# in Manifold (TODO).
return min(metric_dict["coefficient_of_determination"].values())


def _model_std_quality(std: np.ndarray) -> float:
"""Quantifies quality of the model uncertainty. A value of one means the
uncertainty is perfectly predictive of the true standard deviation of the error.
Values larger than one indicate over-estimation and values smaller than one indicate
under-estimation of the true standard deviation of the error. In particular, a value
of 2 (resp. 1 / 2) represents an over-estimation (resp. under-estimation) of the
true standard deviation of the error by a factor of 2.
Args:
std: The standard deviation of the standardized error.
Returns:
The factor corresponding to the worst over- or under-estimation factor of the
standard deviation of the error among all experimentally observed metrics.
"""
max_std, min_std = np.max(std), np.min(std)
# comparing worst over-estimation factor with worst under-estimation factor
inv_model_std_quality = max_std if max_std > 1 / min_std else min_std
# reciprocal so that values greater than one indicate over-estimation and
# values smaller than one indicate under-estimation of the uncertainty.
return 1 / inv_model_std_quality
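
For telemetry consumers, a short hedged sketch of what this refactor changes in practice (a hypothetical `scheduler` is assumed; fit computations that fail now surface as None instead of NaN):

from ax.telemetry.scheduler import SchedulerCompletedRecord

record = SchedulerCompletedRecord.from_scheduler(scheduler)  # `scheduler` assumed to exist
flat = record.flatten()
# The four fields below now come from the shared helper in ax.modelbridge.cross_validation.
print(flat["model_fit_quality"], flat["model_std_quality"])
print(flat["model_fit_generalization"], flat["model_std_generalization"])
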
(Diff for 1 additional changed file not shown.)
