diff --git a/ansible_rulebook/app.py b/ansible_rulebook/app.py index 735d35e4..6f5b5b82 100644 --- a/ansible_rulebook/app.py +++ b/ansible_rulebook/app.py @@ -28,6 +28,7 @@ split_collection_name, ) from ansible_rulebook.common import StartupArgs +from ansible_rulebook.conf import settings from ansible_rulebook.engine import run_rulesets, start_source from ansible_rulebook.job_template_runner import job_template_runner from ansible_rulebook.rule_types import RuleSet, RuleSetQueue @@ -106,16 +107,16 @@ async def run(parsed_args: argparse.ArgumentParser) -> None: logger.info("Starting rules") + feedback_task = None if parsed_args.websocket_address: - tasks.append( - asyncio.create_task( - send_event_log_to_websocket( - event_log, - parsed_args.websocket_address, - parsed_args.websocket_ssl_verify, - ) + feedback_task = asyncio.create_task( + send_event_log_to_websocket( + event_log, + parsed_args.websocket_address, + parsed_args.websocket_ssl_verify, ) ) + tasks.append(feedback_task) await run_rulesets( event_log, @@ -126,6 +127,12 @@ async def run(parsed_args: argparse.ArgumentParser) -> None: startup_args.project_data_file, ) + await event_log.put(dict(type="Exit")) + if feedback_task: + await asyncio.wait( + [feedback_task], timeout=settings.max_feedback_timeout + ) + logger.info("Cancelling event source tasks") for task in tasks: task.cancel() @@ -140,7 +147,6 @@ async def run(parsed_args: argparse.ArgumentParser) -> None: error_found = True logger.info("Main complete") - await event_log.put(dict(type="Exit")) await job_template_runner.close_session() if error_found: raise Exception("One of the source plugins failed") diff --git a/ansible_rulebook/conf.py b/ansible_rulebook/conf.py index e790da38..2c49e5b2 100644 --- a/ansible_rulebook/conf.py +++ b/ansible_rulebook/conf.py @@ -20,6 +20,7 @@ def __init__(self): self.identifier = str(uuid.uuid4()) self.gc_after = 1000 self.default_execution_strategy = "sequential" + self.max_feedback_timeout = 5 settings = _Settings() diff --git a/ansible_rulebook/websocket.py b/ansible_rulebook/websocket.py index ee830fc7..bbcdcaed 100644 --- a/ansible_rulebook/websocket.py +++ b/ansible_rulebook/websocket.py @@ -88,14 +88,14 @@ async def request_workload( async def send_event_log_to_websocket( event_log, websocket_address, websocket_ssl_verify ): - logger.info("websocket %s connecting", websocket_address) + logger.info("feedback websocket %s connecting", websocket_address) event = None async for websocket in websockets.connect( websocket_address, logger=logger, ssl=_sslcontext(websocket_address, websocket_ssl_verify), ): - logger.info("websocket %s connected", websocket_address) + logger.info("feedback websocket %s connected", websocket_address) try: if event: logger.info("Resending last event...") @@ -104,22 +104,26 @@ async def send_event_log_to_websocket( while True: event = await event_log.get() - await websocket.send(json.dumps(event)) + logger.debug(f"Event received, {event}") - if event == dict(type="Shutdown"): + if event == dict(type="Exit"): + logger.info("Exiting feedback websocket task") return + await websocket.send(json.dumps(event)) event = None except websockets.exceptions.ConnectionClosed: logger.warning( - "websocket %s connection closed, will retry...", + "feedback websocket %s connection closed, will retry...", websocket_address, ) except CancelledError: - logger.info("closing websocket due to task cancelled") + logger.info("closing feedback websocket due to task cancelled") return except BaseException as err: - logger.error("websocket error on %s err: %s", event, str(err)) + logger.error( + "feedback websocket error on %s err: %s", event, str(err) + ) def _sslcontext(url, ssl_verify) -> ssl.SSLContext: diff --git a/tests/test_app.py b/tests/test_app.py index 9c3625ec..fc334f16 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -159,14 +159,15 @@ async def test_run_with_websocket(create_ruleset): controller_ssl_verify="no", check_controller_connection=True, ) - with patch( - "ansible_rulebook.app.job_template_runner.get_config", - return_value=dict(version="4.4.1"), - ): - await run(cmdline_args) - assert mock_start_source.call_count == 1 - assert mock_run_rulesets.call_count == 1 - assert mock_request_workload.call_count == 1 + with patch("ansible_rulebook.app.send_event_log_to_websocket"): + with patch( + "ansible_rulebook.app.job_template_runner.get_config", + return_value=dict(version="4.4.1"), + ): + await run(cmdline_args) + assert mock_start_source.call_count == 1 + assert mock_run_rulesets.call_count == 1 + assert mock_request_workload.call_count == 1 @pytest.mark.asyncio diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 215cf5c6..d5eff4b0 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -106,7 +106,7 @@ async def test_send_event_log_to_websocket(): queue = asyncio.Queue() queue.put_nowait({"a": 1}) queue.put_nowait({"b": 1}) - queue.put_nowait(dict(type="Shutdown")) + queue.put_nowait(dict(type="Exit")) data_sent = [] @@ -121,7 +121,7 @@ def my_func(data): mo.return_value.__aiter__.side_effect = [mock_object] mo.return_value.send.side_effect = my_func await send_event_log_to_websocket(queue, "dummy", "yes") - assert data_sent == ['{"a": 1}', '{"b": 1}', '{"type": "Shutdown"}'] + assert data_sent == ['{"a": 1}', '{"b": 1}'] @pytest.mark.asyncio @@ -132,7 +132,7 @@ async def test_send_event_log_to_websocket_with_exception( queue = asyncio.Queue() queue.put_nowait({"a": 1}) queue.put_nowait({"b": 2}) - queue.put_nowait(dict(type="Shutdown")) + queue.put_nowait(dict(type="Exit")) data_sent = [] @@ -146,8 +146,7 @@ async def test_send_event_log_to_websocket_with_exception( websockets.exceptions.ConnectionClosed(rcvd=None, sent=None), data_sent.append({"a": 1}), data_sent.append({"b": 2}), - data_sent.append({"type": "Shutdown"}), ] await send_event_log_to_websocket(queue, "dummy", "yes") - assert data_sent == [{"a": 1}, {"b": 2}, {"type": "Shutdown"}] + assert data_sent == [{"a": 1}, {"b": 2}]