diff --git a/ansible_rulebook/action/__init__.py b/ansible_rulebook/action/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ansible_rulebook/action/base_action.py b/ansible_rulebook/action/base_action.py new file mode 100644 index 000000000..6b3d4f7eb --- /dev/null +++ b/ansible_rulebook/action/base_action.py @@ -0,0 +1,89 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import uuid +from typing import Dict + +from ansible_rulebook.conf import settings +from ansible_rulebook.event_filter.insert_meta_info import main as insert_meta +from ansible_rulebook.util import run_at + +from .control import Control +from .metadata import Metadata + +KEY_EDA_VARS = "ansible_eda" +INTERNAL_ACTION_STATUS = "successful" + + +class BaseAction: + def __init__(self, metadata: Metadata, control: Control, action: str): + self.metadata = metadata + self.control = control + self.uuid = str(uuid.uuid4()) + self.action = action + + async def send(self, data: dict, objtype: str = "Action") -> None: + payload = { + "type": objtype, + "action": self.action, + "action_uuid": self.uuid, + "ruleset": self.metadata.rule_set, + "ruleset_uuid": self.metadata.rule_set_uuid, + "rule": self.metadata.rule, + "rule_uuid": self.metadata.rule_uuid, + "rule_run_at": self.metadata.rule_run_at, + "activation_id": settings.identifier, + } + payload.update(data) + await self.control.event_log.put(payload) + + async def send_default_status(self): + await self.send( + { + "run_at": run_at(), + "status": INTERNAL_ACTION_STATUS, + "matching_events": self._get_events(), + } + ) + + def _get_events(self) -> Dict: + if "event" in self.control.variables: + return {"m": self.control.variables["event"]} + if "events" in self.control.variables: + return self.control.variables["events"] + return {} + + def _embellish_internal_event(self, event: Dict) -> Dict: + return insert_meta( + event, **{"source_name": self.action, "source_type": "internal"} + ) + + def set_action(self, action): + self.action = action + + def _collect_extra_vars(self, user_extra_vars: dict) -> dict: + extra_vars = user_extra_vars.copy() if user_extra_vars else {} + + eda_vars = { + "ruleset": self.metadata.rule_set, + "rule": self.metadata.rule, + } + if "events" in self.control.variables: + eda_vars["events"] = self.control.variables["events"] + if "event" in self.control.variables: + eda_vars["event"] = self.control.variables["event"] + + extra_vars[KEY_EDA_VARS] = eda_vars + return extra_vars diff --git a/ansible_rulebook/action/control.py b/ansible_rulebook/action/control.py new file mode 100644 index 000000000..d162c3394 --- /dev/null +++ b/ansible_rulebook/action/control.py @@ -0,0 +1,33 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import dataclass +from typing import List + + +@dataclass(frozen=True) +class Control: + __slots__ = [ + "event_log", + "inventory", + "hosts", + "variables", + "project_data_file", + ] + event_log: asyncio.Queue + inventory: str + hosts: List[str] + variables: dict + project_data_file: str diff --git a/ansible_rulebook/action/debug.py b/ansible_rulebook/action/debug.py new file mode 100644 index 000000000..d8e1cc17e --- /dev/null +++ b/ansible_rulebook/action/debug.py @@ -0,0 +1,71 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +from dataclasses import asdict +from pprint import pprint + +import dpath +from drools import ruleset as lang + +from ansible_rulebook.util import get_horizontal_rule + +from .base_action import BaseAction +from .control import Control +from .metadata import Metadata + +logger = logging.getLogger(__name__) + + +class Debug(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "debug") + self.action_args = action_args + + async def __call__(self): + + if "msg" in self.action_args: + messages = self.action_args.get("msg") + if not isinstance(messages, list): + messages = [messages] + for msg in messages: + print(msg) + elif "var" in self.action_args: + key = self.action_args.get("var") + try: + print(dpath.get(self.control.variables, key, separator=".")) + except KeyError: + logger.error("Key %s not found in variable pool", key) + raise + else: + print(get_horizontal_rule("=")) + print("kwargs:") + args = asdict(self.metadata) + args.update( + { + "inventory": self.control.inventory, + "hosts": self.control.hosts, + "variables": self.control.variables, + "project_data_file": self.control.project_data_file, + } + ) + pprint(args) + print(get_horizontal_rule("=")) + print("facts:") + pprint(lang.get_facts(self.metadata.rule_set)) + print(get_horizontal_rule("=")) + + sys.stdout.flush() + await self.send_default_status() diff --git a/ansible_rulebook/action/metadata.py b/ansible_rulebook/action/metadata.py new file mode 100644 index 000000000..0222ab695 --- /dev/null +++ b/ansible_rulebook/action/metadata.py @@ -0,0 +1,31 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Metadata: + __slots__ = [ + "rule", + "rule_uuid", + "rule_set", + "rule_set_uuid", + "rule_run_at", + ] + rule: str + rule_uuid: str + rule_set: str + rule_set_uuid: str + rule_run_at: str diff --git a/ansible_rulebook/action/noop.py b/ansible_rulebook/action/noop.py new file mode 100644 index 000000000..24a5b6a45 --- /dev/null +++ b/ansible_rulebook/action/noop.py @@ -0,0 +1,30 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from .base_action import BaseAction +from .control import Control +from .metadata import Metadata + +logger = logging.getLogger(__name__) + + +class Noop(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "noop") + self.action_args = action_args + + async def __call__(self): + await self.send_default_status() diff --git a/ansible_rulebook/action/post_event.py b/ansible_rulebook/action/post_event.py new file mode 100644 index 000000000..e2a4c980a --- /dev/null +++ b/ansible_rulebook/action/post_event.py @@ -0,0 +1,36 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from drools import ruleset as lang + +from .base_action import BaseAction +from .control import Control +from .metadata import Metadata + +logger = logging.getLogger(__name__) + + +class PostEvent(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "post_event") + self.action_args = action_args + + async def __call__(self): + lang.post( + self.action_args["ruleset"], + self._embellish_internal_event(self.action_args["event"]), + ) + await self.send_default_status() diff --git a/ansible_rulebook/action/print_event.py b/ansible_rulebook/action/print_event.py new file mode 100644 index 000000000..105f87f75 --- /dev/null +++ b/ansible_rulebook/action/print_event.py @@ -0,0 +1,39 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pprint import pprint +from typing import Callable + +from .base_action import BaseAction +from .control import Control +from .metadata import Metadata + + +class PrintEvent(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "print_event") + + self.action_args = action_args + + async def __call__(self): + print_fn: Callable = print + if "pretty" in self.action_args: + print_fn = pprint + + var_name = "events" if "events" in self.control.variables else "event" + + print_fn(self.control.variables[var_name]) + sys.stdout.flush() + await self.send_default_status() diff --git a/ansible_rulebook/action/retract_fact.py b/ansible_rulebook/action/retract_fact.py new file mode 100644 index 000000000..f95edb881 --- /dev/null +++ b/ansible_rulebook/action/retract_fact.py @@ -0,0 +1,44 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from drools import ruleset as lang + +from .base_action import BaseAction +from .control import Control +from .metadata import Metadata + +logger = logging.getLogger(__name__) + + +class RetractFact(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "retract_fact") + self.action_args = action_args + + async def __call__(self): + partial = self.action_args.get("partial", True) + if not partial: + exclude_keys = ["meta"] + else: + exclude_keys = [] + + lang.retract_matching_facts( + self.action_args["ruleset"], + self.action_args["fact"], + partial, + exclude_keys, + ) + await self.send_default_status() diff --git a/ansible_rulebook/action/run_job_template.py b/ansible_rulebook/action/run_job_template.py new file mode 100644 index 000000000..75c52cfbc --- /dev/null +++ b/ansible_rulebook/action/run_job_template.py @@ -0,0 +1,147 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import uuid + +from drools import ruleset as lang + +from ansible_rulebook.exception import ( + ControllerApiException, + JobTemplateNotFoundException, +) +from ansible_rulebook.job_template_runner import job_template_runner +from ansible_rulebook.util import run_at + +from .base_action import BaseAction +from .control import Control +from .metadata import Metadata + +logger = logging.getLogger(__name__) + + +class RunJobTemplate(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "run_job_template") + self.action_args = action_args + self.name = self.action_args["name"] + self.organization = self.action_args["organization"] + self.job_id = str(uuid.uuid4()) + hosts_limit = ",".join(self.control.hosts) + self.job_args = self.action_args.get("job_args", {}) + self.job_args["limit"] = hosts_limit + self.controller_job = {} + + async def __call__(self): + logger.info( + "running job template: %s, organization: %s", + self.name, + self.organization, + ) + logger.info( + "ruleset: %s, rule %s", self.metadata.rule_set, self.metadata.rule + ) + + self.job_args["extra_vars"] = self._collect_extra_vars( + self.job_args.get("extra_vars", {}) + ) + await self._job_start_event() + await self._run() + + async def _run(self): + retries = self.action_args.get("retries", 0) + if self.action_args.get("retry", False): + retries = max(self.action_args.get("retries", 0), 1) + delay = self.action_args.get("delay", 0) + + try: + for i in range(retries + 1): + if i > 0: + if delay > 0: + await asyncio.sleep(delay) + logger.info( + "Previous run_job_template failed. Retry %d of %d", + i, + retries, + ) + controller_job = await job_template_runner.run_job_template( + self.name, + self.organization, + self.job_args, + ) + if controller_job["status"] != "failed": + break + except (ControllerApiException, JobTemplateNotFoundException) as ex: + logger.error(ex) + controller_job = {} + controller_job["status"] = "failed" + controller_job["created"] = run_at() + controller_job["error"] = str(ex) + + self.controller_job = controller_job + await self._post_process() + + async def _post_process(self) -> None: + a_log = { + "job_template_name": self.name, + "organization": self.organization, + "job_id": self.job_id, + "status": self.controller_job["status"], + "run_at": self.controller_job["created"], + "url": self._controller_job_url(), + "matching_events": self._get_events(), + } + if "error" in self.controller_job: + a_log["message"] = self.controller_job["error"] + a_log["reason"] = {"error": self.controller_job["error"]} + + await self.send(a_log) + set_facts = self.action_args.get("set_facts", False) + post_events = self.action_args.get("post_events", False) + + if set_facts or post_events: + ruleset = self.action_args.get("ruleset", self.metadata.rule_set) + logger.debug("set_facts") + facts = self.controller_job.get("artifacts", {}) + if facts: + facts = self._embellish_internal_event(facts) + logger.debug("facts %s", facts) + if set_facts: + lang.assert_fact(ruleset, facts) + if post_events: + lang.post(ruleset, facts) + else: + logger.debug("Empty facts are not set") + + async def _job_start_event(self): + await self.send( + { + "run_at": run_at(), + "matching_events": self._get_events(), + "action": self.action, + "hosts": ",".join(self.control.hosts), + "name": self.name, + "job_id": self.job_id, + }, + objtype="Job", + ) + + def _controller_job_url(self) -> str: + if "id" in self.controller_job: + return ( + f"{job_template_runner.host}/#/jobs/" + f"{self.controller_job['id']}/details" + ) + return "" diff --git a/ansible_rulebook/action/run_module.py b/ansible_rulebook/action/run_module.py new file mode 100644 index 000000000..6a688df02 --- /dev/null +++ b/ansible_rulebook/action/run_module.py @@ -0,0 +1,44 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from .control import Control +from .metadata import Metadata +from .run_playbook import RunPlaybook + +logger = logging.getLogger(__name__) + + +class RunModule(RunPlaybook): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, **action_args) + self.set_action("run_module") + + def _runner_args(self): + module_args_str = "" + module_args = self.action_args.get("module_args", {}) + for key, value in module_args.items(): + if len(module_args_str) > 0: + module_args_str += " " + module_args_str += f"{key}={value!r}" + + return { + "module": self.name, + "host_pattern": ",".join(self.control.hosts), + "module_args": module_args_str, + } + + def _copy_playbook_files(self, project_dir): + pass diff --git a/ansible_rulebook/action/run_playbook.py b/ansible_rulebook/action/run_playbook.py new file mode 100644 index 000000000..bc02c3127 --- /dev/null +++ b/ansible_rulebook/action/run_playbook.py @@ -0,0 +1,306 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import concurrent.futures +import glob +import json +import logging +import os +import shutil +import tempfile +import uuid +from asyncio.exceptions import CancelledError +from functools import partial + +import ansible_runner +import janus +import yaml +from drools import ruleset as lang + +from ansible_rulebook.collection import ( + find_playbook, + has_playbook, + split_collection_name, +) +from ansible_rulebook.conf import settings +from ansible_rulebook.exception import ( + PlaybookNotFoundException, + PlaybookStatusNotFoundException, +) +from ansible_rulebook.util import run_at + +from .base_action import BaseAction +from .control import Control +from .metadata import Metadata + +logger = logging.getLogger(__name__) + +tar = shutil.which("tar") + + +class RunPlaybook(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "run_playbook") + self.action_args = action_args + self.job_id = str(uuid.uuid4()) + self.default_copy_files = True + self.default_check_files = True + self.name = self.action_args["name"] + self.verbosity = self.action_args.get("verbosity", 0) + self.json_mode = self.action_args.get("json_mode", False) + self.private_data_dir = None + + async def __call__(self): + try: + logger.info( + f"ruleset: {self.metadata.rule_set}, " + f"rule: {self.metadata.rule}" + ) + self._create_private_dir() + logger.debug("private data dir %s", self.private_data_dir) + await self._pre_process() + await self._job_start_event() + logger.info("Calling Ansible runner") + await self._run() + finally: + if os.path.exists(self.private_data_dir): + shutil.rmtree(self.private_data_dir) + + def _create_private_dir(self): + self.private_data_dir = tempfile.mkdtemp(prefix=self.action) + + async def _job_start_event(self): + await self.send( + { + "run_at": run_at(), + "matching_events": self._get_events(), + "action": self.action, + "hosts": ",".join(self.control.hosts), + "name": self.name, + "job_id": self.job_id, + "ansible_rulebook_id": settings.identifier, + }, + objtype="Job", + ) + + async def _run(self): + retries = self.action_args.get("retries", 0) + if self.action_args.get("retry", False): + retries = max(self.action_args.get("retries", 0), 1) + + delay = self.action_args.get("delay", 0) + + for i in range(retries + 1): + if i > 0: + if delay > 0: + await asyncio.sleep(delay) + logger.info( + "Previous run_playbook failed. Retry %d of %d", i, retries + ) + + await self._call_runner() + if self._get_latest_artifact("status") != "failed": + break + + await self._post_process() + + def _runner_args(self): + return {"playbook": self.name} + + async def _pre_process(self) -> None: + playbook_extra_vars = self._collect_extra_vars( + self.action_args.get("extra_vars", {}) + ) + + env_dir = os.path.join(self.private_data_dir, "env") + inventory_dir = os.path.join(self.private_data_dir, "inventory") + project_dir = os.path.join(self.private_data_dir, "project") + + os.mkdir(env_dir) + with open(os.path.join(env_dir, "extravars"), "w") as file_handle: + file_handle.write(yaml.dump(playbook_extra_vars)) + os.mkdir(inventory_dir) + with open(os.path.join(inventory_dir, "hosts"), "w") as file_handle: + file_handle.write(self.control.inventory) + os.mkdir(project_dir) + + logger.debug("project_data_file: %s", self.control.project_data_file) + if self.control.project_data_file: + if os.path.exists(self.control.project_data_file): + await self._untar_project( + project_dir, self.control.project_data_file + ) + return + self._copy_playbook_files(project_dir) + + def _copy_playbook_files(self, project_dir): + if self.action_args.get("check_files", self.default_check_files): + if os.path.exists(self.name): + tail_name = os.path.basename(self.name) + shutil.copy(self.name, os.path.join(project_dir, tail_name)) + if self.action_args.get("copy_files", self.default_copy_files): + shutil.copytree( + os.path.dirname(os.path.abspath(self.name)), + project_dir, + dirs_exist_ok=True, + ) + self.name = tail_name + elif has_playbook(*split_collection_name(self.name)): + shutil.copy( + find_playbook(*split_collection_name(self.name)), + os.path.join(project_dir, self.name), + ) + else: + msg = ( + f"Could not find a playbook for {self.name} " + f"from {os.getcwd()}" + ) + logger.error(msg) + raise PlaybookNotFoundException(msg) + + async def _post_process(self): + rc = int(self._get_latest_artifact("rc")) + status = self._get_latest_artifact("status") + logger.info("Ansible runner rc: %d, status: %s", rc, status) + if rc != 0: + error_message = self._get_latest_artifact("stderr") + if not error_message: + error_message = self._get_latest_artifact("stdout") + logger.error(error_message) + + await self.send( + { + "playbook_name": self.name, + "job_id": self.job_id, + "rc": rc, + "status": status, + "run_at": run_at(), + "matching_events": self._get_events(), + } + ) + set_facts = self.action_args.get("set_facts", False) + post_events = self.action_args.get("post_events", False) + + if rc == 0 and (set_facts or post_events): + logger.debug("set_facts") + fact_folder = self._get_latest_artifact("fact_cache", False) + ruleset = self.action_args.get("ruleset", self.metadata.rule_set) + for host_facts in glob.glob(os.path.join(fact_folder, "*")): + with open(host_facts) as file_handle: + fact = json.loads(file_handle.read()) + fact = self._embellish_internal_event(fact) + logger.debug("fact %s", fact) + if set_facts: + lang.assert_fact(ruleset, fact) + if post_events: + lang.post(ruleset, fact) + + def _get_latest_artifact(self, component: str, content: bool = True): + files = glob.glob( + os.path.join(self.private_data_dir, "artifacts", "*", component) + ) + files.sort(key=os.path.getmtime, reverse=True) + if not files: + raise PlaybookStatusNotFoundException(f"No {component} file found") + if content: + with open(files[0], "r") as file_handle: + content = file_handle.read() + return content + return files[0] + + async def _untar_project(self, output_dir, project_data_file): + + cmd = [tar, "zxvf", project_data_file] + proc = await asyncio.create_subprocess_exec( + *cmd, + cwd=output_dir, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdout, stderr = await proc.communicate() + + if stdout: + logger.debug(stdout.decode()) + if stderr: + logger.debug(stderr.decode()) + + async def _call_runner(self): + host_limit = ",".join(self.control.hosts) + shutdown = False + + loop = asyncio.get_running_loop() + + queue = janus.Queue() + + # The event_callback is called from the ansible-runner thread + # It needs a thread-safe synchronous queue. + # Janus provides a sync queue connected to an async queue + # Here we push the event into the sync side of janus + def event_callback(event, *_args, **_kwargs): + event["job_id"] = self.job_id + event["ansible_rulebook_id"] = settings.identifier + queue.sync_q.put({"type": "AnsibleEvent", "event": event}) + + # Here we read the async side and push it into the event queue + # which is also async. + # We do this until cancelled at the end of the ansible runner call. + # We might need to drain the queue here before ending. + async def read_queue(): + try: + while True: + val = await queue.async_q.get() + event_data = val.get("event", {}) + val["run_at"] = event_data.get("created") + await self.send(val) + except CancelledError: + logger.info("Ansible runner Queue task cancelled") + + def cancel_callback(): + return shutdown + + tasks = [] + + tasks.append(asyncio.create_task(read_queue())) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as task_pool: + try: + await loop.run_in_executor( + task_pool, + partial( + ansible_runner.run, + private_data_dir=self.private_data_dir, + limit=host_limit, + verbosity=self.verbosity, + event_handler=event_callback, + cancel_callback=cancel_callback, + json_mode=self.json_mode, + **self._runner_args(), + ), + ) + except CancelledError: + logger.debug( + "Ansible Runner Thread Pool executor task cancelled" + ) + shutdown = True + raise + finally: + # Cancel the queue reading task + for task in tasks: + if not task.done(): + logger.debug("Cancel Queue reading task") + task.cancel() + + await asyncio.gather(*tasks) diff --git a/ansible_rulebook/action/set_fact.py b/ansible_rulebook/action/set_fact.py new file mode 100644 index 000000000..1ae4c2859 --- /dev/null +++ b/ansible_rulebook/action/set_fact.py @@ -0,0 +1,41 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from drools import ruleset as lang + +from .base_action import BaseAction +from .control import Control +from .metadata import Metadata + +logger = logging.getLogger(__name__) + + +class SetFact(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "set_fact") + self.action_args = action_args + + async def __call__(self): + logger.debug( + "set_fact %s %s", + self.action_args["ruleset"], + self.action_args["fact"], + ) + lang.assert_fact( + self.action_args["ruleset"], + self._embellish_internal_event(self.action_args["fact"]), + ) + await self.send_default_status() diff --git a/ansible_rulebook/action/shutdown.py b/ansible_rulebook/action/shutdown.py new file mode 100644 index 000000000..85c01b015 --- /dev/null +++ b/ansible_rulebook/action/shutdown.py @@ -0,0 +1,57 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ansible_rulebook.exception import ShutdownException +from ansible_rulebook.messages import Shutdown as ShutdownMessage +from ansible_rulebook.util import run_at + +from .base_action import INTERNAL_ACTION_STATUS, BaseAction +from .control import Control +from .metadata import Metadata + + +class Shutdown(BaseAction): + def __init__(self, metadata: Metadata, control: Control, **action_args): + super().__init__(metadata, control, "shutdown") + self.action_args = action_args + + async def __call__(self): + delay = self.action_args.get("delay", 60.0) + message = self.action_args.get("message", "Default shutdown message") + kind = self.action_args.get("kind", "graceful") + + await self.send( + { + "run_at": run_at(), + "status": INTERNAL_ACTION_STATUS, + "matching_events": self._get_events(), + "delay": delay, + "message": message, + "kind": kind, + } + ) + print( + "Ruleset: %s rule: %s has initiated shutdown of type: %s. " + "Delay: %.3f seconds, Message: %s" + % ( + self.metadata.rule_set, + self.metadata.rule, + kind, + delay, + message, + ) + ) + raise ShutdownException( + ShutdownMessage(message=message, delay=delay, kind=kind) + ) diff --git a/ansible_rulebook/builtin.py b/ansible_rulebook/builtin.py deleted file mode 100644 index e25fc706c..000000000 --- a/ansible_rulebook/builtin.py +++ /dev/null @@ -1,945 +0,0 @@ -# Copyright 2022 Red Hat, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import concurrent.futures -import glob -import json -import logging -import os -import shutil -import sys -import tempfile -import uuid -from asyncio.exceptions import CancelledError -from functools import partial -from pprint import pprint -from typing import Callable, Dict, List, Optional, Union - -import ansible_runner -import dpath -import janus -import yaml -from drools import ruleset as lang - -from .collection import find_playbook, has_playbook, split_collection_name -from .conf import settings -from .event_filter.insert_meta_info import main as insert_meta -from .exception import ( - ControllerApiException, - JobTemplateNotFoundException, - PlaybookNotFoundException, - PlaybookStatusNotFoundException, - ShutdownException, -) -from .job_template_runner import job_template_runner -from .messages import Shutdown -from .util import get_horizontal_rule, run_at - -logger = logging.getLogger(__name__) - -tar = shutil.which("tar") - -KEY_EDA_VARS = "ansible_eda" -INTERNAL_ACTION_STATUS = "successful" - - -async def none( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, -): - await event_log.put( - dict( - type="Action", - action="noop", - action_uuid=str(uuid.uuid4()), - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - activation_id=settings.identifier, - run_at=run_at(), - status=INTERNAL_ACTION_STATUS, - matching_events=_get_events(variables), - rule_run_at=rule_run_at, - ) - ) - - -async def debug(event_log, **kwargs): - if "msg" in kwargs: - messages = kwargs.get("msg") - if not isinstance(messages, list): - messages = [messages] - for msg in messages: - print(msg) - elif "var" in kwargs: - key = kwargs.get("var") - try: - print(dpath.get(kwargs.get("variables"), key, separator=".")) - except KeyError: - logger.error("Key %s not found in variable pool", key) - return - else: - print(get_horizontal_rule("=")) - print("kwargs:") - pprint(kwargs) - print(get_horizontal_rule("=")) - print("facts:") - pprint(lang.get_facts(kwargs["source_ruleset_name"])) - print(get_horizontal_rule("=")) - sys.stdout.flush() - await event_log.put( - dict( - type="Action", - action="debug", - action_uuid=str(uuid.uuid4()), - playbook_name=kwargs.get("name"), - ruleset=kwargs.get("source_ruleset_name"), - ruleset_uuid=kwargs.get("source_ruleset_uuid"), - rule=kwargs.get("source_rule_name"), - rule_uuid=kwargs.get("source_rule_uuid"), - rule_run_at=kwargs.get("rule_run_at"), - activation_id=settings.identifier, - run_at=run_at(), - status=INTERNAL_ACTION_STATUS, - matching_events=_get_events(kwargs.get("variables")), - ) - ) - - -async def print_event( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, - name: Optional[str] = None, - pretty: Optional[str] = None, -): - print_fn: Callable = print - if pretty: - print_fn = pprint - - var_name = "events" if "events" in variables else "event" - - print_fn(variables[var_name]) - sys.stdout.flush() - await event_log.put( - dict( - type="Action", - action="print_event", - action_uuid=str(uuid.uuid4()), - activation_id=settings.identifier, - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - playbook_name=name, - run_at=run_at(), - status=INTERNAL_ACTION_STATUS, - matching_events=_get_events(variables), - rule_run_at=rule_run_at, - ) - ) - - -async def set_fact( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, - fact: Dict, - name: Optional[str] = None, -): - logger.debug("set_fact %s %s", ruleset, fact) - lang.assert_fact(ruleset, _embellish_internal_event(fact, "set_fact")) - await event_log.put( - dict( - type="Action", - action="set_fact", - action_uuid=str(uuid.uuid4()), - activation_id=settings.identifier, - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - playbook_name=name, - run_at=run_at(), - status=INTERNAL_ACTION_STATUS, - matching_events=_get_events(variables), - rule_run_at=rule_run_at, - ) - ) - - -async def retract_fact( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, - fact: Dict, - partial: bool = True, - name: Optional[str] = None, -): - - if not partial: - exclude_keys = ["meta"] - else: - exclude_keys = [] - - lang.retract_matching_facts(ruleset, fact, partial, exclude_keys) - await event_log.put( - dict( - type="Action", - action="retract_fact", - action_uuid=str(uuid.uuid4()), - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - activation_id=settings.identifier, - playbook_name=name, - run_at=run_at(), - status=INTERNAL_ACTION_STATUS, - matching_events=_get_events(variables), - rule_run_at=rule_run_at, - ) - ) - - -async def post_event( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, - event: Dict, -): - lang.post(ruleset, _embellish_internal_event(event, "post_event")) - - await event_log.put( - dict( - type="Action", - action="post_event", - action_uuid=str(uuid.uuid4()), - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - activation_id=settings.identifier, - run_at=run_at(), - status=INTERNAL_ACTION_STATUS, - matching_events=_get_events(variables), - rule_run_at=rule_run_at, - ) - ) - - -async def run_playbook( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, - name: str, - set_facts: Optional[bool] = None, - post_events: Optional[bool] = None, - verbosity: int = 0, - copy_files: Optional[bool] = False, - json_mode: Optional[bool] = False, - retries: Optional[int] = 0, - retry: Optional[bool] = False, - delay: Optional[int] = 0, - extra_vars: Optional[Dict] = None, - **kwargs, -): - - logger.info("running Ansible playbook: %s", name) - temp_dir, playbook_name = await pre_process_runner( - event_log, - inventory, - variables, - source_ruleset_name, - source_rule_name, - name, - "run_playbook", - copy_files, - True, - project_data_file, - extra_vars, - **kwargs, - ) - - job_id = str(uuid.uuid4()) - - logger.info(f"ruleset: {source_ruleset_name}, rule: {source_rule_name}") - await event_log.put( - dict( - type="Job", - job_id=job_id, - ansible_rulebook_id=settings.identifier, - name=playbook_name, - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - hosts=",".join(hosts), - action="run_playbook", - ) - ) - - logger.info("Calling Ansible runner") - - if retry: - retries = max(retries, 1) - for i in range(retries + 1): - if i > 0: - if delay > 0: - await asyncio.sleep(delay) - logger.info( - "Previous run_playbook failed. Retry %d of %d", i, retries - ) - - action_run_at = run_at() - await call_runner( - event_log, - job_id, - temp_dir, - dict(playbook=playbook_name), - hosts, - verbosity, - json_mode, - ) - if _get_latest_artifact(temp_dir, "status") != "failed": - break - - await post_process_runner( - event_log, - variables, - temp_dir, - ruleset, - source_ruleset_uuid, - source_rule_name, - source_rule_uuid, - rule_run_at, - settings.identifier, - name, - "run_playbook", - job_id, - action_run_at, - set_facts, - post_events, - ) - - shutil.rmtree(temp_dir) - - -async def run_module( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, - name: str, - set_facts: Optional[bool] = None, - post_events: Optional[bool] = None, - verbosity: int = 0, - copy_files: Optional[bool] = False, - json_mode: Optional[bool] = False, - module_args: Union[Dict, None] = None, - retries: Optional[int] = 0, - retry: Optional[bool] = False, - delay: Optional[int] = 0, - extra_vars: Optional[Dict] = None, - **kwargs, -): - temp_dir, module_name = await pre_process_runner( - event_log, - inventory, - variables, - source_ruleset_name, - source_rule_name, - name, - "run_module", - copy_files, - False, - project_data_file, - extra_vars, - **kwargs, - ) - job_id = str(uuid.uuid4()) - - await event_log.put( - dict( - type="Job", - job_id=job_id, - ansible_rulebook_id=settings.identifier, - name=module_name, - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - hosts=",".join(hosts), - action="run_module", - ) - ) - - logger.info("Calling Ansible runner") - module_args_str = "" - if module_args: - for k, v in module_args.items(): - if len(module_args_str) > 0: - module_args_str += " " - module_args_str += f"{k}={v!r}" - - if retry: - retries = max(retries, 1) - for i in range(retries + 1): - if i > 0: - if delay > 0: - await asyncio.sleep(delay) - logger.info( - "Previous run_module failed. Retry %d of %d", i, retries - ) - action_run_at = run_at() - await call_runner( - event_log, - job_id, - temp_dir, - dict( - module=module_name, - host_pattern=",".join(hosts), - module_args=module_args_str, - ), - hosts, - verbosity, - json_mode, - ) - if _get_latest_artifact(temp_dir, "status") != "failed": - break - - await post_process_runner( - event_log, - variables, - temp_dir, - ruleset, - source_ruleset_uuid, - source_rule_name, - source_rule_uuid, - rule_run_at, - settings.identifier, - name, - "run_module", - job_id, - action_run_at, - set_facts, - post_events, - ) - shutil.rmtree(temp_dir) - - -async def call_runner( - event_log, - job_id: str, - private_data_dir: str, - runner_args: Dict, - hosts: List, - verbosity: int = 0, - json_mode: Optional[bool] = False, -): - - host_limit = ",".join(hosts) - shutdown = False - - loop = asyncio.get_running_loop() - - queue = janus.Queue() - - # The event_callback is called from the ansible-runner thread - # It needs a thread-safe synchronous queue. - # Janus provides a sync queue connected to an async queue - # Here we push the event into the sync side of janus - def event_callback(event, *args, **kwargs): - event["job_id"] = job_id - event["ansible_rulebook_id"] = settings.identifier - queue.sync_q.put(dict(type="AnsibleEvent", event=event)) - - # Here we read the async side and push it into the event queue - # which is also async. - # We do this until cancelled at the end of the ansible runner call. - # We might need to drain the queue here before ending. - async def read_queue(): - try: - while True: - val = await queue.async_q.get() - event_data = val.get("event", {}) - val["run_at"] = event_data.get("created") - await event_log.put(val) - except CancelledError: - logger.info("Ansible Runner Queue task cancelled") - - def cancel_callback(): - return shutdown - - tasks = [] - - tasks.append(asyncio.create_task(read_queue())) - - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as task_pool: - try: - await loop.run_in_executor( - task_pool, - partial( - ansible_runner.run, - private_data_dir=private_data_dir, - limit=host_limit, - verbosity=verbosity, - event_handler=event_callback, - cancel_callback=cancel_callback, - json_mode=json_mode, - **runner_args, - ), - ) - except CancelledError: - logger.debug("Ansible Runner Thread Pool executor task cancelled") - shutdown = True - raise - finally: - # Cancel the queue reading task - for task in tasks: - if not task.done(): - logger.debug("Cancel Queue reading task") - task.cancel() - - await asyncio.gather(*tasks) - - -async def untar_project(output_dir, project_data_file): - - cmd = [tar, "zxvf", project_data_file] - proc = await asyncio.create_subprocess_exec( - *cmd, - cwd=output_dir, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - stdout, stderr = await proc.communicate() - - if stdout: - logger.debug(stdout.decode()) - if stderr: - logger.debug(stderr.decode()) - - -async def pre_process_runner( - event_log, - inventory: str, - variables: Dict, - ruleset: str, - rulename: str, - name: str, - action: str, - copy_files: Optional[bool] = False, - check_files: Optional[bool] = True, - project_data_file: Optional[str] = None, - extra_vars: Optional[Dict] = None, - **kwargs, -): - - private_data_dir = tempfile.mkdtemp(prefix=action) - logger.debug("private data dir %s", private_data_dir) - - playbook_extra_vars = _collect_extra_vars( - variables, extra_vars, ruleset, rulename - ) - - env_dir = os.path.join(private_data_dir, "env") - inventory_dir = os.path.join(private_data_dir, "inventory") - project_dir = os.path.join(private_data_dir, "project") - - playbook_name = name - - os.mkdir(env_dir) - with open(os.path.join(env_dir, "extravars"), "w") as f: - f.write(yaml.dump(playbook_extra_vars)) - os.mkdir(inventory_dir) - with open(os.path.join(inventory_dir, "hosts"), "w") as f: - f.write(inventory) - os.mkdir(project_dir) - - logger.debug("project_data_file: %s", project_data_file) - if project_data_file: - if os.path.exists(project_data_file): - await untar_project(project_dir, project_data_file) - return (private_data_dir, playbook_name) - - if check_files: - if os.path.exists(name): - playbook_name = os.path.basename(name) - shutil.copy(name, os.path.join(project_dir, playbook_name)) - if copy_files: - shutil.copytree( - os.path.dirname(os.path.abspath(name)), - project_dir, - dirs_exist_ok=True, - ) - elif has_playbook(*split_collection_name(name)): - playbook_name = name - shutil.copy( - find_playbook(*split_collection_name(name)), - os.path.join(project_dir, name), - ) - else: - msg = f"Could not find a playbook for {name} from {os.getcwd()}" - logger.error(msg) - raise PlaybookNotFoundException(msg) - - return (private_data_dir, playbook_name) - - -async def post_process_runner( - event_log, - variables: Dict, - private_data_dir: str, - ruleset: str, - ruleset_uuid: str, - rule: str, - rule_uuid: str, - rule_run_at: str, - activation_id: str, - name: str, - action: str, - job_id: str, - run_at: str, - set_facts: Optional[bool] = None, - post_events: Optional[bool] = None, -): - - rc = int(_get_latest_artifact(private_data_dir, "rc")) - status = _get_latest_artifact(private_data_dir, "status") - logger.info("Playbook rc: %d, status: %s", rc, status) - if rc != 0: - error_message = _get_latest_artifact(private_data_dir, "stderr") - if not error_message: - error_message = _get_latest_artifact(private_data_dir, "stdout") - logger.error(error_message) - - result = dict( - type="Action", - action=action, - action_uuid=str(uuid.uuid4()), - activation_id=activation_id, - playbook_name=name, - job_id=job_id, - ruleset=ruleset, - ruleset_uuid=ruleset_uuid, - rule=rule, - rule_uuid=rule_uuid, - rc=rc, - status=status, - run_at=run_at, - matching_events=_get_events(variables), - rule_run_at=rule_run_at, - ) - await event_log.put(result) - - if rc == 0 and (set_facts or post_events): - logger.debug("set_facts") - fact_folder = _get_latest_artifact( - private_data_dir, "fact_cache", False - ) - for host_facts in glob.glob(os.path.join(fact_folder, "*")): - with open(host_facts) as f: - fact = json.loads(f.read()) - fact = _embellish_internal_event(fact, action) - logger.debug("fact %s", fact) - if set_facts: - lang.assert_fact(ruleset, fact) - if post_events: - lang.post(ruleset, fact) - - -async def run_job_template( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, - name: str, - organization: str, - job_args: Optional[dict] = None, - set_facts: Optional[bool] = None, - post_events: Optional[bool] = None, - verbosity: int = 0, - copy_files: Optional[bool] = False, - json_mode: Optional[bool] = False, - retries: Optional[int] = 0, - retry: Optional[bool] = False, - delay: Optional[int] = 0, - **kwargs, -): - - logger.info( - "running job template: %s, organization: %s", name, organization - ) - logger.info("ruleset: %s, rule %s", source_ruleset_name, source_rule_name) - - hosts_limit = ",".join(hosts) - if not job_args: - job_args = {} - job_args["limit"] = hosts_limit - - job_args["extra_vars"] = _collect_extra_vars( - variables, - job_args.get("extra_vars", {}), - source_ruleset_name, - source_rule_name, - ) - - job_id = str(uuid.uuid4()) - - await event_log.put( - dict( - type="Job", - job_id=job_id, - ansible_rulebook_id=settings.identifier, - name=name, - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - hosts=hosts_limit, - action="run_job_template", - ) - ) - - if retry: - retries = max(retries, 1) - - try: - for i in range(retries + 1): - if i > 0: - if delay > 0: - await asyncio.sleep(delay) - logger.info( - "Previous run_job_template failed. Retry %d of %d", - i, - retries, - ) - controller_job = await job_template_runner.run_job_template( - name, - organization, - job_args, - ) - if controller_job["status"] != "failed": - break - except (ControllerApiException, JobTemplateNotFoundException) as ex: - logger.error(ex) - controller_job = {} - controller_job["status"] = "failed" - controller_job["created"] = run_at() - controller_job["error"] = str(ex) - - a_log = dict( - type="Action", - action="run_job_template", - action_uuid=str(uuid.uuid4()), - activation_id=settings.identifier, - job_template_name=name, - organization=organization, - job_id=job_id, - ruleset=ruleset, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - status=controller_job["status"], - run_at=controller_job["created"], - url=_controller_job_url(controller_job), - matching_events=_get_events(variables), - rule_run_at=rule_run_at, - ) - if "error" in controller_job: - a_log["message"] = controller_job["error"] - await event_log.put(a_log) - - if set_facts or post_events: - logger.debug("set_facts") - facts = controller_job["artifacts"] - if facts: - facts = _embellish_internal_event(facts, "run_job_template") - logger.debug("facts %s", facts) - if set_facts: - lang.assert_fact(ruleset, facts) - if post_events: - lang.post(ruleset, facts) - else: - logger.debug("Empty facts are not set") - - -async def shutdown( - event_log, - inventory: str, - hosts: List, - variables: Dict, - project_data_file: str, - source_ruleset_name: str, - source_ruleset_uuid: str, - source_rule_name: str, - source_rule_uuid: str, - rule_run_at: str, - ruleset: str, - delay: float = 60.0, - message: str = "Default shutdown message", - kind: str = "graceful", -): - await event_log.put( - dict( - type="Action", - action="shutdown", - action_uuid=str(uuid.uuid4()), - activation_id=settings.identifier, - ruleset=source_ruleset_name, - ruleset_uuid=source_ruleset_uuid, - rule=source_rule_name, - rule_uuid=source_rule_uuid, - run_at=run_at(), - status=INTERNAL_ACTION_STATUS, - matching_events=_get_events(variables), - delay=delay, - message=message, - kind=kind, - rule_run_at=rule_run_at, - ) - ) - - print( - "Ruleset: %s rule: %s has initiated shutdown of type: %s. " - "Delay: %.3f seconds, Message: %s" - % (source_ruleset_name, source_rule_name, kind, delay, message) - ) - raise ShutdownException(Shutdown(message=message, delay=delay, kind=kind)) - - -actions: Dict[str, Callable] = dict( - none=none, - debug=debug, - print_event=print_event, - set_fact=set_fact, - retract_fact=retract_fact, - post_event=post_event, - run_playbook=run_playbook, - run_module=run_module, - run_job_template=run_job_template, - shutdown=shutdown, -) - - -def _get_latest_artifact(data_dir: str, artifact: str, content: bool = True): - files = glob.glob(os.path.join(data_dir, "artifacts", "*", artifact)) - files.sort(key=os.path.getmtime, reverse=True) - if not files: - raise PlaybookStatusNotFoundException(f"No {artifact} file found") - if content: - with open(files[0], "r") as f: - content = f.read() - return content - return files[0] - - -def _get_events(variables: Dict): - if "event" in variables: - return {"m": variables["event"]} - elif "events" in variables: - return variables["events"] - return {} - - -def _collect_extra_vars( - variables: Dict, user_extra_vars: Dict, ruleset: str, rule: str -): - extra_vars = user_extra_vars.copy() if user_extra_vars else {} - eda_vars = dict(ruleset=ruleset, rule=rule) - if "events" in variables: - eda_vars["events"] = variables["events"] - if "event" in variables: - eda_vars["event"] = variables["event"] - extra_vars[KEY_EDA_VARS] = eda_vars - return extra_vars - - -def _embellish_internal_event(event: Dict, method_name: str) -> Dict: - return insert_meta( - event, **dict(source_name=method_name, source_type="internal") - ) - - -def _controller_job_url(data: dict) -> str: - if "id" in data: - return f"{job_template_runner.host}/#/jobs/{data['id']}/details" - return "" diff --git a/ansible_rulebook/rule_set_runner.py b/ansible_rulebook/rule_set_runner.py index ecf7a28a8..52e4e58ae 100644 --- a/ansible_rulebook/rule_set_runner.py +++ b/ansible_rulebook/rule_set_runner.py @@ -27,7 +27,18 @@ MessageObservedException, ) -from ansible_rulebook.builtin import actions as builtin_actions +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.debug import Debug +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.noop import Noop +from ansible_rulebook.action.post_event import PostEvent +from ansible_rulebook.action.print_event import PrintEvent +from ansible_rulebook.action.retract_fact import RetractFact +from ansible_rulebook.action.run_job_template import RunJobTemplate +from ansible_rulebook.action.run_module import RunModule +from ansible_rulebook.action.run_playbook import RunPlaybook +from ansible_rulebook.action.set_fact import SetFact +from ansible_rulebook.action.shutdown import Shutdown as ShutdownAction from ansible_rulebook.conf import settings from ansible_rulebook.exception import ( ShutdownException, @@ -49,6 +60,19 @@ logger = logging.getLogger(__name__) +ACTION_CLASSES = { + "debug": Debug, + "print_event": PrintEvent, + "none": Noop, + "set_fact": SetFact, + "post_event": PostEvent, + "retract_fact": RetractFact, + "shutdown": ShutdownAction, + "run_playbook": RunPlaybook, + "run_module": RunModule, + "run_job_template": RunJobTemplate, +} + class RuleSetRunner: def __init__( @@ -286,13 +310,17 @@ def _run_action( f"{action_item.rule}" ) logger.debug("Creating action task %s", task_name) + metadata = Metadata( + rule_set=action_item.ruleset, + rule_set_uuid=action_item.ruleset_uuid, + rule=action_item.rule, + rule_uuid=action_item.rule_uuid, + rule_run_at=rule_run_at, + ) + task = asyncio.create_task( self._call_action( - action_item.ruleset, - action_item.ruleset_uuid, - action_item.rule, - action_item.rule_uuid, - rule_run_at, + metadata, action.action, MappingProxyType(action.action_args), action_item.variables, @@ -308,11 +336,7 @@ def _run_action( async def _call_action( self, - ruleset: str, - ruleset_uuid: str, - rule: str, - rule_uuid: str, - rule_run_at: str, + metadata: Metadata, action: str, immutable_action_args: MappingProxyType, variables: Dict, @@ -324,7 +348,7 @@ async def _call_action( action_args = immutable_action_args.copy() error = None - if action in builtin_actions: + if action in ACTION_CLASSES: try: if action == "run_job_template": limit = dpath.get( @@ -389,21 +413,20 @@ async def _call_action( logger.info("action args: %s", action_args) if "ruleset" not in action_args: - action_args["ruleset"] = ruleset + action_args["ruleset"] = metadata.rule_set - await builtin_actions[action]( + control = Control( event_log=self.event_log, inventory=inventory, hosts=hosts, variables=variables_copy, project_data_file=self.project_data_file, - source_ruleset_name=ruleset, - source_ruleset_uuid=ruleset_uuid, - source_rule_name=rule, - source_rule_uuid=rule_uuid, - rule_run_at=rule_run_at, - **action_args, ) + + await ACTION_CLASSES[action]( + metadata, control, **action_args + )() + except KeyError as e: logger.error( "KeyError %s with variables %s", @@ -451,12 +474,12 @@ async def _call_action( playbook_name=action_args.get("name"), status="failed", run_at=run_at(), - rule_run_at=rule_run_at, + rule_run_at=metadata.rule_run_at, message=str(error), - rule=rule, - ruleset=ruleset, - rule_uuid=rule_uuid, - ruleset_uuid=ruleset_uuid, + rule=metadata.rule, + ruleset=metadata.rule_set, + rule_uuid=metadata.rule_uuid, + ruleset_uuid=metadata.rule_set_uuid, ) ) diff --git a/docs/action_classes.md b/docs/action_classes.md new file mode 100644 index 000000000..f0b24eb60 --- /dev/null +++ b/docs/action_classes.md @@ -0,0 +1,22 @@ + classDiagram + Animal <|-- Duck + Animal <|-- Fish + Animal <|-- Zebra + Animal : +int age + Animal : +String gender + Animal: +isMammal() + Animal: +mate() + class Duck{ + +String beakColor + +swim() + +quack() + } + class Fish{ + -int sizeInFeet + -canEat() + } + class Zebra{ + +bool is_wild + +run() + } + diff --git a/tests/e2e/test_actions.py b/tests/e2e/test_actions.py index 505b0a999..de893601f 100644 --- a/tests/e2e/test_actions.py +++ b/tests/e2e/test_actions.py @@ -102,8 +102,8 @@ def test_actions_sanity(update_environment): "'hosts': ['all']", f"'inventory': {inventory_data}", "'project_data_file': None,", - "'ruleset': 'Test actions sanity'", - "'source_rule_name': 'debug',", + "'rule_set': 'Test actions sanity'", + "'rule': 'debug',", f"'variables': {{'DEFAULT_EVENT_DELAY': '{DEFAULT_EVENT_DELAY}'", f"'DEFAULT_SHUTDOWN_AFTER': '{DEFAULT_SHUTDOWN_AFTER}',", f"'DEFAULT_STARTUP_DELAY': '{DEFAULT_STARTUP_DELAY}'", @@ -123,8 +123,8 @@ def test_actions_sanity(update_environment): r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" ) expected_debug_regexs = [ - r"'source_rule_uuid':" + f" '{uuid_regex}'", - r"'source_ruleset_uuid':" + f" '{uuid_regex}'", + r"'rule_uuid':" + f" '{uuid_regex}'", + r"'rule_set_uuid':" + f" '{uuid_regex}'", r"'uuid': " + f"'{uuid_regex}'" + r"}}}}", ] @@ -177,7 +177,7 @@ def test_actions_sanity(update_environment): ), "multiple_action action failed" assert ( - len(result.stdout.splitlines()) == 56 + len(result.stdout.splitlines()) == 55 ), "unexpected output from the rulebook" diff --git a/tests/examples/69_enhanced_debug.yml b/tests/examples/69_enhanced_debug.yml index a38a65a59..ab45a9cba 100644 --- a/tests/examples/69_enhanced_debug.yml +++ b/tests/examples/69_enhanced_debug.yml @@ -23,11 +23,6 @@ - "Hello World {{ event }}" - "Hello Java" - "Hello Java again {{ event }}" - - name: r4 - condition: event.i == 3 - action: - debug: - var: event.does_not_exist - name: r5 condition: event.i == 4 action: diff --git a/tests/test_examples.py b/tests/test_examples.py index 79100895b..56190d045 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -2094,7 +2094,8 @@ async def test_46_job_template(): job_url = "https://examples.com/#/jobs/945/details" with SourceTask(rs.sources[0], "sources", {}, queue): with patch( - "ansible_rulebook.builtin.job_template_runner.run_job_template", + "ansible_rulebook.action.run_job_template." + "job_template_runner.run_job_template", return_value=response_obj, ): await run_rulesets( @@ -2129,7 +2130,8 @@ async def test_46_job_template_exception(err_msg, err): rs = ruleset_queues[0][0] with SourceTask(rs.sources[0], "sources", {}, queue): with patch( - "ansible_rulebook.builtin.job_template_runner.run_job_template", + "ansible_rulebook.action.run_job_template." + "job_template_runner.run_job_template", side_effect=err, ): await run_rulesets( diff --git a/tests/unit/action/__init__.py b/tests/unit/action/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/action/playbooks/fail.yml b/tests/unit/action/playbooks/fail.yml new file mode 100644 index 000000000..7d470e067 --- /dev/null +++ b/tests/unit/action/playbooks/fail.yml @@ -0,0 +1,8 @@ +- name: Fail the rule + hosts: all + gather_facts: false + tasks: + - name: Fail if we have a rule name defined + when: ansible_eda.rule is defined + ansible.builtin.fail: + msg: "Failed because of Rule name: {{ ansible_eda.rule }}" diff --git a/tests/unit/action/playbooks/rule_name.yml b/tests/unit/action/playbooks/rule_name.yml new file mode 100644 index 000000000..f41872085 --- /dev/null +++ b/tests/unit/action/playbooks/rule_name.yml @@ -0,0 +1,14 @@ +- name: Print rule name that called this playbook + hosts: all + gather_facts: false + tasks: + - name: Print rule name + when: ansible_eda.rule is defined + ansible.builtin.debug: + msg: "Rule name: {{ ansible_eda.rule }}" + - name: Set the RuleName as a fact + ansible.builtin.set_fact: + results: + my_rule_name: "{{ ansible_eda.rule }}" + my_rule_set_name: "{{ ansible_eda.ruleset }}" + cacheable: true diff --git a/tests/unit/action/test_debug.py b/tests/unit/action/test_debug.py new file mode 100644 index 000000000..29b95fcce --- /dev/null +++ b/tests/unit/action/test_debug.py @@ -0,0 +1,179 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.base_action import INTERNAL_ACTION_STATUS +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.debug import Debug +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.conf import settings + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" +REQUIRED_KEYS = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", +} + + +def _validate(queue, metadata): + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + assert action["action"] == "debug" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == INTERNAL_ACTION_STATUS + assert action["type"] == "Action" + assert action["matching_events"] == {"m": {"a": 1}} + assert len(set(action.keys()).difference(REQUIRED_KEYS)) == 0 + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_debug(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = {} + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with patch( + "ansible_rulebook.action.run_job_template.lang.get_facts", + return_value={"a": 1}, + ) as drools_mock: + await Debug(metadata, control, **action_args)() + drools_mock.assert_called_once() + + _validate(queue, metadata) + + +MSG_DATA = [ + ("msg", "Simple Message"), + ("msg", ["First Message", "Second Message"]), +] + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.parametrize("mtype, arg", MSG_DATA) +@pytest.mark.asyncio +async def test_debug_msg(mtype, arg): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = {mtype: arg} + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + await Debug(metadata, control, **action_args)() + + _validate(queue, metadata) + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_debug_var(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"abc": {"xyz": 1}, "event": {"a": 1}}, + project_data_file="", + ) + action_args = {"var": "abc.xyz"} + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + await Debug(metadata, control, **action_args)() + + _validate(queue, metadata) + + +@pytest.mark.asyncio +async def test_debug_var_missing_key(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"abc": {"xyz": 1}, "event": {"a": 1}}, + project_data_file="", + ) + action_args = {"var": "abc.klm"} + + with pytest.raises(KeyError): + await Debug(metadata, control, **action_args)() diff --git a/tests/unit/action/test_noop.py b/tests/unit/action/test_noop.py new file mode 100644 index 000000000..05e2dd055 --- /dev/null +++ b/tests/unit/action/test_noop.py @@ -0,0 +1,89 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.base_action import INTERNAL_ACTION_STATUS +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.noop import Noop +from ansible_rulebook.conf import settings + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_noop(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = {} + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + await Noop(metadata, control, **action_args)() + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", + } + assert action["action"] == "noop" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == INTERNAL_ACTION_STATUS + assert action["type"] == "Action" + assert action["matching_events"] == {"m": {"a": 1}} + + assert len(set(action.keys()).difference(required_keys)) == 0 diff --git a/tests/unit/action/test_post_event.py b/tests/unit/action/test_post_event.py new file mode 100644 index 000000000..57b45e3d3 --- /dev/null +++ b/tests/unit/action/test_post_event.py @@ -0,0 +1,95 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.base_action import INTERNAL_ACTION_STATUS +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.post_event import PostEvent +from ansible_rulebook.conf import settings + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_noop(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = {"event": {"b": 1}, "ruleset": metadata.rule_set} + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with patch( + "ansible_rulebook.action.run_job_template.lang.post" + ) as drools_mock: + await PostEvent(metadata, control, **action_args)() + drools_mock.assert_called_once_with( + action_args["ruleset"], action_args["event"] + ) + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", + } + assert action["action"] == "post_event" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == INTERNAL_ACTION_STATUS + assert action["type"] == "Action" + assert action["matching_events"] == {"m": {"a": 1}} + + assert len(set(action.keys()).difference(required_keys)) == 0 diff --git a/tests/unit/action/test_print_event.py b/tests/unit/action/test_print_event.py new file mode 100644 index 000000000..c8c339d71 --- /dev/null +++ b/tests/unit/action/test_print_event.py @@ -0,0 +1,89 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.base_action import INTERNAL_ACTION_STATUS +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.print_event import PrintEvent +from ansible_rulebook.conf import settings + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_print_event(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = dict(pretty=True) + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + await PrintEvent(metadata, control, **action_args)() + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", + } + assert action["action"] == "print_event" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == INTERNAL_ACTION_STATUS + assert action["type"] == "Action" + assert action["matching_events"] == {"m": {"a": 1}} + + assert len(set(action.keys()).difference(required_keys)) == 0 diff --git a/tests/unit/action/test_retract_fact.py b/tests/unit/action/test_retract_fact.py new file mode 100644 index 000000000..7aa175e5f --- /dev/null +++ b/tests/unit/action/test_retract_fact.py @@ -0,0 +1,106 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.base_action import INTERNAL_ACTION_STATUS +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.retract_fact import RetractFact +from ansible_rulebook.conf import settings + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" + +TEST_DATA = [(True, []), (False, ["meta"])] + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.parametrize("partial,keys_excluded", TEST_DATA) +@pytest.mark.asyncio +async def test_retract_fact(partial, keys_excluded): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = { + "fact": {"b": 1}, + "ruleset": metadata.rule_set, + "partial": partial, + } + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with patch( + "ansible_rulebook.action.run_job_template." + "lang.retract_matching_facts" + ) as drools_mock: + await RetractFact(metadata, control, **action_args)() + drools_mock.assert_called_once_with( + action_args["ruleset"], + action_args["fact"], + partial, + keys_excluded, + ) + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", + } + assert action["action"] == "retract_fact" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == INTERNAL_ACTION_STATUS + assert action["type"] == "Action" + assert action["matching_events"] == {"m": {"a": 1}} + + assert len(set(action.keys()).difference(required_keys)) == 0 diff --git a/tests/unit/action/test_run_job_template.py b/tests/unit/action/test_run_job_template.py new file mode 100644 index 000000000..9ac54dab5 --- /dev/null +++ b/tests/unit/action/test_run_job_template.py @@ -0,0 +1,252 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest + +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.run_job_template import RunJobTemplate +from ansible_rulebook.exception import ( + ControllerApiException, + JobTemplateNotFoundException, +) + +JOB_TEMPLATE_ERRORS = [ + ("api error", ControllerApiException("api error")), + ("jt does not exist", JobTemplateNotFoundException("jt does not exist")), +] + + +@pytest.mark.parametrize("err_msg,err", JOB_TEMPLATE_ERRORS) +@pytest.mark.asyncio +async def test_run_job_template_exception(err_msg, err): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid="u1", + rule_set_uuid="u2", + rule_run_at="abc", + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"a": 1}, + project_data_file="", + ) + action_args = { + "name": "fred", + "set_facts": True, + "organization": "Default", + "retries": 0, + "retry": True, + "delay": 0, + } + with patch( + "ansible_rulebook.action.run_job_template." + "job_template_runner.run_job_template", + side_effect=err, + ): + await RunJobTemplate(metadata, control, **action_args)() + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + assert action["action"] == "run_job_template" + assert action["reason"] == {"error": err_msg} + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + } + assert set(action.keys()).issuperset(required_keys) + + +DROOLS_CALLS = [ + ( + "ansible_rulebook.action.run_job_template.lang.assert_fact", + dict(set_facts=True), + ), + ( + "ansible_rulebook.action.run_job_template.lang.post", + dict(post_events=True), + ), +] + + +@pytest.mark.parametrize("drools_call,additional_args", DROOLS_CALLS) +@pytest.mark.asyncio +async def test_run_job_template(drools_call, additional_args): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid="u1", + rule_set_uuid="u2", + rule_run_at="abc", + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"a": 1}, + project_data_file="", + ) + action_args = { + "name": "fred", + "organization": "Default", + "retries": 0, + "retry": True, + "delay": 0, + } + action_args.update(additional_args) + controller_job = { + "status": "failed", + "rc": 0, + "artifacts": dict(b=1), + "created": "abc", + "id": 10, + } + with patch( + "ansible_rulebook.action.run_job_template." + "job_template_runner.run_job_template", + return_value=controller_job, + ): + with patch(drools_call) as drools_mock: + await RunJobTemplate(metadata, control, **action_args)() + drools_mock.assert_called_once() + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + assert action["action"] == "run_job_template" + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "job_template_name", + "matching_events", + "job_id", + "url", + "organization", + } + x = set(action.keys()).difference(required_keys) + assert len(x) == 0 + + +@pytest.mark.asyncio +async def test_run_job_template_retries(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid="u1", + rule_set_uuid="u2", + rule_run_at="abc", + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"a": 1}, + project_data_file="", + ) + action_args = { + "name": "fred", + "organization": "Default", + "retries": 1, + "retry": True, + "delay": 1, + "set_facts": True, + } + controller_job = [ + { + "status": "failed", + "rc": 0, + "artifacts": dict(b=1), + "created": "abc", + "id": 10, + }, + { + "status": "success", + "rc": 0, + "artifacts": dict(b=1), + "created": "abc", + "id": 10, + }, + ] + + with patch( + "ansible_rulebook.action.run_job_template." + "job_template_runner.run_job_template", + side_effect=controller_job, + ): + with patch( + "ansible_rulebook.action.run_job_template.lang.assert_fact" + ) as drools_mock: + await RunJobTemplate(metadata, control, **action_args)() + drools_mock.assert_called_once() + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + assert action["action"] == "run_job_template" + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "job_template_name", + "matching_events", + "job_id", + "url", + "organization", + } + x = set(action.keys()).difference(required_keys) + assert len(x) == 0 diff --git a/tests/unit/action/test_run_module.py b/tests/unit/action/test_run_module.py new file mode 100644 index 000000000..019fcc307 --- /dev/null +++ b/tests/unit/action/test_run_module.py @@ -0,0 +1,99 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.run_module import RunModule +from ansible_rulebook.conf import settings + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" + + +def _validate(queue, metadata, status, rc): + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", + "job_id", + "playbook_name", + "rc", + } + assert action["action"] == "run_module" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == status + assert action["rc"] == rc + assert action["type"] == "Action" + assert action["matching_events"] == {"m_0": {"a": 1}, "m_1": {"b": 2}} + + assert len(set(action.keys()).difference(required_keys)) == 0 + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_run_module(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["localhost"], + variables={"events": {"m_0": {"a": 1}, "m_1": {"b": 2}}}, + project_data_file="", + ) + action_args = { + "module_args": {"name": "Fred Flintstone"}, + "name": "ansible.eda.upcase", + } + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + await RunModule(metadata, control, **action_args)() + + _validate(queue, metadata, "successful", 0) diff --git a/tests/unit/action/test_run_playbook.py b/tests/unit/action/test_run_playbook.py new file mode 100644 index 000000000..8ce7dfadb --- /dev/null +++ b/tests/unit/action/test_run_playbook.py @@ -0,0 +1,196 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import os +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.run_playbook import RunPlaybook +from ansible_rulebook.conf import settings +from ansible_rulebook.exception import PlaybookNotFoundException + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" + + +def _validate(queue, metadata, status, rc): + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", + "job_id", + "playbook_name", + "rc", + } + assert action["action"] == "run_playbook" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == status + assert action["rc"] == rc + assert action["type"] == "Action" + assert action["matching_events"] == {"m": {"a": 1}} + + assert len(set(action.keys()).difference(required_keys)) == 0 + + +HERE = os.path.dirname(os.path.abspath(__file__)) + +DROOLS_CALLS = [ + ( + "ansible_rulebook.action.run_job_template.lang.assert_fact", + dict(set_facts=True), + ), + ( + "ansible_rulebook.action.run_job_template.lang.post", + dict(post_events=True), + ), +] + + +@pytest.mark.parametrize("drools_call,additional_args", DROOLS_CALLS) +@pytest.mark.asyncio +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_run_playbook(drools_call, additional_args): + os.chdir(HERE) + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = { + "ruleset": metadata.rule_set, + "name": "./playbooks/rule_name.yml", + } + action_args.update(additional_args) + + set_fact_args = { + "results": { + "my_rule_name": metadata.rule, + "my_rule_set_name": metadata.rule_set, + }, + "meta": { + "source": {"name": "run_playbook", "type": "internal"}, + "received_at": ACTION_RUN_AT, + "uuid": DUMMY_UUID, + }, + } + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with patch(drools_call) as drools_mock: + await RunPlaybook(metadata, control, **action_args)() + drools_mock.assert_called_once_with( + action_args["ruleset"], set_fact_args + ) + + _validate(queue, metadata, "successful", 0) + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_run_playbook_missing(): + os.chdir(HERE) + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = { + "ruleset": metadata.rule_set, + "name": "./playbooks/does_not_exist.yml", + "set_facts": True, + } + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with pytest.raises(PlaybookNotFoundException): + await RunPlaybook(metadata, control, **action_args)() + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_run_playbook_fail(): + os.chdir(HERE) + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = { + "ruleset": metadata.rule_set, + "name": "./playbooks/fail.yml", + "set_facts": True, + } + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + await RunPlaybook(metadata, control, **action_args)() + + _validate(queue, metadata, "failed", 2) diff --git a/tests/unit/action/test_set_fact.py b/tests/unit/action/test_set_fact.py new file mode 100644 index 000000000..6f22ec4b6 --- /dev/null +++ b/tests/unit/action/test_set_fact.py @@ -0,0 +1,95 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.base_action import INTERNAL_ACTION_STATUS +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.set_fact import SetFact +from ansible_rulebook.conf import settings + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_noop(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = {"fact": {"b": 1}, "ruleset": metadata.rule_set} + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with patch( + "ansible_rulebook.action.run_job_template.lang.assert_fact" + ) as drools_mock: + await SetFact(metadata, control, **action_args)() + drools_mock.assert_called_once_with( + action_args["ruleset"], action_args["fact"] + ) + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", + } + assert action["action"] == "set_fact" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == INTERNAL_ACTION_STATUS + assert action["type"] == "Action" + assert action["matching_events"] == {"m": {"a": 1}} + + assert len(set(action.keys()).difference(required_keys)) == 0 diff --git a/tests/unit/action/test_shutdown.py b/tests/unit/action/test_shutdown.py new file mode 100644 index 000000000..37844e17b --- /dev/null +++ b/tests/unit/action/test_shutdown.py @@ -0,0 +1,94 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from ansible_rulebook.action.base_action import INTERNAL_ACTION_STATUS +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.shutdown import Shutdown +from ansible_rulebook.conf import settings +from ansible_rulebook.exception import ShutdownException + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +async def test_shutdown(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + control = Control( + event_log=queue, + inventory="abc", + hosts=["all"], + variables={"event": {"a": 1}}, + project_data_file="", + ) + action_args = dict(delay=60, message="Testing Shutdown") + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with pytest.raises(ShutdownException): + await Shutdown(metadata, control, **action_args)() + + while not queue.empty(): + event = queue.get_nowait() + if event["type"] == "Action": + action = event + + required_keys = { + "action", + "action_uuid", + "activation_id", + "reason", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", + "delay", + "message", + "kind", + } + assert action["action"] == "shutdown" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["run_at"] == ACTION_RUN_AT + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == INTERNAL_ACTION_STATUS + assert action["type"] == "Action" + assert action["matching_events"] == {"m": {"a": 1}} + + assert len(set(action.keys()).difference(required_keys)) == 0