Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: optimizers #167

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open

refactor: optimizers #167

wants to merge 12 commits into from

Conversation

eddiebergman
Copy link
Contributor

WIP

Some notes:

  • NePS only cares that whatever you pass in as an optimizers, it has this signature, no need for a base class or anything. This is the only requirement for an optimizer and should make developing new optimizers much more straight-forward.
def __call__(
    self,
    trials: Mapping[str, Trial],
    budget_info: BudgetInfo | None,
    n: int | None = None,
) -> SampledConfig | list[SampledConfig]: ...
  • There are no longer any class heirarchies in the optimizers, no need to track down super() calls that call the base class, which calls the subclass, which calls the base class and so on. In fact most of the optimizers are now a dataclass which just looks like this.
@dataclass
class MyOpt:
    pipeline_space: SearchSpace
    prior: Prior | None
    # ... other attributes
    
    def __call__(
        self,
        trials: Mapping[str, Trial],
        budget_info: BudgetInfo | None,
        n: int | None = None,
    ) -> SampledConfig | list[SampledConfig]:
        # A high level overview of what the optimizer is actually doing logically,
        # calling out to functions where needed.
        # Calls out to other functions where needed
  • Replacing the actual heirarchies is just re-usable components. All of the combination of bracket types (sh, asha, hb, async-hb) and the various samplers (uniform, prior, priorband) are just supported with various ways to construct a BracketOptimizer
def bracket_optimizer(  # noqa: C901
    *,
    pipeline_space: SearchSpace,
    bracket_type: Literal["successive_halving", "hyperband", "asha", "async_hb"],
    eta: int = 3,
    early_stopping_rate: int = 0,
    sampler: Literal["uniform", "prior", "priorband"] = "uniform",
    sample_prior_first: bool | Literal["highest_fidelity"] = False,
) -> BracketOptimizer:
    """Initialise a bracket optimizer.

    Args:
        pipeline_space: Space in which to search
        bracket_type: The type of bracket to use. Can be one of:

            * "successive_halving": Successive Halving
            * "hyperband": HyperBand
            * "asha": ASHA
            * "async_hb": Async

        eta: The reduction factor used for building brackets
        early_stopping_rate: Determines the number of rungs in a bracket
            Choosing 0 creates maximal rungs given the fidelity bounds.

            !!! warning

                This is only used for Successive Halving and Asha.

        sampler: The type of sampling procedure to use:

            * If "uniform", samples uniformly from the space when it needs to sample
            * If "prior", samples from the prior distribution built from the prior
              and prior_confidence values in the pipeline space.
            * If "priorband", samples with weights according to the PriorBand
                algorithm. See: https://arxiv.org/abs/2306.12370

        sample_prior_first: Whether to sample the prior configuration first.
    """
    assert pipeline_space.fidelity is not None
    assert pipeline_space.fidelity_name is not None
    if len(pipeline_space.fidelities) != 1:
        raise ValueError(
            "Fidelity should be defined in the pipeline space."
            f"\nGot: {pipeline_space.fidelities}"
        )

    if sample_prior_first not in (True, False, "highest_fidelity"):
        raise ValueError(
            "sample_prior_first should be either True, False or 'highest_fidelity'"
        )

    from neps.optimizers.utils import brackets

    match bracket_type:
        case "successive_halving":
            rung_to_fidelity, rung_sizes = brackets.calculate_sh_rungs(
                bounds=(pipeline_space.fidelity.lower, pipeline_space.fidelity.upper),
                eta=eta,
                early_stopping_rate=early_stopping_rate,
            )
            create_brackets = partial(
                brackets.Sync.create_repeating, rung_sizes=rung_sizes
            )
        case "hyperband":
            rung_to_fidelity, bracket_layouts = brackets.calculate_hb_bracket_layouts(
                bounds=(pipeline_space.fidelity.lower, pipeline_space.fidelity.upper),
                eta=eta,
            )
            create_brackets = partial(
                brackets.Hyperband.create_repeating,
                bracket_layouts=bracket_layouts,
            )
        case "asha":
            rung_to_fidelity, _rung_sizes = brackets.calculate_sh_rungs(
                bounds=(pipeline_space.fidelity.lower, pipeline_space.fidelity.upper),
                eta=eta,
                early_stopping_rate=early_stopping_rate,
            )
            create_brackets = partial(
                brackets.Async.create, rungs=list(rung_to_fidelity), eta=eta
            )
        case "async_hb":
            rung_to_fidelity, bracket_layouts = brackets.calculate_hb_bracket_layouts(
                bounds=(pipeline_space.fidelity.lower, pipeline_space.fidelity.upper),
                eta=eta,
            )
            # We don't care about the capacity of each bracket, we need the rung layout
            bracket_rungs = [list(bracket.keys()) for bracket in bracket_layouts]
            create_brackets = partial(
                brackets.AsyncHyperband.create,
                bracket_rungs=bracket_rungs,
                eta=eta,
            )
        case _:
            raise ValueError(f"Unknown bracket type: {bracket_type}")

    encoder = ConfigEncoder.from_space(pipeline_space, include_fidelity=False)

    match sampler:
        case "uniform":
            _sampler = Sampler.uniform(ndim=encoder.ndim)
        case "prior":
            _sampler = Prior.from_config(
                pipeline_space.prior_config, space=pipeline_space
            )
        case "priorband":
            _sampler = PriorBandArgs(mutation_rate=0.5, mutation_std=0.25)
        case PriorBandArgs() | Sampler():
            _sampler = sampler
        case _:
            raise ValueError(f"Unknown sampler: {sampler}")

    return BracketOptimizer(
        pipeline_space=pipeline_space,
        encoder=encoder,
        eta=eta,
        rung_to_fid=rung_to_fidelity,
        fid_min=pipeline_space.fidelity.lower,
        fid_max=pipeline_space.fidelity.upper,
        fid_name=pipeline_space.fidelity_name,
        sampler=_sampler,
        sample_prior_first=sample_prior_first,
        create_brackets=create_brackets,
    )
  • A whole lot of optimizations in terms of speed from not having to go back and forth between SearchSpace as a config. This is probably only relevant for benchmarking

TODO

  • Model based sampling for the BracketOptimizer, previously a mixin class called MFBOBase. This can mostly just re-use the things from gp and is essentially just a Sampler that can be used. Here's a snippet from the BracketOptimizer. It we would essentially just use it as the sampler when the bracket says next_action=("new", rung), signalling that a new sample has to happen at a given rung.
        # Get and execute the next action from our brackets that are not pending or done
        brackets = self.create_brackets(table)
        if not isinstance(brackets, Sequence):
            brackets = [brackets]

        next_action = next(
            (
                action
                for bracket in brackets
                if (action := bracket.next()) not in ("done", "pending")
            ),
            None,
        )

        if next_action is None:
            raise RuntimeError(
                f"{self.__class__.__name__} never got a 'sample' or 'pending' action!"
            )

        match next_action:
            case ("promote", config, config_id, new_rung):
                config = {**config, self.fid_name: self.rung_to_fid[new_rung]}
                return SampledConfig(
                    id=f"{config_id}_{new_rung}",
                    config=config,
                    previous_config_id=f"{config_id}_{new_rung - 1}",
                )
            case ("new", rung):
                match self.sampler:
                    case Sampler():
                        config = self.sampler.sample_config(to=self.encoder)
                        config = {**config, self.fid_name: self.rung_to_fid[rung]}
                        return SampledConfig(id=f"{nxt_id}_{rung}", config=config)
                    case PriorBandArgs():
                        config = sample_with_priorband(
                            table=table,
                            space=space,
                            rung_to_sample_for=rung,
                            fid_bounds=(self.fid_min, self.fid_max),
                            encoder=self.encoder,
                            inc_mutation_rate=self.sampler.mutation_rate,
                            inc_mutation_std=self.sampler.mutation_std,
                            eta=self.eta,
                            seed=None,  # TODO
                        )
                        config = {**config, self.fid_name: self.rung_to_fid[rung]}
                        return SampledConfig(id=f"{nxt_id}_{rung}", config=config)
                    case _:
                        raise RuntimeError(f"Unknown sampler: {self.sampler}")
            case _:
                raise RuntimeError(f"Unknown bracket action: {next_action}")

Replaced with a partial of `SuccessiveHalvingBase`
The only thing inheriting from it was `PriorBand`, which ended up
replacing all the defaults of it's `__init__()`. The only other thing
the `HyperbandCustomDefault.__init__()` did was change the sampling args
of the SH brackets, which then `PriorBand` would overwrite it its own
`__init__()`.
The only user of the was `PriorBand`, in the only
thing it did was explicitly pass `use_priors=`True`.
Everything else it set was overwritten by `PriorBand`
passing down its args.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

1 participant