Skip to content

Commit

Permalink
Allow oauth session to restart from scratch
Browse files Browse the repository at this point in the history
  • Loading branch information
HubertJaworski committed Oct 27, 2020
1 parent 7680cf5 commit c808baa
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 37 deletions.
33 changes: 21 additions & 12 deletions neptune/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,29 @@ def __init__(self, api_token=None, proxies=None):
ssl_verify = False

self._http_client = RequestsClient(ssl_verify=ssl_verify)
# for session re-creation we need to keep an authenticator-free version of http client
self._http_client_for_token = RequestsClient(ssl_verify=ssl_verify)

update_session_proxies(self._http_client.session, proxies)
update_session_proxies(self._http_client_for_token.session, proxies)

config_api_url = self.credentials.api_url_opt or self.credentials.token_origin_address
# We don't need to be able to resolve Neptune host if we use proxy
if proxies is None:
self._verify_host_resolution(config_api_url, self.credentials.token_origin_address)

backend_client = self._get_swagger_client('{}/api/backend/swagger.json'.format(config_api_url))
# this backend client is used only for initial configuration and session re-creation
backend_client = self._get_swagger_client(
'{}/api/backend/swagger.json'.format(config_api_url),
self._http_client_for_token
)
self._client_config = self._create_client_config(self.credentials.api_token, backend_client)

self._verify_version()

self._set_swagger_clients(self._client_config, config_api_url, backend_client)
self._set_swagger_clients(self._client_config)

self.authenticator = self._create_authenticator(self.credentials.api_token, ssl_verify, proxies)
self.authenticator = self._create_authenticator(self.credentials.api_token, ssl_verify, proxies, backend_client)
self._http_client.authenticator = self.authenticator

user_agent = 'neptune-client/{lib_version} ({system}, python {python_version})'.format(
Expand Down Expand Up @@ -924,7 +931,7 @@ def _upload_tar_data(self, experiment, api_method, data):
return response

@with_api_exceptions_handler
def _get_swagger_client(self, url):
def _get_swagger_client(self, url, http_client):
return SwaggerClient.from_url(
url,
config=dict(
Expand All @@ -933,12 +940,13 @@ def _get_swagger_client(self, url):
validate_responses=False,
formats=[uuid_format]
),
http_client=self._http_client)
http_client=http_client)

@with_api_exceptions_handler
def _create_authenticator(self, api_token, ssl_verify, proxies):
def _create_authenticator(self, api_token, ssl_verify, proxies, backend_client):
return NeptuneAuthenticator(
self.backend_swagger_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response().result,
api_token,
backend_client,
ssl_verify,
proxies)

Expand Down Expand Up @@ -992,14 +1000,15 @@ def _verify_version(self):
self._client_config.min_recommended_version, self.client_lib_version),
sys.stderr)

def _set_swagger_clients(self, client_config, client_config_api_addr, client_config_backend_client):
self.backend_swagger_client = (
client_config_backend_client if client_config_api_addr == client_config.api_url
else self._get_swagger_client('{}/api/backend/swagger.json'.format(client_config.api_url))
def _set_swagger_clients(self, client_config):
self.backend_swagger_client = self._get_swagger_client(
'{}/api/backend/swagger.json'.format(client_config.api_url),
self._http_client
)

self.leaderboard_swagger_client = self._get_swagger_client(
'{}/api/leaderboard/swagger.json'.format(client_config.api_url)
'{}/api/leaderboard/swagger.json'.format(client_config.api_url),
self._http_client
)

def _verify_host_resolution(self, api_url, app_url):
Expand Down
69 changes: 44 additions & 25 deletions neptune/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import time

import jwt
from bravado.requests_client import Authenticator
from oauthlib.oauth2 import TokenExpiredError
from oauthlib.oauth2 import TokenExpiredError, OAuth2Error
from requests.auth import AuthBase
from requests_oauthlib import OAuth2Session

from neptune.utils import with_api_exceptions_handler, update_session_proxies


class NeptuneAuth(AuthBase):
__LOCK = threading.RLock()

def __init__(self, session):
self.session = session
def __init__(self, session_factory):
self.session_factory = session_factory
self.session = session_factory()
self.token_expires_at = 0

def __call__(self, r):
Expand All @@ -51,6 +54,16 @@ def refresh_token_if_needed(self):
self._refresh_token()

def _refresh_token(self):
with self.__LOCK:
try:
self._refresh_session_token()
except OAuth2Error:
# for some reason oauth session is no longer valid. Retry by creating new fresh session
# we can safely ignore this error, as it will be thrown again if it's persistent
self.session = self.session_factory()
self._refresh_session_token()

def _refresh_session_token(self):
self.session.refresh_token(self.session.auto_refresh_url, verify=self.session.verify)
if self.session.token is not None and self.session.token.get('access_token') is not None:
decoded_json_token = jwt.decode(self.session.token.get('access_token'), verify=False)
Expand All @@ -59,29 +72,35 @@ def _refresh_token(self):

class NeptuneAuthenticator(Authenticator):

def __init__(self, auth_tokens, ssl_verify, proxies):
def __init__(self, api_token, backend_client, ssl_verify, proxies):
super(NeptuneAuthenticator, self).__init__(host='')
decoded_json_token = jwt.decode(auth_tokens.accessToken, verify=False)
expires_at = decoded_json_token.get(u'exp')
client_name = decoded_json_token.get(u'azp')
refresh_url = u'{realm_url}/protocol/openid-connect/token'.format(realm_url=decoded_json_token.get(u'iss'))
token = {
u'access_token': auth_tokens.accessToken,
u'refresh_token': auth_tokens.refreshToken,
u'expires_in': expires_at - time.time()
}
session = OAuth2Session(
client_id=client_name,
token=token,
auto_refresh_url=refresh_url,
auto_refresh_kwargs={'client_id': client_name},
token_updater=_no_token_updater
)
session.verify = ssl_verify

update_session_proxies(session, proxies)

self.auth = NeptuneAuth(session)

# We need to pass a lambda to be able to re-create fresh session at any time when needed
def session_factory():
auth_tokens = backend_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response().result
decoded_json_token = jwt.decode(auth_tokens.accessToken, verify=False)
expires_at = decoded_json_token.get(u'exp')
client_name = decoded_json_token.get(u'azp')
refresh_url = u'{realm_url}/protocol/openid-connect/token'.format(realm_url=decoded_json_token.get(u'iss'))
token = {
u'access_token': auth_tokens.accessToken,
u'refresh_token': auth_tokens.refreshToken,
u'expires_in': expires_at - time.time()
}

session = OAuth2Session(
client_id=client_name,
token=token,
auto_refresh_url=refresh_url,
auto_refresh_kwargs={'client_id': client_name},
token_updater=_no_token_updater
)
session.verify = ssl_verify

update_session_proxies(session, proxies)
return session

self.auth = NeptuneAuth(session_factory)

def matches(self, url):
return True
Expand Down

0 comments on commit c808baa

Please sign in to comment.