Skip to content

Commit

Permalink
Improve performance in parallel case (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Jan 19, 2024
1 parent a94c8b5 commit 0f3d567
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 11 deletions.
30 changes: 22 additions & 8 deletions src/tranquilo/acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from tranquilo.options import AcceptanceOptions


def get_acceptance_decider(acceptance_decider, acceptance_options):
def get_acceptance_decider(
acceptance_decider,
acceptance_options,
):
func_dict = {
"classic": _accept_classic,
"naive_noisy": accept_naive_noisy,
Expand Down Expand Up @@ -92,11 +95,11 @@ def accept_classic_line_search(
state,
history,
*,
speculative_sampling_radius_factor,
wrapped_criterion,
min_improvement,
batch_size,
sample_points,
search_radius_factor,
rng,
):
# ==================================================================================
Expand Down Expand Up @@ -144,11 +147,12 @@ def accept_classic_line_search(
if n_unallocated_evals > 0:
speculative_xs = _generate_speculative_sample(
new_center=candidate_x,
search_radius_factor=search_radius_factor,
radius_factor=speculative_sampling_radius_factor,
trustregion=state.trustregion,
sample_points=sample_points,
n_points=n_unallocated_evals,
history=history,
line_search_xs=line_search_xs,
rng=rng,
)
else:
Expand Down Expand Up @@ -427,7 +431,14 @@ def calculate_rho(actual_improvement, expected_improvement):


def _generate_speculative_sample(
new_center, trustregion, sample_points, n_points, history, search_radius_factor, rng
new_center,
trustregion,
sample_points,
n_points,
history,
line_search_xs,
radius_factor,
rng,
):
"""Generative a speculative sample.
Expand All @@ -437,23 +448,26 @@ def _generate_speculative_sample(
sample_points (callable): Function to sample points.
n_points (int): Number of points to sample.
history (History): Tranquilo history.
search_radius_factor (float): Factor to multiply the trust region radius by to
get the search radius.
radius_factor (float): Factor to multiply the trust region radius by to get the
radius of the region from which to draw the speculative sample.
rng (np.random.Generator): Random number generator.
Returns:
np.ndarray: Speculative sample.
"""
search_region = trustregion._replace(
center=new_center, radius=search_radius_factor * trustregion.radius
center=new_center, radius=radius_factor * trustregion.radius
)

old_indices = history.get_x_indices_in_region(search_region)

old_xs = history.get_xs(old_indices)

model_xs = old_xs
if line_search_xs is not None:
model_xs = np.row_stack([old_xs, line_search_xs])
else:
model_xs = old_xs

new_xs = sample_points(
search_region,
Expand Down
6 changes: 4 additions & 2 deletions src/tranquilo/filter_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ def keep_all(xs, indices):
return xs, indices


def drop_excess(xs, indices, state, target_size):
n_to_drop = max(0, len(xs) - target_size)
def drop_excess(xs, indices, state, target_size, n_max_factor):
filter_target_size = int(np.floor(target_size * n_max_factor))

n_to_drop = max(0, len(xs) - filter_target_size)

if n_to_drop:
xs, indices = drop_worst_points(xs, indices, state, n_to_drop)
Expand Down
2 changes: 2 additions & 0 deletions src/tranquilo/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class AcceptanceOptions(NamedTuple):
n_min: int = 4
n_max: int = 50
min_improvement: float = 0.0
speculative_sampling_radius_factor: float = 0.75


class StagnationOptions(NamedTuple):
Expand Down Expand Up @@ -179,6 +180,7 @@ class VarianceEstimatorOptions(NamedTuple):
class FilterOptions(NamedTuple):
strictness: float = 1e-10
shape: str = "sphere"
n_max_factor: int = 3


class SamplerOptions(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def test_generate_speculative_sample():
sample_points=get_sampler("random_hull"),
n_points=3,
history=history,
search_radius_factor=1.0,
radius_factor=1.0,
line_search_xs=None,
rng=np.random.default_rng(1234),
)

Expand Down
53 changes: 53 additions & 0 deletions tests/test_filter_points.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from tranquilo.filter_points import get_sample_filter
from tranquilo.filter_points import drop_worst_points
from tranquilo.tranquilo import State
from tranquilo.region import Region
from numpy.testing import assert_array_equal as aae
Expand Down Expand Up @@ -46,3 +47,55 @@ def test_keep_all():
got_xs, got_idxs = filter(xs=xs, indices=indices, state=None)
aae(got_xs, xs)
aae(got_idxs, indices)


def test_drop_worst_point(state):
xs = np.array(
[
[1, 1.1], # should be dropped
[1, 1.2],
[1, 1], # center (needs to have index=2)
[3, 3], # should be dropped
]
)

got_xs, got_indices = drop_worst_points(
xs, indices=np.arange(4), state=state, n_to_drop=2
)

expected_xs = np.array(
[
[1, 1.2],
[1, 1],
]
)
expected_indices = np.array([1, 2])

aae(got_xs, expected_xs)
aae(got_indices, expected_indices)


def test_drop_excess(state):
filter = get_sample_filter("drop_excess", user_options={"n_max_factor": 1.0})

xs = np.array(
[
[1, 1.1], # should be dropped
[1, 1.2],
[1, 1], # center (needs to have index=2)
[3, 3], # should be dropped
]
)

got_xs, got_indices = filter(xs, indices=np.arange(4), state=state, target_size=2)

expected_xs = np.array(
[
[1, 1.2],
[1, 1],
]
)
expected_indices = np.array([1, 2])

aae(got_xs, expected_xs)
aae(got_indices, expected_indices)

0 comments on commit 0f3d567

Please sign in to comment.