diff --git a/neptune/internal/backends/hosted_neptune_backend.py b/neptune/internal/backends/hosted_neptune_backend.py index 8ecc39436..9b31f2d50 100644 --- a/neptune/internal/backends/hosted_neptune_backend.py +++ b/neptune/internal/backends/hosted_neptune_backend.py @@ -84,22 +84,29 @@ def __init__(self, api_token=None, proxies=None): ssl_verify = False self._http_client = RequestsClient(ssl_verify=ssl_verify) + # for session re-creation we need to keep an authenticator-free version of http client + self._http_client_for_token = RequestsClient(ssl_verify=ssl_verify) update_session_proxies(self._http_client.session, proxies) + update_session_proxies(self._http_client_for_token.session, proxies) config_api_url = self.credentials.api_url_opt or self.credentials.token_origin_address # We don't need to be able to resolve Neptune host if we use proxy if proxies is None: self._verify_host_resolution(config_api_url, self.credentials.token_origin_address) - backend_client = self._get_swagger_client('{}/api/backend/swagger.json'.format(config_api_url)) + # this backend client is used only for initial configuration and session re-creation + backend_client = self._get_swagger_client( + '{}/api/backend/swagger.json'.format(config_api_url), + self._http_client_for_token + ) self._client_config = self._create_client_config(self.credentials.api_token, backend_client) self._verify_version() - self._set_swagger_clients(self._client_config, config_api_url, backend_client) + self._set_swagger_clients(self._client_config) - self.authenticator = self._create_authenticator(self.credentials.api_token, ssl_verify, proxies) + self.authenticator = self._create_authenticator(self.credentials.api_token, ssl_verify, proxies, backend_client) self._http_client.authenticator = self.authenticator user_agent = 'neptune-client/{lib_version} ({system}, python {python_version})'.format( @@ -924,7 +931,7 @@ def _upload_tar_data(self, experiment, api_method, data): return response @with_api_exceptions_handler - def _get_swagger_client(self, url): + def _get_swagger_client(self, url, http_client): return SwaggerClient.from_url( url, config=dict( @@ -933,12 +940,13 @@ def _get_swagger_client(self, url): validate_responses=False, formats=[uuid_format] ), - http_client=self._http_client) + http_client=http_client) @with_api_exceptions_handler - def _create_authenticator(self, api_token, ssl_verify, proxies): + def _create_authenticator(self, api_token, ssl_verify, proxies, backend_client): return NeptuneAuthenticator( - self.backend_swagger_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response().result, + api_token, + backend_client, ssl_verify, proxies) @@ -992,14 +1000,15 @@ def _verify_version(self): self._client_config.min_recommended_version, self.client_lib_version), sys.stderr) - def _set_swagger_clients(self, client_config, client_config_api_addr, client_config_backend_client): - self.backend_swagger_client = ( - client_config_backend_client if client_config_api_addr == client_config.api_url - else self._get_swagger_client('{}/api/backend/swagger.json'.format(client_config.api_url)) + def _set_swagger_clients(self, client_config): + self.backend_swagger_client = self._get_swagger_client( + '{}/api/backend/swagger.json'.format(client_config.api_url), + self._http_client ) self.leaderboard_swagger_client = self._get_swagger_client( - '{}/api/leaderboard/swagger.json'.format(client_config.api_url) + '{}/api/leaderboard/swagger.json'.format(client_config.api_url), + self._http_client ) def _verify_host_resolution(self, api_url, app_url): diff --git a/neptune/oauth.py b/neptune/oauth.py index 0d511bb92..f189edd36 100644 --- a/neptune/oauth.py +++ b/neptune/oauth.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import threading import time import jwt from bravado.requests_client import Authenticator -from oauthlib.oauth2 import TokenExpiredError +from oauthlib.oauth2 import TokenExpiredError, OAuth2Error from requests.auth import AuthBase from requests_oauthlib import OAuth2Session @@ -25,9 +26,11 @@ class NeptuneAuth(AuthBase): + __LOCK = threading.RLock() - def __init__(self, session): - self.session = session + def __init__(self, session_factory): + self.session_factory = session_factory + self.session = session_factory() self.token_expires_at = 0 def __call__(self, r): @@ -51,6 +54,16 @@ def refresh_token_if_needed(self): self._refresh_token() def _refresh_token(self): + with self.__LOCK: + try: + self._refresh_session_token() + except OAuth2Error: + # for some reason oauth session is no longer valid. Retry by creating new fresh session + # we can safely ignore this error, as it will be thrown again if it's persistent + self.session = self.session_factory() + self._refresh_session_token() + + def _refresh_session_token(self): self.session.refresh_token(self.session.auto_refresh_url, verify=self.session.verify) if self.session.token is not None and self.session.token.get('access_token') is not None: decoded_json_token = jwt.decode(self.session.token.get('access_token'), verify=False) @@ -59,29 +72,35 @@ def _refresh_token(self): class NeptuneAuthenticator(Authenticator): - def __init__(self, auth_tokens, ssl_verify, proxies): + def __init__(self, api_token, backend_client, ssl_verify, proxies): super(NeptuneAuthenticator, self).__init__(host='') - decoded_json_token = jwt.decode(auth_tokens.accessToken, verify=False) - expires_at = decoded_json_token.get(u'exp') - client_name = decoded_json_token.get(u'azp') - refresh_url = u'{realm_url}/protocol/openid-connect/token'.format(realm_url=decoded_json_token.get(u'iss')) - token = { - u'access_token': auth_tokens.accessToken, - u'refresh_token': auth_tokens.refreshToken, - u'expires_in': expires_at - time.time() - } - session = OAuth2Session( - client_id=client_name, - token=token, - auto_refresh_url=refresh_url, - auto_refresh_kwargs={'client_id': client_name}, - token_updater=_no_token_updater - ) - session.verify = ssl_verify - - update_session_proxies(session, proxies) - - self.auth = NeptuneAuth(session) + + # We need to pass a lambda to be able to re-create fresh session at any time when needed + def session_factory(): + auth_tokens = backend_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response().result + decoded_json_token = jwt.decode(auth_tokens.accessToken, verify=False) + expires_at = decoded_json_token.get(u'exp') + client_name = decoded_json_token.get(u'azp') + refresh_url = u'{realm_url}/protocol/openid-connect/token'.format(realm_url=decoded_json_token.get(u'iss')) + token = { + u'access_token': auth_tokens.accessToken, + u'refresh_token': auth_tokens.refreshToken, + u'expires_in': expires_at - time.time() + } + + session = OAuth2Session( + client_id=client_name, + token=token, + auto_refresh_url=refresh_url, + auto_refresh_kwargs={'client_id': client_name}, + token_updater=_no_token_updater + ) + session.verify = ssl_verify + + update_session_proxies(session, proxies) + return session + + self.auth = NeptuneAuth(session_factory) def matches(self, url): return True