diff --git a/src/aap_eda/core/utils/strings.py b/src/aap_eda/core/utils/strings.py index 8f2da212d..a1a243844 100644 --- a/src/aap_eda/core/utils/strings.py +++ b/src/aap_eda/core/utils/strings.py @@ -41,6 +41,21 @@ def _render_string_or_return_value(value: Any, context: Dict) -> Any: return value +def extract_variables(template_string: str) -> set[str]: + env = jinja2.Environment(autoescape=True) + ast = env.parse(template_string) + variables = set() + + def _extract_variables(node): + if isinstance(node, jinja2.nodes.Name): + variables.add(node.name) + for child in node.iter_child_nodes(): + _extract_variables(child) + + _extract_variables(ast) + return variables + + def substitute_variables( value: Union[str, int, Dict, List], context: Dict ) -> Union[str, int, Dict, List]: diff --git a/src/aap_eda/wsapi/consumers.py b/src/aap_eda/wsapi/consumers.py index 5b9612345..19e0363b0 100644 --- a/src/aap_eda/wsapi/consumers.py +++ b/src/aap_eda/wsapi/consumers.py @@ -22,7 +22,7 @@ ) from aap_eda.core.models.activation import ActivationStatus from aap_eda.core.utils.credentials import get_secret_fields -from aap_eda.core.utils.strings import substitute_variables +from aap_eda.core.utils.strings import extract_variables, substitute_variables from aap_eda.tasks import orchestrator from .messages import ( @@ -43,6 +43,8 @@ logger = logging.getLogger(__name__) +BINARY_FORMATS = {"binary_base64"} + class MessageType(Enum): ACTION = "Action" @@ -143,14 +145,11 @@ async def handle_workers(self, message: WorkerMessage): vault_collection = VaultCollection(data=eda_vault_data) await self.send(text_data=vault_collection.json()) - templates = await self.get_file_contents_from_credentials(activation) - for template, contents in templates.items(): - data = FileContentMessage( - template_key=template, - data=base64.b64encode(contents.encode()).decode(), - eof=True, - ) - await self.send(text_data=data.json()) + file_contents = await self.get_file_contents_from_credentials( + activation + ) + for file_content in file_contents: + await self.send(text_data=file_content.json()) env_var = await self.get_env_vars_from_credentials(activation) if env_var: @@ -498,20 +497,32 @@ def _get_url(self, message: ActionMessage, inputs: dict) -> str: @database_sync_to_async def get_file_contents_from_credentials( self, activation: models.Activation - ) -> tp.Optional[dict]: - file_templates = {} + ) -> tp.Optional[list[FileContentMessage]]: + file_template_names = [] + file_messages = [] for eda_credential in activation.eda_credentials.all(): inputs = yaml.safe_load(eda_credential.inputs.get_secret_value()) injectors = eda_credential.credential_type.injectors + binary_fields = [] + for field in eda_credential.credential_type.inputs.get( + "fields", [] + ): + if field.get("format") in BINARY_FORMATS: + binary_fields.append(field["id"]) + if "file" in injectors: for template, value in injectors["file"].items(): - if template in file_templates: + if template in file_template_names: raise DuplicateFileTemplateKeyError( f"{template} already exists" ) - contents = substitute_variables(value, inputs) - file_templates[template] = str(contents) - return file_templates + file_template_names.append(template) + file_messages.append( + self.get_file_content_message( + template, binary_fields, value, inputs + ) + ) + return file_messages @database_sync_to_async def get_env_vars_from_credentials( @@ -575,3 +586,25 @@ def get_vault_password_and_id( vault_inputs = yaml.safe_load(vault_inputs.get_secret_value()) return vault_inputs["vault_password"], vault_inputs["vault_id"] return None, None + + @staticmethod + def get_file_content_message( + template: str, binary_fields: list[str], value: str, inputs: dict + ) -> FileContentMessage: + binary_file = any( + attr in binary_fields for attr in extract_variables(value) + ) + + contents = str(substitute_variables(value, inputs)) + if binary_file: + return FileContentMessage( + template_key=template, + data=contents, + data_format="binary", + eof=True, + ) + return FileContentMessage( + template_key=template, + data=base64.b64encode(contents.encode()).decode(), + eof=True, + ) diff --git a/tests/integration/wsapi/test_consumer.py b/tests/integration/wsapi/test_consumer.py index c213c42f7..7052f6922 100644 --- a/tests/integration/wsapi/test_consumer.py +++ b/tests/integration/wsapi/test_consumer.py @@ -1164,6 +1164,33 @@ def _create_event(data, uuid): }, ], ), + ( + { + "fields": [ + { + "id": "keytab", + "label": "KeyTab", + "format": "binary_base64", + "secret": True, + }, + ] + }, + { + "file": { + "template.keytab_file": "{{ keytab }}", + }, + }, + { + "keytab": base64.b64encode(bytes([1, 2, 3, 4, 5])).decode(), + }, + [ + { + "data": base64.b64encode(bytes([1, 2, 3, 4, 5])).decode(), + "template_key": "template.keytab_file", + "data_format": "binary", + }, + ], + ), ], ) @pytest.mark.django_db(transaction=True) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 70dbffeb9..1e2787ce1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -17,6 +17,7 @@ import pytest +from aap_eda.core.utils.strings import extract_variables from aap_eda.utils import get_eda_version, str_to_bool @@ -45,3 +46,26 @@ def test_get_eda_version(): # assert outcome when aap-eda package is not found with patch("importlib.metadata.version", side_effect=PackageNotFoundError): assert get_eda_version() == "unknown" + + +@pytest.mark.parametrize( + "value,expected", + [ + ("simple", set()), + ( + "And this is a {{demo}}", + { + "demo", + }, + ), + ( + "{{var1}} and {{var2}}", + { + "var1", + "var2", + }, + ), + ], +) +def test_extract_variables(value, expected): + assert extract_variables(value) == expected