diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c523a3..a28c287 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased](https://github.com/at-gmbh/personio-py/compare/v0.2.3...HEAD) -... +* add support for providing a custom `requests.Session` in client + ([#39](https://github.com/at-gmbh/personio-py/pull/39) ## [0.2.3](https://github.com/at-gmbh/personio-py/tree/v0.2.3) - 2023-05-05 diff --git a/src/personio_py/client.py b/src/personio_py/client.py index b26ce2d..d16191f 100644 --- a/src/personio_py/client.py +++ b/src/personio_py/client.py @@ -40,7 +40,8 @@ class Personio: PROJECT_URL = 'company/attendances/projects' def __init__(self, base_url: str = None, client_id: str = None, client_secret: str = None, - dynamic_fields: List[DynamicMapping] = None): + dynamic_fields: List[DynamicMapping] = None, + session: Optional[requests.Session] = None): self.base_url = base_url or self.BASE_URL self.client_id = client_id or os.getenv('CLIENT_ID') self.client_secret = client_secret or os.getenv('CLIENT_SECRET') @@ -48,6 +49,7 @@ def __init__(self, base_url: str = None, client_id: str = None, client_secret: s self.authenticated = False self.dynamic_fields = dynamic_fields self.search_index = SearchIndex(self) + self.session = session or requests.Session() def authenticate(self): """ @@ -68,7 +70,7 @@ def authenticate(self): url = urljoin(self.base_url, 'auth') logger.debug(f"authenticating to {url} with client_id {self.client_id}") params = {"client_id": self.client_id, "client_secret": self.client_secret} - response = requests.request("POST", url, headers=self.headers, params=params) + response = self.session.request("POST", url, headers=self.headers, params=params) if response.ok: token = response.json()['data']['token'] self.headers['Authorization'] = f"Bearer {token}" @@ -107,7 +109,7 @@ def request(self, path: str, method='GET', params: Dict[str, Any] = None, _headers.update(headers) # make the request url = urljoin(self.base_url, path) - response = requests.request(method, url, headers=_headers, params=params, json=data) + response = self.session.request(method, url, headers=_headers, params=params, json=data) # re-new the authorization header authorization = response.headers.get('Authorization') if authorization: diff --git a/tests/test_mock_api.py b/tests/test_mock_api.py index 7562049..886bcc7 100644 --- a/tests/test_mock_api.py +++ b/tests/test_mock_api.py @@ -3,6 +3,7 @@ from typing import Any, Dict import pytest +import requests import responses from personio_py import DynamicMapping, Employee, Personio, PersonioApiError, PersonioError @@ -10,6 +11,18 @@ iso_date_match = re.compile(r'\d\d\d\d-\d\d-\d\d') +@responses.activate +def test_authenticate_ok_with_custom_requests_session(): + # mock a successful authentication response + resp_json = {'success': True, 'data': {'token': 'dummy_token'}} + responses.add(responses.POST, 'https://api.personio.de/v1/auth', json=resp_json, status=200) + # authenticate + personio = Personio(client_id='test', client_secret='test', session=requests.Session()) + personio.authenticate() + # validate + assert personio.authenticated is True + assert personio.headers['Authorization'] == "Bearer dummy_token" + @responses.activate def test_authenticate_ok():