-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Resilient http adapter and project authentication (#548)
* Port http adapter * Project auth implementation * Restore coverage requirements * Http adapter coverage * Apply hooks * Project auth coverage * Update changelog and bump version * Drop unnecessary changes * Apply hooks * Coverage for token value * Address comments
- Loading branch information
Showing
9 changed files
with
202 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "up42-py" | ||
version = "0.36.0" | ||
version = "0.37.0a0" | ||
description = "Python SDK for UP42, the geospatial marketplace and developer platform." | ||
authors = ["UP42 GmbH <[email protected]>"] | ||
license = "https://github.com/up42/up42-py/blob/master/LICENSE" | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import random | ||
|
||
import pytest | ||
|
||
from up42.http.config import ResilienceSettings | ||
from up42.http.http_adapter import create as create_adapter | ||
|
||
|
||
@pytest.mark.parametrize("include_post", [True, False]) | ||
def test_should_create_adapter(include_post): | ||
total_retries = 5 | ||
backoff_factor = 0.4 | ||
statuses = (random.randint(400, 600),) | ||
settings = ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses) | ||
adapter = create_adapter(supply_settings=lambda: settings, include_post=include_post) | ||
assert adapter.max_retries.total == total_retries | ||
assert adapter.max_retries.backoff_factor == backoff_factor | ||
assert adapter.max_retries.status_forcelist == statuses | ||
allowed_methods = adapter.max_retries.allowed_methods or [] | ||
if include_post: | ||
assert "POST" in allowed_methods | ||
else: | ||
assert "POST" not in allowed_methods |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import base64 | ||
import dataclasses as dc | ||
import time | ||
|
||
from requests_mock import Mocker | ||
|
||
from up42.http.config import ProjectCredentialsSettings, TokenProviderSettings | ||
from up42.http.oauth import ProjectAuth | ||
|
||
token_url = "https://localhost/oauth/token" | ||
project_credentials = ProjectCredentialsSettings( | ||
client_id="client_id", | ||
client_secret="client_secret", | ||
) | ||
|
||
token_settings = TokenProviderSettings( | ||
token_url=token_url, | ||
duration=2, | ||
timeout=1, | ||
) | ||
|
||
|
||
def basic_auth(username, password): | ||
token = base64.b64encode(f"{username}:{password}".encode("utf-8")) | ||
return f'Basic {token.decode("ascii")}' | ||
|
||
|
||
basic_client_auth = basic_auth(project_credentials.client_id, project_credentials.client_secret) | ||
basic_auth_headers = {"Authorization": basic_client_auth} | ||
|
||
|
||
@dc.dataclass | ||
class FakeRequest: | ||
headers: dict | ||
|
||
|
||
fake_request = FakeRequest(headers={}) | ||
|
||
|
||
def create_project_auth(): | ||
return ProjectAuth( | ||
supply_credentials_settings=lambda: project_credentials, | ||
supply_token_settings=lambda: token_settings, | ||
) | ||
|
||
|
||
class TestProjectAuth: | ||
def test_should_fetch_token_when_created(self, requests_mock: Mocker): | ||
token_value = "some-value" | ||
requests_mock.post(token_url, json={"access_token": token_value}, request_headers=basic_auth_headers) | ||
project_auth = create_project_auth() | ||
project_auth(fake_request) | ||
assert fake_request.headers["Authorization"] == f"Bearer {token_value}" | ||
assert project_auth.token.access_token == token_value | ||
assert requests_mock.called_once | ||
|
||
def test_should_fetch_token_when_expired(self, requests_mock: Mocker): | ||
responses = [{"json": {"access_token": f"token{idx}"}} for idx in range(1, 3)] | ||
requests_mock.post(token_url, response_list=responses, request_headers=basic_auth_headers) | ||
project_auth = create_project_auth() | ||
time.sleep(token_settings.duration + 1) | ||
project_auth(fake_request) | ||
expected_token_value = "token2" | ||
assert fake_request.headers["Authorization"] == f"Bearer {expected_token_value}" | ||
assert project_auth.token.access_token == expected_token_value | ||
assert requests_mock.call_count == 2 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass(eq=True, frozen=True) | ||
class ResilienceSettings: | ||
total: int = 10 | ||
backoff_factor: float = 0.001 | ||
statuses: tuple = tuple(range(500, 600)) + (429,) | ||
|
||
|
||
@dataclass(eq=True, frozen=True) | ||
class TokenProviderSettings: | ||
token_url: str | ||
duration: int = 5 * 60 | ||
timeout: int = 120 | ||
|
||
|
||
@dataclass(eq=True, frozen=True) | ||
class ProjectCredentialsSettings: | ||
client_id: str | ||
client_secret: str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Callable | ||
|
||
from requests import adapters | ||
from urllib3.util import Retry | ||
|
||
from up42.http import config | ||
|
||
|
||
def create( | ||
supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False | ||
) -> adapters.HTTPAdapter: | ||
settings = supply_settings() | ||
allowed_methods = set(Retry.DEFAULT_ALLOWED_METHODS) | ||
if include_post: | ||
allowed_methods.add("POST") | ||
|
||
retries = Retry( | ||
total=settings.total, | ||
backoff_factor=settings.backoff_factor, | ||
status_forcelist=settings.statuses, | ||
allowed_methods=allowed_methods, | ||
) | ||
return adapters.HTTPAdapter(max_retries=retries) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import datetime as dt | ||
from dataclasses import dataclass | ||
|
||
import requests | ||
from requests import auth | ||
|
||
from up42.http import config, http_adapter | ||
|
||
|
||
@dataclass(eq=True, frozen=True) | ||
class Token: | ||
access_token: str | ||
expires_on: dt.datetime | ||
|
||
@property | ||
def has_expired(self) -> bool: | ||
return self.expires_on <= dt.datetime.now() | ||
|
||
|
||
class ProjectAuth(requests.auth.AuthBase): | ||
def __init__( | ||
self, | ||
supply_credentials_settings=config.ProjectCredentialsSettings, | ||
supply_token_settings=config.TokenProviderSettings, | ||
create_adapter=http_adapter.create, | ||
): | ||
credentials_settings = supply_credentials_settings() | ||
token_settings = supply_token_settings() | ||
self.client_id = credentials_settings.client_id | ||
self.client_secret = credentials_settings.client_secret | ||
self.token_url = token_settings.token_url | ||
self.duration = token_settings.duration | ||
self.timeout = token_settings.timeout | ||
self.adapter = create_adapter(include_post=True) | ||
self._token = self._fetch_token() | ||
|
||
def __call__(self, request): | ||
request.headers["Authorization"] = f"Bearer {self.token.access_token}" | ||
return request | ||
|
||
def _fetch_token(self): | ||
basic_auth = auth.HTTPBasicAuth(self.client_id, self.client_secret) | ||
session = requests.Session() | ||
session.mount("https://", self.adapter) | ||
auth_response = session.post( | ||
url=self.token_url, | ||
auth=basic_auth, | ||
data={"grant_type": "client_credentials"}, | ||
timeout=self.timeout, | ||
) | ||
access_token = auth_response.json()["access_token"] | ||
expires_on = dt.datetime.now() + dt.timedelta(seconds=self.duration) | ||
return Token(access_token=access_token, expires_on=expires_on) | ||
|
||
@property | ||
def token(self) -> Token: | ||
if self._token.has_expired: | ||
self._token = self._fetch_token() | ||
return self._token |