Skip to content

Commit

Permalink
Merge pull request #535 from neptune-ai/dev/websockets
Browse files Browse the repository at this point in the history
Add stop and abort signal handling
  • Loading branch information
HubertJaworski authored May 4, 2021
2 parents 06e7271 + 167725d commit 9867722
Show file tree
Hide file tree
Showing 24 changed files with 561 additions and 29 deletions.
2 changes: 1 addition & 1 deletion neptune/api_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(self, project_identifier):
{end}
Project {python}{project}{end} not found.
Verify if your project's name was not misspelled. You can find proper name after logging into Neptune UI: ui.neptune.ai.
Verify if your project's name was not misspelled. You can find proper name after logging into Neptune UI.
"""
inputs = dict(list({'project': project_identifier}.items()) + list(STYLES.items()))
super(ProjectNotFound, self).__init__(message.format(**inputs))
Expand Down
4 changes: 4 additions & 0 deletions neptune/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def get_leaderboard_entries(self, project,
min_running_time=None):
pass

# pylint: disable=unused-argument
def websockets_factory(self, project_uuid, experiment_id):
return None

@abstractmethod
def get_channel_points_csv(self, experiment, channel_internal_id, channel_name):
pass
Expand Down
2 changes: 1 addition & 1 deletion neptune/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def __init__(self, project_qualified_name):
- {correct}neptune-ai{end}: {underline}WORKSPACE{end} our company organization name
- {correct}credit-default-prediction{end}: {underline}PROJECT_NAME{end} a project name
The URL to this project looks like this: https://ui.neptune.ai/neptune-ai/credit-default-prediction
The URL to this project looks like this: https://app.neptune.ai/neptune-ai/credit-default-prediction
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/workspace-project-and-user-management/index.html
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#
import logging
import os
import re
import time
import uuid
from collections import namedtuple
from http.client import NOT_FOUND
Expand All @@ -25,6 +27,7 @@

import six
from bravado.exception import HTTPNotFound
from neptune.internal.websockets.reconnecting_websocket_factory import ReconnectingWebsocketFactory

from neptune.api_exceptions import (ExperimentNotFound, ExperimentOperationErrors, PathInExperimentNotFound,
ProjectNotFound)
Expand All @@ -49,14 +52,15 @@
from neptune.model import ChannelWithLastValue, LeaderboardEntry
from neptune.new import exceptions as alpha_exceptions
from neptune.new.attributes import constants as alpha_consts
from neptune.new.attributes.constants import MONITORING_TRACEBACK_ATTRIBUTE_PATH, SYSTEM_FAILED_ATTRIBUTE_PATH
from neptune.new.internal import operation as alpha_operation
from neptune.new.internal.backends import hosted_file_operations as alpha_hosted_file_operations
from neptune.new.internal.backends.api_model import AttributeType
from neptune.new.internal.backends.operation_api_name_visitor import \
OperationApiNameVisitor as AlphaOperationApiNameVisitor
from neptune.new.internal.backends.operation_api_object_converter import \
OperationApiObjectConverter as AlphaOperationApiObjectConverter
from neptune.new.internal.operation import AssignString, ConfigFloatSeries, LogFloats
from neptune.new.internal.operation import AssignString, ConfigFloatSeries, LogFloats, AssignBool, LogStrings
from neptune.new.internal.utils import base64_decode, base64_encode, paths as alpha_path_utils
from neptune.new.internal.utils.paths import parse_path
from neptune.utils import assure_directory_exists, with_api_exceptions_handler
Expand Down Expand Up @@ -375,7 +379,12 @@ def mark_succeeded(self, experiment):
pass

def mark_failed(self, experiment, traceback):
pass
operations = []
path = parse_path(SYSTEM_FAILED_ATTRIBUTE_PATH)
traceback_values = [LogStrings.ValueType(val, step=None, ts=time.time()) for val in traceback.split("\n")]
operations.append(AssignBool(path=path, value=True))
operations.append(LogStrings(values=traceback_values, path=parse_path(MONITORING_TRACEBACK_ATTRIBUTE_PATH)))
self._execute_operations(experiment, operations)

def ping_experiment(self, experiment):
try:
Expand Down Expand Up @@ -764,6 +773,13 @@ def get_portion(limit, offset):
except HTTPNotFound:
raise ProjectNotFound(project_identifier=project.full_id)

def websockets_factory(self, project_uuid, experiment_id):
base_url = re.sub(r'^http', 'ws', self.api_address) + '/api/notifications/v1'
return ReconnectingWebsocketFactory(
backend=self,
url=base_url + f'/runs/{project_uuid}/{experiment_id}/signal'
)

@staticmethod
def _to_leaderboard_entry_dto(experiment_attributes):
attributes = experiment_attributes.attributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import requests
import six
from bravado.exception import HTTPBadRequest, HTTPNotFound, HTTPUnprocessableEntity, HTTPConflict
from neptune.internal.websockets.reconnecting_websocket_factory import ReconnectingWebsocketFactory

from neptune.api_exceptions import (
ChannelAlreadyExists,
Expand Down Expand Up @@ -151,6 +152,13 @@ def get_portion(limit, offset):
except HTTPNotFound:
raise ProjectNotFound(project_identifier=project.full_id)

def websockets_factory(self, project_uuid, experiment_id):
base_url = re.sub(r'^http', 'ws', self.api_address) + '/api/notifications/v1'
return ReconnectingWebsocketFactory(
backend=self,
url=base_url + '/experiments/' + experiment_id + '/operations'
)

@with_api_exceptions_handler
def get_channel_points_csv(self, experiment, channel_internal_id, channel_name):
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def get_leaderboard_entries(self, project,
min_running_time=None):
return self._client.get_leaderboard_entries(project, entry_types, ids, states, owners, tags, min_running_time)

@with_migration_handling
def websockets_factory(self, project_uuid, experiment_id):
return self._client.websockets_factory(project_uuid, experiment_id)

@with_migration_handling
def get_channel_points_csv(self, experiment, channel_internal_id, channel_name):
return self._client.get_channel_points_csv(experiment, channel_internal_id, channel_name)
Expand Down
11 changes: 7 additions & 4 deletions neptune/internal/execution/execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from neptune.internal.threads.aborting_thread import AbortingThread
from neptune.internal.threads.hardware_metric_reporting_thread import HardwareMetricReportingThread
from neptune.internal.threads.ping_thread import PingThread
from neptune.internal.websockets.reconnecting_websocket_factory import ReconnectingWebsocketFactory
from neptune.utils import is_notebook, in_docker, is_ipython

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -137,14 +136,18 @@ def _run_aborting_thread(self, abort_callback):
else:
return

websocket_factory = ReconnectingWebsocketFactory(
backend=self._backend,
websocket_factory = self._backend.websockets_factory(
# pylint: disable=protected-access
project_uuid=self._experiment._project.internal_id,
experiment_id=self._experiment.internal_id
)
if not websocket_factory:
return

self._aborting_thread = AbortingThread(
websocket_factory=websocket_factory,
abort_impl=abort_impl,
experiment_id=self._experiment.internal_id
experiment=self._experiment
)
self._aborting_thread.start()

Expand Down
14 changes: 9 additions & 5 deletions neptune/internal/threads/aborting_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@


class AbortingThread(NeptuneThread):
def __init__(self, websocket_factory, abort_impl, experiment_id):
def __init__(self, websocket_factory, abort_impl, experiment):
super(AbortingThread, self).__init__(is_daemon=True)
self._abort_message_processor = AbortMessageProcessor(abort_impl, experiment_id)
self._abort_message_processor = AbortMessageProcessor(abort_impl, experiment)
self._ws_client = websocket_factory.create(shutdown_condition=threading.Event())

def run(self):
Expand All @@ -46,14 +46,18 @@ def _is_heartbeat(message):


class AbortMessageProcessor(WebsocketMessageProcessor):
def __init__(self, abort_impl, experiment_id):
def __init__(self, abort_impl, experiment):
super(AbortMessageProcessor, self).__init__()
self._abort_impl = abort_impl
self._experiment_id = experiment_id
self._experiment = experiment
self.received_abort_message = False

def _process_message(self, message):
if message.get_type() == MessageType.ABORT:
if message.get_type() == MessageType.STOP:
self._experiment.stop()
self._abort()
elif message.get_type() == MessageType.ABORT:
self._experiment.stop("Remotely aborted")
self._abort()

def _abort(self):
Expand Down
26 changes: 24 additions & 2 deletions neptune/internal/websockets/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.new.attributes.constants import SIGNAL_TYPE_ABORT, SIGNAL_TYPE_STOP


class Message(object):
def __init__(self):
pass

MESSAGE_TYPE = 'messageType'
MESSAGE_NEW_TYPE = 'type'
MESSAGE_BODY = 'messageBody'
MESSAGE_NEW_BODY = 'body'

@classmethod
def from_json(cls, json_value):
message_type = json_value[Message.MESSAGE_TYPE]
message_body = json_value[Message.MESSAGE_BODY]
message_type = json_value.get(Message.MESSAGE_TYPE) or json_value.get(Message.MESSAGE_NEW_TYPE)
message_body = json_value.get(Message.MESSAGE_BODY) or json_value.get(Message.MESSAGE_NEW_BODY)

if message_type == SIGNAL_TYPE_STOP:
message_type = MessageType.STOP
elif message_type == SIGNAL_TYPE_ABORT:
message_type = MessageType.ABORT

if message_type in MessageClassRegistry.MESSAGE_CLASSES:
return MessageClassRegistry.MESSAGE_CLASSES[message_type].from_json(message_body)
Expand Down Expand Up @@ -53,6 +61,19 @@ def body_to_json(self):
return None


class StopMessage(Message):
@classmethod
def get_type(cls):
return MessageType.STOP

@classmethod
def from_json(cls, json_value):
return StopMessage()

def body_to_json(self):
return None


class ActionInvocationMessage(Message):
_ACTION_ID_JSON_KEY = 'actionId'
_ACTION_INVOCATION_ID_JSON_KEY = 'actionInvocationId'
Expand Down Expand Up @@ -87,6 +108,7 @@ def body_to_json(self):
class MessageType(object):
NEW_CHANNEL_VALUES = 'NewChannelValues'
ABORT = 'Abort'
STOP = 'Stop'
ACTION_INVOCATION = 'InvokeAction'


Expand Down
13 changes: 3 additions & 10 deletions neptune/internal/websockets/reconnecting_websocket_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re

from neptune.internal.websockets.reconnecting_websocket import ReconnectingWebsocket


class ReconnectingWebsocketFactory(object):
def __init__(self, backend, experiment_id):
def __init__(self, backend, url):
self._backend = backend
self._base_address = re.sub(r'^http', 'ws', self._backend.api_address) + '/api/notifications/v1'
self._experiment_id = experiment_id
self._url = url

def create(self, shutdown_condition):
url = self._experiment_url(self._base_address, self._experiment_id)
return ReconnectingWebsocket(
url=url,
url=self._url,
oauth2_session=self._backend.authenticator.auth.session,
shutdown_event=shutdown_condition,
proxies=self._backend.proxies)

@staticmethod
def _experiment_url(base_address, experiment_id):
return base_address + '/experiments/' + experiment_id + '/operations'
5 changes: 5 additions & 0 deletions neptune/new/attributes/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
MONITORING_ATTRIBUTE_SPACE = 'monitoring/'
MONITORING_STDERR_ATTRIBUTE_PATH = f'{MONITORING_ATTRIBUTE_SPACE}stderr'
MONITORING_STDOUT_ATTRIBUTE_PATH = f'{MONITORING_ATTRIBUTE_SPACE}stdout'
MONITORING_TRACEBACK_ATTRIBUTE_PATH = f'{MONITORING_ATTRIBUTE_SPACE}traceback'

PARAMETERS_ATTRIBUTE_SPACE = 'parameters/'

Expand All @@ -36,3 +37,7 @@
SYSTEM_NAME_ATTRIBUTE_PATH = f'{SYSTEM_ATTRIBUTE_SPACE}name'
SYSTEM_STATE_ATTRIBUTE_PATH = f'{SYSTEM_ATTRIBUTE_SPACE}state'
SYSTEM_TAGS_ATTRIBUTE_PATH = f'{SYSTEM_ATTRIBUTE_SPACE}tags'
SYSTEM_FAILED_ATTRIBUTE_PATH = f'{SYSTEM_ATTRIBUTE_SPACE}failed'

SIGNAL_TYPE_STOP = "neptune/stop"
SIGNAL_TYPE_ABORT = "neptune/abort"
4 changes: 2 additions & 2 deletions neptune/new/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, project_id):
Project {python}{project}{end} not found.
Verify if your project's name was not misspelled.
You can find proper name after logging into Neptune UI: ui.neptune.ai.
You can find proper name after logging into Neptune UI.
"""
inputs = dict(list({'project': project_id}.items()) + list(STYLES.items()))
super().__init__(message.format(**inputs))
Expand Down Expand Up @@ -184,7 +184,7 @@ def __init__(self, project):
- {correct}neptune-ai{end}: {underline}WORKSPACE{end} our company organization name
- {correct}credit-default-prediction{end}: {underline}PROJECT_NAME{end} a project name
The URL to this project looks like this: https://ui.neptune.ai/neptune-ai/credit-default-prediction
The URL to this project looks like this: https://app.neptune.ai/neptune-ai/credit-default-prediction
You may also want to check the following docs pages:
- https://docs.neptune.ai/administration/workspace-project-and-user-management
Expand Down
14 changes: 13 additions & 1 deletion neptune/new/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
import os
import platform
import re
import uuid
from typing import List, Optional, Dict, Iterable

Expand Down Expand Up @@ -90,6 +91,7 @@
)
from neptune.new.internal.utils import verify_type, base64_decode
from neptune.new.internal.utils.paths import path_to_str
from neptune.new.internal.websockets.websockets_factory import WebsocketsFactory
from neptune.new.types.atoms import GitRef
from neptune.new.version import version as neptune_client_version
from neptune.oauth import NeptuneAuthenticator
Expand All @@ -103,6 +105,7 @@ class HostedNeptuneBackend(NeptuneBackend):

def __init__(self, credentials: Credentials, proxies: Optional[Dict[str, str]] = None):
self.credentials = credentials
self.proxies = proxies

ssl_verify = True
if os.getenv(NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE):
Expand Down Expand Up @@ -132,11 +135,12 @@ def __init__(self, credentials: Credentials, proxies: Optional[Dict[str, str]] =
self._http_client)

# TODO: Do not use NeptuneAuthenticator from old_neptune. Move it to new package.
self._http_client.authenticator = NeptuneAuthenticator(
self._authenticator = NeptuneAuthenticator(
self.credentials.api_token,
token_client,
ssl_verify,
proxies)
self._http_client.authenticator = self._authenticator

user_agent = 'neptune-client/{lib_version} ({system}, python {python_version})'.format(
lib_version=neptune_client_version,
Expand All @@ -147,6 +151,14 @@ def __init__(self, credentials: Credentials, proxies: Optional[Dict[str, str]] =
def get_display_address(self) -> str:
return self._client_config.display_url

def websockets_factory(self, project_uuid: uuid.UUID, run_uuid: uuid.UUID) -> Optional[WebsocketsFactory]:
base_url = re.sub(r'^http', 'ws', self._client_config.api_url)
return WebsocketsFactory(
url=base_url + f'/api/notifications/v1/runs/{str(project_uuid)}/{str(run_uuid)}/signal',
session=self._authenticator.auth.session,
proxies=self.proxies
)

@with_api_exceptions_handler
def get_project(self, project_id: str) -> Project:
verify_type("project_id", project_id, str)
Expand Down
6 changes: 6 additions & 0 deletions neptune/new/internal/backends/neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,20 @@
ImageSeriesValues,
)
from neptune.new.internal.operation import Operation
from neptune.new.internal.websockets.websockets_factory import WebsocketsFactory
from neptune.new.types.atoms import GitRef


class NeptuneBackend:

@abc.abstractmethod
def get_display_address(self) -> str:
pass

# pylint: disable=unused-argument
def websockets_factory(self, project_uuid: uuid.UUID, run_uuid: uuid.UUID) -> Optional[WebsocketsFactory]:
return None

@abc.abstractmethod
def get_project(self, project_id: str) -> Project:
pass
Expand Down
1 change: 1 addition & 0 deletions neptune/new/internal/backends/neptune_backend_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def create_run(self,
self._runs[new_run_uuid].set(["sys", "tags"], StringSet(set()))
self._runs[new_run_uuid].set(["sys", "creation_time"], Datetime(datetime.now()))
self._runs[new_run_uuid].set(["sys", "modification_time"], Datetime(datetime.now()))
self._runs[new_run_uuid].set(["sys", "failed"], Boolean(False))
if git_ref:
self._runs[new_run_uuid].set(["source_code", "git"], git_ref)
return ApiRun(new_run_uuid, short_id, 'workspace', 'sandbox', False)
Expand Down
Loading

0 comments on commit 9867722

Please sign in to comment.