Skip to content

Commit

Permalink
Speed up extend_pending_observations (#3148)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3148

This diff updates `extend_pending_observations` to utilize sets for deduplicated updates, which leads to ~200x speed up over existing implementation.

Reviewed By: dme65

Differential Revision: D66780561

fbshipit-source-id: 4e265b6e1584fffd88e7ba26554d227880899d02
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Dec 4, 2024
1 parent 3847400 commit a71448c
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,18 +414,16 @@ def extend_pending_observations(
Returns:
A new dictionary of pending observations to avoid in-place modification
"""
# TODO: T203665729 @mgarrard add arm signature to ObservationFeatures and then use
# that to compare to arm signature in GR to speed up this method
extended_obs = deepcopy(pending_observations)
pending_observations = deepcopy(pending_observations)
extended_observations: dict[str, list[ObservationFeatures]] = {}
for m in experiment.metrics:
if m not in extended_obs:
extended_obs[m] = []
extended_obs_set = set(pending_observations.get(m, []))
for generator_run in generator_runs:
for a in generator_run.arms:
ob_ft = ObservationFeatures.from_arm(a)
if ob_ft not in extended_obs[m]:
extended_obs[m].append(ob_ft)
return extended_obs
extended_obs_set.add(ob_ft)
extended_observations[m] = list(extended_obs_set)
return extended_observations


# -------------------- Get target trial utils. ---------------------
Expand Down

0 comments on commit a71448c

Please sign in to comment.