Skip to content

Commit

Permalink
Fix oauth session proxy (#186)
Browse files Browse the repository at this point in the history
* Fix oauth session proxy

* Review chagnes
  • Loading branch information
Hubert Jaworski authored Oct 29, 2019
1 parent 509fccf commit 6ec89c4
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
19 changes: 7 additions & 12 deletions neptune/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from neptune.notebook import Notebook
from neptune.oauth import NeptuneAuthenticator
from neptune.projects import Project
from neptune.utils import is_float, with_api_exceptions_handler
from neptune.utils import is_float, with_api_exceptions_handler, update_session_proxies

_logger = logging.getLogger(__name__)

Expand All @@ -63,16 +63,16 @@ def __init__(self, api_token=None, proxies=None):
ssl_verify = False

self._http_client = RequestsClient(ssl_verify=ssl_verify)
if proxies is not None:
self._update_proxies(proxies)

update_session_proxies(self._http_client.session, proxies)

self.backend_swagger_client = self._get_swagger_client('{}/api/backend/swagger.json'
.format(self.api_address))

self.leaderboard_swagger_client = self._get_swagger_client('{}/api/leaderboard/swagger.json'
.format(self.api_address))

self.authenticator = self._create_authenticator(self.credentials.api_token, ssl_verify)
self.authenticator = self._create_authenticator(self.credentials.api_token, ssl_verify, proxies)
self._http_client.authenticator = self.authenticator

# This is not a top-level import because of circular dependencies
Expand Down Expand Up @@ -872,12 +872,6 @@ def _upload_tar_data(self, experiment, api_method, data):
response.raise_for_status()
return response

def _update_proxies(self, proxies):
try:
self._http_client.session.proxies.update(proxies)
except (TypeError, ValueError):
raise ValueError("Wrong proxies format: {}".format(proxies))

@with_api_exceptions_handler
def _get_swagger_client(self, url):
return SwaggerClient.from_url(
Expand All @@ -892,10 +886,11 @@ def _get_swagger_client(self, url):
)

@with_api_exceptions_handler
def _create_authenticator(self, api_token, ssl_verify):
def _create_authenticator(self, api_token, ssl_verify, proxies):
return NeptuneAuthenticator(
self.backend_swagger_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response().result,
ssl_verify
ssl_verify,
proxies
)


Expand Down
7 changes: 5 additions & 2 deletions neptune/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from requests.auth import AuthBase
from requests_oauthlib import OAuth2Session

from neptune.utils import with_api_exceptions_handler
from neptune.utils import with_api_exceptions_handler, update_session_proxies


class NeptuneAuth(AuthBase):
Expand Down Expand Up @@ -59,7 +59,7 @@ def _refresh_token(self):

class NeptuneAuthenticator(Authenticator):

def __init__(self, auth_tokens, ssl_verify):
def __init__(self, auth_tokens, ssl_verify, proxies):
super(NeptuneAuthenticator, self).__init__(host='')
decoded_json_token = jwt.decode(auth_tokens.accessToken, verify=False)
expires_at = decoded_json_token.get(u'exp')
Expand All @@ -78,6 +78,9 @@ def __init__(self, auth_tokens, ssl_verify):
token_updater=_no_token_updater
)
session.verify = ssl_verify

update_session_proxies(session, proxies)

self.auth = NeptuneAuth(session)

def matches(self, url):
Expand Down
8 changes: 8 additions & 0 deletions neptune/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ def discover_git_repo_location():
return None


def update_session_proxies(session, proxies):
if proxies is not None:
try:
session.proxies.update(proxies)
except (TypeError, ValueError):
raise ValueError("Wrong proxies format: {}".format(proxies))


def get_git_info(repo_path=None):
"""Retrieve information about git repository.
Expand Down
2 changes: 1 addition & 1 deletion tests/neptune/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_apply_oauth2_session_to_request(self, time_mock, session_mock):
session.token = dict()

# and
neptune_authenticator = NeptuneAuthenticator(auth_tokens, False)
neptune_authenticator = NeptuneAuthenticator(auth_tokens, False, None)
request = a_request()

# when
Expand Down

0 comments on commit 6ec89c4

Please sign in to comment.