Skip to content

Commit

Permalink
Refactor _get_trial_indices_to_fetch (#3086)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3086

This diff refactors the _get_trial_indices_to_fetch method to improve its readability and maintainability. The changes include:

1. Extracting a new method _identify_trial_indices_to_fetch that takes both old and new trial statuses as input
2. Simplifying the logic for identifying newly completed, running, and previously completed trials with new data after completion.
3. Improving code organization and reducing duplication

Reviewed By: lena-kashtelyan

Differential Revision: D66045355

fbshipit-source-id: 70e8b22701dbbc6c9a915a0a14ba172b18857cee
  • Loading branch information
paschai authored and facebook-github-bot committed Dec 9, 2024
1 parent 5f3e8e2 commit ec08fe5
Showing 1 changed file with 72 additions and 55 deletions.
127 changes: 72 additions & 55 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,85 +1491,102 @@ def _apply_new_trial_statuses(
trial.mark_as(status=status, unsafe=True)
return updated_trial_indices

def _get_trial_indices_to_fetch(
self, new_status_to_trial_idcs: Mapping[TrialStatus, set[int]]
def _identify_trial_indices_to_fetch(
self,
old_status_to_trial_idcs: Mapping[TrialStatus, set[int]],
new_status_to_trial_idcs: Mapping[TrialStatus, set[int]],
) -> set[int]:
"""Get trial indices to fetch data for the experiment given
`new_status_to_trial_idcs` and metric properties. This should include:
- newly completed trials (about to be completed)
- running trials if the experiment has metrics available while running
- previously completed (or early stopped) trials if the experiment
has metrics with new data after completion which finished recently
"""
Identify trial indices to fetch data for based on changes in trial statuses.
Args:
new_status_to_trial_idcs: Changes about to be applied to trial statuses.
old_status_to_trial_idcs: Mapping of old trial statuses
to their corresponding trial indices.
new_status_to_trial_idcs: Mapping of new trial statuses
to their corresponding trial indices.
Returns:
Set of trial indices to fetch data for.
"""
terminated_trial_idcs = {
index
for status, indices in new_status_to_trial_idcs.items()
if status.is_terminal
for index in indices
}
running_trial_indices = {
trial.index
for trial in self.running_trials
if trial.index not in terminated_trial_idcs
}
# add in any trials that will be marked running
running_trial_indices.update(
new_status_to_trial_idcs.get(TrialStatus.RUNNING, set())
)

# includes completed and early stopped trials
prev_completed_trial_idcs = {
t.index for t in self.trials_expecting_data
} - self.running_trial_indices
trial_indices_to_fetch = set()
# Get newly completed trials
prev_completed_trial_idcs = old_status_to_trial_idcs.get(
TrialStatus.COMPLETED, set()
) | old_status_to_trial_idcs.get(TrialStatus.EARLY_STOPPED, set())

# Fetch data for newly completed trials
newly_completed = (
new_status_to_trial_idcs.get(TrialStatus.COMPLETED, set())
- prev_completed_trial_idcs
)

idcs = make_indices_str(indices=newly_completed)
if newly_completed:
self.logger.info(f"Fetching data for newly completed trials: {idcs}.")
trial_indices_to_fetch.update(newly_completed)
self.logger.debug(f"Will fetch data for newly completed trials: {idcs}.")
else:
self.logger.info("No newly completed trials; not fetching data for any.")
self.logger.debug("No newly completed trials; not fetching data for any.")

# Fetch data for running trials that have metrics available while running
if (
any(
m.is_available_while_running() for m in self.experiment.metrics.values()
)
and len(running_trial_indices) > 0
# Get running trials with metrics available while running
running_trial_indices_with_metrics = set()
if any(
m.is_available_while_running() for m in self.experiment.metrics.values()
):
# NOTE: Metrics that are *not* available_while_running will be skipped
# in fetch_trials_data
idcs = make_indices_str(indices=running_trial_indices)
self.logger.info(
f"Fetching data for trials: {idcs} because some metrics "
"on experiment are available while trials are running."
)
trial_indices_to_fetch.update(running_trial_indices)
running_trial_indices_with_metrics = new_status_to_trial_idcs.get(
TrialStatus.RUNNING, set()
) | old_status_to_trial_idcs.get(TrialStatus.RUNNING, set())

for status, indices in new_status_to_trial_idcs.items():
if status.is_terminal and indices:
running_trial_indices_with_metrics -= indices

if running_trial_indices_with_metrics:
idcs = make_indices_str(indices=running_trial_indices_with_metrics)
self.logger.debug(
f"Will fetch data for trials: {idcs} because some metrics "
"on experiment are available while trials are running."
)

# Fetch data for previously completed trials that have metrics available
# after trial completion that were completed within the max of the period
# specified by metrics
# Get previously completed trials with new data after completion
recently_completed_trial_indices = self._get_recently_completed_trial_indices()
if len(recently_completed_trial_indices) > 0:
idcs = make_indices_str(indices=recently_completed_trial_indices)
self.logger.info(
f"Fetching data for trials: {idcs} because some metrics "
self.logger.debug(
f"Will fetch data for trials: {idcs} because some metrics "
"on experiment have new data after completion."
)
trial_indices_to_fetch.update(recently_completed_trial_indices)

# Combine all trial indices to fetch data for
trial_indices_to_fetch = (
newly_completed
| running_trial_indices_with_metrics
| recently_completed_trial_indices
)

return trial_indices_to_fetch

def _get_trial_indices_to_fetch(
self, new_status_to_trial_idcs: Mapping[TrialStatus, set[int]]
) -> set[int]:
"""Get trial indices to fetch data for the experiment given
`new_status_to_trial_idcs` and metric properties. This should include:
- newly completed trials
- running trials if the experiment has metrics available while running
- previously completed (or early stopped) trials if the experiment
has metrics with new data after completion which finished recently
Args:
new_status_to_trial_idcs: Changes about to be applied to trial statuses.
Returns:
Set of trial indices to fetch data for.
"""
old_status_to_trial_idcs = {status: set() for status in TrialStatus}

for trial in self.trials:
old_status_to_trial_idcs[trial.status].add(trial.index)

return self._identify_trial_indices_to_fetch(
old_status_to_trial_idcs=old_status_to_trial_idcs,
new_status_to_trial_idcs=new_status_to_trial_idcs,
)

def _get_recently_completed_trial_indices(self) -> set[int]:
"""Get trials that have completed within the max period specified by metrics."""
if len(self.experiment.metrics) == 0:
Expand Down

0 comments on commit ec08fe5

Please sign in to comment.