diff --git a/ax/core/generator_run.py b/ax/core/generator_run.py index 4ba6a3a33ef..9faf4f17f1d 100644 --- a/ax/core/generator_run.py +++ b/ax/core/generator_run.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum +from logging import Logger from typing import Any, Dict, List, MutableMapping, Optional, Set, Tuple import pandas as pd @@ -23,9 +24,13 @@ TModelPredict, TModelPredictArm, ) +from ax.exceptions.core import UnsupportedError from ax.utils.common.base import Base, SortableBase +from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import not_none +logger: Logger = get_logger(__name__) + class GeneratorRunType(Enum): """Class for enumerating generator run types.""" @@ -98,6 +103,7 @@ def __init__( candidate_metadata_by_arm_signature: Optional[ Dict[str, TCandidateMetadata] ] = None, + generation_node_name: Optional[str] = None, ) -> None: """ Inits GeneratorRun. @@ -133,14 +139,20 @@ def __init__( model when reinstantiating it to continue generation from it, rather than to reproduce the conditions, in which this generator run was created. - generation_step_index: Optional index of the generation step that produced - this generator run. Applicable only if the generator run was created - via a generation strategy (in which case this index should reflect the - index of generation step in a generation strategy) or a standalone - generation step (in which case this index should be ``-1``). + generation_step_index: Deprecated in favor of generation_node_name. + Optional index of the generation step that produced this generator run. + Applicable only if the generator run was created via a generation + strategy (in which case this index should reflect the index of + generation step in a generation strategy) or a standalone generation + step (in which case this index should be ``-1``). 
candidate_metadata_by_arm_signature: Optional dictionary of arm signatures to model-produced candidate metadata that corresponds to that arm in this generator run. + generation_node_name: Optional name of the generation node that produced + this generator run. Applicable only if the generator run was created + via a generation strategy (in which case this name should reflect the + name of the generation node in a generation strategy) or a standalone + generation node (in which case this name should be ``-1``). """ self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict() if weights is None: @@ -198,6 +210,11 @@ def __init__( ) self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature + if generation_step_index is not None: + logger.warn( + "The generation_step_index argument is deprecated.Please use the more" + "generalized generation_node_name argument instead.", + ) # Validate that generation step index is either not set (not from generation # strategy or ste), is non-negative (from generation step) or is -1 (from a # standalone generation step that was not a part of a generation strategy). @@ -207,6 +224,7 @@ def __init__( or generation_step_index >= 0 # Generation strategy ) self._generation_step_index = generation_step_index + self._generation_node_name = generation_node_name @property def arms(self) -> List[Arm]: @@ -249,8 +267,13 @@ def index(self) -> Optional[int]: @index.setter def index(self, index: int) -> None: + """Sets the index of this generator run within a trial's list of + generator runs. Cannot be changed after being set. + """ if self._index is not None and self._index != index: - raise ValueError("Cannot change the index of a generator run once set.") + raise UnsupportedError( + "Cannot change the index of a generator run once set." 
+ ) self._index = index @property @@ -265,18 +288,26 @@ def search_space(self) -> Optional[SearchSpace]: @property def model_predictions(self) -> Optional[TModelPredict]: + """Means and covariances for the arms in this run recorded at + the time the run was executed. + """ return self._model_predictions @property def fit_time(self) -> Optional[float]: + """Time taken to fit the model in seconds.""" return self._fit_time @property def gen_time(self) -> Optional[float]: + """Time taken to generate in seconds.""" return self._gen_time @property def model_predictions_by_arm(self) -> Optional[Dict[str, TModelPredictArm]]: + """Model predictions for each arm in this run, at the time the run was + executed. + """ if self._model_predictions is None: return None @@ -289,6 +320,9 @@ def model_predictions_by_arm(self) -> Optional[Dict[str, TModelPredictArm]]: @property def best_arm_predictions(self) -> Optional[Tuple[Arm, Optional[TModelPredictArm]]]: + """Best arm in this run (according to the optimization config) and its + optional respective model predictions. 
+ """ return self._best_arm_predictions @property @@ -343,6 +377,7 @@ def clone(self) -> GeneratorRun: model_state_after_gen=self._model_state_after_gen, generation_step_index=self._generation_step_index, candidate_metadata_by_arm_signature=cand_metadata, + generation_node_name=self._generation_node_name, ) generator_run._time_created = self._time_created generator_run._index = self._index @@ -361,6 +396,7 @@ def clone(self) -> GeneratorRun: return generator_run def __repr__(self) -> str: + """String representation of a GeneratorRun.""" class_name = self.__class__.__name__ num_arms = len(self.arms) total_weight = sum(self.weights) @@ -368,12 +404,8 @@ def __repr__(self) -> str: @property def _unique_id(self) -> str: + """Unique (within a given experiment) identifier for a GeneratorRun.""" if self.index is not None: - return str(self.index) - elif self._generation_step_index is not None: - return str(self._generation_step_index) + return str(self.index) + str(self.time_created) else: - raise ValueError( - "GeneratorRuns only have a unique id if attached " - "to a Trial or GenerationStrategy." 
- ) + return str(self) + str(self.time_created) diff --git a/ax/core/tests/test_generator_run.py b/ax/core/tests/test_generator_run.py index 07aae3bbdf6..4b948e96cf0 100644 --- a/ax/core/tests/test_generator_run.py +++ b/ax/core/tests/test_generator_run.py @@ -6,6 +6,7 @@ from ax.core.arm import Arm from ax.core.generator_run import GeneratorRun +from ax.exceptions.core import UnsupportedError from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_arms, @@ -107,7 +108,7 @@ def test_MergeDuplicateArm(self) -> None: def test_Index(self) -> None: self.assertIsNone(self.unweighted_run.index) self.unweighted_run.index = 1 - with self.assertRaises(ValueError): + with self.assertRaises(UnsupportedError): self.unweighted_run.index = 2 def test_ModelPredictions(self) -> None: diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 95fa965c6fe..557a200afa2 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -307,8 +307,12 @@ def gen( "generation step in an attempt to deduplicate. Candidates " f"produced in the last generator run: {generator_run.arms}." ) - - return not_none(generator_run) + assert generator_run is not None, ( + "The GeneratorRun is None which is an unexpected state of this" + " GenerationStrategy. This occured on GenerationNode: {self.node_name}." + ) + generator_run._generation_node_name = self.node_name + return generator_run # ------------------------- Model selection logic helpers. 
------------------------- diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index d6a75a3da01..30cc6b534ca 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -415,6 +415,7 @@ def generator_run_to_dict(generator_run: GeneratorRun) -> Dict[str, Any]: "model_state_after_gen": gr._model_state_after_gen, "generation_step_index": gr._generation_step_index, "candidate_metadata_by_arm_signature": cand_metadata, + "generation_node_name": gr._generation_node_name, } diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index f5fd00a1bed..1984983e533 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -717,6 +717,7 @@ def generator_run_from_sqa( decoder_registry=self.config.json_decoder_registry, class_decoder_registry=self.config.json_class_decoder_registry, ), + generation_node_name=generator_run_sqa.generation_node_name, ) generator_run._time_created = generator_run_sqa.time_created generator_run._generator_run_type = self.get_enum_name( diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index ab5c48ccd9d..e137b539128 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -821,6 +821,7 @@ def generator_run_to_sqa( encoder_registry=self.config.json_encoder_registry, class_encoder_registry=self.config.json_class_encoder_registry, ), + generation_node_name=generator_run._generation_node_name, ) return gr_sqa diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index d1c35582c2a..05e67c00e99 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -287,6 +287,8 @@ class SQAGeneratorRun(Base): candidate_metadata_by_arm_signature: Optional[Dict[str, Any]] = Column( JSONEncodedTextDict ) + # pyre-fixme[8]: Attribute has type `Optional[str]`; used as `Column[str]`. 
+ generation_node_name: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) # relationships # Use selectin loading for collections to prevent idle timeout errors