Skip to content

Commit

Permalink
Add generation_node_name arg to GeneratorRun (facebook#1918)
Browse files Browse the repository at this point in the history
Summary:

As we transition to using GenerationNodes by default in GenerationStrategy, we want to add in a geneartion_node_name argument to GeneratorRun. This arg is analogous to the generation_step_index, but more flexibile.

For backwards compatibility purposes, we aren't replacing generation_step_index at this time. I anticipate, over time we will be able to do so.

While I was here, I also updated the functions with docstrings to abide by our standards :)

Things in the pipeline:
(1) Update other code impacted by this generator run change
(2) Add the transition_to field to transition criterion class
(3) Update the transition criterion class to check on a per node basis, instead of per experiment
(4) add is_complete to generationNode and then use that in generation Strategy for moving forward
(5) [Mby] skip max trial criterion addition if numtrials == -1
(6) add transition criterion to the repr string + some of the other fields that havent made it yet

Reviewed By: ItsMrLin

Differential Revision: D50276770
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Oct 23, 2023
1 parent f7d01cf commit 536b3f1
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 16 deletions.
58 changes: 45 additions & 13 deletions ax/core/generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from logging import Logger
from typing import Any, Dict, List, MutableMapping, Optional, Set, Tuple

import pandas as pd
Expand All @@ -23,9 +24,13 @@
TModelPredict,
TModelPredictArm,
)
from ax.exceptions.core import UnsupportedError
from ax.utils.common.base import Base, SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none

logger: Logger = get_logger(__name__)


class GeneratorRunType(Enum):
"""Class for enumerating generator run types."""
Expand Down Expand Up @@ -98,6 +103,7 @@ def __init__(
candidate_metadata_by_arm_signature: Optional[
Dict[str, TCandidateMetadata]
] = None,
generation_node_name: Optional[str] = None,
) -> None:
"""
Inits GeneratorRun.
Expand Down Expand Up @@ -133,14 +139,20 @@ def __init__(
model when reinstantiating it to continue generation from it,
rather than to reproduce the conditions, in which this generator
run was created.
generation_step_index: Optional index of the generation step that produced
this generator run. Applicable only if the generator run was created
via a generation strategy (in which case this index should reflect the
index of generation step in a generation strategy) or a standalone
generation step (in which case this index should be ``-1``).
generation_step_index: Deprecated in favor of generation_node_name.
Optional index of the generation step that produced this generator run.
Applicable only if the generator run was created via a generation
strategy (in which case this index should reflect the index of
generation step in a generation strategy) or a standalone generation
step (in which case this index should be ``-1``).
candidate_metadata_by_arm_signature: Optional dictionary of arm signatures
to model-produced candidate metadata that corresponds to that arm in
this generator run.
generaiton_node_name: Optional name of the generation node that produced
this generator run. Applicable only if the generator run was created
via a generation strategy (in which case this name should reflect the
name of the generation node in a generation strategy) or a standalone
generation node (in which case this name should be ``-1``).
"""
self._arm_weight_table: OrderedDict[str, ArmWeight] = OrderedDict()
if weights is None:
Expand Down Expand Up @@ -198,6 +210,11 @@ def __init__(
)
self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature

if generation_step_index is not None:
logger.warn(
"The generation_step_index argument is deprecated.Please use the more"
"generalized generation_node_name argument instead.",
)
# Validate that generation step index is either not set (not from generation
# strategy or ste), is non-negative (from generation step) or is -1 (from a
# standalone generation step that was not a part of a generation strategy).
Expand All @@ -207,6 +224,7 @@ def __init__(
or generation_step_index >= 0 # Generation strategy
)
self._generation_step_index = generation_step_index
self._generation_node_name = generation_node_name

@property
def arms(self) -> List[Arm]:
Expand Down Expand Up @@ -249,8 +267,13 @@ def index(self) -> Optional[int]:

@index.setter
def index(self, index: int) -> None:
"""Sets the index of this generator run within a trial's list of
generator runs. Cannot be changed after being set.
"""
if self._index is not None and self._index != index:
raise ValueError("Cannot change the index of a generator run once set.")
raise UnsupportedError(
"Cannot change the index of a generator run once set."
)
self._index = index

@property
Expand All @@ -265,18 +288,26 @@ def search_space(self) -> Optional[SearchSpace]:

@property
def model_predictions(self) -> Optional[TModelPredict]:
"""Means and covariances for the arms in this run recorded at
the time the run was executed.
"""
return self._model_predictions

@property
def fit_time(self) -> Optional[float]:
"""Time taken to fit the model in seconds."""
return self._fit_time

@property
def gen_time(self) -> Optional[float]:
"""Time taken to generate in seconds."""
return self._gen_time

@property
def model_predictions_by_arm(self) -> Optional[Dict[str, TModelPredictArm]]:
"""Model predictions for each arm in this run, at the time the run was
executed.
"""
if self._model_predictions is None:
return None

Expand All @@ -289,6 +320,9 @@ def model_predictions_by_arm(self) -> Optional[Dict[str, TModelPredictArm]]:

@property
def best_arm_predictions(self) -> Optional[Tuple[Arm, Optional[TModelPredictArm]]]:
"""Best arm in this run (according to the optimization config) and its
optional respective model predictions.
"""
return self._best_arm_predictions

@property
Expand Down Expand Up @@ -343,6 +377,7 @@ def clone(self) -> GeneratorRun:
model_state_after_gen=self._model_state_after_gen,
generation_step_index=self._generation_step_index,
candidate_metadata_by_arm_signature=cand_metadata,
generation_node_name=self._generation_node_name,
)
generator_run._time_created = self._time_created
generator_run._index = self._index
Expand All @@ -361,19 +396,16 @@ def clone(self) -> GeneratorRun:
return generator_run

def __repr__(self) -> str:
"""String representation of a GeneratorRun."""
class_name = self.__class__.__name__
num_arms = len(self.arms)
total_weight = sum(self.weights)
return f"{class_name}({num_arms} arms, total weight {total_weight})"

@property
def _unique_id(self) -> str:
"""Unique (within a given experiment) identifier for a GeneratorRun."""
if self.index is not None:
return str(self.index)
elif self._generation_step_index is not None:
return str(self._generation_step_index)
return str(self.index) + str(self.time_created)
else:
raise ValueError(
"GeneratorRuns only have a unique id if attached "
"to a Trial or GenerationStrategy."
)
return str(self) + str(self.time_created)
3 changes: 2 additions & 1 deletion ax/core/tests/test_generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ax.core.arm import Arm
from ax.core.generator_run import GeneratorRun
from ax.exceptions.core import UnsupportedError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_arms,
Expand Down Expand Up @@ -107,7 +108,7 @@ def test_MergeDuplicateArm(self) -> None:
def test_Index(self) -> None:
self.assertIsNone(self.unweighted_run.index)
self.unweighted_run.index = 1
with self.assertRaises(ValueError):
with self.assertRaises(UnsupportedError):
self.unweighted_run.index = 2

def test_ModelPredictions(self) -> None:
Expand Down
8 changes: 6 additions & 2 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,12 @@ def gen(
"generation step in an attempt to deduplicate. Candidates "
f"produced in the last generator run: {generator_run.arms}."
)

return not_none(generator_run)
assert generator_run is not None, (
"The GeneratorRun is None which is an unexpected state of this"
" GenerationStrategy. This occured on GenerationNode: {self.node_name}."
)
generator_run._generation_node_name = self.node_name
return generator_run

# ------------------------- Model selection logic helpers. -------------------------

Expand Down
1 change: 1 addition & 0 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def generator_run_to_dict(generator_run: GeneratorRun) -> Dict[str, Any]:
"model_state_after_gen": gr._model_state_after_gen,
"generation_step_index": gr._generation_step_index,
"candidate_metadata_by_arm_signature": cand_metadata,
"generation_node_name": gr._generation_node_name,
}


Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@ def generator_run_from_sqa(
decoder_registry=self.config.json_decoder_registry,
class_decoder_registry=self.config.json_class_decoder_registry,
),
generation_node_name=generator_run_sqa.generation_node_name,
)
generator_run._time_created = generator_run_sqa.time_created
generator_run._generator_run_type = self.get_enum_name(
Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ def generator_run_to_sqa(
encoder_registry=self.config.json_encoder_registry,
class_encoder_registry=self.config.json_class_encoder_registry,
),
generation_node_name=generator_run._generation_node_name,
)
return gr_sqa

Expand Down
2 changes: 2 additions & 0 deletions ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ class SQAGeneratorRun(Base):
candidate_metadata_by_arm_signature: Optional[Dict[str, Any]] = Column(
JSONEncodedTextDict
)
# pyre-fixme[8]: Attribute has type `Optional[str]`; used as `Column[str]`.
generation_node_name: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))

# relationships
# Use selectin loading for collections to prevent idle timeout errors
Expand Down

0 comments on commit 536b3f1

Please sign in to comment.