Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Orion Extension concept [OC-343] #673

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,5 @@ target/

# Notebooks
tests/**.ipynb
dask-worker-space/
tests/functional/commands/*.json
1 change: 1 addition & 0 deletions docs/src/code/client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Client helper functions
client/cli
client/experiment
client/manual
client/extensions

.. automodule:: orion.client
:members:
7 changes: 7 additions & 0 deletions docs/src/code/client/extensions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Extensions
==========

.. automodule:: orion.ext.extensions
:members:


1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"orion.client",
"orion.core",
"orion.executor",
"orion.ext",
"orion.plotting",
"orion.serving",
"orion.storage",
Expand Down
53 changes: 37 additions & 16 deletions src/orion/client/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from orion.core.worker.trial import Trial, TrialCM
from orion.core.worker.trial_pacemaker import TrialPacemaker
from orion.executor.base import Executor
from orion.ext.extensions import OrionExtensionManager
from orion.plotting.base import PlotAccessor
from orion.storage.base import FailedUpdate

Expand Down Expand Up @@ -72,6 +73,11 @@ class ExperimentClient:
producer: `orion.core.worker.producer.Producer`
Producer object used to produce new trials.

Notes
-----

Users can write generic extensions to ExperimentClient through
`orion.client.experiment.OrionExtension`.
"""

def __init__(self, experiment, producer, executor=None, heartbeat=None):
Expand All @@ -87,6 +93,7 @@ def __init__(self, experiment, producer, executor=None, heartbeat=None):
**orion.core.config.worker.executor_configuration,
)
self.plot = PlotAccessor(self)
self.extensions = OrionExtensionManager()

###
# Attributes
Expand Down Expand Up @@ -320,6 +327,16 @@ def fetch_noncompleted_trials(self, with_evc_tree=False):
###
# Actions
###
def register_extension(self, ext):
"""Register a third party extension

Parameters
----------
ext: OrionExtension
object that implements the OrionExtension interface

"""
return self.extensions.register(ext)

# pylint: disable=unused-argument
def insert(self, params, results=None, reserve=False):
Expand Down Expand Up @@ -753,19 +770,20 @@ def workon(
self._experiment.max_trials = max_trials
self._experiment.algorithms.algorithm.max_trials = max_trials

trials = self.executor.wait(
self.executor.submit(
self._optimize,
fct,
pool_size,
max_trials_per_worker,
max_broken,
trial_arg,
on_error,
**kwargs,
with self.extensions.experiment(self._experiment):
trials = self.executor.wait(
self.executor.submit(
self._optimize,
fct,
pool_size,
max_trials_per_worker,
max_broken,
trial_arg,
on_error,
**kwargs,
)
for _ in range(n_workers)
)
for _ in range(n_workers)
)

return sum(trials)

Expand All @@ -776,6 +794,7 @@ def _optimize(
trials = 0
kwargs = flatten(kwargs)
max_trials = min(max_trials, self.max_trials)

while not self.is_done and trials - worker_broken_trials < max_trials:
try:
with self.suggest(pool_size=pool_size) as trial:
Expand All @@ -786,10 +805,11 @@ def _optimize(
kwargs[trial_arg] = trial

try:
results = self.executor.wait(
[self.executor.submit(fct, **unflatten(kwargs))]
)[0]
self.observe(trial, results=results)
with self.extensions.trial(trial):
results = self.executor.wait(
[self.executor.submit(fct, **unflatten(kwargs))]
)[0]
self.observe(trial, results=results)
except (KeyboardInterrupt, InvalidResult):
raise
except BaseException as e:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line below executing on_error should be replaced by the Extension, with an event on_trial_fail or something like that.

Expand All @@ -808,6 +828,7 @@ def _optimize(
)
else:
self.release(trial, status="broken")

except CompletedExperiment as e:
log.warning(e)
break
Expand Down
222 changes: 222 additions & 0 deletions src/orion/ext/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
"""Defines extension mechanism for third party to hook into Orion"""


class EventDelegate:
"""Allow extensions to listen to incoming events from Orion.
Orion broadcasts events which trigger extensions callbacks.

Parameters
----------
name: str
name of the event we are creating, this is useful for error reporting

deferred: bool
if false events are triggered as soon as broadcast is called
Delaunay marked this conversation as resolved.
Show resolved Hide resolved
If true, the events will need to be triggered manually.
"""

def __init__(self, name, deferred=False) -> None:
self.handlers = []
self.deferred_calls = []
self.name = name
self.deferred = deferred
self.bad_handlers = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this for?

self.manager = None

def remove(self, function) -> bool:
"""Remove an event handler from the handler list"""
try:
self.handlers.remove(function)
return True
except ValueError:
return False

def add(self, function):
"""Add an event handler to our handler list"""
self.handlers.append(function)

def broadcast(self, *args, **kwargs):
"""Broadcast and event to all our handlers"""
if not self.deferred:
self._execute(args, kwargs)
return

self.deferred_calls.append((args, kwargs))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be best to leave the implementation of deferred calls for later unless we test it now.


def _execute(self, args, kwargs):
for fun in self.handlers:
try:
fun(*args, **kwargs)
except Exception as err:
if self.manager:
self.manager.on_extension_error.broadcast(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this silence any extension error by default? Perhaps the default should be to raise if there are no callbacks registered for on_extension_error.

self.name, fun, err, args=(args, kwargs)
)

def execute(self):
"""Execute all our deferred handlers if any"""
for args, kwargs in self.deferred_calls:
self._execute(args, kwargs)


class _DelegateStartEnd:
def __init__(self, start, error, end, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.start = start
self.end = end
self.error = error

def __enter__(self):
self.start.broadcast(*self.args, **self.kwargs)
return self

def __exit__(self, exception_type, exception_value, exception_traceback):
self.end.broadcast(*self.args, **self.kwargs)

if exception_value is not None:
self.error.broadcast(
*self.args,
exception_type,
exception_value,
exception_traceback,
**self.kwargs
)


class OrionExtensionManager:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think extension is a bit too general. It could be an extension of anything. Here we are specifically talking of callbacks for specific events. What would you think of a name like OrionCallbackManager instead? Or something related to events.

"""Manages third party extensions for Orion"""
bouthilx marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self):
self._events = {}
self._get_event("on_extension_error")

# -- Trials
self._get_event("new_trial")
self._get_event("on_trial_error")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be coherent I think there should be no on_ or it should be for all.

self._get_event("end_trial")

# -- Experiments
self._get_event("start_experiment")
self._get_event("on_experiment_error")
self._get_event("end_experiment")

@property
def on_extension_error(self):
"""Called when an extension is throwing an exception"""
return self._get_event("on_extension_error")

def experiment(self, *args, **kwargs):
"""Initialize a context manager that will call start/error/end events automatically"""
return _DelegateStartEnd(
self._get_event("start_experiment"),
self._get_event("on_experiment_error"),
self._get_event("end_experiment"),
*args,
**kwargs
)

def trial(self, *args, **kwargs):
"""Initialize a context manager that will call start/error/end events automatically"""
return _DelegateStartEnd(
self._get_event("new_trial"),
self._get_event("on_trial_error"),
self._get_event("end_trial"),
*args,
**kwargs
)

def broadcast(self, name, *args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is indeed not tested nor used yet.

return self._get_event(name).broadcast(*args, **kwargs)

def _get_event(self, key):
"""Retrieve event delegate

Will generate one if not defined already.
"""
delegate = self._events.get(key)

if delegate is None:
delegate = EventDelegate(key)
delegate.manager = self
self._events[key] = delegate

return delegate

def register(self, ext):
"""Register a new extensions

Parameters
----------
ext: ``OrionExtension``
object implementing :class:`OrionExtension` methods

Returns
-------
the number of calls that was registered
"""
registered_callbacks = 0
for name, delegate in self._events.items():
if hasattr(ext, name):
delegate.add(getattr(ext, name))
registered_callbacks += 1

return registered_callbacks

def unregister(self, ext):
"""Remove an extensions if it was already registered"""
unregistered_callbacks = 0
for name, delegate in self._events.items():
if hasattr(ext, name):
delegate.remove(getattr(ext, name))
unregistered_callbacks += 1

return unregistered_callbacks


class OrionExtension:
"""Base orion extension interface you need to implement"""

def on_extension_error(self, name, fun, exception, args):
"""Called when an extension callbakc raise an exception
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should clarify that if the callback does not raise the error, then the execution will continue when the callback is done.


Parameters
----------
fun: callable
handler that raised the error

exception:
raised exception

args: tuple
tuple of the arguments that were used
"""
return

def on_trial_error(
self, trial, exception_type, exception_value, exception_traceback
):
"""Called when a error occur during the optimization process"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is expected of the trial status? Will it be changed to broken already? It should be specified in the docstring.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also what will happen afterwards, it this callback supposed to raise an error if we want the worker to stop?

return

def new_trial(self, trial):
"""Called when the trial starts with a new configuration"""
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is when the trial execution start then the name should be start_trial instead. Because it may not be a new trial, it can be a resumed one, so the current name is confusing.


def end_trial(self, trial):
"""Called when the trial finished"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it called anytime, even if the trial was interrupted or crashed? Is it called before a status change occurred?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should also be an event interrupt_trial.

return

def on_experiment_error(
self, experiment, exception_type, exception_value, exception_traceback
):
"""Called when a error occur during the optimization process"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment as for on_trial_error

return

def start_experiment(self, experiment):
"""Called at the begin of the optimization process before the worker starts"""
return

def end_experiment(self, experiment):
"""Called at the end of the optimization process after the worker exits"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment as for end_trial.

return
Loading