Skip to content

Commit

Permalink
fix: wait for the feedback task to end (#562)
Browse files Browse the repository at this point in the history
[AAP-13874](https://issues.redhat.com/browse/AAP-13874)

ansible-rulebook uses a websocket channel to send feedback events for
reporting purposes. The feedback task runs asynchronously and might get
shutdown/cancelled before it can post all of the events. This PR ensures
that the feedback task gets sufficient time to post all of the events
back to the server
  • Loading branch information
mkanoor authored Aug 3, 2023
1 parent 6c3887a commit 1ff8fca
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 28 deletions.
22 changes: 14 additions & 8 deletions ansible_rulebook/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
split_collection_name,
)
from ansible_rulebook.common import StartupArgs
from ansible_rulebook.conf import settings
from ansible_rulebook.engine import run_rulesets, start_source
from ansible_rulebook.job_template_runner import job_template_runner
from ansible_rulebook.rule_types import RuleSet, RuleSetQueue
Expand Down Expand Up @@ -106,16 +107,16 @@ async def run(parsed_args: argparse.ArgumentParser) -> None:

logger.info("Starting rules")

feedback_task = None
if parsed_args.websocket_address:
tasks.append(
asyncio.create_task(
send_event_log_to_websocket(
event_log,
parsed_args.websocket_address,
parsed_args.websocket_ssl_verify,
)
feedback_task = asyncio.create_task(
send_event_log_to_websocket(
event_log,
parsed_args.websocket_address,
parsed_args.websocket_ssl_verify,
)
)
tasks.append(feedback_task)

await run_rulesets(
event_log,
Expand All @@ -126,6 +127,12 @@ async def run(parsed_args: argparse.ArgumentParser) -> None:
startup_args.project_data_file,
)

await event_log.put(dict(type="Exit"))
if feedback_task:
await asyncio.wait(
[feedback_task], timeout=settings.max_feedback_timeout
)

logger.info("Cancelling event source tasks")
for task in tasks:
task.cancel()
Expand All @@ -140,7 +147,6 @@ async def run(parsed_args: argparse.ArgumentParser) -> None:
error_found = True

logger.info("Main complete")
await event_log.put(dict(type="Exit"))
await job_template_runner.close_session()
if error_found:
raise Exception("One of the source plugins failed")
Expand Down
1 change: 1 addition & 0 deletions ansible_rulebook/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self):
self.identifier = str(uuid.uuid4())
self.gc_after = 1000
self.default_execution_strategy = "sequential"
self.max_feedback_timeout = 5


settings = _Settings()
18 changes: 11 additions & 7 deletions ansible_rulebook/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ async def request_workload(
async def send_event_log_to_websocket(
event_log, websocket_address, websocket_ssl_verify
):
logger.info("websocket %s connecting", websocket_address)
logger.info("feedback websocket %s connecting", websocket_address)
event = None
async for websocket in websockets.connect(
websocket_address,
logger=logger,
ssl=_sslcontext(websocket_address, websocket_ssl_verify),
):
logger.info("websocket %s connected", websocket_address)
logger.info("feedback websocket %s connected", websocket_address)
try:
if event:
logger.info("Resending last event...")
Expand All @@ -104,22 +104,26 @@ async def send_event_log_to_websocket(

while True:
event = await event_log.get()
await websocket.send(json.dumps(event))
logger.debug(f"Event received, {event}")

if event == dict(type="Shutdown"):
if event == dict(type="Exit"):
logger.info("Exiting feedback websocket task")
return

await websocket.send(json.dumps(event))
event = None
except websockets.exceptions.ConnectionClosed:
logger.warning(
"websocket %s connection closed, will retry...",
"feedback websocket %s connection closed, will retry...",
websocket_address,
)
except CancelledError:
logger.info("closing websocket due to task cancelled")
logger.info("closing feedback websocket due to task cancelled")
return
except BaseException as err:
logger.error("websocket error on %s err: %s", event, str(err))
logger.error(
"feedback websocket error on %s err: %s", event, str(err)
)


def _sslcontext(url, ssl_verify) -> ssl.SSLContext:
Expand Down
17 changes: 9 additions & 8 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,15 @@ async def test_run_with_websocket(create_ruleset):
controller_ssl_verify="no",
check_controller_connection=True,
)
with patch(
"ansible_rulebook.app.job_template_runner.get_config",
return_value=dict(version="4.4.1"),
):
await run(cmdline_args)
assert mock_start_source.call_count == 1
assert mock_run_rulesets.call_count == 1
assert mock_request_workload.call_count == 1
with patch("ansible_rulebook.app.send_event_log_to_websocket"):
with patch(
"ansible_rulebook.app.job_template_runner.get_config",
return_value=dict(version="4.4.1"),
):
await run(cmdline_args)
assert mock_start_source.call_count == 1
assert mock_run_rulesets.call_count == 1
assert mock_request_workload.call_count == 1


@pytest.mark.asyncio
Expand Down
9 changes: 4 additions & 5 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def test_send_event_log_to_websocket():
queue = asyncio.Queue()
queue.put_nowait({"a": 1})
queue.put_nowait({"b": 1})
queue.put_nowait(dict(type="Shutdown"))
queue.put_nowait(dict(type="Exit"))

data_sent = []

Expand All @@ -121,7 +121,7 @@ def my_func(data):
mo.return_value.__aiter__.side_effect = [mock_object]
mo.return_value.send.side_effect = my_func
await send_event_log_to_websocket(queue, "dummy", "yes")
assert data_sent == ['{"a": 1}', '{"b": 1}', '{"type": "Shutdown"}']
assert data_sent == ['{"a": 1}', '{"b": 1}']


@pytest.mark.asyncio
Expand All @@ -132,7 +132,7 @@ async def test_send_event_log_to_websocket_with_exception(
queue = asyncio.Queue()
queue.put_nowait({"a": 1})
queue.put_nowait({"b": 2})
queue.put_nowait(dict(type="Shutdown"))
queue.put_nowait(dict(type="Exit"))

data_sent = []

Expand All @@ -146,8 +146,7 @@ async def test_send_event_log_to_websocket_with_exception(
websockets.exceptions.ConnectionClosed(rcvd=None, sent=None),
data_sent.append({"a": 1}),
data_sent.append({"b": 2}),
data_sent.append({"type": "Shutdown"}),
]

await send_event_log_to_websocket(queue, "dummy", "yes")
assert data_sent == [{"a": 1}, {"b": 2}, {"type": "Shutdown"}]
assert data_sent == [{"a": 1}, {"b": 2}]

0 comments on commit 1ff8fca

Please sign in to comment.