diff --git a/kubernetes/requirements.txt b/kubernetes/requirements.txt index 1e3a5f2..2a3547d 100755 --- a/kubernetes/requirements.txt +++ b/kubernetes/requirements.txt @@ -4,3 +4,4 @@ moto[s3] requests_mock pytest coverage +freezegun diff --git a/kubernetes/vault_snapshot/vault_snapshot.py b/kubernetes/vault_snapshot/vault_snapshot.py index 045c94f..ac0aac0 100755 --- a/kubernetes/vault_snapshot/vault_snapshot.py +++ b/kubernetes/vault_snapshot/vault_snapshot.py @@ -6,7 +6,7 @@ import hvac import os from pathlib import Path -import datetime +from datetime import UTC, datetime, timedelta class VaultSnapshot: """ @@ -68,6 +68,13 @@ def __init__(self, **kwargs): else: raise NameError("S3_BUCKET undefined") + if "s3_expire_days" in kwargs: + self.s3_expire_days = kwargs["s3_expire_days"] + elif "S3_EXPIRE_DAYS" in os.environ: + self.s3_expire_days = os.environ["S3_EXPIRE_DAYS"] + else: + self.s3_expire_days = -1 + if "jwt_secret_path" in kwargs: self.jwt_secret_path = kwargs["jwt_secret_path"] elif "JWT_SECRET_PATH" in os.environ: @@ -118,7 +125,7 @@ def snapshot(self): self.logger.info("Raft snapshot status code: %d" % resp.status_code) - date_str = datetime.datetime.now(datetime.UTC).strftime("%F-%H%M") + date_str = datetime.now(UTC).strftime("%F-%H%M") file_name = "vault_%s.snapshot" % (date_str) self.logger.info(f"File name: {file_name}") @@ -140,10 +147,16 @@ def snapshot(self): endpoint_url=self.s3_host, aws_access_key_id=self.s3_access_key_id, aws_secret_access_key=self.s3_secret_access_key) - bucket = s3.Bucket(self.s3_bucket) - for key in bucket.objects.all(): - self.logger.info(key.key) - # todo: do the S3_EXPIRE_DAYS magic + objs = self.s3_client.list_objects(Bucket=self.s3_bucket)["Contents"] + #self.logger.info(objs) + + for o in objs: + self.logger.info(f"LastModified: {o["LastModified"]}") + # expire keys when older than S3_EXPIRE_DAYS + if self.s3_expire_days >= 0: + if o["LastModified"] <= datetime.now(UTC) - timedelta(days=self.s3_expire_days): + self.logger.info(f"Deleting expired snapshot {o["Key"]}") + s3.Object(self.s3_bucket, o["Key"]).delete() return file_name diff --git a/kubernetes/vault_snapshot/vault_snapshot_test.py b/kubernetes/vault_snapshot/vault_snapshot_test.py index cfb0e47..2764e6c 100644 --- a/kubernetes/vault_snapshot/vault_snapshot_test.py +++ b/kubernetes/vault_snapshot/vault_snapshot_test.py @@ -3,6 +3,8 @@ import pytest import requests_mock from moto import mock_aws +from freezegun import freeze_time +from datetime import datetime, timedelta from vault_snapshot import VaultSnapshot @@ -83,3 +85,74 @@ def test_snapshot_with_jwt(self, **kwargs): body = s3obj["Body"] assert body.read() == b"blob" + + @mock_aws + @requests_mock.Mocker(kw="mock") + def test_snapshot_expiration(self, **kwargs): + """ + Test snapshot expiration. + """ + + kwargs['mock'].post("http://127.0.0.1:8200/v1/auth/kubernetes/login", json={ + "auth": { + "client_token": "root" + } + }) + kwargs['mock'].get("http://127.0.0.1:8200/v1/sys/storage/raft/snapshot", text="blob") + kwargs['mock'].get("http://127.0.0.1:8200/v1/auth/token/lookup-self", text="blob") + + bucket_name = "vault-snapshots" + region_name = "us-east-1" + s3_host = f"https://s3.{region_name}.amazonaws.com" + s3_access_key_id = "test" + s3_secret_access_key = "test" + + s3_client = boto3.client(service_name="s3", + endpoint_url=s3_host, + aws_access_key_id=s3_access_key_id, + aws_secret_access_key=s3_secret_access_key) + conn = boto3.resource("s3", region_name=region_name) + # We need to create the bucket since this is all in Moto's 'virtual' AWS account + conn.create_bucket(Bucket=bucket_name ) + + vault_snapshot = VaultSnapshot( + vault_addr="http://127.0.0.1:8200", + # the mock server assumes a "default" role + vault_role="default", + jwt_secret_path="./vault_snapshot/fixtures/jwt", + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_host=s3_host, + s3_bucket=bucket_name, + # delete snapshots older than 1 day + s3_expire_days=1 + ) + + # Snapshot day before yesterday + with freeze_time(datetime.today() - timedelta(days=2.5)): + file_name = vault_snapshot.snapshot() + s3obj = conn.Object(bucket_name, file_name).get() + body = s3obj["Body"] + assert body.read() == b"blob" + + objs = s3_client.list_objects(Bucket=bucket_name)["Contents"] + assert len(objs) == 1 + + # Snapshot yesterday + with freeze_time(datetime.today() - timedelta(days=1.5)): + file_name = vault_snapshot.snapshot() + s3obj = conn.Object(bucket_name, file_name).get() + body = s3obj["Body"] + assert body.read() == b"blob" + + objs = s3_client.list_objects(Bucket=bucket_name)["Contents"] + assert len(objs) == 1 + + # Snapshot now + file_name = vault_snapshot.snapshot() + s3obj = conn.Object(bucket_name, file_name).get() + body = s3obj["Body"] + assert body.read() == b"blob" + + objs = s3_client.list_objects(Bucket=bucket_name)["Contents"] + assert len(objs) == 1