
Add CMA-MAE Sampler (CmaMaeSampler) #173

Open · wants to merge 14 commits into main
Conversation


btjanaka commented Nov 7, 2024

Contributor Agreements

Please read the contributor agreements and if you agree, please click the checkbox below.

  • I agree to the contributor agreements.

Tip

Please follow the Quick TODO list to smoothly merge your PR.

Motivation

This PR creates the CmaMaeSampler, which provides the CMA-MAE algorithm as implemented in pyribs.

CMA-MAE is a quality diversity algorithm that has demonstrated state-of-the-art performance in a variety of domains. pyribs is a bare-bones Python library for quality diversity optimization algorithms. For a primer on CMA-MAE and pyribs, we recommend referring to the series of pyribs tutorials.

For simplicity, this implementation provides a default instantiation of CMA-MAE with a GridArchive and EvolutionStrategyEmitter with improvement ranking, all wrapped up in a Scheduler.

Description of the changes

  • Added CmaMaeSampler in sampler.py
  • Added an example for using CmaMaeSampler in example.py

TODO List towards PR Merge

Please remove this section if this PR is not an addition of a new package.
Otherwise, please check the following TODO list:

  • Copy ./template/ to create your package
  • Replace <COPYRIGHT HOLDER> in LICENSE of your package with your name
  • Fill out README.md in your package
  • Add import statements of your function or class names to be used in __init__.py
  • (Optional) Add from __future__ import annotations at the head of any Python files that include typing to support older Python versions
  • Apply the formatter based on the tips in README.md
  • Check whether your module works as intended based on the tips in README.md

btjanaka (Author) commented Nov 7, 2024

Hi @nabenabe0928, I just started this PR for implementing the CMA-MAE algorithm in Optuna, and I'm wondering whether I'm on the right track. So far, I have set up the initialization function, but I'm not sure how to handle sample_relative -- I believe there was a part we wanted to check due to the enqueuing of trials.

Currently, I am actually getting an error that I believe is related to sample_relative; roughly, it says

TypeError: argument of type 'FrozenTrial' is not iterable

I'm also available to meet and work through these errors together. Thanks!

@btjanaka btjanaka marked this pull request as draft November 7, 2024 20:13
nabenabe0928 (Contributor) commented:

@btjanaka
Thank you for the PR!
I will look into your problem first and get back to you on how we should proceed!

nabenabe0928 (Contributor) commented Nov 8, 2024

@btjanaka

I created example code for wrapping batch optimization libraries!
Could you have a look? If you run into any difficulties, please tell me!

Note

For the first version, it is probably fine not to support parallel optimization or trial failures.
We can work on those as follow-ups in the future :)

Dummy Sampler Code to Be Wrapped
from __future__ import annotations

import numpy as np


class SamplerToWrap:
    def __init__(self, dim: int, batch_size: int, seed: int | None) -> None:
        self._batch_size = batch_size
        self._dim = dim
        self._rng = np.random.RandomState(seed)
        self._mean = np.zeros(dim, dtype=float)
        self._cov = np.identity(dim, dtype=float)

    def ask(self) -> np.ndarray:
        return self._rng.multivariate_normal(mean=self._mean, cov=self._cov, size=self._batch_size)

    def tell(self, params: np.ndarray, values: np.ndarray) -> None:
        assert len(values.shape) == 1
        assert len(params) == len(values)
        assert params.shape[-1] == self._dim
        # Take quantile so that at least two solutions will be considered as good.
        quantile = max(2 / len(values), 0.1)
        good_value = np.quantile(values, quantile)
        good_params = params[values <= good_value]
        # Take the statistics of good parameters.
        self._mean = np.mean(good_params, axis=0)
        self._cov = np.cov(good_params, rowvar=False)
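To sanity-check the ask/tell contract before wrapping it in a sampler, here is a minimal standalone driver loop. The class is repeated (slightly condensed) so the snippet runs on its own; the sphere objective with its optimum at (1, 1) and the batch size are made up for illustration:

```python
from __future__ import annotations

import numpy as np


class SamplerToWrap:
    """Same dummy batch sampler as above, repeated so this snippet is standalone."""

    def __init__(self, dim: int, batch_size: int, seed: int | None) -> None:
        self._batch_size = batch_size
        self._dim = dim
        self._rng = np.random.RandomState(seed)
        self._mean = np.zeros(dim, dtype=float)
        self._cov = np.identity(dim, dtype=float)

    def ask(self) -> np.ndarray:
        return self._rng.multivariate_normal(
            mean=self._mean, cov=self._cov, size=self._batch_size
        )

    def tell(self, params: np.ndarray, values: np.ndarray) -> None:
        # Keep at least the two best solutions and refit the Gaussian to them.
        quantile = max(2 / len(values), 0.1)
        good_value = np.quantile(values, quantile)
        good_params = params[values <= good_value]
        self._mean = np.mean(good_params, axis=0)
        self._cov = np.cov(good_params, rowvar=False)


sampler = SamplerToWrap(dim=2, batch_size=20, seed=0)
for _ in range(30):
    batch = sampler.ask()                        # Shape: (batch_size, dim).
    values = np.sum((batch - 1.0) ** 2, axis=1)  # Sphere objective, optimum at (1, 1).
    sampler.tell(batch, values)

print("final mean:", sampler._mean)  # Should drift from 0 toward the optimum.
```

The wrapper sampler below replays exactly this loop, except that the evaluations come from Optuna trials instead of a direct function call, which is why the results must be buffered until a full batch is available.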
Code Example
from __future__ import annotations

from typing import Iterable
from typing import Sequence

import numpy as np
import optuna
from optuna.distributions import BaseDistribution
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
import optunahub


_logger = optuna.logging.get_logger(f"optuna.{__name__}")
SimpleBaseSampler = optunahub.load_module("samplers/simple", force_reload=False).SimpleBaseSampler


class BatchOptSampler(SimpleBaseSampler):
    def __init__(self, dim: int, batch_size: int = 4, seed: int | None = None) -> None:
        _logger.warning("This sampler does not support parallel optimization now.")
        self._external_sampler = SamplerToWrap(batch_size=batch_size, dim=dim, seed=seed)
        self._batch_size = batch_size
        # Use a list (not a set) so that the parameter order is deterministic.
        self._param_names = [f"x{d}" for d in range(dim)]
        # NOTE: SimpleBaseSampler must know Optuna search_space information.
        search_space = {
            name: optuna.distributions.FloatDistribution(-1e9, 1e9) for name in self._param_names
        }
        super().__init__(search_space=search_space)

        # Store the batch results.
        self._params_to_tell: list[np.ndarray] = []
        self._values_to_tell: list[np.ndarray] = []

    def _validate_param_names(self, given_param_names: Iterable[str]) -> None:
        if any(param_name not in self._param_names for param_name in given_param_names):
            raise ValueError("All param_name must be like ``x0``, ``x1``, ..., and so on.")

    def sample_relative(
        self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
    ) -> dict[str, float]:
        self._validate_param_names(search_space.keys())
        # Ask the next batch.
        params_array = self._external_sampler.ask()

        # Use the first entry as the next parameter.
        next_params = params_array[0]

        # Enqueue the parameters except for the first one.
        for params in params_array[1:]:
            study.enqueue_trial({name: params[d] for d, name in enumerate(self._param_names)})

        # Convert the first entry into the Optuna format and return it.
        return {name: next_params[d] for d, name in enumerate(self._param_names)}

    def after_trial(
        self,
        study: Study,
        trial: FrozenTrial,
        state: TrialState,
        values: Sequence[float] | None,
    ) -> None:
        assert len(self._params_to_tell) == len(self._values_to_tell)
        # Single-objective only; failed trials are not supported yet (see the note above).
        assert values is not None and len(values) == 1
        self._validate_param_names(trial.params.keys())

        # Store the trial result.
        self._params_to_tell.append([trial.params[name] for name in self._param_names])
        self._values_to_tell.append(values[0])

        if len(self._values_to_tell) != self._batch_size:
            return

        # Tell the batch results to external sampler once the batch is ready.
        self._external_sampler.tell(
            params=np.asarray(self._params_to_tell), values=np.asarray(self._values_to_tell)
        )
        # Empty the results.
        self._params_to_tell = []
        self._values_to_tell = []
Verification of Example
import optuna

def objective(trial: optuna.Trial) -> float:
    x0 = trial.suggest_float("x0", -5, 5)
    x1 = trial.suggest_float("x1", -5, 5)
    return x0**2 + x1**2

sampler = BatchOptSampler(dim=2, batch_size=5)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=50)

btjanaka (Author) commented Nov 8, 2024

@nabenabe0928 Thank you for sending the example code. I got it to work with pyribs, but I ran into a couple of small questions. I left them as comments on the PR; would you mind taking a look? Thank you!

Also, how strict are the mypy checks? Do I need to pass them all?

nabenabe0928 (Contributor) commented:

Some updates will come after the PR below is merged:
btjanaka#1

@y0z y0z added the new-package New packages label Nov 18, 2024
@btjanaka btjanaka marked this pull request as ready for review November 21, 2024 23:40
btjanaka (Author) commented:

Hi @nabenabe0928, I think the PR is now ready for review! I will work on the visualization in a separate PR, but I did check that I could visualize the archive using the pyribs grid_archive_heatmap function and the results look good.

nabenabe0928 (Contributor) commented:

@btjanaka
Hey, thank you for the update!
I will check your changes as soon as possible and merge the PR!
