diff --git a/src/orion/client/experiment.py b/src/orion/client/experiment.py index 8e2180d11..d8724e60a 100644 --- a/src/orion/client/experiment.py +++ b/src/orion/client/experiment.py @@ -28,6 +28,7 @@ from orion.executor.base import Executor from orion.plotting.base import PlotAccessor from orion.storage.base import FailedUpdate +from orion.ext.extensions import OrionExtensionManager log = logging.getLogger(__name__) @@ -87,6 +88,7 @@ def __init__(self, experiment, producer, executor=None, heartbeat=None): **orion.core.config.worker.executor_configuration, ) self.plot = PlotAccessor(self) + self.extensions = OrionExtensionManager() ### # Attributes @@ -753,22 +755,54 @@ def workon( self._experiment.max_trials = max_trials self._experiment.algorithms.algorithm.max_trials = max_trials - trials = self.executor.wait( - self.executor.submit( - self._optimize, - fct, - pool_size, - max_trials_per_worker, - max_broken, - trial_arg, - on_error, - **kwargs, + with self.extensions.experiment(self._experiment): + trials = self.executor.wait( + self.executor.submit( + self._optimize, + fct, + pool_size, + max_trials_per_worker, + max_broken, + trial_arg, + on_error, + **kwargs, + ) + for _ in range(n_workers) ) - for _ in range(n_workers) - ) return sum(trials) + def _optimize_trial(self, fct, trial, trial_arg, kwargs, worker_broken_trials, max_broken, on_error): + kwargs.update(flatten(trial.params)) + + if trial_arg: + kwargs[trial_arg] = trial + + try: + with self.extensions.trial(trial): + results = self.executor.wait( + [self.executor.submit(fct, **unflatten(kwargs))] + )[0] + self.observe(trial, results=results) + except (KeyboardInterrupt, InvalidResult): + raise + except BaseException as e: + if on_error is None or on_error(self, trial, e, worker_broken_trials): + log.error(traceback.format_exc()) + worker_broken_trials += 1 + else: + log.error(str(e)) + log.debug(traceback.format_exc()) + + if worker_broken_trials >= max_broken: + raise BrokenExperiment( + "Worker has reached broken trials threshold" + ) + else: + self.release(trial, status="broken") + + return worker_broken_trials + def _optimize( self, fct, pool_size, max_trials, max_broken, trial_arg, on_error, **kwargs ): @@ -776,43 +810,26 @@ def _optimize( trials = 0 kwargs = flatten(kwargs) max_trials = min(max_trials, self.max_trials) + while not self.is_done and trials - worker_broken_trials < max_trials: - try: - with self.suggest(pool_size=pool_size) as trial: - - kwargs.update(flatten(trial.params)) - - if trial_arg: - kwargs[trial_arg] = trial - - try: - results = self.executor.wait( - [self.executor.submit(fct, **unflatten(kwargs))] - )[0] - self.observe(trial, results=results) - except (KeyboardInterrupt, InvalidResult): - raise - except BaseException as e: - if on_error is None or on_error( - self, trial, e, worker_broken_trials - ): - log.error(traceback.format_exc()) - worker_broken_trials += 1 - else: - log.error(str(e)) - log.debug(traceback.format_exc()) - - if worker_broken_trials >= max_broken: - raise BrokenExperiment( - "Worker has reached broken trials threshold" - ) - else: - self.release(trial, status="broken") - except CompletedExperiment as e: - log.warning(e) - break - - trials += 1 + try: + with self.suggest(pool_size=pool_size) as trial: + + worker_broken_trials = self._optimize_trial( + fct, + trial, + trial_arg, + kwargs, + worker_broken_trials, + max_broken, + on_error + ) + + except CompletedExperiment as e: + log.warning(e) + break + + trials += 1 return trials diff --git a/src/orion/ext/extensions.py b/src/orion/ext/extensions.py new file mode 100644 index 000000000..05abbd92d --- /dev/null +++ b/src/orion/ext/extensions.py @@ -0,0 +1,209 @@ +"""Defines extension mechanism for third party to hook into Orion""" + + +class EventDelegate: + """Allow extensions to listen to incoming events from Orion. + Orion broadcasts events which trigger extensions callbacks. + + Parameters + ---------- + name: str + name of the event we are creating, this is useful for error reporting + + deferred: bool + if false events are triggered as soon as broadcast is called + if true the events will need to be triggered manually + """ + def __init__(self, name, deferred=False) -> None: + self.handlers = [] + self.deferred_calls = [] + self.name = name + self.deferred = deferred + self.bad_handlers = [] + self.manager = None + + def remove(self, function) -> bool: + """Remove an event handler from the handler list""" + try: + self.handlers.remove(function) + return True + except ValueError: + return False + + def add(self, function): + """Add an event handler to our handler list""" + self.handlers.append(function) + + def broadcast(self, *args, **kwargs): + """Broadcast and event to all our handlers""" + if not self.deferred: + self._execute(args, kwargs) + return + + self.deferred_calls.append((args, kwargs)) + + def _execute(self, args, kwargs): + for fun in self.handlers: + try: + fun(*args, **kwargs) + except Exception as err: + if self.manager: + self.manager.on_extension_error.broadcast(self.name, fun, err, args=(args, kwargs)) + + def execute(self): + """Execute all our deferred handlers if any""" + for args, kwargs in self.deferred_calls: + self._execute(args, kwargs) + + +class _DelegateStartEnd: + def __init__(self, start, error, end, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.start = start + self.end = end + self.error = error + + def __enter__(self): + self.start.broadcast(*self.args, **self.kwargs) + return self + + def __exit__(self, exception_type, exception_value, exception_traceback): + self.end.broadcast(*self.args, **self.kwargs) + + if exception_value is not None: + self.error.broadcast( + *self.args, + exception_type, + exception_value, + exception_traceback, + **self.kwargs + ) + + +class OrionExtensionManager: + """Manages third party extensions for Orion""" + + def __init__(self): + self._events = {} + self._get_event('on_extension_error') + + # -- Trials + self._get_event('new_trial') + self._get_event('on_trial_error') + self._get_event('end_trial') + + # -- Experiments + self._get_event('start_experiment') + self._get_event('on_experiment_error') + self._get_event('end_experiment') + + def experiment(self, *args, **kwargs): + """Initialize a context manager that will call start/error/end events automatically""" + return _DelegateStartEnd( + self.start_experiment, + self.on_experiment_error, + self.end_experiment, + *args, + **kwargs + ) + + def trial(self, *args, **kwargs): + """Initialize a context manager that will call start/error/end events automatically""" + return _DelegateStartEnd( + self.new_trial, + self.on_trial_error, + self.end_trial, + *args, + **kwargs + ) + + def __getattr__(self, name): + if name in self._events: + return self._get_event(name) + + def _get_event(self, key): + """Retrieve or generate a new event delegate""" + delegate = self._events.get(key) + + if delegate is None: + delegate = EventDelegate(key) + delegate.manager = self + self._events[key] = delegate + + return delegate + + def register(self, ext): + """Register a new extensions + + Parameters + ---------- + ext + object implementing :class`OrionExtension` methods + + Returns + ------- + the number of calls that was registered + """ + registered_callbacks = 0 + for name, delegate in self._events.items(): + if hasattr(ext, name): + delegate.add(getattr(ext, name)) + registered_callbacks += 1 + + return registered_callbacks + + def unregister(self, ext): + """Remove an extensions if it was already registered""" + unregistered_callbacks = 0 + for name, delegate in self._events.items(): + if hasattr(ext, name): + delegate.remove(getattr(ext, name)) + unregistered_callbacks += 1 + + return unregistered_callbacks + + +class OrionExtension: + """Base orion extension interface you need to implement""" + + def on_extension_error(self, name, fun, exception, args): + """Called when an extension callbakc raise an exception + + Parameters + ---------- + fun: callable + handler that raised the error + + exception: + raised exception + + args: tuple + tuple of the arguments that were used + """ + return + + def on_trial_error(self, trial, exception_type, exception_value, exception_traceback): + """Called when a error occur during the optimization process""" + return + + def new_trial(self, trial): + """Called when the trial starts with a new configuration""" + return + + def end_trial(self, trial): + """Called when the trial finished""" + return + + def on_experiment_error(self, experiment, exception_type, exception_value, exception_traceback): + """Called when a error occur during the optimization process""" + return + + def start_experiment(self, experiment): + """Called at the begin of the optimization process before the worker starts""" + return + + def end_experiment(self, experiment): + """Called at the end of the optimization process after the worker exits""" + return + diff --git a/tests/unittests/ext/test_extension.py b/tests/unittests/ext/test_extension.py new file mode 100644 index 000000000..7d205b942 --- /dev/null +++ b/tests/unittests/ext/test_extension.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Example usage and tests for :mod:`orion.client.experiment`.""" +from collections import defaultdict + +import pytest + +from orion.core.utils.exceptions import BrokenExperiment +from orion.testing import create_experiment + +config = dict( + name="supernaekei", + space={"x": "uniform(0, 200)"}, + metadata={ + "user": "tsirif", + "orion_version": "XYZ", + "VCS": { + "type": "git", + "is_dirty": False, + "HEAD_sha": "test", + "active_branch": None, + "diff_sha": "diff", + }, + }, + version=1, + max_trials=10, + max_broken=5, + working_dir="", + algorithms={"random": {"seed": 1}}, + producer={"strategy": "NoParallelStrategy"}, + refers=dict(root_id="supernaekei", parent_id=None, adapter=[]), +) + +base_trial = { + "experiment": 0, + "status": "new", # new, reserved, suspended, completed, broken + "worker": None, + "start_time": None, + "end_time": None, + "heartbeat": None, + "results": [], + "params": [], +} + +class OrionExtensionTest: + """Base orion extension interface you need to implement""" + def __init__(self) -> None: + self.calls = defaultdict(int) + + def on_experiment_error(self, *args, **kwargs): + self.calls['on_experiment_error'] += 1 + + def on_trial_error(self, *args, **kwargs): + self.calls['on_trial_error'] += 1 + + def start_experiment(self, *args, **kwargs): + self.calls['start_experiment'] += 1 + + def new_trial(self, *args, **kwargs): + self.calls['new_trial'] += 1 + + def end_trial(self, *args, **kwargs): + self.calls['end_trial'] += 1 + + def end_experiment(self, *args, **kwargs): + self.calls['end_experiment'] += 1 + + +def test_client_extension(): + ext = OrionExtensionTest() + with create_experiment(config, base_trial) as (cfg, experiment, client): + registered_callback = client.extensions.register(ext) + assert registered_callback == 6, "All ext callbacks got registered" + + def foo(x): + if len(client.fetch_trials()) > 5: + raise RuntimeError() + return [dict(name="result", type="objective", value=x * 2)] + + MAX_TRIALS = 10 + MAX_BROKEN = 5 + assert client.max_trials == MAX_TRIALS + + with pytest.raises(BrokenExperiment): + client.workon(foo, max_trials=MAX_TRIALS, max_broken=MAX_BROKEN) + + n_trials = len(experiment.fetch_trials_by_status("completed")) + n_broken = len(experiment.fetch_trials_by_status("broken")) + n_reserved = len(experiment.fetch_trials_by_status("reserved")) + + assert ext.calls['new_trial'] == n_trials + n_broken - n_reserved, 'all trials should have triggered callbacks' + assert ext.calls['end_trial'] == n_trials + n_broken - n_reserved, 'all trials should have triggered callbacks' + assert ext.calls['on_trial_error'] == n_broken, 'failed trial should be reported ' + + assert ext.calls['start_experiment'] == 1, 'experiment should have started' + assert ext.calls['end_experiment'] == 1, 'experiment should have ended' + assert ext.calls['on_experiment_error'] == 1, 'failed experiment ' + + unregistered_callback = client.extensions.unregister(ext) + assert unregistered_callback == 6, "All ext callbacks got unregistered" + + +class BadOrionExtensionTest: + """Base orion extension interface you need to implement""" + def __init__(self) -> None: + self.calls = defaultdict(int) + + def on_extension_error(self, name, fun, exception, args): + self.calls['on_extension_error'] += 1 + + def on_experiment_error(self, *args, **kwargs): + self.calls['on_experiment_error'] += 1 + + def on_trial_error(self, *args, **kwargs): + self.calls['on_trial_error'] += 1 + + def new_trial(self, *args, **kwargs): + raise RuntimeError() + + +def test_client_bad_extension(): + ext = BadOrionExtensionTest() + with create_experiment(config, base_trial) as (cfg, experiment, client): + registered_callback = client.extensions.register(ext) + assert registered_callback == 4, "All ext callbacks got registered" + + def foo(x): + return [dict(name="result", type="objective", value=x * 2)] + + MAX_TRIALS = 10 + MAX_BROKEN = 5 + assert client.max_trials == MAX_TRIALS + client.workon(foo, max_trials=MAX_TRIALS, max_broken=MAX_BROKEN) + + assert ext.calls['on_trial_error'] == 0, 'Orion worked as expected' + assert ext.calls['on_experiment_error'] == 0, 'Orion worked as expected' + assert ext.calls['on_extension_error'] == 9, 'Extension error got reported' + + unregistered_callback = client.extensions.unregister(ext) + assert unregistered_callback == 4, "All ext callbacks got unregistered"