Skip to content

Commit

Permalink
AWS: Allow setting tags to snapshots or AMIs only
Browse files Browse the repository at this point in the history
This commit introduces two new optional properties on
`AWSPublishingMetadata` named `snapshot_tags` and `ami_tags`, which when
set allows to add tags just for snapshot imports or the registered image
respectively.

Rationale: Currently our team is facing sporadic issues on community
AMIs workflow when uploading the image to different billing types
(`access` or `hourly`) after a retry, as the second attempt will look
for the image name, fail to find, search for tags and find an AMI which
concerns the other billing type. This was initially fixed by
release-engineering/pubtools-marketplacesvm#74 by adding the billing
type as tags. However, the fix brings another potential issue: we will
tag everything with the billing tags, including the S3 object, which
could being unecessarily uploaded twice, as well as having a tag which
doesn't concern the S3 RAW object, but the snapshot/image only.

With this change on `cloudimg` we aim to provide the billing type tags
just for the snapshot and AMI, leaving the S3 object with the previous
ones as it shouldn't receive the billing tag.

This change also allows other teams to have a more granular control on
which tags should go where.

Refers to SPSTRAT-451
  • Loading branch information
JAVGan committed Nov 8, 2024
1 parent b416dc0 commit c49b004
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 4 deletions.
25 changes: 21 additions & 4 deletions cloudimg/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ class AWSPublishingMetadata(PublishingMetadata):
billing_products (list, optional): Billing product identifiers
boot_mode (str, optional): The boot mode for booting up the AMI on EC2.
snapshot_tags (dict, optional): Tags to be applied to the snapshot
import only.
ami_tags (dict, optional): Tags to be applied to the registered AMI
only.
"""

def __init__(self, *args, **kwargs):
Expand All @@ -91,6 +97,8 @@ def __init__(self, *args, **kwargs):
self.billing_products = kwargs.pop('billing_products', None)
bmode_str = kwargs.pop('boot_mode', None) or "not_set"
self.boot_mode = AWSBootMode[bmode_str]
self.snapshot_tags = kwargs.pop('snapshot_tags', None)
self.ami_tags = kwargs.pop('ami_tags', None)

super(AWSPublishingMetadata, self).__init__(*args, **kwargs)

Expand Down Expand Up @@ -558,6 +566,13 @@ def publish(self, metadata):
Returns:
An EC2 Image
"""
def add_tags(tag_parameter_name, extra_kwargs):
new_tags = getattr(metadata, tag_parameter_name, None)
if new_tags:
tags = extra_kwargs.get("tags") or {}
new_tags.update(tags)
extra_kwargs.update({"tags": new_tags})

log.info('Searching for image: %s', metadata.image_name)
image = (
self.get_image_by_name(metadata.image_name) or
Expand All @@ -578,8 +593,7 @@ def publish(self, metadata):

# Set tags when they're provided
extra_kwargs = {}
if metadata.tags:
extra_kwargs.update({"tags": metadata.tags})
add_tags("tags", extra_kwargs)

if not obj:
log.info('Object does not exist: %s', metadata.object_name)
Expand All @@ -590,6 +604,7 @@ def publish(self, metadata):
else:
log.info('Object already exists')

add_tags("snapshot_tags", extra_kwargs)
snapshot = self.import_snapshot(obj,
metadata.snapshot_name,
**extra_kwargs)
Expand Down Expand Up @@ -765,8 +780,10 @@ def register_image(self, snapshot, metadata):
BillingProducts=metadata.billing_products,
**optional_kwargs,
)
if metadata.tags:
self.tag_image(image, metadata.tags)
if metadata.tags or metadata.ami_tags:
tags = metadata.tags or {}
tags.update(metadata.ami_tags or {})
self.tag_image(image, tags)
return image

def share_image(self, image, accounts=[], groups=[]):
Expand Down
106 changes: 106 additions & 0 deletions tests/test_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,85 @@ def test_publish_tags(self,
self.md.object_name,
tags=tags)

@patch('cloudimg.aws.AWSService.upload_to_container')
@patch('cloudimg.aws.AWSService.import_snapshot')
@patch('cloudimg.aws.AWSService.register_image')
@patch('cloudimg.aws.AWSService.share_image')
@patch('cloudimg.aws.AWSService.get_image_by_tags')
@patch('cloudimg.aws.AWSService.get_image_by_name')
@patch('cloudimg.aws.AWSService.get_snapshot_by_name')
@patch('cloudimg.aws.AWSService.get_object_by_name')
def test_publish_snapshot_tags(self,
get_object_by_name,
get_snapshot_by_name,
get_image_by_name,
get_image_by_tags,
share_image,
register_image,
import_snapshot,
upload_to_container):
get_image_by_name.return_value = None
get_image_by_tags.return_value = None
get_snapshot_by_name.return_value = None
get_object_by_name.return_value = None
self.md.tags = None
self.md.snapshot_tags = {"snapshot": "tag"}
published = self.svc.publish(self.md)

share_image.assert_called_once_with(published,
accounts=[],
groups=[])
self.assertEqual(register_image.call_count, 1)
import_snapshot.assert_called_once_with(
ANY,
self.md.snapshot_name,
tags={"snapshot": "tag"},
)
upload_to_container.assert_called_once_with(self.md.image_path,
self.md.container,
self.md.object_name)

@patch('cloudimg.aws.AWSService.upload_to_container')
@patch('cloudimg.aws.AWSService.import_snapshot')
@patch('cloudimg.aws.AWSService.register_image')
@patch('cloudimg.aws.AWSService.share_image')
@patch('cloudimg.aws.AWSService.get_image_by_tags')
@patch('cloudimg.aws.AWSService.get_image_by_name')
@patch('cloudimg.aws.AWSService.get_snapshot_by_name')
@patch('cloudimg.aws.AWSService.get_object_by_name')
def test_publish_merged_snapshot_tags(
self,
get_object_by_name,
get_snapshot_by_name,
get_image_by_name,
get_image_by_tags,
share_image,
register_image,
import_snapshot,
upload_to_container,
):
get_image_by_name.return_value = None
get_image_by_tags.return_value = None
get_snapshot_by_name.return_value = None
get_object_by_name.return_value = None
self.md.tags = {"foo": "bar"}
self.md.snapshot_tags = {"snapshot": "tag"}
published = self.svc.publish(self.md)

share_image.assert_called_once_with(published,
accounts=[],
groups=[])
self.assertEqual(register_image.call_count, 1)
import_snapshot.assert_called_once_with(
ANY,
self.md.snapshot_name,
tags={"foo": "bar", "snapshot": "tag"},
)
upload_to_container.assert_called_once_with(self.md.image_path,
self.md.container,
self.md.object_name,
tags={"foo": "bar"})

@patch('cloudimg.aws.AWSService.tag_image')
def test_register_image_no_tags(self, tag_image):
self.mock_register_image.return_value = "fakeimg"
Expand Down Expand Up @@ -741,6 +820,33 @@ def test_register_image_tags(self, tag_image):
tag_image.assert_called_once_with("fakeimg", self.md.tags)
self.assertEqual(res, "fakeimg")

@patch('cloudimg.aws.AWSService.tag_image')
def test_register_image_ami_tags(self, tag_image):
self.md.tags = None
self.md.ami_tags = {"ami": "tag"}
self.mock_register_image.return_value = "fakeimg"

res = self.svc.register_image(MagicMock(), self.md)

self.mock_register_image.assert_called_once()
tag_image.assert_called_once_with("fakeimg", {"ami": "tag"})
self.assertEqual(res, "fakeimg")

@patch('cloudimg.aws.AWSService.tag_image')
def test_register_image_merged_ami_tags(self, tag_image):
self.md.tags = {"tag": "tag"}
self.md.ami_tags = {"ami": "tag"}
self.mock_register_image.return_value = "fakeimg"

res = self.svc.register_image(MagicMock(), self.md)

self.mock_register_image.assert_called_once()
tag_image.assert_called_once_with(
"fakeimg",
{"tag": "tag", "ami": "tag"},
)
self.assertEqual(res, "fakeimg")

@patch('cloudimg.aws.AWSService.tag_image')
def test_register_image_boot_mode(self, tag_image):
self.mock_register_image.return_value = "fakeimg"
Expand Down

0 comments on commit c49b004

Please sign in to comment.