diff --git a/ansible_rulebook/app.py b/ansible_rulebook/app.py index 1971af06..444b585e 100644 --- a/ansible_rulebook/app.py +++ b/ansible_rulebook/app.py @@ -107,7 +107,7 @@ async def run(parsed_args: argparse.Namespace) -> None: event_log = NullQueue() logger.info("Starting sources") - tasks, ruleset_queues = spawn_sources( + source_tasks, ruleset_queues = spawn_sources( startup_args.rulesets, startup_args.variables, [parsed_args.source_dir], @@ -121,10 +121,10 @@ async def run(parsed_args: argparse.Namespace) -> None: feedback_task = asyncio.create_task( send_event_log_to_websocket(event_log=event_log) ) - tasks.append(feedback_task) should_reload = await run_rulesets( event_log, + source_tasks, ruleset_queues, startup_args.variables, startup_args.inventory, @@ -139,10 +139,17 @@ async def run(parsed_args: argparse.Namespace) -> None: [feedback_task], timeout=settings.max_feedback_timeout ) + tasks = [] logger.info("Cancelling event source tasks") - for task in tasks: + for task in source_tasks: + tasks.append(task) task.cancel() + if feedback_task: + logger.info("Cancelling feedback task") + tasks.append(feedback_task) + feedback_task.cancel() + error_found = False results = await asyncio.gather(*tasks, return_exceptions=True) for result in results: diff --git a/ansible_rulebook/engine.py b/ansible_rulebook/engine.py index 00bff598..affbf312 100644 --- a/ansible_rulebook/engine.py +++ b/ansible_rulebook/engine.py @@ -257,6 +257,7 @@ async def monitor_rulebook(rulebook_file): async def run_rulesets( event_log: asyncio.Queue, + source_tasks: List[asyncio.Task], ruleset_queues: List[RuleSetQueue], variables: Dict, inventory: str = "", @@ -301,6 +302,7 @@ async def run_rulesets( ruleset_runner = RuleSetRunner( event_log=event_log, ruleset_queue_plan=ruleset_queue_plan, + source_tasks=source_tasks, hosts_facts=hosts_facts, variables=variables, rule_set=rulesets[ruleset_queue_plan.ruleset.name], diff --git a/ansible_rulebook/rule_set_runner.py b/ansible_rulebook/rule_set_runner.py index b1dbad20..f07d45e4 100644 --- a/ansible_rulebook/rule_set_runner.py +++ b/ansible_rulebook/rule_set_runner.py @@ -86,6 +86,7 @@ def __init__( self, event_log: asyncio.Queue, ruleset_queue_plan: EngineRuleSetQueuePlan, + source_tasks: List[asyncio.Task], hosts_facts, variables, rule_set, @@ -96,6 +97,7 @@ def __init__( self.action_loop_task = None self.event_log = event_log self.ruleset_queue_plan = ruleset_queue_plan + self.source_tasks = source_tasks self.name = ruleset_queue_plan.ruleset.name self.rule_set = rule_set self.hosts_facts = hosts_facts @@ -180,6 +182,8 @@ async def _handle_shutdown(self): self.name, str(self.shutdown), ) + for task in self.source_tasks: + task.cancel() if self.shutdown.kind == "now": logger.debug( "ruleset: %s has issued an immediate shutdown", self.name diff --git a/tests/test_engine.py b/tests/test_engine.py index d5459ebf..489f9899 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -248,6 +248,7 @@ async def test_run_rulesets(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -276,6 +277,7 @@ async def test_run_rules_with_assignment(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -299,6 +301,7 @@ async def test_run_rules_with_assignment2(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -320,7 +323,7 @@ async def test_run_rules_simple(): queue.put_nowait(Shutdown()) await run_rulesets( - event_log, ruleset_queues, dict(), "playbooks/inventory.yml" + event_log, [], ruleset_queues, dict(), "playbooks/inventory.yml" ) assert event_log.get_nowait()["type"] == "Action", "0" @@ -358,6 +361,7 @@ async def test_run_multiple_hosts(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory1.yml", @@ -390,6 +394,7 @@ async def test_run_multiple_hosts2(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory1.yml", @@ -417,6 +422,7 @@ async def test_run_multiple_hosts3(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -438,7 +444,7 @@ async def test_filters(): queue.put_nowait(Shutdown()) await run_rulesets( - event_log, ruleset_queues, dict(), "playbooks/inventory.yml" + event_log, [], ruleset_queues, dict(), "playbooks/inventory.yml" ) assert event_log.get_nowait()["type"] == "Action", "0" @@ -474,6 +480,7 @@ async def test_run_rulesets_on_hosts(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory1.yml", @@ -504,6 +511,7 @@ async def test_run_assert_facts(): queue.put_nowait(Shutdown()) await run_rulesets( event_log, + [], ruleset_queues, dict(Naboo="naboo"), temp.name, diff --git a/tests/test_examples.py b/tests/test_examples.py index 88d6822a..9e048fa8 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -63,6 +63,7 @@ async def test_01_noop(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -92,6 +93,7 @@ async def test_02_debug(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -121,6 +123,7 @@ async def test_03_print_event(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -150,6 +153,7 @@ async def test_04_set_fact(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -182,6 +186,7 @@ async def test_05_post_event(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -214,6 +219,7 @@ async def test_06_retract_fact(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -248,6 +254,7 @@ async def test_07_and(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -272,6 +279,7 @@ async def test_08_or(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -296,6 +304,7 @@ async def test_09_gt(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -320,6 +329,7 @@ async def test_10_lt(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -344,6 +354,7 @@ async def test_11_le(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -368,6 +379,7 @@ async def test_12_ge(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -392,6 +404,7 @@ async def test_13_add(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -416,6 +429,7 @@ async def test_14_sub(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -443,6 +457,7 @@ async def test_15_multiple_events_all(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -472,6 +487,7 @@ async def test_16_multiple_events_any(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -499,6 +515,7 @@ async def test_17_multiple_sources_any(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -530,6 +547,7 @@ async def test_18_multiple_sources_all(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -558,6 +576,7 @@ async def test_19_is_defined(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -588,6 +607,7 @@ async def test_20_is_not_defined(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -635,6 +655,7 @@ async def test_21_run_playbook(rule, ansible_events): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -673,6 +694,7 @@ async def test_23_nested_data(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -699,6 +721,7 @@ async def test_24_max_attributes(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -726,6 +749,7 @@ async def test_25_max_attributes_nested(): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -750,6 +774,7 @@ async def test_26_print_events(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -775,6 +800,7 @@ async def test_27_var_root(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -809,6 +835,7 @@ async def test_28_right_side_condition_template(): await run_rulesets( event_log, + [], ruleset_queues, {"custom": {"expected_index": 2}}, dict(), @@ -836,6 +863,7 @@ async def test_29_run_module(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -877,6 +905,7 @@ async def test_30_run_module_missing(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -910,6 +939,7 @@ async def test_31_run_module_missing_args(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -943,6 +973,7 @@ async def test_32_run_module_fail(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -981,6 +1012,7 @@ async def test_35_multiple_rulesets_1_fired(): await run_rulesets( event_log, + [], ruleset_queues, dict(), ) @@ -1013,6 +1045,7 @@ async def test_36_multiple_rulesets_both_fired(): ): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1038,6 +1071,7 @@ async def test_37_hosts_facts(): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1061,6 +1095,7 @@ async def test_38_shutdown_action(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1089,6 +1124,7 @@ async def test_40_in(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1111,6 +1147,7 @@ async def test_41_not_in(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1133,6 +1170,7 @@ async def test_42_contains(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1156,6 +1194,7 @@ async def test_43_not_contains(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1178,6 +1217,7 @@ async def test_44_in_and(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1201,6 +1241,7 @@ async def test_45_in_or(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1226,6 +1267,7 @@ async def test_47_generic_plugin(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1250,6 +1292,7 @@ async def test_48_echo(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1282,6 +1325,7 @@ async def test_49_float(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1327,6 +1371,7 @@ async def test_50_negation(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1391,6 +1436,7 @@ async def test_51_vars_namespace(): await run_rulesets( event_log, + [], ruleset_queues, person, "playbooks/inventory.yml", @@ -1435,6 +1481,7 @@ async def test_51_vars_namespace_missing_key(): with pytest.raises(VarsKeyMissingException) as exc_info: await run_rulesets( event_log, + [], ruleset_queues, person, "playbooks/inventory.yml", @@ -1451,6 +1498,7 @@ async def test_52_once_within(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1478,6 +1526,7 @@ async def test_53_once_within_multiple_hosts(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1505,6 +1554,7 @@ async def test_54_time_window(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1541,6 +1591,7 @@ async def test_55_not_all(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1568,6 +1619,7 @@ async def test_56_once_after(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1596,6 +1648,7 @@ async def test_57_once_after_multiple(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1616,6 +1669,7 @@ async def test_58_string_search(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1646,6 +1700,7 @@ async def test_59_multiple_actions(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1673,6 +1728,7 @@ async def test_60_json_filter(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1697,6 +1753,7 @@ async def test_61_select_1(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1721,6 +1778,7 @@ async def test_62_select_2(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1747,6 +1805,7 @@ async def test_63_selectattr_1(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1772,6 +1831,7 @@ async def test_64_selectattr_2(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1796,6 +1856,7 @@ async def test_65_selectattr_3(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1828,6 +1889,7 @@ async def test_66_sleepy_playbook(): ): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1860,6 +1922,7 @@ async def test_67_shutdown_now(): ): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1886,6 +1949,7 @@ async def test_68_disabled_rule(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1910,6 +1974,7 @@ async def test_69_enhanced_debug(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1937,6 +2002,7 @@ async def test_70_null(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -1965,6 +2031,7 @@ async def test_72_set_fact_with_type(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(my_bool=True, my_int=2, my_float=3.123), "playbooks/inventory.yml", @@ -1996,6 +2063,7 @@ async def test_73_mix_and_match_list(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -2026,6 +2094,7 @@ async def test_74_self_referential(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -2050,6 +2119,7 @@ async def test_75_all_conditions(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -2074,6 +2144,7 @@ async def test_76_all_conditions(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -2109,6 +2180,7 @@ async def test_46_job_template(): ): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -2145,6 +2217,7 @@ async def test_46_job_template_exception(err_msg, err): ): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -2185,6 +2258,7 @@ async def test_77_default_events_ttl(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), "playbooks/inventory.yml", @@ -2213,6 +2287,7 @@ async def test_78_complete_retract_fact(): with patch("uuid.uuid4", return_value=DUMMY_UUID): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -2265,6 +2340,7 @@ async def test_79_workflow_job_template_exception(err_msg, err): ): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -2316,6 +2392,7 @@ async def test_79_workflow_job_template(): ): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -2341,6 +2418,7 @@ async def test_80_match_multiple_rules(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -2368,6 +2446,7 @@ async def test_81_match_single_rule(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -2393,6 +2472,7 @@ async def test_82_non_alpha_keys(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(), @@ -2419,6 +2499,7 @@ async def test_83_boolean_true(): with SourceTask(rs.sources[0], "sources", {}, queue): await run_rulesets( event_log, + [], ruleset_queues, dict(), dict(),