diff --git a/ax/core/utils.py b/ax/core/utils.py index 33fa0d027fe..6b4cc4fd174 100644 --- a/ax/core/utils.py +++ b/ax/core/utils.py @@ -414,18 +414,16 @@ def extend_pending_observations( Returns: A new dictionary of pending observations to avoid in-place modification """ - # TODO: T203665729 @mgarrard add arm signature to ObservationFeatures and then use - # that to compare to arm signature in GR to speed up this method - extended_obs = deepcopy(pending_observations) + pending_observations = deepcopy(pending_observations) + extended_observations: dict[str, list[ObservationFeatures]] = {} for m in experiment.metrics: - if m not in extended_obs: - extended_obs[m] = [] + extended_obs_set = set(pending_observations.get(m, [])) for generator_run in generator_runs: for a in generator_run.arms: ob_ft = ObservationFeatures.from_arm(a) - if ob_ft not in extended_obs[m]: - extended_obs[m].append(ob_ft) - return extended_obs + extended_obs_set.add(ob_ft) + extended_observations[m] = list(extended_obs_set) + return extended_observations # -------------------- Get target trial utils. ---------------------