Skip to content

Commit

Permalink
Add Orion extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
bouthilx authored and Delaunay committed Oct 5, 2021
1 parent da00153 commit 7e8bd78
Show file tree
Hide file tree
Showing 3 changed files with 414 additions and 48 deletions.
113 changes: 65 additions & 48 deletions src/orion/client/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from orion.executor.base import Executor
from orion.plotting.base import PlotAccessor
from orion.storage.base import FailedUpdate
from orion.ext.extensions import OrionExtensionManager

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, experiment, producer, executor=None, heartbeat=None):
**orion.core.config.worker.executor_configuration,
)
self.plot = PlotAccessor(self)
self.extensions = OrionExtensionManager()

###
# Attributes
Expand Down Expand Up @@ -753,66 +755,81 @@ def workon(
self._experiment.max_trials = max_trials
self._experiment.algorithms.algorithm.max_trials = max_trials

trials = self.executor.wait(
self.executor.submit(
self._optimize,
fct,
pool_size,
max_trials_per_worker,
max_broken,
trial_arg,
on_error,
**kwargs,
with self.extensions.experiment(self._experiment):
trials = self.executor.wait(
self.executor.submit(
self._optimize,
fct,
pool_size,
max_trials_per_worker,
max_broken,
trial_arg,
on_error,
**kwargs,
)
for _ in range(n_workers)
)
for _ in range(n_workers)
)

return sum(trials)

def _optimize_trial(self, fct, trial, trial_arg, kwargs, worker_broken_trials, max_broken, on_error):
kwargs.update(flatten(trial.params))

if trial_arg:
kwargs[trial_arg] = trial

try:
with self.extensions.trial(trial):
results = self.executor.wait(
[self.executor.submit(fct, **unflatten(kwargs))]
)[0]
self.observe(trial, results=results)
except (KeyboardInterrupt, InvalidResult):
raise
except BaseException as e:
if on_error is None or on_error(self, trial, e, worker_broken_trials):
log.error(traceback.format_exc())
worker_broken_trials += 1
else:
log.error(str(e))
log.debug(traceback.format_exc())

if worker_broken_trials >= max_broken:
raise BrokenExperiment(
"Worker has reached broken trials threshold"
)
else:
self.release(trial, status="broken")

return worker_broken_trials

def _optimize(
self, fct, pool_size, max_trials, max_broken, trial_arg, on_error, **kwargs
):
worker_broken_trials = 0
trials = 0
kwargs = flatten(kwargs)
max_trials = min(max_trials, self.max_trials)

while not self.is_done and trials - worker_broken_trials < max_trials:
try:
with self.suggest(pool_size=pool_size) as trial:

kwargs.update(flatten(trial.params))

if trial_arg:
kwargs[trial_arg] = trial

try:
results = self.executor.wait(
[self.executor.submit(fct, **unflatten(kwargs))]
)[0]
self.observe(trial, results=results)
except (KeyboardInterrupt, InvalidResult):
raise
except BaseException as e:
if on_error is None or on_error(
self, trial, e, worker_broken_trials
):
log.error(traceback.format_exc())
worker_broken_trials += 1
else:
log.error(str(e))
log.debug(traceback.format_exc())

if worker_broken_trials >= max_broken:
raise BrokenExperiment(
"Worker has reached broken trials threshold"
)
else:
self.release(trial, status="broken")
except CompletedExperiment as e:
log.warning(e)
break

trials += 1
try:
with self.suggest(pool_size=pool_size) as trial:

worker_broken_trials = self._optimize_trial(
fct,
trial,
trial_arg,
kwargs,
worker_broken_trials,
max_broken,
on_error
)

except CompletedExperiment as e:
log.warning(e)
break

trials += 1

return trials

Expand Down
209 changes: 209 additions & 0 deletions src/orion/ext/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""Defines extension mechanism for third party to hook into Orion"""


class EventDelegate:
"""Allow extensions to listen to incoming events from Orion.
Orion broadcasts events which trigger extensions callbacks.
Parameters
----------
name: str
name of the event we are creating, this is useful for error reporting
deferred: bool
if false events are triggered as soon as broadcast is called
if true the events will need to be triggered manually
"""
def __init__(self, name, deferred=False) -> None:
self.handlers = []
self.deferred_calls = []
self.name = name
self.deferred = deferred
self.bad_handlers = []
self.manager = None

def remove(self, function) -> bool:
"""Remove an event handler from the handler list"""
try:
self.handlers.remove(function)
return True
except ValueError:
return False

def add(self, function):
"""Add an event handler to our handler list"""
self.handlers.append(function)

def broadcast(self, *args, **kwargs):
"""Broadcast and event to all our handlers"""
if not self.deferred:
self._execute(args, kwargs)
return

self.deferred_calls.append((args, kwargs))

def _execute(self, args, kwargs):
for fun in self.handlers:
try:
fun(*args, **kwargs)
except Exception as err:
if self.manager:
self.manager.on_extension_error.broadcast(self.name, fun, err, args=(args, kwargs))

def execute(self):
"""Execute all our deferred handlers if any"""
for args, kwargs in self.deferred_calls:
self._execute(args, kwargs)


class _DelegateStartEnd:
def __init__(self, start, error, end, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.start = start
self.end = end
self.error = error

def __enter__(self):
self.start.broadcast(*self.args, **self.kwargs)
return self

def __exit__(self, exception_type, exception_value, exception_traceback):
self.end.broadcast(*self.args, **self.kwargs)

if exception_value is not None:
self.error.broadcast(
*self.args,
exception_type,
exception_value,
exception_traceback,
**self.kwargs
)


class OrionExtensionManager:
"""Manages third party extensions for Orion"""

def __init__(self):
self._events = {}
self._get_event('on_extension_error')

# -- Trials
self._get_event('new_trial')
self._get_event('on_trial_error')
self._get_event('end_trial')

# -- Experiments
self._get_event('start_experiment')
self._get_event('on_experiment_error')
self._get_event('end_experiment')

def experiment(self, *args, **kwargs):
"""Initialize a context manager that will call start/error/end events automatically"""
return _DelegateStartEnd(
self.start_experiment,
self.on_experiment_error,
self.end_experiment,
*args,
**kwargs
)

def trial(self, *args, **kwargs):
"""Initialize a context manager that will call start/error/end events automatically"""
return _DelegateStartEnd(
self.new_trial,
self.on_trial_error,
self.end_trial,
*args,
**kwargs
)

def __getattr__(self, name):
if name in self._events:
return self._get_event(name)

def _get_event(self, key):
"""Retrieve or generate a new event delegate"""
delegate = self._events.get(key)

if delegate is None:
delegate = EventDelegate(key)
delegate.manager = self
self._events[key] = delegate

return delegate

def register(self, ext):
"""Register a new extensions
Parameters
----------
ext
object implementing :class`OrionExtension` methods
Returns
-------
the number of calls that was registered
"""
registered_callbacks = 0
for name, delegate in self._events.items():
if hasattr(ext, name):
delegate.add(getattr(ext, name))
registered_callbacks += 1

return registered_callbacks

def unregister(self, ext):
"""Remove an extensions if it was already registered"""
unregistered_callbacks = 0
for name, delegate in self._events.items():
if hasattr(ext, name):
delegate.remove(getattr(ext, name))
unregistered_callbacks += 1

return unregistered_callbacks


class OrionExtension:
"""Base orion extension interface you need to implement"""

def on_extension_error(self, name, fun, exception, args):
"""Called when an extension callbakc raise an exception
Parameters
----------
fun: callable
handler that raised the error
exception:
raised exception
args: tuple
tuple of the arguments that were used
"""
return

def on_trial_error(self, trial, exception_type, exception_value, exception_traceback):
"""Called when a error occur during the optimization process"""
return

def new_trial(self, trial):
"""Called when the trial starts with a new configuration"""
return

def end_trial(self, trial):
"""Called when the trial finished"""
return

def on_experiment_error(self, experiment, exception_type, exception_value, exception_traceback):
"""Called when a error occur during the optimization process"""
return

def start_experiment(self, experiment):
"""Called at the begin of the optimization process before the worker starts"""
return

def end_experiment(self, experiment):
"""Called at the end of the optimization process after the worker exits"""
return

Loading

0 comments on commit 7e8bd78

Please sign in to comment.