Skip to content

Commit

Permalink
feat: handle binary files in file injectors (ansible#1151)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkanoor authored and bzwei committed Dec 6, 2024
1 parent 4fff43c commit bd35080
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 16 deletions.
15 changes: 15 additions & 0 deletions src/aap_eda/core/utils/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@ def _render_string_or_return_value(value: Any, context: Dict) -> Any:
return value


def extract_variables(template_string: str) -> set[str]:
env = jinja2.Environment(autoescape=True)
ast = env.parse(template_string)
variables = set()

def _extract_variables(node):
if isinstance(node, jinja2.nodes.Name):
variables.add(node.name)
for child in node.iter_child_nodes():
_extract_variables(child)

_extract_variables(ast)
return variables


def substitute_variables(
value: Union[str, int, Dict, List], context: Dict
) -> Union[str, int, Dict, List]:
Expand Down
2 changes: 1 addition & 1 deletion src/aap_eda/core/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

from aap_eda.core import enums, models
from aap_eda.core.utils.credentials import (
validate_registry_host_name,
check_reserved_keys_in_extra_vars,
validate_registry_host_name,
validate_schema,
)
from aap_eda.core.utils.k8s_service_name import is_rfc_1035_compliant
Expand Down
63 changes: 48 additions & 15 deletions src/aap_eda/wsapi/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from aap_eda.core.models.activation import ActivationStatus
from aap_eda.core.utils.credentials import get_secret_fields
from aap_eda.core.utils.strings import substitute_variables
from aap_eda.core.utils.strings import extract_variables, substitute_variables
from aap_eda.tasks import orchestrator

from .messages import (
Expand All @@ -43,6 +43,8 @@

logger = logging.getLogger(__name__)

BINARY_FORMATS = {"binary_base64"}


class MessageType(Enum):
ACTION = "Action"
Expand Down Expand Up @@ -143,14 +145,11 @@ async def handle_workers(self, message: WorkerMessage):
vault_collection = VaultCollection(data=eda_vault_data)
await self.send(text_data=vault_collection.json())

templates = await self.get_file_contents_from_credentials(activation)
for template, contents in templates.items():
data = FileContentMessage(
template_key=template,
data=base64.b64encode(contents.encode()).decode(),
eof=True,
)
await self.send(text_data=data.json())
file_contents = await self.get_file_contents_from_credentials(
activation
)
for file_content in file_contents:
await self.send(text_data=file_content.json())

env_var = await self.get_env_vars_from_credentials(activation)
if env_var:
Expand Down Expand Up @@ -498,20 +497,32 @@ def _get_url(self, message: ActionMessage, inputs: dict) -> str:
@database_sync_to_async
def get_file_contents_from_credentials(
self, activation: models.Activation
) -> tp.Optional[dict]:
file_templates = {}
) -> tp.Optional[list[FileContentMessage]]:
file_template_names = []
file_messages = []
for eda_credential in activation.eda_credentials.all():
inputs = yaml.safe_load(eda_credential.inputs.get_secret_value())
injectors = eda_credential.credential_type.injectors
binary_fields = []
for field in eda_credential.credential_type.inputs.get(
"fields", []
):
if field.get("format") in BINARY_FORMATS:
binary_fields.append(field["id"])

if "file" in injectors:
for template, value in injectors["file"].items():
if template in file_templates:
if template in file_template_names:
raise DuplicateFileTemplateKeyError(
f"{template} already exists"
)
contents = substitute_variables(value, inputs)
file_templates[template] = str(contents)
return file_templates
file_template_names.append(template)
file_messages.append(
self.get_file_content_message(
template, binary_fields, value, inputs
)
)
return file_messages

@database_sync_to_async
def get_env_vars_from_credentials(
Expand Down Expand Up @@ -575,3 +586,25 @@ def get_vault_password_and_id(
vault_inputs = yaml.safe_load(vault_inputs.get_secret_value())
return vault_inputs["vault_password"], vault_inputs["vault_id"]
return None, None

@staticmethod
def get_file_content_message(
template: str, binary_fields: list[str], value: str, inputs: dict
) -> FileContentMessage:
binary_file = any(
attr in binary_fields for attr in extract_variables(value)
)

contents = str(substitute_variables(value, inputs))
if binary_file:
return FileContentMessage(
template_key=template,
data=contents,
data_format="binary",
eof=True,
)
return FileContentMessage(
template_key=template,
data=base64.b64encode(contents.encode()).decode(),
eof=True,
)
27 changes: 27 additions & 0 deletions tests/integration/wsapi/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,33 @@ def _create_event(data, uuid):
},
],
),
(
{
"fields": [
{
"id": "keytab",
"label": "KeyTab",
"format": "binary_base64",
"secret": True,
},
]
},
{
"file": {
"template.keytab_file": "{{ keytab }}",
},
},
{
"keytab": base64.b64encode(bytes([1, 2, 3, 4, 5])).decode(),
},
[
{
"data": base64.b64encode(bytes([1, 2, 3, 4, 5])).decode(),
"template_key": "template.keytab_file",
"data_format": "binary",
},
],
),
],
)
@pytest.mark.django_db(transaction=True)
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pytest

from aap_eda.core.utils.strings import extract_variables
from aap_eda.utils import get_eda_version, str_to_bool


Expand Down Expand Up @@ -45,3 +46,26 @@ def test_get_eda_version():
# assert outcome when aap-eda package is not found
with patch("importlib.metadata.version", side_effect=PackageNotFoundError):
assert get_eda_version() == "unknown"


@pytest.mark.parametrize(
"value,expected",
[
("simple", set()),
(
"And this is a {{demo}}",
{
"demo",
},
),
(
"{{var1}} and {{var2}}",
{
"var1",
"var2",
},
),
],
)
def test_extract_variables(value, expected):
assert extract_variables(value) == expected

0 comments on commit bd35080

Please sign in to comment.