From c2739ad0db3fb6c7768c453934de7315a999d260 Mon Sep 17 00:00:00 2001 From: Bill Wei Date: Mon, 8 Jan 2024 14:13:38 -0500 Subject: [PATCH] feat: add websocket authentication using jwt token (#628) The cli expects to receive websocket-access-token, websocket-refresh-token, and websocket-token-address. It does not send the authentication header if above arguments are not provided, so it works with the old eda-server that does not authenticate incomming websocket connecitons. Fixes AAP-17776: ansible-rulebook uses token for authentication --- CHANGELOG.md | 1 + ansible_rulebook/app.py | 18 +-- ansible_rulebook/cli.py | 31 +++- ansible_rulebook/conf.py | 5 + ansible_rulebook/exception.py | 5 + ansible_rulebook/token.py | 50 ++++++ ansible_rulebook/websocket.py | 276 ++++++++++++++++++++++------------ docs/usage.rst | 20 ++- tests/e2e/utils.py | 11 +- tests/test_token.py | 61 ++++++++ tests/test_websocket.py | 17 ++- 11 files changed, 373 insertions(+), 122 deletions(-) create mode 100644 ansible_rulebook/token.py create mode 100644 tests/test_token.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 18a4cd50..9b9b0125 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - ssl_verify option now also supports "true" or "false" values - Support for standalone boolean in conditions - Add basic auth to controller +- Use token for websocket authentication ### Changed - Generic print as well as printing of events use new banner style diff --git a/ansible_rulebook/app.py b/ansible_rulebook/app.py index a9855330..11a99859 100644 --- a/ansible_rulebook/app.py +++ b/ansible_rulebook/app.py @@ -64,13 +64,9 @@ def qsize(self): async def run(parsed_args: argparse.Namespace) -> None: file_monitor = None - if parsed_args.worker and parsed_args.websocket_address and parsed_args.id: + if parsed_args.worker and parsed_args.websocket_url and parsed_args.id: logger.info("Starting worker mode") - startup_args = await request_workload( - parsed_args.id, - parsed_args.websocket_address, - parsed_args.websocket_ssl_verify, - ) + startup_args = await request_workload(parsed_args.id) if not startup_args: logger.error("Error communicating with web socket server") raise WebSocketExchangeException( @@ -102,7 +98,7 @@ async def run(parsed_args: argparse.Namespace) -> None: if startup_args.check_controller_connection: await validate_controller_params(startup_args) - if parsed_args.websocket_address: + if parsed_args.websocket_url: event_log = asyncio.Queue() else: event_log = NullQueue() @@ -118,13 +114,9 @@ async def run(parsed_args: argparse.Namespace) -> None: logger.info("Starting rules") feedback_task = None - if parsed_args.websocket_address: + if parsed_args.websocket_url: feedback_task = asyncio.create_task( - send_event_log_to_websocket( - event_log, - parsed_args.websocket_address, - parsed_args.websocket_ssl_verify, - ) + send_event_log_to_websocket(event_log=event_log) ) tasks.append(feedback_task) diff --git a/ansible_rulebook/cli.py b/ansible_rulebook/cli.py index c3717580..db16b107 100644 --- a/ansible_rulebook/cli.py +++ b/ansible_rulebook/cli.py @@ -97,16 +97,32 @@ def get_parser() -> argparse.ArgumentParser: ) parser.add_argument( "-W", - "--websocket-address", "--websocket-url", + "--websocket-address", help="Connect the event log to a websocket", + default=os.environ.get("EDA_WEBSOCKET_URL", ""), ) parser.add_argument( "--websocket-ssl-verify", help="How to verify SSL when connecting to the " "websocket: (yes|true) | (no|false) | , " "default to yes for wss connection.", - default="yes", + default=os.environ.get("EDA_WEBSOCKET_SSL_VERIFY", "yes"), + ) + parser.add_argument( + "--websocket-access-token", + help="Token used to autheticate the websocket connection.", + default=os.environ.get("EDA_WEBSOCKET_ACCESS_TOKEN", ""), + ) + parser.add_argument( + "--websocket-refresh-token", + help="Token used to renew a websocket access token.", + default=os.environ.get("EDA_WEBSOCKET_REFRESH_TOKEN", ""), + ) + parser.add_argument( + "--websocket-token-url", + help="Url to renew websocket access token.", + default=os.environ.get("EDA_WEBSOCKET_TOKEN_URL", ""), ) parser.add_argument("--id", help="Identifier") parser.add_argument( @@ -215,10 +231,8 @@ def get_version() -> str: def validate_args(args: argparse.Namespace) -> None: - if args.worker and (not args.id or not args.websocket_address): - raise ValueError( - "Worker mode needs an id and websocket address specfied" - ) + if args.worker and (not args.id or not args.websocket_url): + raise ValueError("Worker mode needs an id and websocket url specfied") if not args.worker and not args.rulebook: raise ValueError("Rulebook must be specified in non worker mode") @@ -255,6 +269,11 @@ def update_settings(args: argparse.Namespace) -> None: settings.default_execution_strategy = args.execution_strategy settings.print_events = args.print_events + settings.websocket_url = args.websocket_url + settings.websocket_ssl_verify = args.websocket_ssl_verify + settings.websocket_token_url = args.websocket_token_url + settings.websocket_access_token = args.websocket_access_token + settings.websocket_refresh_token = args.websocket_refresh_token def main(args: List[str] = None) -> int: diff --git a/ansible_rulebook/conf.py b/ansible_rulebook/conf.py index bfa76f22..15c04110 100644 --- a/ansible_rulebook/conf.py +++ b/ansible_rulebook/conf.py @@ -22,6 +22,11 @@ def __init__(self): self.default_execution_strategy = "sequential" self.max_feedback_timeout = 5 self.print_events = False + self.websocket_url = None + self.websocket_ssl_verify = "yes" + self.websocket_token_url = None + self.websocket_access_token = None + self.websocket_refresh_token = None settings = _Settings() diff --git a/ansible_rulebook/exception.py b/ansible_rulebook/exception.py index 0b6e0da9..9495901c 100644 --- a/ansible_rulebook/exception.py +++ b/ansible_rulebook/exception.py @@ -166,3 +166,8 @@ class InventoryNotFound(Exception): class MissingArtifactKeyException(Exception): pass + + +class TokenNotFound(Exception): + + pass diff --git a/ansible_rulebook/token.py b/ansible_rulebook/token.py new file mode 100644 index 00000000..05417ac2 --- /dev/null +++ b/ansible_rulebook/token.py @@ -0,0 +1,50 @@ +# Copyright 2024 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import ssl +import typing as tp + +import aiohttp + +from ansible_rulebook.conf import settings +from ansible_rulebook.exception import TokenNotFound + +logger = logging.getLogger(__name__) + + +async def renew_token() -> str: + logger.info("Renew websocket token from %s", settings.websocket_token_url) + async with aiohttp.ClientSession() as session: + async with session.post( + settings.websocket_token_url, + data={"refresh": settings.websocket_refresh_token}, + ssl_context=_sslcontext(), + ) as resp: + data = await resp.json() + if "access" not in data: + logger.error(f"Failed to renew token. Error: {str(data)}") + raise TokenNotFound("Response does not contain access token") + return data["access"] + + +def _sslcontext() -> tp.Optional[ssl.SSLContext]: + if settings.websocket_token_url.startswith("https"): + ssl_verify = settings.websocket_ssl_verify.lower() + if ssl_verify in ["yes", "true"]: + return ssl.create_default_context() + if ssl_verify in ["no", "false"]: + return ssl._create_unverified_context() + return ssl.create_default_context(cafile=ssl_verify) + return None diff --git a/ansible_rulebook/websocket.py b/ansible_rulebook/websocket.py index 5cee165f..3d661a84 100644 --- a/ansible_rulebook/websocket.py +++ b/ansible_rulebook/websocket.py @@ -12,126 +12,216 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import base64 import json import logging import os +import random import ssl import tempfile import typing as tp -from asyncio.exceptions import CancelledError +from dataclasses import dataclass, field import websockets import yaml +from websockets.client import WebSocketClientProtocol from ansible_rulebook import rules_parser as rules_parser from ansible_rulebook.common import StartupArgs +from ansible_rulebook.conf import settings +from ansible_rulebook.exception import ShutdownException +from ansible_rulebook.token import renew_token logger = logging.getLogger(__name__) -async def request_workload( - activation_id: str, websocket_address: str, websocket_ssl_verify: str -) -> StartupArgs: - logger.info("websocket %s connecting", websocket_address) - async with websockets.connect( - websocket_address, - ssl=_sslcontext(websocket_address, websocket_ssl_verify), - ) as websocket: +BACKOFF_MIN = 1.92 +BACKOFF_MAX = 60.0 +BACKOFF_FACTOR = 1.618 +BACKOFF_INITIAL = 5 + + +async def _connect_websocket( + handler: tp.Callable[[WebSocketClientProtocol], tp.Awaitable], + retry: bool, + **kwargs: list, +) -> tp.Any: + logger.info("websocket %s connecting", settings.websocket_url) + if settings.websocket_access_token: + extra_headers = { + "Authorization": f"Bearer {settings.websocket_access_token}" + } + else: + extra_headers = {} + + result = None + refresh = True + + while True: + backoff_delay = BACKOFF_MIN try: - logger.info("websocket %s connected", websocket_address) - await websocket.send( - json.dumps(dict(type="Worker", activation_id=activation_id)) + async with websockets.connect( + settings.websocket_url, + ssl=_sslcontext(), + extra_headers=extra_headers, + ) as websocket: + result = await handler(websocket, **kwargs) + if not retry: + break + except asyncio.CancelledError: # pragma: no cover + raise + except ShutdownException: + break + except Exception as e: + status403_legacy = ( + isinstance(e, websockets.exceptions.InvalidStatusCode) + and e.status_code == 403 ) + status403 = ( + isinstance(e, websockets.exceptions.InvalidStatus) + and e.response.status_code == 403 + ) + if status403_legacy or status403: + if refresh and settings.websocket_refresh_token: + new_token = await renew_token() + extra_headers["Authorization"] = f"Bearer {new_token}" + # Only attempt to refresh token once. If a new token cannot + # establish the connection, something else must cause 403 + refresh = False + else: + raise + elif isinstance(e, OSError) and "[Errno 61]" in str(e): + # Sleep and retry implemention duplicated from + # websockets.lagacy.client.Client - project_data_fh = None - response = StartupArgs() - while True: - msg = await websocket.recv() - data = json.loads(msg) - if data.get("type") == "EndOfResponse": - break - if data.get("type") == "ProjectData": - if not project_data_fh: - ( - project_data_fh, - response.project_data_file, - ) = tempfile.mkstemp() - - if data.get("data") and data.get("more"): - os.write( - project_data_fh, base64.b64decode(data.get("data")) - ) - if not data.get("data") and not data.get("more"): - os.close(project_data_fh) - logger.debug("wrote %s", response.project_data_file) - if data.get("type") == "Rulebook": - response.rulesets = rules_parser.parse_rule_sets( - yaml.safe_load(base64.b64decode(data.get("data"))) + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. + if backoff_delay == BACKOFF_MIN: + initial_delay = random.random() * BACKOFF_INITIAL + logger.info( + "! connect failed; reconnecting in %.1f seconds", + initial_delay, + exc_info=True, ) - if data.get("type") == "ExtraVars": - response.variables = yaml.safe_load( - base64.b64decode(data.get("data")) + await asyncio.sleep(initial_delay) + else: + logger.info( + "! connect failed again; retrying in %d seconds", + int(backoff_delay), + exc_info=True, ) - if data.get("type") == "ControllerInfo": - response.controller_url = data.get("url") - response.controller_token = data.get("token") - response.controller_ssl_verify = data.get("ssl_verify") - return response - except CancelledError: - logger.info("closing websocket due to task cancelled") - return - except websockets.exceptions.ConnectionClosed: - logger.info("websocket %s closed", websocket_address) - return - - -async def send_event_log_to_websocket( - event_log, websocket_address, websocket_ssl_verify -): - logger.info("feedback websocket %s connecting", websocket_address) - event = None - async for websocket in websockets.connect( - websocket_address, - logger=logger, - ssl=_sslcontext(websocket_address, websocket_ssl_verify), - ): - logger.info("feedback websocket %s connected", websocket_address) - try: - if event: - logger.info("Resending last event...") - await websocket.send(json.dumps(event)) - event = None - - while True: - event = await event_log.get() - logger.debug(f"Event received, {event}") - - if event == dict(type="Exit"): - logger.info("Exiting feedback websocket task") - return - - await websocket.send(json.dumps(event)) - event = None - except websockets.exceptions.ConnectionClosed: - logger.warning( - "feedback websocket %s connection closed, will retry...", - websocket_address, + await asyncio.sleep(int(backoff_delay)) + # Increase delay with truncated exponential backoff. + backoff_delay = backoff_delay * BACKOFF_FACTOR + backoff_delay = min(backoff_delay, BACKOFF_MAX) + continue + else: + # Connection succeeded - reset backoff delay + backoff_delay = BACKOFF_MIN + refresh = True + + return result + + +async def request_workload(activation_id: str) -> StartupArgs: + return await _connect_websocket( + handler=_handle_request_workload, + retry=False, + activation_id=activation_id, + ) + + +async def _handle_request_workload( + websocket: WebSocketClientProtocol, + activation_id: str, +) -> StartupArgs: + logger.info("workload websocket connected") + await websocket.send( + json.dumps(dict(type="Worker", activation_id=activation_id)) + ) + + project_data_fh = None + response = StartupArgs() + while True: + msg = await websocket.recv() + data = json.loads(msg) + if data.get("type") == "EndOfResponse": + break + if data.get("type") == "ProjectData": + if not project_data_fh: + ( + project_data_fh, + response.project_data_file, + ) = tempfile.mkstemp() + + if data.get("data") and data.get("more"): + os.write(project_data_fh, base64.b64decode(data.get("data"))) + if not data.get("data") and not data.get("more"): + os.close(project_data_fh) + logger.debug("wrote %s", response.project_data_file) + if data.get("type") == "Rulebook": + response.rulesets = rules_parser.parse_rule_sets( + yaml.safe_load(base64.b64decode(data.get("data"))) ) - except CancelledError: - logger.info("closing feedback websocket due to task cancelled") - return - except BaseException as err: - logger.error( - "feedback websocket error on %s err: %s", event, str(err) + if data.get("type") == "ExtraVars": + response.variables = yaml.safe_load( + base64.b64decode(data.get("data")) ) + if data.get("type") == "ControllerInfo": + response.controller_url = data.get("url") + response.controller_token = data.get("token") + response.controller_ssl_verify = data.get("ssl_verify") + return response + + +@dataclass +class EventLogQueue: + queue: asyncio.Queue = field(default=None) + event: dict = field(default=None) + + +async def send_event_log_to_websocket(event_log: asyncio.Queue): + logs = EventLogQueue() + logs.queue = event_log + + return await _connect_websocket( + handler=_handle_send_event_log, + retry=True, + logs=logs, + ) + + +async def _handle_send_event_log( + websocket: WebSocketClientProtocol, + logs: EventLogQueue, +): + logger.info("feedback websocket connected") + + if logs.event: + logger.info("Resending last event...") + await websocket.send(json.dumps(logs.event)) + logs.event = None + + while True: + event = await logs.queue.get() + logger.debug(f"Event received, {event}") + + if event == dict(type="Exit"): + logger.info("Exiting feedback websocket task") + raise ShutdownException(shutdown=None) + + logs.event = event + await websocket.send(json.dumps(event)) + logs.event = None -def _sslcontext(url: str, ssl_verify: str) -> tp.Optional[ssl.SSLContext]: - if url.startswith("wss"): - if ssl_verify.lower() in ["yes", "true"]: +def _sslcontext() -> tp.Optional[ssl.SSLContext]: + if settings.websocket_url.startswith("wss"): + ssl_verify = settings.websocket_ssl_verify.lower() + if ssl_verify in ["yes", "true"]: return ssl.create_default_context() - if ssl_verify.lower() in ["no", "false"]: + if ssl_verify in ["no", "false"]: return ssl._create_unverified_context() return ssl.create_default_context(cafile=ssl_verify) return None diff --git a/docs/usage.rst b/docs/usage.rst index 3457a994..32db0a5e 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -6,8 +6,10 @@ The `ansible-rulebook` CLI supports the following options: .. code-block:: console - usage: ansible-rulebook [-h] [-r RULEBOOK] [-e VARS] [-E ENV_VARS] [-v] [--version] [-S SOURCE_DIR] [-i INVENTORY] [-W WEBSOCKET_URL] [--id ID] [-w] [-T PROJECT_TARBALL] [--controller-url CONTROLLER_URL] - [--controller-token CONTROLLER_TOKEN] [--controller-ssl-verify CONTROLLER_SSL_VERIFY] [--print-events] [--heartbeat n] [--execution-strategy sequential|parallel] + usage: ansible-rulebook [-h] [-r RULEBOOK] [-e VARS] [-E ENV_VARS] [-v] [--version] [-S SOURCE_DIR] [-i INVENTORY] [-W WEBSOCKET_URL] [--websocket-ssl-verify WEBSOCKET_SSL_VERIFY] + [--websocket-token-url WEBSOCKET_TOKEN_URL] [--websocket-access-token WEBSOCKET_ACCESS_TOKEN] [--websocket-refresh-token WEBSOCKET_REFRESH_TOKEN] + [--id ID] [-w] [-T PROJECT_TARBALL] [--controller-url CONTROLLER_URL] [--controller-token CONTROLLER_TOKEN] [--controller-ssl-verify CONTROLLER_SSL_VERIFY] + [--print-events] [--heartbeat n] [--execution-strategy sequential|parallel] optional arguments: -h, --help show this help message and exit @@ -22,11 +24,17 @@ The `ansible-rulebook` CLI supports the following options: Source dir -i INVENTORY, --inventory INVENTORY Inventory can be a file or a directory - -W WEBSOCKET_URL, --websocket-url WEBSOCKET_ADDRESS + -W WEBSOCKET_URL, --websocket-url WEBSOCKET_URL Connect the event log to a websocket - --websocket-ssl-verify How to verify the wss connection + --websocket-ssl-verify WEBSOCKET_SSL_VERIFY How to verify SSL when connecting to the websocket api. yes|no|, default to yes for wss connection. Connect the event log to a websocket + --websocket-token-url WEBSOCKET_TOKEN_URL + Fetch a renewed token to authenticate websocket connection + --websocket-access-token WEBSOCKET_ACCESS_TOKEN + Initial token used to authenticate websocket connection + --websocket-refresh-token WEBSOCKET_REFRESH_TOKEN + A token needed to renew an authentication token --id ID Identifier -w, --worker Enable worker mode -T PROJECT_TARBALL, --project-tarball PROJECT_TARBALL @@ -74,9 +82,9 @@ If you are using custom event source plugins use the following: .. note:: Here `sources` is a directory containing your event source plugins. -To run `ansible-rulebook` with worker mode enabled the `--worker` option can be used. The `--id`, and `--websocket-address` options can also be used to expose the event stream data:: +To run `ansible-rulebook` with worker mode enabled the `--worker` option can be used. The `--id`, and `--websocket-url` options can also be used to expose the event stream data:: - ansible-rulebook --rulebook rules.yml --inventory inventory.yml --websocket-address "ws://localhost:8080/api/ws2" --id 1 --worker + ansible-rulebook --rulebook rules.yml --inventory inventory.yml --websocket-url "ws://localhost:8080/api/ws2" --id 1 --worker .. note:: The `id` is the `activation_instance` id which allows the results to be communicated back to the websocket. diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 93eaf669..0b28ade2 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -33,6 +33,9 @@ class Command: verbose: bool = False debug: bool = False websocket: Optional[str] = None + token_url: Optional[str] = None + access_token: Optional[str] = None + refresh_token: Optional[str] = None project_tarball: Optional[Path] = None worker_mode: bool = False verbosity: int = 0 @@ -66,7 +69,13 @@ def to_list(self) -> List: if self.proc_id: result.extend(["--id", str(self.proc_id)]) if self.websocket: - result.extend(["--websocket-address", self.websocket]) + result.extend(["--websocket-url", self.websocket]) + if self.access_token: + result.extend(["--websocket-access-token", self.access_token]) + if self.refresh_token: + result.extend(["--websocket-refresh-token", self.refresh_token]) + if self.token_url: + result.extend(["--websocket-token-url", self.token_url]) if self.project_tarball: result.extend( ["--project-tarball", str(self.project_tarball.absolute())] diff --git a/tests/test_token.py b/tests/test_token.py new file mode 100644 index 00000000..8606566f --- /dev/null +++ b/tests/test_token.py @@ -0,0 +1,61 @@ +# Copyright 2024 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest + +from ansible_rulebook import token +from ansible_rulebook.conf import settings +from ansible_rulebook.exception import TokenNotFound + + +class MockResponse: + def __init__(self, data): + self.data = data + + async def json(self): + return self.data + + async def __aexit__(self, exc_type, exc, tb): + pass + + async def __aenter__(self): + return self + + +def prepare_settings() -> None: + settings.websocket_token_url = "https://dummy.org/xyz" + settings.websocket_access_token = "dummy" + settings.websocket_refresh_token = "dummy" + + +@pytest.mark.asyncio +async def test_renew_token(): + prepare_settings() + with patch("ansible_rulebook.token.aiohttp.ClientSession.post") as mock: + data = {"access": "new_token"} + mock.return_value = MockResponse(data) + renewed = await token.renew_token() + assert renewed == "new_token" + + +@pytest.mark.asyncio +async def test_renew_invalid_token(): + prepare_settings() + with patch("ansible_rulebook.token.aiohttp.ClientSession.post") as mock: + data = {"error": "invalid_token"} + mock.return_value = MockResponse(data) + with pytest.raises(TokenNotFound): + await token.renew_token() diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 9218ac96..90ac47ff 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -10,12 +10,20 @@ import pytest import websockets +from ansible_rulebook.conf import settings from ansible_rulebook.websocket import ( request_workload, send_event_log_to_websocket, ) +def prepare_settings() -> None: + settings.websocket_url = "wss://dummy.org/ws" + settings.websocket_token_url = "https://dummy.org/token" + settings.websocket_access_token = "dummy" + settings.websocket_refresh_token = "dummy" + + def file_sha256(filename: str) -> str: sha256_hash = hashlib.sha256() with open(filename, "rb") as f: @@ -64,6 +72,7 @@ def load_file( @pytest.mark.asyncio async def test_request_workload(): + prepare_settings() os.chdir(HERE) controller_url = "https://www.example.com" controller_token = "abc" @@ -91,7 +100,7 @@ async def test_request_workload(): mo.return_value.__aenter__.return_value.recv.side_effect = test_data mo.return_value.__aenter__.return_value.send.return_value = None - response = await request_workload("dummy", "dummy", "yes") + response = await request_workload("dummy") sha2 = file_sha256(response.project_data_file) assert sha1 == sha2 assert response.controller_url == controller_url @@ -103,6 +112,7 @@ async def test_request_workload(): @pytest.mark.asyncio async def test_send_event_log_to_websocket(): + prepare_settings() queue = asyncio.Queue() queue.put_nowait({"a": 1}) queue.put_nowait({"b": 1}) @@ -120,7 +130,7 @@ def my_func(data): mo.return_value.__anext__.return_value = mock_object mo.return_value.__aiter__.side_effect = [mock_object] mo.return_value.send.side_effect = my_func - await send_event_log_to_websocket(queue, "dummy", "yes") + await send_event_log_to_websocket(queue) assert data_sent == ['{"a": 1}', '{"b": 1}'] @@ -129,6 +139,7 @@ def my_func(data): async def test_send_event_log_to_websocket_with_exception( socket_mock: AsyncMock, ): + prepare_settings() queue = asyncio.Queue() queue.put_nowait({"a": 1}) queue.put_nowait({"b": 2}) @@ -148,5 +159,5 @@ async def test_send_event_log_to_websocket_with_exception( data_sent.append({"b": 2}), ] - await send_event_log_to_websocket(queue, "dummy", "yes") + await send_event_log_to_websocket(queue) assert data_sent == [{"a": 1}, {"b": 2}]