From be56a0ffe6c5fa87be64f49d77450c84c2db4d46 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:47:08 -0500 Subject: [PATCH] chore: Add `X-Goog-Api-Client` metric header to requests (#826) --- firebase_admin/_http_client.py | 5 + firebase_admin/_utils.py | 3 + firebase_admin/app_check.py | 7 +- firebase_admin/storage.py | 7 +- tests/test_auth_providers.py | 190 +++++++++++++------------------ tests/test_db.py | 118 ++++++++----------- tests/test_functions.py | 4 + tests/test_http_client.py | 14 ++- tests/test_instance_id.py | 21 ++-- tests/test_messaging.py | 71 ++++++------ tests/test_ml.py | 74 +++++------- tests/test_project_management.py | 5 +- tests/test_tenant_mgt.py | 13 +++ tests/test_user_mgt.py | 2 + 14 files changed, 255 insertions(+), 279 deletions(-) diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index d259faddf..f1eccbcf2 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -21,6 +21,7 @@ import requests from requests.packages.urllib3.util import retry # pylint: disable=import-error +from firebase_admin import _utils if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): _ANY_METHOD = {'allowed_methods': None} @@ -36,6 +37,9 @@ DEFAULT_TIMEOUT_SECONDS = 120 +METRICS_HEADERS = { + 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), +} class HttpClient: """Base HTTP client used to make HTTP calls. @@ -72,6 +76,7 @@ def __init__( if headers: self._session.headers.update(headers) + self._session.headers.update(METRICS_HEADERS) if retries: self._session.mount('http://', requests.adapters.HTTPAdapter(max_retries=retries)) self._session.mount('https://', requests.adapters.HTTPAdapter(max_retries=retries)) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index dcfb520d2..b6e292546 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -15,6 +15,7 @@ """Internal utilities common to all modules.""" import json +from platform import python_version import google.auth import requests @@ -75,6 +76,8 @@ 16: exceptions.UNAUTHENTICATED, } +def get_metrics_header(): + return f'gl-python/{python_version()} fire-admin/{firebase_admin.__version__}' def _get_initialized_app(app): """Returns a reference to an initialized App instance.""" diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 6bc10b2f4..e6b66efc1 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -51,6 +51,10 @@ class _AppCheckService: _scoped_project_id = None _jwks_client = None + _APP_CHECK_HEADERS = { + 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), + } + def __init__(self, app): # Validate and store the project_id to validate the JWT claims self._project_id = app.project_id @@ -62,7 +66,8 @@ def __init__(self, app): 'GOOGLE_CLOUD_PROJECT environment variable.') self._scoped_project_id = 'projects/' + app.project_id # Default lifespan is 300 seconds (5 minutes) so we change it to 21600 seconds (6 hours). - self._jwks_client = PyJWKClient(self._JWKS_URL, lifespan=21600) + self._jwks_client = PyJWKClient( + self._JWKS_URL, lifespan=21600, headers=self._APP_CHECK_HEADERS) def verify_token(self, token: str) -> Dict[str, Any]: diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index f3948371c..46f5f6043 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -55,8 +55,13 @@ def bucket(name=None, app=None) -> storage.Bucket: class _StorageClient: """Holds a Google Cloud Storage client instance.""" + STORAGE_HEADERS = { + 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), + } + def __init__(self, credentials, project, default_bucket): - self._client = storage.Client(credentials=credentials, project=project) + self._client = storage.Client( + credentials=credentials, project=project, extra_headers=self.STORAGE_HEADERS) self._default_bucket = default_bucket @classmethod diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index a5716266c..48f38a011 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -21,6 +21,7 @@ import firebase_admin from firebase_admin import auth from firebase_admin import exceptions +from firebase_admin import _utils from tests import testutils ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v2' @@ -70,6 +71,11 @@ def _instrument_provider_mgt(app, status, payload): testutils.MockAdapter(payload, status, recorder)) return recorder +def _assert_request(request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() class TestOIDCProviderConfig: @@ -110,9 +116,8 @@ def test_get(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs/oidc.provider') + _assert_request( + recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider') @pytest.mark.parametrize('invalid_opts', [ {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'saml.provider'}, @@ -140,11 +145,9 @@ def test_create(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider') + got = json.loads(recorder[0].body.decode()) assert got == self.OIDC_CONFIG_REQUEST def test_create_minimal(self, user_mgt_app): @@ -165,11 +168,9 @@ def test_create_minimal(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider') + got = json.loads(recorder[0].body.decode()) assert got == want def test_create_empty_values(self, user_mgt_app): @@ -191,11 +192,9 @@ def test_create_empty_values(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider') + got = json.loads(recorder[0].body.decode()) assert got == want @pytest.mark.parametrize('invalid_opts', [ @@ -225,13 +224,12 @@ def test_update(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' mask = ['clientId', 'clientSecret', 'displayName', 'enabled', 'issuer', 'responseType.code', 'responseType.idToken'] - assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( - USER_MGT_URLS['PREFIX'], ','.join(mask)) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider?' + f'updateMask={",".join(mask)}') + got = json.loads(recorder[0].body.decode()) assert got == self.OIDC_CONFIG_REQUEST def test_update_minimal(self, user_mgt_app): @@ -242,11 +240,10 @@ def test_update_minimal(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' - assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask=displayName'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider?' + f'updateMask=displayName') + got = json.loads(recorder[0].body.decode()) assert got == {'displayName': 'oidcProviderName'} def test_update_empty_values(self, user_mgt_app): @@ -258,12 +255,11 @@ def test_update_empty_values(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' mask = ['displayName', 'enabled', 'responseType.idToken'] - assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( - USER_MGT_URLS['PREFIX'], ','.join(mask)) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider?' + f'updateMask={",".join(mask)}') + got = json.loads(recorder[0].body.decode()) assert got == {'displayName': None, 'enabled': False, 'responseType': {'idToken': False}} @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['saml.provider']) @@ -279,9 +275,8 @@ def test_delete(self, user_mgt_app): auth.delete_oidc_provider_config('oidc.provider', app=user_mgt_app) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'DELETE' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs/oidc.provider') + _assert_request(recorder[0], 'DELETE', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider') @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): @@ -302,9 +297,8 @@ def test_list_single_page(self, user_mgt_app): assert len(provider_configs) == 2 assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs?pageSize=100') + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=100') def test_list_multiple_pages(self, user_mgt_app): sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) @@ -320,9 +314,8 @@ def test_list_multiple_pages(self, user_mgt_app): self._assert_page(page, next_page_token='token') assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=10'.format(USER_MGT_URLS['PREFIX']) + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=10') # Page 2 (also the last page) response = {'oauthIdpConfigs': configs[2:]} @@ -331,10 +324,8 @@ def test_list_multiple_pages(self, user_mgt_app): self._assert_page(page, count=1, start=2) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=10&pageToken=token'.format( - USER_MGT_URLS['PREFIX']) + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=10&pageToken=token') def test_paged_iteration(self, user_mgt_app): sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) @@ -353,9 +344,8 @@ def test_paged_iteration(self, user_mgt_app): provider_config = next(iterator) assert provider_config.provider_id == 'oidc.provider{0}'.format(index) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=100'.format(USER_MGT_URLS['PREFIX']) + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=100') # Page 2 (also the last page) response = {'oauthIdpConfigs': configs[2:]} @@ -364,10 +354,8 @@ def test_paged_iteration(self, user_mgt_app): provider_config = next(iterator) assert provider_config.provider_id == 'oidc.provider2' assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=100&pageToken=token'.format( - USER_MGT_URLS['PREFIX']) + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=100&pageToken=token') with pytest.raises(StopIteration): next(iterator) @@ -464,10 +452,8 @@ def test_get(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], - '/inboundSamlConfigs/saml.provider') + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider') @pytest.mark.parametrize('invalid_opts', [ {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'oidc.provider'}, @@ -494,11 +480,10 @@ def test_create(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?' + f'inboundSamlConfigId=saml.provider') + got = json.loads(recorder[0].body.decode()) assert got == self.SAML_CONFIG_REQUEST def test_create_minimal(self, user_mgt_app): @@ -514,11 +499,10 @@ def test_create_minimal(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?' + f'inboundSamlConfigId=saml.provider') + got = json.loads(recorder[0].body.decode()) assert got == want def test_create_empty_values(self, user_mgt_app): @@ -534,11 +518,10 @@ def test_create_empty_values(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?' + f'inboundSamlConfigId=saml.provider') + got = json.loads(recorder[0].body.decode()) assert got == want @pytest.mark.parametrize('invalid_opts', [ @@ -567,15 +550,14 @@ def test_update(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' mask = [ 'displayName', 'enabled', 'idpConfig.idpCertificates', 'idpConfig.idpEntityId', 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', ] - assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( - USER_MGT_URLS['PREFIX'], ','.join(mask)) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider?' + f'updateMask={",".join(mask)}') + got = json.loads(recorder[0].body.decode()) assert got == self.SAML_CONFIG_REQUEST def test_update_minimal(self, user_mgt_app): @@ -586,11 +568,10 @@ def test_update_minimal(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' - assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask=displayName'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider?' + f'updateMask=displayName') + got = json.loads(recorder[0].body.decode()) assert got == {'displayName': 'samlProviderName'} def test_update_empty_values(self, user_mgt_app): @@ -601,12 +582,11 @@ def test_update_empty_values(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' mask = ['displayName', 'enabled'] - assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( - USER_MGT_URLS['PREFIX'], ','.join(mask)) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider?' + f'updateMask={",".join(mask)}') + got = json.loads(recorder[0].body.decode()) assert got == {'displayName': None, 'enabled': False} @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) @@ -622,10 +602,8 @@ def test_delete(self, user_mgt_app): auth.delete_saml_provider_config('saml.provider', app=user_mgt_app) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'DELETE' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], - '/inboundSamlConfigs/saml.provider') + _assert_request( + recorder[0], 'DELETE', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider') def test_config_not_found(self, user_mgt_app): _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) @@ -658,10 +636,8 @@ def test_list_single_page(self, user_mgt_app): assert len(provider_configs) == 2 assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], - '/inboundSamlConfigs?pageSize=100') + _assert_request( + recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=100') def test_list_multiple_pages(self, user_mgt_app): sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) @@ -677,9 +653,8 @@ def test_list_multiple_pages(self, user_mgt_app): self._assert_page(page, next_page_token='token') assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URLS['PREFIX']) + _assert_request( + recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=10') # Page 2 (also the last page) response = {'inboundSamlConfigs': configs[2:]} @@ -688,10 +663,9 @@ def test_list_multiple_pages(self, user_mgt_app): self._assert_page(page, count=1, start=2) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=10&pageToken=token'.format( - USER_MGT_URLS['PREFIX']) + _assert_request( + recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=10&pageToken=token') def test_paged_iteration(self, user_mgt_app): sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) @@ -710,9 +684,8 @@ def test_paged_iteration(self, user_mgt_app): provider_config = next(iterator) assert provider_config.provider_id == 'saml.provider{0}'.format(index) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URLS['PREFIX']) + _assert_request( + recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=100') # Page 2 (also the last page) response = {'inboundSamlConfigs': configs[2:]} @@ -721,10 +694,9 @@ def test_paged_iteration(self, user_mgt_app): provider_config = next(iterator) assert provider_config.provider_id == 'saml.provider2' assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=100&pageToken=token'.format( - USER_MGT_URLS['PREFIX']) + _assert_request( + recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=100&pageToken=token') with pytest.raises(StopIteration): next(iterator) diff --git a/tests/test_db.py b/tests/test_db.py index aa2c83bd9..4245f65fb 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -193,16 +193,20 @@ def instrument(self, ref, payload, status=200, etag=MockAdapter.ETAG): ref._client.session.mount(self.test_url, adapter) return recorder + def _assert_request(self, request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['Authorization'] == 'Bearer mock-token' + assert request.headers['User-Agent'] == db._USER_AGENT + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + @pytest.mark.parametrize('data', valid_values) def test_get_value(self, data): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps(data)) assert ref.get() == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') assert 'X-Firebase-ETag' not in recorder[0].headers @pytest.mark.parametrize('data', valid_values) @@ -211,10 +215,7 @@ def test_get_with_etag(self, data): recorder = self.instrument(ref, json.dumps(data)) assert ref.get(etag=True) == (data, MockAdapter.ETAG) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') assert recorder[0].headers['X-Firebase-ETag'] == 'true' @pytest.mark.parametrize('data', valid_values) @@ -223,10 +224,8 @@ def test_get_shallow(self, data): recorder = self.instrument(ref, json.dumps(data)) assert ref.get(shallow=True) == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?shallow=true' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?shallow=true') def test_get_with_etag_and_shallow(self): ref = db.reference('/test') @@ -240,14 +239,12 @@ def test_get_if_changed(self, data): assert ref.get_if_changed('invalid-etag') == (True, data, MockAdapter.ETAG) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') assert recorder[0].headers['if-none-match'] == 'invalid-etag' assert ref.get_if_changed(MockAdapter.ETAG) == (False, None, None) assert len(recorder) == 2 - assert recorder[1].method == 'GET' - assert recorder[1].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[1], 'GET', 'https://test.firebaseio.com/test.json') assert recorder[1].headers['if-none-match'] == MockAdapter.ETAG @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) @@ -264,9 +261,8 @@ def test_order_by_query(self, data): query_str = 'orderBy=%22foo%22' assert query.get() == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) @pytest.mark.parametrize('data', valid_values) def test_limit_query(self, data): @@ -277,9 +273,8 @@ def test_limit_query(self, data): query_str = 'limitToFirst=100&orderBy=%22foo%22' assert query.get() == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) @pytest.mark.parametrize('data', valid_values) def test_range_query(self, data): @@ -291,9 +286,8 @@ def test_range_query(self, data): query_str = 'endAt=200&orderBy=%22foo%22&startAt=100' assert query.get() == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) @pytest.mark.parametrize('data', valid_values) def test_set_value(self, data): @@ -301,10 +295,9 @@ def test_set_value(self, data): recorder = self.instrument(ref, '') ref.set(data) assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + self._assert_request( + recorder[0], 'PUT', 'https://test.firebaseio.com/test.json?print=silent') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' def test_set_none_value(self): ref = db.reference('/test') @@ -327,10 +320,9 @@ def test_update_children(self, data): recorder = self.instrument(ref, json.dumps(data)) ref.update(data) assert len(recorder) == 1 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + self._assert_request( + recorder[0], 'PATCH', 'https://test.firebaseio.com/test.json?print=silent') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' @pytest.mark.parametrize('data', valid_values) def test_set_if_unchanged_success(self, data): @@ -339,10 +331,8 @@ def test_set_if_unchanged_success(self, data): vals = ref.set_if_unchanged(MockAdapter.ETAG, data) assert vals == (True, data, MockAdapter.ETAG) assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'PUT', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['if-match'] == MockAdapter.ETAG @pytest.mark.parametrize('data', valid_values) @@ -352,10 +342,8 @@ def test_set_if_unchanged_failure(self, data): vals = ref.set_if_unchanged('invalid-etag', data) assert vals == (False, {'foo':'bar'}, MockAdapter.ETAG) assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'PUT', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['if-match'] == 'invalid-etag' @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) @@ -397,22 +385,16 @@ def test_push(self, data): assert isinstance(child, db.Reference) assert child.key == 'testkey' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'POST', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT def test_push_default(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) assert ref.push().key == 'testkey' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'POST', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[0].body.decode()) == '' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT def test_push_none_value(self): ref = db.reference('/test') @@ -425,10 +407,7 @@ def test_delete(self): recorder = self.instrument(ref, '') ref.delete() assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request(recorder[0], 'DELETE', 'https://test.firebaseio.com/test.json') def test_transaction(self): ref = db.reference('/test') @@ -442,8 +421,8 @@ def transaction_update(data): new_value = ref.transaction(transaction_update) assert new_value == {'foo1' : 'bar1', 'foo2' : 'bar2'} assert len(recorder) == 2 - assert recorder[0].method == 'GET' - assert recorder[1].method == 'PUT' + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') + self._assert_request(recorder[1], 'PUT', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[1].body.decode()) == {'foo1': 'bar1', 'foo2': 'bar2'} def test_transaction_scalar(self): @@ -454,8 +433,8 @@ def test_transaction_scalar(self): new_value = ref.transaction(lambda x: x + 1 if x else 1) assert new_value == 43 assert len(recorder) == 2 - assert recorder[0].method == 'GET' - assert recorder[1].method == 'PUT' + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test/count.json') + self._assert_request(recorder[1], 'PUT', 'https://test.firebaseio.com/test/count.json') assert json.loads(recorder[1].body.decode()) == 43 def test_transaction_error(self): @@ -471,7 +450,7 @@ def transaction_update(data): ref.transaction(transaction_update) assert str(excinfo.value) == 'test error' assert len(recorder) == 1 - assert recorder[0].method == 'GET' + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') def test_transaction_abort(self): ref = db.reference('/test/count') @@ -638,16 +617,21 @@ def instrument(self, ref, payload, status=200): ref._client.session.mount(self.test_url, adapter) return recorder + def _assert_request(self, request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['Authorization'] == 'Bearer mock-token' + assert request.headers['User-Agent'] == db._USER_AGENT + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + def test_get_value(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps('data')) query_str = 'auth_variable_override={0}'.format(self.encoded_override) assert ref.get() == 'data' assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) def test_set_value(self): ref = db.reference('/test') @@ -656,11 +640,9 @@ def test_set_value(self): ref.set(data) query_str = 'print=silent&auth_variable_override={0}'.format(self.encoded_override) assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + self._assert_request( + recorder[0], 'PUT', 'https://test.firebaseio.com/test.json?' + query_str) assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT def test_order_by_query(self): ref = db.reference('/test') @@ -669,10 +651,8 @@ def test_order_by_query(self): query_str = 'orderBy=%22foo%22&auth_variable_override={0}'.format(self.encoded_override) assert query.get() == 'data' assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) def test_range_query(self): ref = db.reference('/test') @@ -682,10 +662,8 @@ def test_range_query(self): 'auth_variable_override={0}'.format(self.encoded_override)) assert query.get() == 'data' assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) class TestDatabaseInitialization: diff --git a/tests/test_functions.py b/tests/test_functions.py index 75809c1ad..f8f675890 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -21,6 +21,7 @@ import firebase_admin from firebase_admin import functions +from firebase_admin import _utils from tests import testutils @@ -121,6 +122,7 @@ def test_task_enqueue(self): assert recorder[0].url == _DEFAULT_REQUEST_URL assert recorder[0].headers['Content-Type'] == 'application/json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() assert task_id == 'test-task-id' def test_task_enqueue_with_extension(self): @@ -137,6 +139,7 @@ def test_task_enqueue_with_extension(self): assert recorder[0].url == _CLOUD_TASKS_URL + resource_name assert recorder[0].headers['Content-Type'] == 'application/json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() assert task_id == 'test-task-id' def test_task_delete(self): @@ -146,6 +149,7 @@ def test_task_delete(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == _DEFAULT_TASK_URL + assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() class TestTaskQueueOptions: diff --git a/tests/test_http_client.py b/tests/test_http_client.py index 12ba03b48..cc948b393 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -17,7 +17,7 @@ from pytest_localserver import http import requests -from firebase_admin import _http_client +from firebase_admin import _http_client, _utils from tests import testutils @@ -61,6 +61,18 @@ def test_base_url(): assert recorder[0].method == 'GET' assert recorder[0].url == _TEST_URL + 'foo' +def test_metrics_headers(): + client = _http_client.HttpClient() + assert client.session is not None + recorder = _instrument(client, 'body') + resp = client.request('get', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + def test_credential(): client = _http_client.HttpClient( credential=testutils.MockGoogleCredential()) diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index 08b0fe6db..720171cd9 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -20,6 +20,7 @@ from firebase_admin import exceptions from firebase_admin import instance_id from firebase_admin import _http_client +from firebase_admin import _utils from tests import testutils @@ -64,6 +65,11 @@ def _instrument_iid_service(self, app, status=200, payload='True'): testutils.MockAdapter(payload, status, recorder)) return iid_service, recorder + def _assert_request(self, request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + def _get_url(self, project_id, iid): return instance_id._IID_SERVICE_URL + 'project/{0}/instanceId/{1}'.format(project_id, iid) @@ -86,8 +92,8 @@ def test_delete_instance_id(self): _, recorder = self._instrument_iid_service(app) instance_id.delete_instance_id('test_iid') assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == self._get_url('explicit-project-id', 'test_iid') + self._assert_request( + recorder[0], 'DELETE', self._get_url('explicit-project-id', 'test_iid')) def test_delete_instance_id_with_explicit_app(self): cred = testutils.MockCredential() @@ -95,8 +101,8 @@ def test_delete_instance_id_with_explicit_app(self): _, recorder = self._instrument_iid_service(app) instance_id.delete_instance_id('test_iid', app) assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == self._get_url('explicit-project-id', 'test_iid') + self._assert_request( + recorder[0], 'DELETE', self._get_url('explicit-project-id', 'test_iid')) @pytest.mark.parametrize('status', http_errors.keys()) def test_delete_instance_id_error(self, status): @@ -114,8 +120,8 @@ def test_delete_instance_id_error(self, status): else: # 401 responses are automatically retried by google-auth assert len(recorder) == 3 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == self._get_url('explicit-project-id', 'test_iid') + self._assert_request( + recorder[0], 'DELETE', self._get_url('explicit-project-id', 'test_iid')) def test_delete_instance_id_unexpected_error(self): cred = testutils.MockCredential() @@ -129,8 +135,7 @@ def test_delete_instance_id_unexpected_error(self): assert excinfo.value.cause is not None assert excinfo.value.http_response is not None assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == url + self._assert_request(recorder[0], 'DELETE', url) @pytest.mark.parametrize('iid', [None, '', 0, 1, True, False, list(), dict(), tuple()]) def test_invalid_instance_id(self, iid): diff --git a/tests/test_messaging.py b/tests/test_messaging.py index d482438f5..edb36f53a 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -25,6 +25,7 @@ from firebase_admin import exceptions from firebase_admin import messaging from firebase_admin import _http_client +from firebase_admin import _utils from tests import testutils @@ -1660,6 +1661,18 @@ def _instrument_messaging_service(self, app=None, status=200, payload=_DEFAULT_R testutils.MockAdapter(payload, status, recorder)) return fcm_service, recorder + + def _assert_request(self, request, expected_method, expected_url, expected_body=None): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['X-GOOG-API-FORMAT-VERSION'] == '2' + assert request.headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + if expected_body is None: + assert request.body is None + else: + assert json.loads(request.body.decode()) == expected_body + def _get_url(self, project_id): return messaging._MessagingService.FCM_URL.format(project_id) @@ -1682,15 +1695,11 @@ def test_send_dry_run(self): msg_id = messaging.send(msg, dry_run=True) assert msg_id == 'message-id' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') - assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' - assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION body = { 'message': messaging._MessagingService.encode_message(msg), 'validate_only': True, } - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) def test_send(self): _, recorder = self._instrument_messaging_service() @@ -1698,12 +1707,8 @@ def test_send(self): msg_id = messaging.send(msg) assert msg_id == 'message-id' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') - assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' - assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION body = {'message': messaging._MessagingService.encode_message(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status,exc_type', HTTP_ERROR_CODES.items()) def test_send_error(self, status, exc_type): @@ -1714,12 +1719,8 @@ def test_send_error(self, status, exc_type): expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) check_exception(excinfo.value, expected, status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') - assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' - assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_detailed_error(self, status): @@ -1735,10 +1736,8 @@ def test_send_detailed_error(self, status): messaging.send(msg) check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_canonical_error_code(self, status): @@ -1754,10 +1753,8 @@ def test_send_canonical_error_code(self, status): messaging.send(msg) check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) @@ -1780,10 +1777,8 @@ def test_send_fcm_error_code(self, status, fcm_error_code, exc_type): messaging.send(msg) check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_unknown_fcm_error_code(self, status): @@ -1805,10 +1800,8 @@ def test_send_unknown_fcm_error_code(self, status): messaging.send(msg) check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) class _HttpMockException: @@ -2591,6 +2584,12 @@ def _instrument_iid_service(self, app=None, status=200, payload=_DEFAULT_RESPONS testutils.MockAdapter(payload, status, recorder)) return fcm_service, recorder + def _assert_request(self, request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['access_token_auth'] == 'true' + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + def _get_url(self, path): return '{0}/{1}'.format(messaging._MessagingService.IID_URL, path) @@ -2625,8 +2624,7 @@ def test_subscribe_to_topic(self, args): resp = messaging.subscribe_to_topic(args[0], args[1]) self._check_response(resp) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchAdd') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchAdd')) assert json.loads(recorder[0].body.decode()) == args[2] @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) @@ -2637,8 +2635,7 @@ def test_subscribe_to_topic_error(self, status, exc_type): messaging.subscribe_to_topic('foo', 'test-topic') assert str(excinfo.value) == 'Error while calling the IID service: error_reason' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchAdd') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchAdd')) @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) def test_subscribe_to_topic_non_json_error(self, status, exc_type): @@ -2648,8 +2645,7 @@ def test_subscribe_to_topic_non_json_error(self, status, exc_type): reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) assert str(excinfo.value) == reason assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchAdd') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchAdd')) @pytest.mark.parametrize('args', _VALID_ARGS) def test_unsubscribe_from_topic(self, args): @@ -2657,8 +2653,7 @@ def test_unsubscribe_from_topic(self, args): resp = messaging.unsubscribe_from_topic(args[0], args[1]) self._check_response(resp) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchRemove') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchRemove')) assert json.loads(recorder[0].body.decode()) == args[2] @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) @@ -2669,8 +2664,7 @@ def test_unsubscribe_from_topic_error(self, status, exc_type): messaging.unsubscribe_from_topic('foo', 'test-topic') assert str(excinfo.value) == 'Error while calling the IID service: error_reason' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchRemove') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchRemove')) @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): @@ -2680,8 +2674,7 @@ def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) assert str(excinfo.value) == reason assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchRemove') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchRemove')) def _check_response(self, resp): assert resp.success_count == 1 diff --git a/tests/test_ml.py b/tests/test_ml.py index abd6d06f9..137fe4cf6 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -21,12 +21,11 @@ import firebase_admin from firebase_admin import exceptions from firebase_admin import ml +from firebase_admin import _utils from tests import testutils BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' -HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' -HEADER_CLIENT_VALUE = 'fire-admin-python/{0}'.format(firebase_admin.__version__) PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' @@ -336,6 +335,12 @@ def instrument_ml_service(status=200, payload=None, operations=False, app=None): session_url, adapter(payload, status, recorder)) return recorder +def _assert_request(request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['X-FIREBASE-CLIENT'] == f'fire-admin-python/{firebase_admin.__version__}' + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + class _TestStorageClient: @staticmethod def upload(bucket_name, model_file_name, app): @@ -599,9 +604,7 @@ def test_wait_for_unlocked(self): model.wait_for_unlocked() assert model == FULL_MODEL_PUBLISHED assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestModel._op_url(PROJECT_ID) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestModel._op_url(PROJECT_ID)) def test_wait_for_unlocked_timeout(self): recorder = instrument_ml_service( @@ -653,12 +656,8 @@ def test_returns_locked(self): assert model == expected_model assert len(recorder) == 2 - assert recorder[0].method == 'POST' - assert recorder[0].url == TestCreateModel._url(PROJECT_ID) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE - assert recorder[1].method == 'GET' - assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'POST', TestCreateModel._url(PROJECT_ID)) + _assert_request(recorder[1], 'GET', TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)) def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -747,12 +746,8 @@ def test_returns_locked(self): assert model == expected_model assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE - assert recorder[1].method == 'GET' - assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'PATCH', TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)) + _assert_request(recorder[1], 'GET', TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)) def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -846,9 +841,8 @@ def test_immediate_done(self, publish_function, published): model = publish_function(MODEL_ID_1) assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request( + recorder[0], 'PATCH', TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)) body = json.loads(recorder[0].body.decode()) assert body.get('state', {}).get('published', None) is published @@ -862,12 +856,10 @@ def test_returns_locked(self, publish_function): assert model == expected_model assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE - assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request( + recorder[0], 'PATCH', TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)) + _assert_request( + recorder[1], 'GET', TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1)) @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): @@ -918,9 +910,7 @@ def test_get_model(self): recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) model = ml.get_model(MODEL_ID_1) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestGetModel._url(PROJECT_ID, MODEL_ID_1)) assert model == MODEL_1 assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 @@ -942,9 +932,7 @@ def test_get_model_error(self): ERROR_MSG_NOT_FOUND ) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestGetModel._url(PROJECT_ID, MODEL_ID_1)) def test_no_project_id(self): def evaluate(): @@ -973,9 +961,7 @@ def test_delete_model(self): recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) ml.delete_model(MODEL_ID_1) # no response for delete assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'DELETE', TestDeleteModel._url(PROJECT_ID, MODEL_ID_1)) @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): @@ -994,9 +980,7 @@ def test_delete_model_error(self): ERROR_MSG_NOT_FOUND ) assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'DELETE', self._url(PROJECT_ID, MODEL_ID_1)) def test_no_project_id(self): def evaluate(): @@ -1032,9 +1016,7 @@ def test_list_models_no_args(self): recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) models_page = ml.list_models() assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestListModels._url(PROJECT_ID) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestListModels._url(PROJECT_ID)) TestListModels._check_page(models_page, 2) assert models_page.has_next_page assert models_page.next_page_token == NEXT_PAGE_TOKEN @@ -1048,12 +1030,10 @@ def test_list_models_with_all_args(self): page_size=10, page_token=PAGE_TOKEN) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == ( + _assert_request(recorder[0], 'GET', ( TestListModels._url(PROJECT_ID) + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' - .format(PAGE_TOKEN)) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + .format(PAGE_TOKEN))) assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 @@ -1097,9 +1077,7 @@ def test_list_models_error(self): ERROR_MSG_BAD_REQUEST ) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestListModels._url(PROJECT_ID) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestListModels._url(PROJECT_ID)) def test_no_project_id(self): def evaluate(): diff --git a/tests/test_project_management.py b/tests/test_project_management.py index 183195510..0a1bf97e5 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -23,6 +23,7 @@ from firebase_admin import exceptions from firebase_admin import project_management from firebase_admin import _http_client +from firebase_admin import _utils from tests import testutils OPERATION_IN_PROGRESS_RESPONSE = json.dumps({ @@ -521,8 +522,8 @@ def _assert_request_is_correct( self, request, expected_method, expected_url, expected_body=None): assert request.method == expected_method assert request.url == expected_url - client_version = 'Python/Admin/{0}'.format(firebase_admin.__version__) - assert request.headers['X-Client-Version'] == client_version + assert request.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() if expected_body is None: assert request.body is None else: diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 53b766239..1da6d938a 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -26,6 +26,7 @@ from firebase_admin import tenant_mgt from firebase_admin import _auth_providers from firebase_admin import _user_mgt +from firebase_admin import _utils from tests import testutils from tests import test_token_gen @@ -195,6 +196,8 @@ def test_get_tenant(self, tenant_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() def test_tenant_not_found(self, tenant_mgt_app): _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) @@ -285,6 +288,8 @@ def _assert_request(self, recorder, body): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/tenants'.format(TENANT_MGT_URL_PREFIX) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() got = json.loads(req.body.decode()) assert got == body @@ -383,6 +388,8 @@ def _assert_request(self, recorder, body, mask): assert req.method == 'PATCH' assert req.url == '{0}/tenants/tenant-id?updateMask={1}'.format( TENANT_MGT_URL_PREFIX, ','.join(mask)) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() got = json.loads(req.body.decode()) assert got == body @@ -403,6 +410,8 @@ def test_delete_tenant(self, tenant_mgt_app): req = recorder[0] assert req.method == 'DELETE' assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() def test_tenant_not_found(self, tenant_mgt_app): _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) @@ -545,6 +554,8 @@ def _assert_request(self, recorder, expected=None): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() request = dict(parse.parse_qsl(parse.urlsplit(req.url).query)) assert request == expected @@ -920,6 +931,8 @@ def _assert_request( req = recorder[0] assert req.method == method assert req.url == '{0}/tenants/tenant-id{1}'.format(prefix, want_url) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() body = json.loads(req.body.decode()) assert body == want_body diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index ea9c87e6f..604ec9959 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -28,6 +28,7 @@ from firebase_admin import _http_client from firebase_admin import _user_import from firebase_admin import _user_mgt +from firebase_admin import _utils from tests import testutils @@ -135,6 +136,7 @@ def _check_request(recorder, want_url, want_body=None, want_timeout=None): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], want_url) + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() if want_body: body = json.loads(req.body.decode()) assert body == want_body