Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel event sources earlier #634

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions ansible_rulebook/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def run(parsed_args: argparse.Namespace) -> None:
event_log = NullQueue()

logger.info("Starting sources")
tasks, ruleset_queues = spawn_sources(
source_tasks, ruleset_queues = spawn_sources(
startup_args.rulesets,
startup_args.variables,
[parsed_args.source_dir],
Expand All @@ -121,10 +121,10 @@ async def run(parsed_args: argparse.Namespace) -> None:
feedback_task = asyncio.create_task(
send_event_log_to_websocket(event_log=event_log)
)
tasks.append(feedback_task)

should_reload = await run_rulesets(
event_log,
source_tasks,
ruleset_queues,
startup_args.variables,
startup_args.inventory,
Expand All @@ -139,10 +139,17 @@ async def run(parsed_args: argparse.Namespace) -> None:
[feedback_task], timeout=settings.max_feedback_timeout
)

tasks = []
logger.info("Cancelling event source tasks")
for task in tasks:
for task in source_tasks:
tasks.append(task)
task.cancel()

if feedback_task:
logger.info("Cancelling feedback task")
tasks.append(feedback_task)
feedback_task.cancel()

error_found = False
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
Expand Down
2 changes: 2 additions & 0 deletions ansible_rulebook/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ async def monitor_rulebook(rulebook_file):

async def run_rulesets(
event_log: asyncio.Queue,
source_tasks: List[asyncio.Task],
ruleset_queues: List[RuleSetQueue],
variables: Dict,
inventory: str = "",
Expand Down Expand Up @@ -301,6 +302,7 @@ async def run_rulesets(
ruleset_runner = RuleSetRunner(
event_log=event_log,
ruleset_queue_plan=ruleset_queue_plan,
source_tasks=source_tasks,
hosts_facts=hosts_facts,
variables=variables,
rule_set=rulesets[ruleset_queue_plan.ruleset.name],
Expand Down
4 changes: 4 additions & 0 deletions ansible_rulebook/rule_set_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
self,
event_log: asyncio.Queue,
ruleset_queue_plan: EngineRuleSetQueuePlan,
source_tasks: List[asyncio.Task],
hosts_facts,
variables,
rule_set,
Expand All @@ -96,6 +97,7 @@ def __init__(
self.action_loop_task = None
self.event_log = event_log
self.ruleset_queue_plan = ruleset_queue_plan
self.source_tasks = source_tasks
self.name = ruleset_queue_plan.ruleset.name
self.rule_set = rule_set
self.hosts_facts = hosts_facts
Expand Down Expand Up @@ -180,6 +182,8 @@ async def _handle_shutdown(self):
self.name,
str(self.shutdown),
)
for task in self.source_tasks:
task.cancel()
if self.shutdown.kind == "now":
logger.debug(
"ruleset: %s has issued an immediate shutdown", self.name
Expand Down
12 changes: 10 additions & 2 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ async def test_run_rulesets():

await run_rulesets(
event_log,
[],
ruleset_queues,
dict(),
"playbooks/inventory.yml",
Expand Down Expand Up @@ -276,6 +277,7 @@ async def test_run_rules_with_assignment():

await run_rulesets(
event_log,
[],
ruleset_queues,
dict(),
dict(),
Expand All @@ -299,6 +301,7 @@ async def test_run_rules_with_assignment2():

await run_rulesets(
event_log,
[],
ruleset_queues,
dict(),
dict(),
Expand All @@ -320,7 +323,7 @@ async def test_run_rules_simple():
queue.put_nowait(Shutdown())

await run_rulesets(
event_log, ruleset_queues, dict(), "playbooks/inventory.yml"
event_log, [], ruleset_queues, dict(), "playbooks/inventory.yml"
)

assert event_log.get_nowait()["type"] == "Action", "0"
Expand Down Expand Up @@ -358,6 +361,7 @@ async def test_run_multiple_hosts():

await run_rulesets(
event_log,
[],
ruleset_queues,
dict(),
"playbooks/inventory1.yml",
Expand Down Expand Up @@ -390,6 +394,7 @@ async def test_run_multiple_hosts2():

await run_rulesets(
event_log,
[],
ruleset_queues,
dict(),
"playbooks/inventory1.yml",
Expand Down Expand Up @@ -417,6 +422,7 @@ async def test_run_multiple_hosts3():

await run_rulesets(
event_log,
[],
ruleset_queues,
dict(),
"playbooks/inventory.yml",
Expand All @@ -438,7 +444,7 @@ async def test_filters():
queue.put_nowait(Shutdown())

await run_rulesets(
event_log, ruleset_queues, dict(), "playbooks/inventory.yml"
event_log, [], ruleset_queues, dict(), "playbooks/inventory.yml"
)

assert event_log.get_nowait()["type"] == "Action", "0"
Expand Down Expand Up @@ -474,6 +480,7 @@ async def test_run_rulesets_on_hosts():

await run_rulesets(
event_log,
[],
ruleset_queues,
dict(),
"playbooks/inventory1.yml",
Expand Down Expand Up @@ -504,6 +511,7 @@ async def test_run_assert_facts():
queue.put_nowait(Shutdown())
await run_rulesets(
event_log,
[],
ruleset_queues,
dict(Naboo="naboo"),
temp.name,
Expand Down
Loading
Loading